Bug Summary

File:build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/include/mlir/IR/OpDefinition.h
Warning:line 104, column 5
Called C++ object pointer is null

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name SuperVectorize.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-16/lib/clang/16.0.0 -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I tools/mlir/lib/Dialect/Affine/Transforms -I /build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/lib/Dialect/Affine/Transforms -I include -I /build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/llvm/include -I /build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/include -I tools/mlir/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-16/lib/clang/16.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/= -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/= -O2 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 
-fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-10-03-140002-15933-1 -x c++ /build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp

/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp

1//===- SuperVectorize.cpp - Vectorize Pass Impl ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements vectorization of loops, operations and data types to
10// a target-independent, n-D super-vector abstraction.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/Passes.h"
15
16#include "mlir/Analysis/SliceAnalysis.h"
17#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
18#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
19#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
20#include "mlir/Dialect/Affine/IR/AffineOps.h"
21#include "mlir/Dialect/Affine/Utils.h"
22#include "mlir/Dialect/Arith/IR/Arith.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/Vector/IR/VectorOps.h"
25#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26#include "mlir/IR/BlockAndValueMapping.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Support/LLVM.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/Support/Debug.h"
31
32namespace mlir {
33#define GEN_PASS_DEF_AFFINEVECTORIZE
34#include "mlir/Dialect/Affine/Passes.h.inc"
35} // namespace mlir
36
37using namespace mlir;
38using namespace vector;
39
40///
41/// Implements a high-level vectorization strategy on a Function.
42/// The abstraction used is that of super-vectors, which provide a single,
43/// compact, representation in the vector types, information that is expected
44/// to reduce the impact of the phase ordering problem
45///
46/// Vector granularity:
47/// ===================
48/// This pass is designed to perform vectorization at a super-vector
49/// granularity. A super-vector is loosely defined as a vector type that is a
50/// multiple of a "good" vector size so the HW can efficiently implement a set
51/// of high-level primitives. Multiple is understood along any dimension; e.g.
52/// both vector<16xf32> and vector<2x8xf32> are valid super-vectors for a
53/// vector<8xf32> HW vector. Note that a "good vector size so the HW can
54/// efficiently implement a set of high-level primitives" is not necessarily an
55/// integer multiple of actual hardware registers. We leave details of this
56/// distinction unspecified for now.
57///
58/// Some may prefer the terminology a "tile of HW vectors". In this case, one
59/// should note that super-vectors implement an "always full tile" abstraction.
60/// They guarantee no partial-tile separation is necessary by relying on a
61/// high-level copy-reshape abstraction that we call vector.transfer. This
62/// copy-reshape operations is also responsible for performing layout
63/// transposition if necessary. In the general case this will require a scoped
64/// allocation in some notional local memory.
65///
66/// Whatever the mental model one prefers to use for this abstraction, the key
67/// point is that we burn into a single, compact, representation in the vector
68/// types, information that is expected to reduce the impact of the phase
69/// ordering problem. Indeed, a vector type conveys information that:
70/// 1. the associated loops have dependency semantics that do not prevent
71/// vectorization;
72/// 2. the associate loops have been sliced in chunks of static sizes that are
73/// compatible with vector sizes (i.e. similar to unroll-and-jam);
74/// 3. the inner loops, in the unroll-and-jam analogy of 2, are captured by
75/// the
76/// vector type and no vectorization hampering transformations can be
77/// applied to them anymore;
78/// 4. the underlying memrefs are accessed in some notional contiguous way
79/// that allows loading into vectors with some amount of spatial locality;
80/// In other words, super-vectorization provides a level of separation of
81/// concern by way of opacity to subsequent passes. This has the effect of
82/// encapsulating and propagating vectorization constraints down the list of
83/// passes until we are ready to lower further.
84///
85/// For a particular target, a notion of minimal n-d vector size will be
86/// specified and vectorization targets a multiple of those. In the following
87/// paragraph, let "k ." represent "a multiple of", to be understood as a
88/// multiple in the same dimension (e.g. vector<16 x k . 128> summarizes
89/// vector<16 x 128>, vector<16 x 256>, vector<16 x 1024>, etc).
90///
91/// Some non-exhaustive notable super-vector sizes of interest include:
92/// - CPU: vector<k . HW_vector_size>,
93/// vector<k' . core_count x k . HW_vector_size>,
94/// vector<socket_count x k' . core_count x k . HW_vector_size>;
95/// - GPU: vector<k . warp_size>,
96/// vector<k . warp_size x float2>,
97/// vector<k . warp_size x float4>,
98/// vector<k . warp_size x 4 x 4x 4> (for tensor_core sizes).
99///
100/// Loops and operations are emitted that operate on those super-vector shapes.
101/// Subsequent lowering passes will materialize to actual HW vector sizes. These
102/// passes are expected to be (gradually) more target-specific.
103///
104/// At a high level, a vectorized load in a loop will resemble:
105/// ```mlir
106/// affine.for %i = ? to ? step ? {
107/// %v_a = vector.transfer_read A[%i] : memref<?xf32>, vector<128xf32>
108/// }
109/// ```
110/// It is the responsibility of the implementation of vector.transfer_read to
111/// materialize vector registers from the original scalar memrefs. A later (more
112/// target-dependent) lowering pass will materialize to actual HW vector sizes.
113/// This lowering may be occur at different times:
114/// 1. at the MLIR level into a combination of loops, unrolling, DmaStartOp +
115/// DmaWaitOp + vectorized operations for data transformations and shuffle;
116/// thus opening opportunities for unrolling and pipelining. This is an
117/// instance of library call "whiteboxing"; or
118/// 2. later in the a target-specific lowering pass or hand-written library
119/// call; achieving full separation of concerns. This is an instance of
120/// library call; or
121/// 3. a mix of both, e.g. based on a model.
122/// In the future, these operations will expose a contract to constrain the
123/// search on vectorization patterns and sizes.
124///
125/// Occurrence of super-vectorization in the compiler flow:
126/// =======================================================
127/// This is an active area of investigation. We start with 2 remarks to position
128/// super-vectorization in the context of existing ongoing work: LLVM VPLAN
129/// and LLVM SLP Vectorizer.
130///
131/// LLVM VPLAN:
132/// -----------
133/// The astute reader may have noticed that in the limit, super-vectorization
134/// can be applied at a similar time and with similar objectives than VPLAN.
135/// For instance, in the case of a traditional, polyhedral compilation-flow (for
136/// instance, the PPCG project uses ISL to provide dependence analysis,
137/// multi-level(scheduling + tiling), lifting footprint to fast memory,
138/// communication synthesis, mapping, register optimizations) and before
139/// unrolling. When vectorization is applied at this *late* level in a typical
140/// polyhedral flow, and is instantiated with actual hardware vector sizes,
141/// super-vectorization is expected to match (or subsume) the type of patterns
142/// that LLVM's VPLAN aims at targeting. The main difference here is that MLIR
143/// is higher level and our implementation should be significantly simpler. Also
144/// note that in this mode, recursive patterns are probably a bit of an overkill
145/// although it is reasonable to expect that mixing a bit of outer loop and
146/// inner loop vectorization + unrolling will provide interesting choices to
147/// MLIR.
148///
149/// LLVM SLP Vectorizer:
150/// --------------------
151/// Super-vectorization however is not meant to be usable in a similar fashion
152/// to the SLP vectorizer. The main difference lies in the information that
153/// both vectorizers use: super-vectorization examines contiguity of memory
154/// references along fastest varying dimensions and loops with recursive nested
155/// patterns capturing imperfectly-nested loop nests; the SLP vectorizer, on
156/// the other hand, performs flat pattern matching inside a single unrolled loop
157/// body and stitches together pieces of load and store operations into full
158/// 1-D vectors. We envision that the SLP vectorizer is a good way to capture
159/// innermost loop, control-flow dependent patterns that super-vectorization may
160/// not be able to capture easily. In other words, super-vectorization does not
161/// aim at replacing the SLP vectorizer and the two solutions are complementary.
162///
163/// Ongoing investigations:
164/// -----------------------
165/// We discuss the following *early* places where super-vectorization is
166/// applicable and touch on the expected benefits and risks . We list the
167/// opportunities in the context of the traditional polyhedral compiler flow
168/// described in PPCG. There are essentially 6 places in the MLIR pass pipeline
169/// we expect to experiment with super-vectorization:
170/// 1. Right after language lowering to MLIR: this is the earliest time where
171/// super-vectorization is expected to be applied. At this level, all the
172/// language/user/library-level annotations are available and can be fully
173/// exploited. Examples include loop-type annotations (such as parallel,
174/// reduction, scan, dependence distance vector, vectorizable) as well as
175/// memory access annotations (such as non-aliasing writes guaranteed,
176/// indirect accesses that are permutations by construction) accesses or
177/// that a particular operation is prescribed atomic by the user. At this
178/// level, anything that enriches what dependence analysis can do should be
179/// aggressively exploited. At this level we are close to having explicit
180/// vector types in the language, except we do not impose that burden on the
181/// programmer/library: we derive information from scalar code + annotations.
182/// 2. After dependence analysis and before polyhedral scheduling: the
183/// information that supports vectorization does not need to be supplied by a
184/// higher level of abstraction. Traditional dependence analysis is available
185/// in MLIR and will be used to drive vectorization and cost models.
186///
187/// Let's pause here and remark that applying super-vectorization as described
188/// in 1. and 2. presents clear opportunities and risks:
189/// - the opportunity is that vectorization is burned in the type system and
190/// is protected from the adverse effect of loop scheduling, tiling, loop
191/// interchange and all passes downstream. Provided that subsequent passes are
192/// able to operate on vector types; the vector shapes, associated loop
193/// iterator properties, alignment, and contiguity of fastest varying
194/// dimensions are preserved until we lower the super-vector types. We expect
195/// this to significantly rein in on the adverse effects of phase ordering.
196/// - the risks are that a. all passes after super-vectorization have to work
197/// on elemental vector types (not that this is always true, wherever
198/// vectorization is applied) and b. that imposing vectorization constraints
199/// too early may be overall detrimental to loop fusion, tiling and other
200/// transformations because the dependence distances are coarsened when
201/// operating on elemental vector types. For this reason, the pattern
202/// profitability analysis should include a component that also captures the
203/// maximal amount of fusion available under a particular pattern. This is
204/// still at the stage of rough ideas but in this context, search is our
205/// friend as the Tensor Comprehensions and auto-TVM contributions
206/// demonstrated previously.
207/// Bottom-line is we do not yet have good answers for the above but aim at
208/// making it easy to answer such questions.
209///
210/// Back to our listing, the last places where early super-vectorization makes
211/// sense are:
212/// 3. right after polyhedral-style scheduling: PLUTO-style algorithms are known
213/// to improve locality, parallelism and be configurable (e.g. max-fuse,
214/// smart-fuse etc). They can also have adverse effects on contiguity
215/// properties that are required for vectorization but the vector.transfer
216/// copy-reshape-pad-transpose abstraction is expected to help recapture
217/// these properties.
218/// 4. right after polyhedral-style scheduling+tiling;
219/// 5. right after scheduling+tiling+rescheduling: points 4 and 5 represent
220/// probably the most promising places because applying tiling achieves a
221/// separation of concerns that allows rescheduling to worry less about
222/// locality and more about parallelism and distribution (e.g. min-fuse).
223///
224/// At these levels the risk-reward looks different: on one hand we probably
225/// lost a good deal of language/user/library-level annotation; on the other
226/// hand we gained parallelism and locality through scheduling and tiling.
227/// However we probably want to ensure tiling is compatible with the
228/// full-tile-only abstraction used in super-vectorization or suffer the
229/// consequences. It is too early to place bets on what will win but we expect
230/// super-vectorization to be the right abstraction to allow exploring at all
231/// these levels. And again, search is our friend.
232///
233/// Lastly, we mention it again here:
234/// 6. as a MLIR-based alternative to VPLAN.
235///
236/// Lowering, unrolling, pipelining:
237/// ================================
238/// TODO: point to the proper places.
239///
240/// Algorithm:
241/// ==========
242/// The algorithm proceeds in a few steps:
243/// 1. defining super-vectorization patterns and matching them on the tree of
244/// AffineForOp. A super-vectorization pattern is defined as a recursive
245/// data structures that matches and captures nested, imperfectly-nested
246/// loops that have a. conformable loop annotations attached (e.g. parallel,
247/// reduction, vectorizable, ...) as well as b. all contiguous load/store
248/// operations along a specified minor dimension (not necessarily the
249/// fastest varying) ;
250/// 2. analyzing those patterns for profitability (TODO: and
251/// interference);
252/// 3. then, for each pattern in order:
253/// a. applying iterative rewriting of the loops and all their nested
254/// operations in topological order. Rewriting is implemented by
255/// coarsening the loops and converting operations and operands to their
256/// vector forms. Processing operations in topological order is relatively
257/// simple due to the structured nature of the control-flow
258/// representation. This order ensures that all the operands of a given
259/// operation have been vectorized before the operation itself in a single
260/// traversal, except for operands defined outside of the loop nest. The
261/// algorithm can convert the following operations to their vector form:
262/// * Affine load and store operations are converted to opaque vector
263/// transfer read and write operations.
264/// * Scalar constant operations/operands are converted to vector
265/// constant operations (splat).
266/// * Uniform operands (only induction variables of loops not mapped to
267/// a vector dimension, or operands defined outside of the loop nest
268/// for now) are broadcasted to a vector.
269/// TODO: Support more uniform cases.
270/// * Affine for operations with 'iter_args' are vectorized by
271/// vectorizing their 'iter_args' operands and results.
272/// TODO: Support more complex loops with divergent lbs and/or ubs.
273/// * The remaining operations in the loop nest are vectorized by
274/// widening their scalar types to vector types.
275/// b. if everything under the root AffineForOp in the current pattern
276/// is vectorized properly, we commit that loop to the IR and remove the
277/// scalar loop. Otherwise, we discard the vectorized loop and keep the
278/// original scalar loop.
279/// c. vectorization is applied on the next pattern in the list. Because
280/// pattern interference avoidance is not yet implemented and that we do
281/// not support further vectorizing an already vector load we need to
282/// re-verify that the pattern is still vectorizable. This is expected to
283/// make cost models more difficult to write and is subject to improvement
284/// in the future.
285///
286/// Choice of loop transformation to support the algorithm:
287/// =======================================================
288/// The choice of loop transformation to apply for coarsening vectorized loops
289/// is still subject to exploratory tradeoffs. In particular, say we want to
290/// vectorize by a factor 128, we want to transform the following input:
291/// ```mlir
292/// affine.for %i = %M to %N {
293/// %a = affine.load %A[%i] : memref<?xf32>
294/// }
295/// ```
296///
297/// Traditionally, one would vectorize late (after scheduling, tiling,
298/// memory promotion etc) say after stripmining (and potentially unrolling in
299/// the case of LLVM's SLP vectorizer):
300/// ```mlir
301/// affine.for %i = floor(%M, 128) to ceil(%N, 128) {
302/// affine.for %ii = max(%M, 128 * %i) to min(%N, 128*%i + 127) {
303/// %a = affine.load %A[%ii] : memref<?xf32>
304/// }
305/// }
306/// ```
307///
308/// Instead, we seek to vectorize early and freeze vector types before
309/// scheduling, so we want to generate a pattern that resembles:
310/// ```mlir
311/// affine.for %i = ? to ? step ? {
312/// %v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32>
313/// }
314/// ```
315///
316/// i. simply dividing the lower / upper bounds by 128 creates issues
317/// when representing expressions such as ii + 1 because now we only
318/// have access to original values that have been divided. Additional
319/// information is needed to specify accesses at below-128 granularity;
320/// ii. another alternative is to coarsen the loop step but this may have
321/// consequences on dependence analysis and fusability of loops: fusable
322/// loops probably need to have the same step (because we don't want to
323/// stripmine/unroll to enable fusion).
324/// As a consequence, we choose to represent the coarsening using the loop
325/// step for now and reevaluate in the future. Note that we can renormalize
326/// loop steps later if/when we have evidence that they are problematic.
327///
328/// For the simple strawman example above, vectorizing for a 1-D vector
329/// abstraction of size 128 returns code similar to:
330/// ```mlir
331/// affine.for %i = %M to %N step 128 {
332/// %v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32>
333/// }
334/// ```
335///
336/// Unsupported cases, extensions, and work in progress (help welcome :-) ):
337/// ========================================================================
338/// 1. lowering to concrete vector types for various HW;
339/// 2. reduction support for n-D vectorization and non-unit steps;
340/// 3. non-effecting padding during vector.transfer_read and filter during
341/// vector.transfer_write;
342/// 4. misalignment support vector.transfer_read / vector.transfer_write
343/// (hopefully without read-modify-writes);
344/// 5. control-flow support;
345/// 6. cost-models, heuristics and search;
346/// 7. Op implementation, extensions and implication on memref views;
347/// 8. many TODOs left around.
348///
349/// Examples:
350/// =========
351/// Consider the following Function:
352/// ```mlir
353/// func @vector_add_2d(%M : index, %N : index) -> f32 {
354/// %A = alloc (%M, %N) : memref<?x?xf32, 0>
355/// %B = alloc (%M, %N) : memref<?x?xf32, 0>
356/// %C = alloc (%M, %N) : memref<?x?xf32, 0>
357/// %f1 = arith.constant 1.0 : f32
358/// %f2 = arith.constant 2.0 : f32
359/// affine.for %i0 = 0 to %M {
360/// affine.for %i1 = 0 to %N {
361/// // non-scoped %f1
362/// affine.store %f1, %A[%i0, %i1] : memref<?x?xf32, 0>
363/// }
364/// }
365/// affine.for %i2 = 0 to %M {
366/// affine.for %i3 = 0 to %N {
367/// // non-scoped %f2
368/// affine.store %f2, %B[%i2, %i3] : memref<?x?xf32, 0>
369/// }
370/// }
371/// affine.for %i4 = 0 to %M {
372/// affine.for %i5 = 0 to %N {
373/// %a5 = affine.load %A[%i4, %i5] : memref<?x?xf32, 0>
374/// %b5 = affine.load %B[%i4, %i5] : memref<?x?xf32, 0>
375/// %s5 = arith.addf %a5, %b5 : f32
376/// // non-scoped %f1
377/// %s6 = arith.addf %s5, %f1 : f32
378/// // non-scoped %f2
379/// %s7 = arith.addf %s5, %f2 : f32
380/// // diamond dependency.
381/// %s8 = arith.addf %s7, %s6 : f32
382/// affine.store %s8, %C[%i4, %i5] : memref<?x?xf32, 0>
383/// }
384/// }
385/// %c7 = arith.constant 7 : index
386/// %c42 = arith.constant 42 : index
387/// %res = load %C[%c7, %c42] : memref<?x?xf32, 0>
388/// return %res : f32
389/// }
390/// ```
391///
392/// The -affine-super-vectorize pass with the following arguments:
393/// ```
394/// -affine-super-vectorize="virtual-vector-size=256 test-fastest-varying=0"
395/// ```
396///
397/// produces this standard innermost-loop vectorized code:
398/// ```mlir
399/// func @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 {
400/// %0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
401/// %1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
402/// %2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
403/// %cst = arith.constant 1.0 : f32
404/// %cst_0 = arith.constant 2.0 : f32
405/// affine.for %i0 = 0 to %arg0 {
406/// affine.for %i1 = 0 to %arg1 step 256 {
407/// %cst_1 = arith.constant dense<vector<256xf32>, 1.0> :
408/// vector<256xf32>
409/// vector.transfer_write %cst_1, %0[%i0, %i1] :
410/// vector<256xf32>, memref<?x?xf32>
411/// }
412/// }
413/// affine.for %i2 = 0 to %arg0 {
414/// affine.for %i3 = 0 to %arg1 step 256 {
415/// %cst_2 = arith.constant dense<vector<256xf32>, 2.0> :
416/// vector<256xf32>
417/// vector.transfer_write %cst_2, %1[%i2, %i3] :
418/// vector<256xf32>, memref<?x?xf32>
419/// }
420/// }
421/// affine.for %i4 = 0 to %arg0 {
422/// affine.for %i5 = 0 to %arg1 step 256 {
423/// %3 = vector.transfer_read %0[%i4, %i5] :
424/// memref<?x?xf32>, vector<256xf32>
425/// %4 = vector.transfer_read %1[%i4, %i5] :
426/// memref<?x?xf32>, vector<256xf32>
427/// %5 = arith.addf %3, %4 : vector<256xf32>
428/// %cst_3 = arith.constant dense<vector<256xf32>, 1.0> :
429/// vector<256xf32>
430/// %6 = arith.addf %5, %cst_3 : vector<256xf32>
431/// %cst_4 = arith.constant dense<vector<256xf32>, 2.0> :
432/// vector<256xf32>
433/// %7 = arith.addf %5, %cst_4 : vector<256xf32>
434/// %8 = arith.addf %7, %6 : vector<256xf32>
435/// vector.transfer_write %8, %2[%i4, %i5] :
436/// vector<256xf32>, memref<?x?xf32>
437/// }
438/// }
439/// %c7 = arith.constant 7 : index
440/// %c42 = arith.constant 42 : index
441/// %9 = load %2[%c7, %c42] : memref<?x?xf32>
442/// return %9 : f32
443/// }
444/// ```
445///
446/// The -affine-super-vectorize pass with the following arguments:
447/// ```
448/// -affine-super-vectorize="virtual-vector-size=32,256 \
449/// test-fastest-varying=1,0"
450/// ```
451///
452/// produces this more interesting mixed outer-innermost-loop vectorized code:
453/// ```mlir
454/// func @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 {
455/// %0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
456/// %1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
457/// %2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
458/// %cst = arith.constant 1.0 : f32
459/// %cst_0 = arith.constant 2.0 : f32
460/// affine.for %i0 = 0 to %arg0 step 32 {
461/// affine.for %i1 = 0 to %arg1 step 256 {
462/// %cst_1 = arith.constant dense<vector<32x256xf32>, 1.0> :
463/// vector<32x256xf32>
464/// vector.transfer_write %cst_1, %0[%i0, %i1] :
465/// vector<32x256xf32>, memref<?x?xf32>
466/// }
467/// }
468/// affine.for %i2 = 0 to %arg0 step 32 {
469/// affine.for %i3 = 0 to %arg1 step 256 {
470/// %cst_2 = arith.constant dense<vector<32x256xf32>, 2.0> :
471/// vector<32x256xf32>
472/// vector.transfer_write %cst_2, %1[%i2, %i3] :
473/// vector<32x256xf32>, memref<?x?xf32>
474/// }
475/// }
476/// affine.for %i4 = 0 to %arg0 step 32 {
477/// affine.for %i5 = 0 to %arg1 step 256 {
478/// %3 = vector.transfer_read %0[%i4, %i5] :
479/// memref<?x?xf32> vector<32x256xf32>
480/// %4 = vector.transfer_read %1[%i4, %i5] :
481/// memref<?x?xf32>, vector<32x256xf32>
482/// %5 = arith.addf %3, %4 : vector<32x256xf32>
483/// %cst_3 = arith.constant dense<vector<32x256xf32>, 1.0> :
484/// vector<32x256xf32>
485/// %6 = arith.addf %5, %cst_3 : vector<32x256xf32>
486/// %cst_4 = arith.constant dense<vector<32x256xf32>, 2.0> :
487/// vector<32x256xf32>
488/// %7 = arith.addf %5, %cst_4 : vector<32x256xf32>
489/// %8 = arith.addf %7, %6 : vector<32x256xf32>
490/// vector.transfer_write %8, %2[%i4, %i5] :
491/// vector<32x256xf32>, memref<?x?xf32>
492/// }
493/// }
494/// %c7 = arith.constant 7 : index
495/// %c42 = arith.constant 42 : index
496/// %9 = load %2[%c7, %c42] : memref<?x?xf32>
497/// return %9 : f32
498/// }
499/// ```
500///
501/// Of course, much more intricate n-D imperfectly-nested patterns can be
502/// vectorized too and specified in a fully declarative fashion.
503///
504/// Reduction:
505/// ==========
506/// Vectorizing reduction loops along the reduction dimension is supported if:
507/// - the reduction kind is supported,
508/// - the vectorization is 1-D, and
509/// - the step size of the loop equals to one.
510///
511/// Comparing to the non-vector-dimension case, two additional things are done
512/// during vectorization of such loops:
513/// - The resulting vector returned from the loop is reduced to a scalar using
514/// `vector.reduce`.
515/// - In some cases a mask is applied to the vector yielded at the end of the
516/// loop to prevent garbage values from being written to the accumulator.
517///
518/// Reduction vectorization is switched off by default, it can be enabled by
519/// passing a map from loops to reductions to utility functions, or by passing
520/// `vectorize-reductions=true` to the vectorization pass.
521///
522/// Consider the following example:
523/// ```mlir
524/// func @vecred(%in: memref<512xf32>) -> f32 {
525/// %cst = arith.constant 0.000000e+00 : f32
526/// %sum = affine.for %i = 0 to 500 iter_args(%part_sum = %cst) -> (f32) {
527/// %ld = affine.load %in[%i] : memref<512xf32>
528/// %cos = math.cos %ld : f32
529/// %add = arith.addf %part_sum, %cos : f32
530/// affine.yield %add : f32
531/// }
532/// return %sum : f32
533/// }
534/// ```
535///
536/// The -affine-super-vectorize pass with the following arguments:
537/// ```
538/// -affine-super-vectorize="virtual-vector-size=128 test-fastest-varying=0 \
539/// vectorize-reductions=true"
540/// ```
541/// produces the following output:
542/// ```mlir
543/// #map = affine_map<(d0) -> (-d0 + 500)>
544/// func @vecred(%arg0: memref<512xf32>) -> f32 {
545/// %cst = arith.constant 0.000000e+00 : f32
546/// %cst_0 = arith.constant dense<0.000000e+00> : vector<128xf32>
547/// %0 = affine.for %arg1 = 0 to 500 step 128 iter_args(%arg2 = %cst_0)
548/// -> (vector<128xf32>) {
549/// // %2 is the number of iterations left in the original loop.
550/// %2 = affine.apply #map(%arg1)
551/// %3 = vector.create_mask %2 : vector<128xi1>
552/// %cst_1 = arith.constant 0.000000e+00 : f32
553/// %4 = vector.transfer_read %arg0[%arg1], %cst_1 :
554/// memref<512xf32>, vector<128xf32>
555/// %5 = math.cos %4 : vector<128xf32>
556/// %6 = arith.addf %arg2, %5 : vector<128xf32>
557/// // We filter out the effect of last 12 elements using the mask.
558/// %7 = select %3, %6, %arg2 : vector<128xi1>, vector<128xf32>
559/// affine.yield %7 : vector<128xf32>
560/// }
561/// %1 = vector.reduction <add>, %0 : vector<128xf32> into f32
562/// return %1 : f32
563/// }
564/// ```
565///
566/// Note that because of loop misalignment we needed to apply a mask to prevent
567/// last 12 elements from affecting the final result. The mask is full of ones
568/// in every iteration except for the last one, in which it has the form
569/// `11...100...0` with 116 ones and 12 zeros.
570
571#define DEBUG_TYPE"early-vect" "early-vect"
572
573using llvm::dbgs;
574
575/// Forward declaration.
576static FilterFunctionType
577isVectorizableLoopPtrFactory(const DenseSet<Operation *> &parallelLoops,
578 int fastestVaryingMemRefDimension);
579
580/// Creates a vectorization pattern from the command line arguments.
581/// Up to 3-D patterns are supported.
582/// If the command line argument requests a pattern of higher order, returns an
583/// empty pattern list which will conservatively result in no vectorization.
584static Optional<NestedPattern>
585makePattern(const DenseSet<Operation *> &parallelLoops, int vectorRank,
586 ArrayRef<int64_t> fastestVaryingPattern) {
587 using matcher::For;
588 int64_t d0 = fastestVaryingPattern.empty() ? -1 : fastestVaryingPattern[0];
589 int64_t d1 = fastestVaryingPattern.size() < 2 ? -1 : fastestVaryingPattern[1];
590 int64_t d2 = fastestVaryingPattern.size() < 3 ? -1 : fastestVaryingPattern[2];
591 switch (vectorRank) {
592 case 1:
593 return For(isVectorizableLoopPtrFactory(parallelLoops, d0));
594 case 2:
595 return For(isVectorizableLoopPtrFactory(parallelLoops, d0),
596 For(isVectorizableLoopPtrFactory(parallelLoops, d1)));
597 case 3:
598 return For(isVectorizableLoopPtrFactory(parallelLoops, d0),
599 For(isVectorizableLoopPtrFactory(parallelLoops, d1),
600 For(isVectorizableLoopPtrFactory(parallelLoops, d2))));
601 default: {
602 return llvm::None;
603 }
604 }
605}
606
607static NestedPattern &vectorTransferPattern() {
608 static auto pattern = matcher::Op([](Operation &op) {
609 return isa<vector::TransferReadOp, vector::TransferWriteOp>(op);
610 });
611 return pattern;
612}
613
614namespace {
615
616/// Base state for the vectorize pass.
617/// Command line arguments are preempted by non-empty pass arguments.
618struct Vectorize : public impl::AffineVectorizeBase<Vectorize> {
619 using Base::Base;
620
621 void runOnOperation() override;
622};
623
624} // namespace
625
626static void vectorizeLoopIfProfitable(Operation *loop, unsigned depthInPattern,
627 unsigned patternDepth,
628 VectorizationStrategy *strategy) {
629 assert(patternDepth > depthInPattern &&(static_cast <bool> (patternDepth > depthInPattern &&
"patternDepth is greater than depthInPattern") ? void (0) : __assert_fail
("patternDepth > depthInPattern && \"patternDepth is greater than depthInPattern\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 630
, __extension__ __PRETTY_FUNCTION__))
630 "patternDepth is greater than depthInPattern")(static_cast <bool> (patternDepth > depthInPattern &&
"patternDepth is greater than depthInPattern") ? void (0) : __assert_fail
("patternDepth > depthInPattern && \"patternDepth is greater than depthInPattern\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 630
, __extension__ __PRETTY_FUNCTION__))
;
631 if (patternDepth - depthInPattern > strategy->vectorSizes.size()) {
632 // Don't vectorize this loop
633 return;
634 }
635 strategy->loopToVectorDim[loop] =
636 strategy->vectorSizes.size() - (patternDepth - depthInPattern);
637}
638
639/// Implements a simple strawman strategy for vectorization.
640/// Given a matched pattern `matches` of depth `patternDepth`, this strategy
641/// greedily assigns the fastest varying dimension ** of the vector ** to the
642/// innermost loop in the pattern.
643/// When coupled with a pattern that looks for the fastest varying dimension in
644/// load/store MemRefs, this creates a generic vectorization strategy that works
645/// for any loop in a hierarchy (outermost, innermost or intermediate).
646///
647/// TODO: In the future we should additionally increase the power of the
648/// profitability analysis along 3 directions:
649/// 1. account for loop extents (both static and parametric + annotations);
650/// 2. account for data layout permutations;
651/// 3. account for impact of vectorization on maximal loop fusion.
652/// Then we can quantify the above to build a cost model and search over
653/// strategies.
654static LogicalResult analyzeProfitability(ArrayRef<NestedMatch> matches,
655 unsigned depthInPattern,
656 unsigned patternDepth,
657 VectorizationStrategy *strategy) {
658 for (auto m : matches) {
659 if (failed(analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1,
660 patternDepth, strategy))) {
661 return failure();
662 }
663 vectorizeLoopIfProfitable(m.getMatchedOperation(), depthInPattern,
664 patternDepth, strategy);
665 }
666 return success();
667}
668
669///// end TODO: Hoist to a VectorizationStrategy.cpp when appropriate /////
670
671namespace {
672
673struct VectorizationState {
674
675 VectorizationState(MLIRContext *context) : builder(context) {}
676
677 /// Registers the vector replacement of a scalar operation and its result
678 /// values. Both operations must have the same number of results.
679 ///
680 /// This utility is used to register the replacement for the vast majority of
681 /// the vectorized operations.
682 ///
683 /// Example:
684 /// * 'replaced': %0 = arith.addf %1, %2 : f32
685 /// * 'replacement': %0 = arith.addf %1, %2 : vector<128xf32>
686 void registerOpVectorReplacement(Operation *replaced, Operation *replacement);
687
688 /// Registers the vector replacement of a scalar value. The replacement
689 /// operation should have a single result, which replaces the scalar value.
690 ///
691 /// This utility is used to register the vector replacement of block arguments
692 /// and operation results which are not directly vectorized (i.e., their
693 /// scalar version still exists after vectorization), like uniforms.
694 ///
695 /// Example:
696 /// * 'replaced': block argument or operation outside of the vectorized
697 /// loop.
698 /// * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32>
699 void registerValueVectorReplacement(Value replaced, Operation *replacement);
700
701 /// Registers the vector replacement of a block argument (e.g., iter_args).
702 ///
703 /// Example:
704 /// * 'replaced': 'iter_arg' block argument.
705 /// * 'replacement': vectorized 'iter_arg' block argument.
706 void registerBlockArgVectorReplacement(BlockArgument replaced,
707 BlockArgument replacement);
708
709 /// Registers the scalar replacement of a scalar value. 'replacement' must be
710 /// scalar. Both values must be block arguments. Operation results should be
711 /// replaced using the 'registerOp*' utilitites.
712 ///
713 /// This utility is used to register the replacement of block arguments
714 /// that are within the loop to be vectorized and will continue being scalar
715 /// within the vector loop.
716 ///
717 /// Example:
718 /// * 'replaced': induction variable of a loop to be vectorized.
719 /// * 'replacement': new induction variable in the new vector loop.
720 void registerValueScalarReplacement(BlockArgument replaced,
721 BlockArgument replacement);
722
723 /// Registers the scalar replacement of a scalar result returned from a
724 /// reduction loop. 'replacement' must be scalar.
725 ///
726 /// This utility is used to register the replacement for scalar results of
727 /// vectorized reduction loops with iter_args.
728 ///
729 /// Example 2:
730 /// * 'replaced': %0 = affine.for %i = 0 to 512 iter_args(%x = ...) -> (f32)
731 /// * 'replacement': %1 = vector.reduction <add>, %0 : vector<4xf32> into
732 /// f32
733 void registerLoopResultScalarReplacement(Value replaced, Value replacement);
734
735 /// Returns in 'replacedVals' the scalar replacement for values in
736 /// 'inputVals'.
737 void getScalarValueReplacementsFor(ValueRange inputVals,
738 SmallVectorImpl<Value> &replacedVals);
739
740 /// Erases the scalar loop nest after its successful vectorization.
741 void finishVectorizationPattern(AffineForOp rootLoop);
742
743 // Used to build and insert all the new operations created. The insertion
744 // point is preserved and updated along the vectorization process.
745 OpBuilder builder;
746
747 // Maps input scalar operations to their vector counterparts.
748 DenseMap<Operation *, Operation *> opVectorReplacement;
749 // Maps input scalar values to their vector counterparts.
750 BlockAndValueMapping valueVectorReplacement;
751 // Maps input scalar values to their new scalar counterparts in the vector
752 // loop nest.
753 BlockAndValueMapping valueScalarReplacement;
754 // Maps results of reduction loops to their new scalar counterparts.
755 DenseMap<Value, Value> loopResultScalarReplacement;
756
757 // Maps the newly created vector loops to their vector dimension.
758 DenseMap<Operation *, unsigned> vecLoopToVecDim;
759
760 // Maps the new vectorized loops to the corresponding vector masks if it is
761 // required.
762 DenseMap<Operation *, Value> vecLoopToMask;
763
764 // The strategy drives which loop to vectorize by which amount.
765 const VectorizationStrategy *strategy = nullptr;
766
767private:
768 /// Internal implementation to map input scalar values to new vector or scalar
769 /// values.
770 void registerValueVectorReplacementImpl(Value replaced, Value replacement);
771 void registerValueScalarReplacementImpl(Value replaced, Value replacement);
772};
773
774} // namespace
775
776/// Registers the vector replacement of a scalar operation and its result
777/// values. Both operations must have the same number of results.
778///
779/// This utility is used to register the replacement for the vast majority of
780/// the vectorized operations.
781///
782/// Example:
783/// * 'replaced': %0 = arith.addf %1, %2 : f32
784/// * 'replacement': %0 = arith.addf %1, %2 : vector<128xf32>
785void VectorizationState::registerOpVectorReplacement(Operation *replaced,
786 Operation *replacement) {
787 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op:\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ commit vectorized op:\n"
; } } while (false)
;
788 LLVM_DEBUG(dbgs() << *replaced << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << *replaced << "\n"; } }
while (false)
;
789 LLVM_DEBUG(dbgs() << "into\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "into\n"; } } while (false)
;
790 LLVM_DEBUG(dbgs() << *replacement << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << *replacement << "\n";
} } while (false)
;
791
792 assert(replaced->getNumResults() == replacement->getNumResults() &&(static_cast <bool> (replaced->getNumResults() == replacement
->getNumResults() && "Unexpected replaced and replacement results"
) ? void (0) : __assert_fail ("replaced->getNumResults() == replacement->getNumResults() && \"Unexpected replaced and replacement results\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 793
, __extension__ __PRETTY_FUNCTION__))
793 "Unexpected replaced and replacement results")(static_cast <bool> (replaced->getNumResults() == replacement
->getNumResults() && "Unexpected replaced and replacement results"
) ? void (0) : __assert_fail ("replaced->getNumResults() == replacement->getNumResults() && \"Unexpected replaced and replacement results\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 793
, __extension__ __PRETTY_FUNCTION__))
;
794 assert(opVectorReplacement.count(replaced) == 0 && "already registered")(static_cast <bool> (opVectorReplacement.count(replaced
) == 0 && "already registered") ? void (0) : __assert_fail
("opVectorReplacement.count(replaced) == 0 && \"already registered\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 794
, __extension__ __PRETTY_FUNCTION__))
;
795 opVectorReplacement[replaced] = replacement;
796
797 for (auto resultTuple :
798 llvm::zip(replaced->getResults(), replacement->getResults()))
799 registerValueVectorReplacementImpl(std::get<0>(resultTuple),
800 std::get<1>(resultTuple));
801}
802
803/// Registers the vector replacement of a scalar value. The replacement
804/// operation should have a single result, which replaces the scalar value.
805///
806/// This utility is used to register the vector replacement of block arguments
807/// and operation results which are not directly vectorized (i.e., their
808/// scalar version still exists after vectorization), like uniforms.
809///
810/// Example:
811/// * 'replaced': block argument or operation outside of the vectorized loop.
812/// * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32>
813void VectorizationState::registerValueVectorReplacement(
814 Value replaced, Operation *replacement) {
815 assert(replacement->getNumResults() == 1 &&(static_cast <bool> (replacement->getNumResults() ==
1 && "Expected single-result replacement") ? void (0
) : __assert_fail ("replacement->getNumResults() == 1 && \"Expected single-result replacement\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 816
, __extension__ __PRETTY_FUNCTION__))
816 "Expected single-result replacement")(static_cast <bool> (replacement->getNumResults() ==
1 && "Expected single-result replacement") ? void (0
) : __assert_fail ("replacement->getNumResults() == 1 && \"Expected single-result replacement\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 816
, __extension__ __PRETTY_FUNCTION__))
;
817 if (Operation *defOp = replaced.getDefiningOp())
818 registerOpVectorReplacement(defOp, replacement);
819 else
820 registerValueVectorReplacementImpl(replaced, replacement->getResult(0));
821}
822
823/// Registers the vector replacement of a block argument (e.g., iter_args).
824///
825/// Example:
826/// * 'replaced': 'iter_arg' block argument.
827/// * 'replacement': vectorized 'iter_arg' block argument.
828void VectorizationState::registerBlockArgVectorReplacement(
829 BlockArgument replaced, BlockArgument replacement) {
830 registerValueVectorReplacementImpl(replaced, replacement);
831}
832
833void VectorizationState::registerValueVectorReplacementImpl(Value replaced,
834 Value replacement) {
835 assert(!valueVectorReplacement.contains(replaced) &&(static_cast <bool> (!valueVectorReplacement.contains(replaced
) && "Vector replacement already registered") ? void (
0) : __assert_fail ("!valueVectorReplacement.contains(replaced) && \"Vector replacement already registered\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 836
, __extension__ __PRETTY_FUNCTION__))
836 "Vector replacement already registered")(static_cast <bool> (!valueVectorReplacement.contains(replaced
) && "Vector replacement already registered") ? void (
0) : __assert_fail ("!valueVectorReplacement.contains(replaced) && \"Vector replacement already registered\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 836
, __extension__ __PRETTY_FUNCTION__))
;
837 assert(replacement.getType().isa<VectorType>() &&(static_cast <bool> (replacement.getType().isa<VectorType
>() && "Expected vector type in vector replacement"
) ? void (0) : __assert_fail ("replacement.getType().isa<VectorType>() && \"Expected vector type in vector replacement\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 838
, __extension__ __PRETTY_FUNCTION__))
838 "Expected vector type in vector replacement")(static_cast <bool> (replacement.getType().isa<VectorType
>() && "Expected vector type in vector replacement"
) ? void (0) : __assert_fail ("replacement.getType().isa<VectorType>() && \"Expected vector type in vector replacement\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 838
, __extension__ __PRETTY_FUNCTION__))
;
839 valueVectorReplacement.map(replaced, replacement);
840}
841
842/// Registers the scalar replacement of a scalar value. 'replacement' must be
843/// scalar. Both values must be block arguments. Operation results should be
844/// replaced using the 'registerOp*' utilitites.
845///
846/// This utility is used to register the replacement of block arguments
847/// that are within the loop to be vectorized and will continue being scalar
848/// within the vector loop.
849///
850/// Example:
851/// * 'replaced': induction variable of a loop to be vectorized.
852/// * 'replacement': new induction variable in the new vector loop.
853void VectorizationState::registerValueScalarReplacement(
854 BlockArgument replaced, BlockArgument replacement) {
855 registerValueScalarReplacementImpl(replaced, replacement);
856}
857
858/// Registers the scalar replacement of a scalar result returned from a
859/// reduction loop. 'replacement' must be scalar.
860///
861/// This utility is used to register the replacement for scalar results of
862/// vectorized reduction loops with iter_args.
863///
864/// Example 2:
865/// * 'replaced': %0 = affine.for %i = 0 to 512 iter_args(%x = ...) -> (f32)
866/// * 'replacement': %1 = vector.reduction <add>, %0 : vector<4xf32> into f32
867void VectorizationState::registerLoopResultScalarReplacement(
868 Value replaced, Value replacement) {
869 assert(isa<AffineForOp>(replaced.getDefiningOp()))(static_cast <bool> (isa<AffineForOp>(replaced.getDefiningOp
())) ? void (0) : __assert_fail ("isa<AffineForOp>(replaced.getDefiningOp())"
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 869
, __extension__ __PRETTY_FUNCTION__))
;
870 assert(loopResultScalarReplacement.count(replaced) == 0 &&(static_cast <bool> (loopResultScalarReplacement.count(
replaced) == 0 && "already registered") ? void (0) : __assert_fail
("loopResultScalarReplacement.count(replaced) == 0 && \"already registered\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 871
, __extension__ __PRETTY_FUNCTION__))
871 "already registered")(static_cast <bool> (loopResultScalarReplacement.count(
replaced) == 0 && "already registered") ? void (0) : __assert_fail
("loopResultScalarReplacement.count(replaced) == 0 && \"already registered\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 871
, __extension__ __PRETTY_FUNCTION__))
;
872 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ will replace a result of the loop "do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ will replace a result of the loop "
"with scalar: " << replacement; } } while (false)
873 "with scalar: "do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ will replace a result of the loop "
"with scalar: " << replacement; } } while (false)
874 << replacement)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ will replace a result of the loop "
"with scalar: " << replacement; } } while (false)
;
875 loopResultScalarReplacement[replaced] = replacement;
876}
877
878void VectorizationState::registerValueScalarReplacementImpl(Value replaced,
879 Value replacement) {
880 assert(!valueScalarReplacement.contains(replaced) &&(static_cast <bool> (!valueScalarReplacement.contains(replaced
) && "Scalar value replacement already registered") ?
void (0) : __assert_fail ("!valueScalarReplacement.contains(replaced) && \"Scalar value replacement already registered\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 881
, __extension__ __PRETTY_FUNCTION__))
881 "Scalar value replacement already registered")(static_cast <bool> (!valueScalarReplacement.contains(replaced
) && "Scalar value replacement already registered") ?
void (0) : __assert_fail ("!valueScalarReplacement.contains(replaced) && \"Scalar value replacement already registered\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 881
, __extension__ __PRETTY_FUNCTION__))
;
882 assert(!replacement.getType().isa<VectorType>() &&(static_cast <bool> (!replacement.getType().isa<VectorType
>() && "Expected scalar type in scalar replacement"
) ? void (0) : __assert_fail ("!replacement.getType().isa<VectorType>() && \"Expected scalar type in scalar replacement\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 883
, __extension__ __PRETTY_FUNCTION__))
883 "Expected scalar type in scalar replacement")(static_cast <bool> (!replacement.getType().isa<VectorType
>() && "Expected scalar type in scalar replacement"
) ? void (0) : __assert_fail ("!replacement.getType().isa<VectorType>() && \"Expected scalar type in scalar replacement\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 883
, __extension__ __PRETTY_FUNCTION__))
;
884 valueScalarReplacement.map(replaced, replacement);
885}
886
887/// Returns in 'replacedVals' the scalar replacement for values in 'inputVals'.
888void VectorizationState::getScalarValueReplacementsFor(
889 ValueRange inputVals, SmallVectorImpl<Value> &replacedVals) {
890 for (Value inputVal : inputVals)
891 replacedVals.push_back(valueScalarReplacement.lookupOrDefault(inputVal));
892}
893
894/// Erases a loop nest, including all its nested operations.
895static void eraseLoopNest(AffineForOp forOp) {
896 LLVM_DEBUG(dbgs() << "[early-vect]+++++ erasing:\n" << forOp << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "[early-vect]+++++ erasing:\n"
<< forOp << "\n"; } } while (false)
;
897 forOp.erase();
898}
899
900/// Erases the scalar loop nest after its successful vectorization.
901void VectorizationState::finishVectorizationPattern(AffineForOp rootLoop) {
902 LLVM_DEBUG(dbgs() << "\n[early-vect] Finalizing vectorization\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect] Finalizing vectorization\n"
; } } while (false)
;
903 eraseLoopNest(rootLoop);
904}
905
906// Apply 'map' with 'mapOperands' returning resulting values in 'results'.
907static void computeMemoryOpIndices(Operation *op, AffineMap map,
908 ValueRange mapOperands,
909 VectorizationState &state,
910 SmallVectorImpl<Value> &results) {
911 for (auto resultExpr : map.getResults()) {
912 auto singleResMap =
913 AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr);
914 auto afOp = state.builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
915 mapOperands);
916 results.push_back(afOp);
917 }
918}
919
920/// Returns a FilterFunctionType that can be used in NestedPattern to match a
921/// loop whose underlying load/store accesses are either invariant or all
922// varying along the `fastestVaryingMemRefDimension`.
923static FilterFunctionType
924isVectorizableLoopPtrFactory(const DenseSet<Operation *> &parallelLoops,
925 int fastestVaryingMemRefDimension) {
926 return [&parallelLoops, fastestVaryingMemRefDimension](Operation &forOp) {
927 auto loop = cast<AffineForOp>(forOp);
928 auto parallelIt = parallelLoops.find(loop);
929 if (parallelIt == parallelLoops.end())
930 return false;
931 int memRefDim = -1;
932 auto vectorizableBody =
933 isVectorizableLoopBody(loop, &memRefDim, vectorTransferPattern());
934 if (!vectorizableBody)
935 return false;
936 return memRefDim == -1 || fastestVaryingMemRefDimension == -1 ||
937 memRefDim == fastestVaryingMemRefDimension;
938 };
939}
940
941/// Returns the vector type resulting from applying the provided vectorization
942/// strategy on the scalar type.
943static VectorType getVectorType(Type scalarTy,
944 const VectorizationStrategy *strategy) {
945 assert(!scalarTy.isa<VectorType>() && "Expected scalar type")(static_cast <bool> (!scalarTy.isa<VectorType>() &&
"Expected scalar type") ? void (0) : __assert_fail ("!scalarTy.isa<VectorType>() && \"Expected scalar type\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 945
, __extension__ __PRETTY_FUNCTION__))
;
946 return VectorType::get(strategy->vectorSizes, scalarTy);
947}
948
949/// Tries to transform a scalar constant into a vector constant. Returns the
950/// vector constant if the scalar type is valid vector element type. Returns
951/// nullptr, otherwise.
952static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp,
953 VectorizationState &state) {
954 Type scalarTy = constOp.getType();
21
Calling 'Impl::getType'
25
Returning from 'Impl::getType'
955 if (!VectorType::isValidElementType(scalarTy))
26
Taking true branch
956 return nullptr;
27
Calling constructor for 'ConstantOp'
28
Calling constructor for 'Op<mlir::arith::ConstantOp, mlir::OpTrait::ZeroRegions, mlir::OpTrait::OneResult, mlir::OpTrait::OneTypedResult<mlir::Type>::Impl, mlir::OpTrait::ZeroSuccessors, mlir::OpTrait::ZeroOperands, mlir::OpTrait::OpInvariants, mlir::OpTrait::ConstantLike, mlir::MemoryEffectOpInterface::Trait, mlir::OpAsmOpInterface::Trait, mlir::InferIntRangeInterface::Trait, mlir::InferTypeOpInterface::Trait>'
33
Returning from constructor for 'Op<mlir::arith::ConstantOp, mlir::OpTrait::ZeroRegions, mlir::OpTrait::OneResult, mlir::OpTrait::OneTypedResult<mlir::Type>::Impl, mlir::OpTrait::ZeroSuccessors, mlir::OpTrait::ZeroOperands, mlir::OpTrait::OpInvariants, mlir::OpTrait::ConstantLike, mlir::MemoryEffectOpInterface::Trait, mlir::OpAsmOpInterface::Trait, mlir::InferIntRangeInterface::Trait, mlir::InferTypeOpInterface::Trait>'
34
Returning from constructor for 'ConstantOp'
957
958 auto vecTy = getVectorType(scalarTy, state.strategy);
959 auto vecAttr = DenseElementsAttr::get(vecTy, constOp.getValue());
960
961 OpBuilder::InsertionGuard guard(state.builder);
962 Operation *parentOp = state.builder.getInsertionBlock()->getParentOp();
963 // Find the innermost vectorized ancestor loop to insert the vector constant.
964 while (parentOp && !state.vecLoopToVecDim.count(parentOp))
965 parentOp = parentOp->getParentOp();
966 assert(parentOp && state.vecLoopToVecDim.count(parentOp) &&(static_cast <bool> (parentOp && state.vecLoopToVecDim
.count(parentOp) && isa<AffineForOp>(parentOp) &&
"Expected a vectorized for op") ? void (0) : __assert_fail (
"parentOp && state.vecLoopToVecDim.count(parentOp) && isa<AffineForOp>(parentOp) && \"Expected a vectorized for op\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 967
, __extension__ __PRETTY_FUNCTION__))
967 isa<AffineForOp>(parentOp) && "Expected a vectorized for op")(static_cast <bool> (parentOp && state.vecLoopToVecDim
.count(parentOp) && isa<AffineForOp>(parentOp) &&
"Expected a vectorized for op") ? void (0) : __assert_fail (
"parentOp && state.vecLoopToVecDim.count(parentOp) && isa<AffineForOp>(parentOp) && \"Expected a vectorized for op\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 967
, __extension__ __PRETTY_FUNCTION__))
;
968 auto vecForOp = cast<AffineForOp>(parentOp);
969 state.builder.setInsertionPointToStart(vecForOp.getBody());
970 auto newConstOp =
971 state.builder.create<arith::ConstantOp>(constOp.getLoc(), vecAttr);
972
973 // Register vector replacement for future uses in the scope.
974 state.registerOpVectorReplacement(constOp, newConstOp);
975 return newConstOp;
976}
977
978/// Creates a constant vector filled with the neutral elements of the given
979/// reduction. The scalar type of vector elements will be taken from
980/// `oldOperand`.
981static arith::ConstantOp createInitialVector(arith::AtomicRMWKind reductionKind,
982 Value oldOperand,
983 VectorizationState &state) {
984 Type scalarTy = oldOperand.getType();
985 if (!VectorType::isValidElementType(scalarTy))
986 return nullptr;
987
988 Attribute valueAttr = getIdentityValueAttr(
989 reductionKind, scalarTy, state.builder, oldOperand.getLoc());
990 auto vecTy = getVectorType(scalarTy, state.strategy);
991 auto vecAttr = DenseElementsAttr::get(vecTy, valueAttr);
992 auto newConstOp =
993 state.builder.create<arith::ConstantOp>(oldOperand.getLoc(), vecAttr);
994
995 return newConstOp;
996}
997
998/// Creates a mask used to filter out garbage elements in the last iteration
999/// of unaligned loops. If a mask is not required then `nullptr` is returned.
1000/// The mask will be a vector of booleans representing meaningful vector
1001/// elements in the current iteration. It is filled with ones for each iteration
1002/// except for the last one, where it has the form `11...100...0` with the
1003/// number of ones equal to the number of meaningful elements (i.e. the number
1004/// of iterations that would be left in the original loop).
1005static Value createMask(AffineForOp vecForOp, VectorizationState &state) {
1006 assert(state.strategy->vectorSizes.size() == 1 &&(static_cast <bool> (state.strategy->vectorSizes.size
() == 1 && "Creating a mask non-1-D vectors is not supported."
) ? void (0) : __assert_fail ("state.strategy->vectorSizes.size() == 1 && \"Creating a mask non-1-D vectors is not supported.\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1007
, __extension__ __PRETTY_FUNCTION__))
1007 "Creating a mask non-1-D vectors is not supported.")(static_cast <bool> (state.strategy->vectorSizes.size
() == 1 && "Creating a mask non-1-D vectors is not supported."
) ? void (0) : __assert_fail ("state.strategy->vectorSizes.size() == 1 && \"Creating a mask non-1-D vectors is not supported.\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1007
, __extension__ __PRETTY_FUNCTION__))
;
1008 assert(vecForOp.getStep() == state.strategy->vectorSizes[0] &&(static_cast <bool> (vecForOp.getStep() == state.strategy
->vectorSizes[0] && "Creating a mask for loops with non-unit original step size is not "
"supported.") ? void (0) : __assert_fail ("vecForOp.getStep() == state.strategy->vectorSizes[0] && \"Creating a mask for loops with non-unit original step size is not \" \"supported.\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1010
, __extension__ __PRETTY_FUNCTION__))
1009 "Creating a mask for loops with non-unit original step size is not "(static_cast <bool> (vecForOp.getStep() == state.strategy
->vectorSizes[0] && "Creating a mask for loops with non-unit original step size is not "
"supported.") ? void (0) : __assert_fail ("vecForOp.getStep() == state.strategy->vectorSizes[0] && \"Creating a mask for loops with non-unit original step size is not \" \"supported.\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1010
, __extension__ __PRETTY_FUNCTION__))
1010 "supported.")(static_cast <bool> (vecForOp.getStep() == state.strategy
->vectorSizes[0] && "Creating a mask for loops with non-unit original step size is not "
"supported.") ? void (0) : __assert_fail ("vecForOp.getStep() == state.strategy->vectorSizes[0] && \"Creating a mask for loops with non-unit original step size is not \" \"supported.\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1010
, __extension__ __PRETTY_FUNCTION__))
;
1011
1012 // Check if we have already created the mask.
1013 if (Value mask = state.vecLoopToMask.lookup(vecForOp))
1014 return mask;
1015
1016 // If the loop has constant bounds and the original number of iterations is
1017 // divisable by the vector size then we don't need a mask.
1018 if (vecForOp.hasConstantBounds()) {
1019 int64_t originalTripCount =
1020 vecForOp.getConstantUpperBound() - vecForOp.getConstantLowerBound();
1021 if (originalTripCount % vecForOp.getStep() == 0)
1022 return nullptr;
1023 }
1024
1025 OpBuilder::InsertionGuard guard(state.builder);
1026 state.builder.setInsertionPointToStart(vecForOp.getBody());
1027
1028 // We generate the mask using the `vector.create_mask` operation which accepts
1029 // the number of meaningful elements (i.e. the length of the prefix of 1s).
1030 // To compute the number of meaningful elements we subtract the current value
1031 // of the iteration variable from the upper bound of the loop. Example:
1032 //
1033 // // 500 is the upper bound of the loop
1034 // #map = affine_map<(d0) -> (500 - d0)>
1035 // %elems_left = affine.apply #map(%iv)
1036 // %mask = vector.create_mask %elems_left : vector<128xi1>
1037
1038 Location loc = vecForOp.getLoc();
1039
1040 // First we get the upper bound of the loop using `affine.apply` or
1041 // `affine.min`.
1042 AffineMap ubMap = vecForOp.getUpperBoundMap();
1043 Value ub;
1044 if (ubMap.getNumResults() == 1)
1045 ub = state.builder.create<AffineApplyOp>(loc, vecForOp.getUpperBoundMap(),
1046 vecForOp.getUpperBoundOperands());
1047 else
1048 ub = state.builder.create<AffineMinOp>(loc, vecForOp.getUpperBoundMap(),
1049 vecForOp.getUpperBoundOperands());
1050 // Then we compute the number of (original) iterations left in the loop.
1051 AffineExpr subExpr =
1052 state.builder.getAffineDimExpr(0) - state.builder.getAffineDimExpr(1);
1053 Value itersLeft =
1054 makeComposedAffineApply(state.builder, loc, AffineMap::get(2, 0, subExpr),
1055 {ub, vecForOp.getInductionVar()});
1056 // If the affine maps were successfully composed then `ub` is unneeded.
1057 if (ub.use_empty())
1058 ub.getDefiningOp()->erase();
1059 // Finally we create the mask.
1060 Type maskTy = VectorType::get(state.strategy->vectorSizes,
1061 state.builder.getIntegerType(1));
1062 Value mask =
1063 state.builder.create<vector::CreateMaskOp>(loc, maskTy, itersLeft);
1064
1065 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a mask:\n"do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ creating a mask:\n"
<< itersLeft << "\n" << mask << "\n"
; } } while (false)
1066 << itersLeft << "\n"do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ creating a mask:\n"
<< itersLeft << "\n" << mask << "\n"
; } } while (false)
1067 << mask << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ creating a mask:\n"
<< itersLeft << "\n" << mask << "\n"
; } } while (false)
;
1068
1069 state.vecLoopToMask[vecForOp] = mask;
1070 return mask;
1071}
1072
1073/// Returns true if the provided value is vector uniform given the vectorization
1074/// strategy.
1075// TODO: For now, only values that are induction variables of loops not in
1076// `loopToVectorDim` or invariants to all the loops in the vectorization
1077// strategy are considered vector uniforms.
1078static bool isUniformDefinition(Value value,
1079 const VectorizationStrategy *strategy) {
1080 AffineForOp forOp = getForInductionVarOwner(value);
1081 if (forOp && strategy->loopToVectorDim.count(forOp) == 0)
1082 return true;
1083
1084 for (auto loopToDim : strategy->loopToVectorDim) {
1085 auto loop = cast<AffineForOp>(loopToDim.first);
1086 if (!loop.isDefinedOutsideOfLoop(value))
1087 return false;
1088 }
1089 return true;
1090}
1091
1092/// Generates a broadcast op for the provided uniform value using the
1093/// vectorization strategy in 'state'.
1094static Operation *vectorizeUniform(Value uniformVal,
1095 VectorizationState &state) {
1096 OpBuilder::InsertionGuard guard(state.builder);
1097 Value uniformScalarRepl =
1098 state.valueScalarReplacement.lookupOrDefault(uniformVal);
1099 state.builder.setInsertionPointAfterValue(uniformScalarRepl);
1100
1101 auto vectorTy = getVectorType(uniformVal.getType(), state.strategy);
1102 auto bcastOp = state.builder.create<BroadcastOp>(uniformVal.getLoc(),
1103 vectorTy, uniformScalarRepl);
1104 state.registerValueVectorReplacement(uniformVal, bcastOp);
1105 return bcastOp;
1106}
1107
1108/// Tries to vectorize a given `operand` by applying the following logic:
1109/// 1. if the defining operation has been already vectorized, `operand` is
1110/// already in the proper vector form;
1111/// 2. if the `operand` is a constant, returns the vectorized form of the
1112/// constant;
1113/// 3. if the `operand` is uniform, returns a vector broadcast of the `op`;
1114/// 4. otherwise, the vectorization of `operand` is not supported.
1115/// Newly created vector operations are registered in `state` as replacement
1116/// for their scalar counterparts.
1117/// In particular this logic captures some of the use cases where definitions
1118/// that are not scoped under the current pattern are needed to vectorize.
1119/// One such example is top level function constants that need to be splatted.
1120///
1121/// Returns an operand that has been vectorized to match `state`'s strategy if
1122/// vectorization is possible with the above logic. Returns nullptr otherwise.
1123///
1124/// TODO: handle more complex cases.
1125static Value vectorizeOperand(Value operand, VectorizationState &state) {
1126 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorize operand: " << operand)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ vectorize operand: "
<< operand; } } while (false)
;
15
Assuming 'DebugFlag' is false
16
Loop condition is false. Exiting loop
1127 // If this value is already vectorized, we are done.
1128 if (Value vecRepl = state.valueVectorReplacement.lookupOrNull(operand)) {
1129 LLVM_DEBUG(dbgs() << " -> already vectorized: " << vecRepl)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << " -> already vectorized: "
<< vecRepl; } } while (false)
;
1130 return vecRepl;
1131 }
1132
1133 // An vector operand that is not in the replacement map should never reach
1134 // this point. Reaching this point could mean that the code was already
1135 // vectorized and we shouldn't try to vectorize already vectorized code.
1136 assert(!operand.getType().isa<VectorType>() &&(static_cast <bool> (!operand.getType().isa<VectorType
>() && "Vector op not found in replacement map") ?
void (0) : __assert_fail ("!operand.getType().isa<VectorType>() && \"Vector op not found in replacement map\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1137
, __extension__ __PRETTY_FUNCTION__))
17
Taking false branch
18
'?' condition is true
1137 "Vector op not found in replacement map")(static_cast <bool> (!operand.getType().isa<VectorType
>() && "Vector op not found in replacement map") ?
void (0) : __assert_fail ("!operand.getType().isa<VectorType>() && \"Vector op not found in replacement map\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1137
, __extension__ __PRETTY_FUNCTION__))
;
1138
1139 // Vectorize constant.
1140 if (auto constOp = operand.getDefiningOp<arith::ConstantOp>()) {
19
Taking true branch
1141 auto vecConstant = vectorizeConstant(constOp, state);
20
Calling 'vectorizeConstant'
35
Returning from 'vectorizeConstant'
36
'vecConstant' initialized here
1142 LLVM_DEBUG(dbgs() << "-> constant: " << vecConstant)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "-> constant: " <<
vecConstant; } } while (false)
;
37
Assuming 'DebugFlag' is true
38
Assuming the condition is true
39
Taking true branch
40
Null pointer value stored to 'op.state'
41
Calling 'operator<<'
1143 return vecConstant.getResult();
1144 }
1145
1146 // Vectorize uniform values.
1147 if (isUniformDefinition(operand, state.strategy)) {
1148 Operation *vecUniform = vectorizeUniform(operand, state);
1149 LLVM_DEBUG(dbgs() << "-> uniform: " << *vecUniform)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "-> uniform: " << *
vecUniform; } } while (false)
;
1150 return vecUniform->getResult(0);
1151 }
1152
1153 // Check for unsupported block argument scenarios. A supported block argument
1154 // should have been vectorized already.
1155 if (!operand.getDefiningOp())
1156 LLVM_DEBUG(dbgs() << "-> unsupported block argument\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "-> unsupported block argument\n"
; } } while (false)
;
1157 else
1158 // Generic unsupported case.
1159 LLVM_DEBUG(dbgs() << "-> non-vectorizable\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "-> non-vectorizable\n";
} } while (false)
;
1160
1161 return nullptr;
1162}
1163
1164/// Vectorizes an affine load with the vectorization strategy in 'state' by
1165/// generating a 'vector.transfer_read' op with the proper permutation map
1166/// inferred from the indices of the load. The new 'vector.transfer_read' is
1167/// registered as replacement of the scalar load. Returns the newly created
1168/// 'vector.transfer_read' if vectorization was successful. Returns nullptr,
1169/// otherwise.
1170static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
1171 VectorizationState &state) {
1172 MemRefType memRefType = loadOp.getMemRefType();
1173 Type elementType = memRefType.getElementType();
1174 auto vectorType = VectorType::get(state.strategy->vectorSizes, elementType);
1175
1176 // Replace map operands with operands from the vector loop nest.
1177 SmallVector<Value, 8> mapOperands;
1178 state.getScalarValueReplacementsFor(loadOp.getMapOperands(), mapOperands);
1179
1180 // Compute indices for the transfer op. AffineApplyOp's may be generated.
1181 SmallVector<Value, 8> indices;
1182 indices.reserve(memRefType.getRank());
1183 if (loadOp.getAffineMap() !=
1184 state.builder.getMultiDimIdentityMap(memRefType.getRank()))
1185 computeMemoryOpIndices(loadOp, loadOp.getAffineMap(), mapOperands, state,
1186 indices);
1187 else
1188 indices.append(mapOperands.begin(), mapOperands.end());
1189
1190 // Compute permutation map using the information of new vector loops.
1191 auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(),
1192 indices, state.vecLoopToVecDim);
1193 if (!permutationMap) {
1194 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ can't compute permutationMap\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ can't compute permutationMap\n"
; } } while (false)
;
1195 return nullptr;
1196 }
1197 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ permutationMap: "
; } } while (false)
;
1198 LLVM_DEBUG(permutationMap.print(dbgs()))do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { permutationMap.print(dbgs()); } } while (false
)
;
1199
1200 auto transfer = state.builder.create<vector::TransferReadOp>(
1201 loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap);
1202
1203 // Register replacement for future uses in the scope.
1204 state.registerOpVectorReplacement(loadOp, transfer);
1205 return transfer;
1206}
1207
1208/// Vectorizes an affine store with the vectorization strategy in 'state' by
1209/// generating a 'vector.transfer_write' op with the proper permutation map
1210/// inferred from the indices of the store. The new 'vector.transfer_store' is
1211/// registered as replacement of the scalar load. Returns the newly created
1212/// 'vector.transfer_write' if vectorization was successful. Returns nullptr,
1213/// otherwise.
1214static Operation *vectorizeAffineStore(AffineStoreOp storeOp,
1215 VectorizationState &state) {
1216 MemRefType memRefType = storeOp.getMemRefType();
1217 Value vectorValue = vectorizeOperand(storeOp.getValueToStore(), state);
1218 if (!vectorValue)
1219 return nullptr;
1220
1221 // Replace map operands with operands from the vector loop nest.
1222 SmallVector<Value, 8> mapOperands;
1223 state.getScalarValueReplacementsFor(storeOp.getMapOperands(), mapOperands);
1224
1225 // Compute indices for the transfer op. AffineApplyOp's may be generated.
1226 SmallVector<Value, 8> indices;
1227 indices.reserve(memRefType.getRank());
1228 if (storeOp.getAffineMap() !=
1229 state.builder.getMultiDimIdentityMap(memRefType.getRank()))
1230 computeMemoryOpIndices(storeOp, storeOp.getAffineMap(), mapOperands, state,
1231 indices);
1232 else
1233 indices.append(mapOperands.begin(), mapOperands.end());
1234
1235 // Compute permutation map using the information of new vector loops.
1236 auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(),
1237 indices, state.vecLoopToVecDim);
1238 if (!permutationMap)
1239 return nullptr;
1240 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ permutationMap: "
; } } while (false)
;
1241 LLVM_DEBUG(permutationMap.print(dbgs()))do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { permutationMap.print(dbgs()); } } while (false
)
;
1242
1243 auto transfer = state.builder.create<vector::TransferWriteOp>(
1244 storeOp.getLoc(), vectorValue, storeOp.getMemRef(), indices,
1245 permutationMap);
1246 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << transfer)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ vectorized store: "
<< transfer; } } while (false)
;
1247
1248 // Register replacement for future uses in the scope.
1249 state.registerOpVectorReplacement(storeOp, transfer);
1250 return transfer;
1251}
1252
1253/// Returns true if `value` is a constant equal to the neutral element of the
1254/// given vectorizable reduction.
1255static bool isNeutralElementConst(arith::AtomicRMWKind reductionKind,
1256 Value value, VectorizationState &state) {
1257 Type scalarTy = value.getType();
1258 if (!VectorType::isValidElementType(scalarTy))
1259 return false;
1260 Attribute valueAttr = getIdentityValueAttr(reductionKind, scalarTy,
1261 state.builder, value.getLoc());
1262 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(value.getDefiningOp()))
1263 return constOp.getValue() == valueAttr;
1264 return false;
1265}
1266
1267/// Vectorizes a loop with the vectorization strategy in 'state'. A new loop is
1268/// created and registered as replacement for the scalar loop. The builder's
1269/// insertion point is set to the new loop's body so that subsequent vectorized
1270/// operations are inserted into the new loop. If the loop is a vector
1271/// dimension, the step of the newly created loop will reflect the vectorization
1272/// factor used to vectorized that dimension.
1273static Operation *vectorizeAffineForOp(AffineForOp forOp,
1274 VectorizationState &state) {
1275 const VectorizationStrategy &strategy = *state.strategy;
1276 auto loopToVecDimIt = strategy.loopToVectorDim.find(forOp);
1277 bool isLoopVecDim = loopToVecDimIt != strategy.loopToVectorDim.end();
1278
1279 // TODO: Vectorization of reduction loops is not supported for non-unit steps.
1280 if (isLoopVecDim
11.1
'isLoopVecDim' is false
11.1
'isLoopVecDim' is false
11.1
'isLoopVecDim' is false
11.1
'isLoopVecDim' is false
&& forOp.getNumIterOperands() > 0 && forOp.getStep() != 1) {
1281 LLVM_DEBUG(do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ unsupported step size for reduction loop: "
<< forOp.getStep() << "\n"; } } while (false)
1282 dbgs()do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ unsupported step size for reduction loop: "
<< forOp.getStep() << "\n"; } } while (false)
1283 << "\n[early-vect]+++++ unsupported step size for reduction loop: "do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ unsupported step size for reduction loop: "
<< forOp.getStep() << "\n"; } } while (false)
1284 << forOp.getStep() << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ unsupported step size for reduction loop: "
<< forOp.getStep() << "\n"; } } while (false)
;
1285 return nullptr;
1286 }
1287
1288 // If we are vectorizing a vector dimension, compute a new step for the new
1289 // vectorized loop using the vectorization factor for the vector dimension.
1290 // Otherwise, propagate the step of the scalar loop.
1291 unsigned newStep;
1292 if (isLoopVecDim
11.2
'isLoopVecDim' is false
11.2
'isLoopVecDim' is false
11.2
'isLoopVecDim' is false
11.2
'isLoopVecDim' is false
) {
12
Taking false branch
1293 unsigned vectorDim = loopToVecDimIt->second;
1294 assert(vectorDim < strategy.vectorSizes.size() && "vector dim overflow")(static_cast <bool> (vectorDim < strategy.vectorSizes
.size() && "vector dim overflow") ? void (0) : __assert_fail
("vectorDim < strategy.vectorSizes.size() && \"vector dim overflow\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1294
, __extension__ __PRETTY_FUNCTION__))
;
1295 int64_t forOpVecFactor = strategy.vectorSizes[vectorDim];
1296 newStep = forOp.getStep() * forOpVecFactor;
1297 } else {
1298 newStep = forOp.getStep();
1299 }
1300
1301 // Get information about reduction kinds.
1302 ArrayRef<LoopReduction> reductions;
1303 if (isLoopVecDim
12.1
'isLoopVecDim' is false
12.1
'isLoopVecDim' is false
12.1
'isLoopVecDim' is false
12.1
'isLoopVecDim' is false
&& forOp.getNumIterOperands() > 0) {
1304 auto it = strategy.reductionLoops.find(forOp);
1305 assert(it != strategy.reductionLoops.end() &&(static_cast <bool> (it != strategy.reductionLoops.end(
) && "Reduction descriptors not found when vectorizing a reduction loop"
) ? void (0) : __assert_fail ("it != strategy.reductionLoops.end() && \"Reduction descriptors not found when vectorizing a reduction loop\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1306
, __extension__ __PRETTY_FUNCTION__))
1306 "Reduction descriptors not found when vectorizing a reduction loop")(static_cast <bool> (it != strategy.reductionLoops.end(
) && "Reduction descriptors not found when vectorizing a reduction loop"
) ? void (0) : __assert_fail ("it != strategy.reductionLoops.end() && \"Reduction descriptors not found when vectorizing a reduction loop\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1306
, __extension__ __PRETTY_FUNCTION__))
;
1307 reductions = it->second;
1308 assert(reductions.size() == forOp.getNumIterOperands() &&(static_cast <bool> (reductions.size() == forOp.getNumIterOperands
() && "The size of reductions array must match the number of iter_args"
) ? void (0) : __assert_fail ("reductions.size() == forOp.getNumIterOperands() && \"The size of reductions array must match the number of iter_args\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1309
, __extension__ __PRETTY_FUNCTION__))
1309 "The size of reductions array must match the number of iter_args")(static_cast <bool> (reductions.size() == forOp.getNumIterOperands
() && "The size of reductions array must match the number of iter_args"
) ? void (0) : __assert_fail ("reductions.size() == forOp.getNumIterOperands() && \"The size of reductions array must match the number of iter_args\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1309
, __extension__ __PRETTY_FUNCTION__))
;
1310 }
1311
1312 // Vectorize 'iter_args'.
1313 SmallVector<Value, 8> vecIterOperands;
1314 if (!isLoopVecDim
12.2
'isLoopVecDim' is false
12.2
'isLoopVecDim' is false
12.2
'isLoopVecDim' is false
12.2
'isLoopVecDim' is false
) {
13
Taking true branch
1315 for (auto operand : forOp.getIterOperands())
1316 vecIterOperands.push_back(vectorizeOperand(operand, state));
14
Calling 'vectorizeOperand'
1317 } else {
1318 // For reduction loops we need to pass a vector of neutral elements as an
1319 // initial value of the accumulator. We will add the original initial value
1320 // later.
1321 for (auto redAndOperand : llvm::zip(reductions, forOp.getIterOperands())) {
1322 vecIterOperands.push_back(createInitialVector(
1323 std::get<0>(redAndOperand).kind, std::get<1>(redAndOperand), state));
1324 }
1325 }
1326
1327 auto vecForOp = state.builder.create<AffineForOp>(
1328 forOp.getLoc(), forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(),
1329 forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), newStep,
1330 vecIterOperands,
1331 /*bodyBuilder=*/[](OpBuilder &, Location, Value, ValueRange) {
1332 // Make sure we don't create a default terminator in the loop body as
1333 // the proper terminator will be added during vectorization.
1334 });
1335
1336 // Register loop-related replacements:
1337 // 1) The new vectorized loop is registered as vector replacement of the
1338 // scalar loop.
1339 // 2) The new iv of the vectorized loop is registered as scalar replacement
1340 // since a scalar copy of the iv will prevail in the vectorized loop.
1341 // TODO: A vector replacement will also be added in the future when
1342 // vectorization of linear ops is supported.
1343 // 3) The new 'iter_args' region arguments are registered as vector
1344 // replacements since they have been vectorized.
1345 // 4) If the loop performs a reduction along the vector dimension, a
1346 // `vector.reduction` or similar op is inserted for each resulting value
1347 // of the loop and its scalar value replaces the corresponding scalar
1348 // result of the loop.
1349 state.registerOpVectorReplacement(forOp, vecForOp);
1350 state.registerValueScalarReplacement(forOp.getInductionVar(),
1351 vecForOp.getInductionVar());
1352 for (auto iterTuple :
1353 llvm ::zip(forOp.getRegionIterArgs(), vecForOp.getRegionIterArgs()))
1354 state.registerBlockArgVectorReplacement(std::get<0>(iterTuple),
1355 std::get<1>(iterTuple));
1356
1357 if (isLoopVecDim) {
1358 for (unsigned i = 0; i < vecForOp.getNumIterOperands(); ++i) {
1359 // First, we reduce the vector returned from the loop into a scalar.
1360 Value reducedRes =
1361 getVectorReductionOp(reductions[i].kind, state.builder,
1362 vecForOp.getLoc(), vecForOp.getResult(i));
1363 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a vector reduction: "do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ creating a vector reduction: "
<< reducedRes; } } while (false)
1364 << reducedRes)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ creating a vector reduction: "
<< reducedRes; } } while (false)
;
1365 // Then we combine it with the original (scalar) initial value unless it
1366 // is equal to the neutral element of the reduction.
1367 Value origInit = forOp.getOperand(forOp.getNumControlOperands() + i);
1368 Value finalRes = reducedRes;
1369 if (!isNeutralElementConst(reductions[i].kind, origInit, state))
1370 finalRes =
1371 arith::getReductionOp(reductions[i].kind, state.builder,
1372 reducedRes.getLoc(), reducedRes, origInit);
1373 state.registerLoopResultScalarReplacement(forOp.getResult(i), finalRes);
1374 }
1375 }
1376
1377 if (isLoopVecDim)
1378 state.vecLoopToVecDim[vecForOp] = loopToVecDimIt->second;
1379
1380 // Change insertion point so that upcoming vectorized instructions are
1381 // inserted into the vectorized loop's body.
1382 state.builder.setInsertionPointToStart(vecForOp.getBody());
1383
1384 // If this is a reduction loop then we may need to create a mask to filter out
1385 // garbage in the last iteration.
1386 if (isLoopVecDim && forOp.getNumIterOperands() > 0)
1387 createMask(vecForOp, state);
1388
1389 return vecForOp;
1390}
1391
1392/// Vectorizes arbitrary operation by plain widening. We apply generic type
1393/// widening of all its results and retrieve the vector counterparts for all its
1394/// operands.
1395static Operation *widenOp(Operation *op, VectorizationState &state) {
1396 SmallVector<Type, 8> vectorTypes;
1397 for (Value result : op->getResults())
1398 vectorTypes.push_back(
1399 VectorType::get(state.strategy->vectorSizes, result.getType()));
1400
1401 SmallVector<Value, 8> vectorOperands;
1402 for (Value operand : op->getOperands()) {
1403 Value vecOperand = vectorizeOperand(operand, state);
1404 if (!vecOperand) {
1405 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ an operand failed vectorize\n"
; } } while (false)
;
1406 return nullptr;
1407 }
1408 vectorOperands.push_back(vecOperand);
1409 }
1410
1411 // Create a clone of the op with the proper operands and return types.
1412 // TODO: The following assumes there is always an op with a fixed
1413 // name that works both in scalar mode and vector mode.
1414 // TODO: Is it worth considering an Operation.clone operation which
1415 // changes the type so we can promote an Operation with less boilerplate?
1416 Operation *vecOp =
1417 state.builder.create(op->getLoc(), op->getName().getIdentifier(),
1418 vectorOperands, vectorTypes, op->getAttrs());
1419 state.registerOpVectorReplacement(op, vecOp);
1420 return vecOp;
1421}
1422
1423/// Vectorizes a yield operation by widening its types. The builder's insertion
1424/// point is set after the vectorized parent op to continue vectorizing the
1425/// operations after the parent op. When vectorizing a reduction loop a mask may
1426/// be used to prevent adding garbage values to the accumulator.
1427static Operation *vectorizeAffineYieldOp(AffineYieldOp yieldOp,
1428 VectorizationState &state) {
1429 Operation *newYieldOp = widenOp(yieldOp, state);
1430 Operation *newParentOp = state.builder.getInsertionBlock()->getParentOp();
1431
1432 // If there is a mask for this loop then we must prevent garbage values from
1433 // being added to the accumulator by inserting `select` operations, for
1434 // example:
1435 //
1436 // %val_masked = select %mask, %val, %neutralCst : vector<128xi1>,
1437 // vector<128xf32>
1438 // %res = arith.addf %acc, %val_masked : vector<128xf32>
1439 // affine.yield %res : vector<128xf32>
1440 //
1441 if (Value mask = state.vecLoopToMask.lookup(newParentOp)) {
1442 state.builder.setInsertionPoint(newYieldOp);
1443 for (unsigned i = 0; i < newYieldOp->getNumOperands(); ++i) {
1444 SmallVector<Operation *> combinerOps;
1445 Value reducedVal = matchReduction(
1446 cast<AffineForOp>(newParentOp).getRegionIterArgs(), i, combinerOps);
1447 assert(reducedVal && "expect non-null value for parallel reduction loop")(static_cast <bool> (reducedVal && "expect non-null value for parallel reduction loop"
) ? void (0) : __assert_fail ("reducedVal && \"expect non-null value for parallel reduction loop\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1447
, __extension__ __PRETTY_FUNCTION__))
;
1448 assert(combinerOps.size() == 1 && "expect only one combiner op")(static_cast <bool> (combinerOps.size() == 1 &&
"expect only one combiner op") ? void (0) : __assert_fail ("combinerOps.size() == 1 && \"expect only one combiner op\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1448
, __extension__ __PRETTY_FUNCTION__))
;
1449 // IterOperands are neutral element vectors.
1450 Value neutralVal = cast<AffineForOp>(newParentOp).getIterOperands()[i];
1451 state.builder.setInsertionPoint(combinerOps.back());
1452 Value maskedReducedVal = state.builder.create<arith::SelectOp>(
1453 reducedVal.getLoc(), mask, reducedVal, neutralVal);
1454 LLVM_DEBUG(do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ masking an input to a binary op that"
"produces value for a yield Op: " << maskedReducedVal;
} } while (false)
1455 dbgs() << "\n[early-vect]+++++ masking an input to a binary op that"do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ masking an input to a binary op that"
"produces value for a yield Op: " << maskedReducedVal;
} } while (false)
1456 "produces value for a yield Op: "do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ masking an input to a binary op that"
"produces value for a yield Op: " << maskedReducedVal;
} } while (false)
1457 << maskedReducedVal)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ masking an input to a binary op that"
"produces value for a yield Op: " << maskedReducedVal;
} } while (false)
;
1458 combinerOps.back()->replaceUsesOfWith(reducedVal, maskedReducedVal);
1459 }
1460 }
1461
1462 state.builder.setInsertionPointAfter(newParentOp);
1463 return newYieldOp;
1464}
1465
1466/// Encodes Operation-specific behavior for vectorization. In general we
1467/// assume that all operands of an op must be vectorized but this is not
1468/// always true. In the future, it would be nice to have a trait that
1469/// describes how a particular operation vectorizes. For now we implement the
1470/// case distinction here. Returns a vectorized form of an operation or
1471/// nullptr if vectorization fails.
1472// TODO: consider adding a trait to Op to describe how it gets vectorized.
1473// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
1474// do one-off logic here; ideally it would be TableGen'd.
1475static Operation *vectorizeOneOperation(Operation *op,
1476 VectorizationState &state) {
1477 // Sanity checks.
1478 assert(!isa<vector::TransferReadOp>(op) &&(static_cast <bool> (!isa<vector::TransferReadOp>
(op) && "vector.transfer_read cannot be further vectorized"
) ? void (0) : __assert_fail ("!isa<vector::TransferReadOp>(op) && \"vector.transfer_read cannot be further vectorized\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1479
, __extension__ __PRETTY_FUNCTION__))
4
Assuming 'op' is not a 'TransferReadOp'
5
'?' condition is true
1479 "vector.transfer_read cannot be further vectorized")(static_cast <bool> (!isa<vector::TransferReadOp>
(op) && "vector.transfer_read cannot be further vectorized"
) ? void (0) : __assert_fail ("!isa<vector::TransferReadOp>(op) && \"vector.transfer_read cannot be further vectorized\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1479
, __extension__ __PRETTY_FUNCTION__))
;
1480 assert(!isa<vector::TransferWriteOp>(op) &&(static_cast <bool> (!isa<vector::TransferWriteOp>
(op) && "vector.transfer_write cannot be further vectorized"
) ? void (0) : __assert_fail ("!isa<vector::TransferWriteOp>(op) && \"vector.transfer_write cannot be further vectorized\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1481
, __extension__ __PRETTY_FUNCTION__))
6
Assuming 'op' is not a 'TransferWriteOp'
7
'?' condition is true
1481 "vector.transfer_write cannot be further vectorized")(static_cast <bool> (!isa<vector::TransferWriteOp>
(op) && "vector.transfer_write cannot be further vectorized"
) ? void (0) : __assert_fail ("!isa<vector::TransferWriteOp>(op) && \"vector.transfer_write cannot be further vectorized\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1481
, __extension__ __PRETTY_FUNCTION__))
;
1482
1483 if (auto loadOp = dyn_cast<AffineLoadOp>(op))
8
Taking false branch
1484 return vectorizeAffineLoad(loadOp, state);
1485 if (auto storeOp = dyn_cast<AffineStoreOp>(op))
9
Taking false branch
1486 return vectorizeAffineStore(storeOp, state);
1487 if (auto forOp = dyn_cast<AffineForOp>(op))
10
Taking true branch
1488 return vectorizeAffineForOp(forOp, state);
11
Calling 'vectorizeAffineForOp'
1489 if (auto yieldOp = dyn_cast<AffineYieldOp>(op))
1490 return vectorizeAffineYieldOp(yieldOp, state);
1491 if (auto constant = dyn_cast<arith::ConstantOp>(op))
1492 return vectorizeConstant(constant, state);
1493
1494 // Other ops with regions are not supported.
1495 if (op->getNumRegions() != 0)
1496 return nullptr;
1497
1498 return widenOp(op, state);
1499}
1500
1501/// Recursive implementation to convert all the nested loops in 'match' to a 2D
1502/// vector container that preserves the relative nesting level of each loop with
1503/// respect to the others in 'match'. 'currentLevel' is the nesting level that
1504/// will be assigned to the loop in the current 'match'.
1505static void
1506getMatchedAffineLoopsRec(NestedMatch match, unsigned currentLevel,
1507 std::vector<SmallVector<AffineForOp, 2>> &loops) {
1508 // Add a new empty level to the output if it doesn't exist already.
1509 assert(currentLevel <= loops.size() && "Unexpected currentLevel")(static_cast <bool> (currentLevel <= loops.size() &&
"Unexpected currentLevel") ? void (0) : __assert_fail ("currentLevel <= loops.size() && \"Unexpected currentLevel\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1509
, __extension__ __PRETTY_FUNCTION__))
;
1510 if (currentLevel == loops.size())
1511 loops.emplace_back();
1512
1513 // Add current match and recursively visit its children.
1514 loops[currentLevel].push_back(cast<AffineForOp>(match.getMatchedOperation()));
1515 for (auto childMatch : match.getMatchedChildren()) {
1516 getMatchedAffineLoopsRec(childMatch, currentLevel + 1, loops);
1517 }
1518}
1519
1520/// Converts all the nested loops in 'match' to a 2D vector container that
1521/// preserves the relative nesting level of each loop with respect to the others
1522/// in 'match'. This means that every loop in 'loops[i]' will have a parent loop
1523/// in 'loops[i-1]'. A loop in 'loops[i]' may or may not have a child loop in
1524/// 'loops[i+1]'.
1525static void
1526getMatchedAffineLoops(NestedMatch match,
1527 std::vector<SmallVector<AffineForOp, 2>> &loops) {
1528 getMatchedAffineLoopsRec(match, /*currLoopDepth=*/0, loops);
1529}
1530
1531/// Internal implementation to vectorize affine loops from a single loop nest
1532/// using an n-D vectorization strategy.
1533static LogicalResult
1534vectorizeLoopNest(std::vector<SmallVector<AffineForOp, 2>> &loops,
1535 const VectorizationStrategy &strategy) {
1536 assert(loops[0].size() == 1 && "Expected single root loop")(static_cast <bool> (loops[0].size() == 1 && "Expected single root loop"
) ? void (0) : __assert_fail ("loops[0].size() == 1 && \"Expected single root loop\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1536
, __extension__ __PRETTY_FUNCTION__))
;
1537 AffineForOp rootLoop = loops[0][0];
1538 VectorizationState state(rootLoop.getContext());
1539 state.builder.setInsertionPointAfter(rootLoop);
1540 state.strategy = &strategy;
1541
1542 // Since patterns are recursive, they can very well intersect.
1543 // Since we do not want a fully greedy strategy in general, we decouple
1544 // pattern matching, from profitability analysis, from application.
1545 // As a consequence we must check that each root pattern is still
1546 // vectorizable. If a pattern is not vectorizable anymore, we just skip it.
1547 // TODO: implement a non-greedy profitability analysis that keeps only
1548 // non-intersecting patterns.
1549 if (!isVectorizableLoopBody(rootLoop, vectorTransferPattern())) {
1550 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ loop is not vectorizable"
; } } while (false)
;
1551 return failure();
1552 }
1553
1554 //////////////////////////////////////////////////////////////////////////////
1555 // Vectorize the scalar loop nest following a topological order. A new vector
1556 // loop nest with the vectorized operations is created along the process. If
1557 // vectorization succeeds, the scalar loop nest is erased. If vectorization
1558 // fails, the vector loop nest is erased and the scalar loop nest is not
1559 // modified.
1560 //////////////////////////////////////////////////////////////////////////////
1561
1562 auto opVecResult = rootLoop.walk<WalkOrder::PreOrder>([&](Operation *op) {
1563 LLVM_DEBUG(dbgs() << "[early-vect]+++++ Vectorizing: " << *op)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "[early-vect]+++++ Vectorizing: "
<< *op; } } while (false)
;
1
Assuming 'DebugFlag' is false
2
Loop condition is false. Exiting loop
1564 Operation *vectorOp = vectorizeOneOperation(op, state);
3
Calling 'vectorizeOneOperation'
1565 if (!vectorOp) {
1566 LLVM_DEBUG(do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "[early-vect]+++++ failed vectorizing the operation: "
<< *op << "\n"; } } while (false)
1567 dbgs() << "[early-vect]+++++ failed vectorizing the operation: "do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "[early-vect]+++++ failed vectorizing the operation: "
<< *op << "\n"; } } while (false)
1568 << *op << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "[early-vect]+++++ failed vectorizing the operation: "
<< *op << "\n"; } } while (false)
;
1569 return WalkResult::interrupt();
1570 }
1571
1572 return WalkResult::advance();
1573 });
1574
1575 if (opVecResult.wasInterrupted()) {
1576 LLVM_DEBUG(dbgs() << "[early-vect]+++++ failed vectorization for: "do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "[early-vect]+++++ failed vectorization for: "
<< rootLoop << "\n"; } } while (false)
1577 << rootLoop << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "[early-vect]+++++ failed vectorization for: "
<< rootLoop << "\n"; } } while (false)
;
1578 // Erase vector loop nest if it was created.
1579 auto vecRootLoopIt = state.opVectorReplacement.find(rootLoop);
1580 if (vecRootLoopIt != state.opVectorReplacement.end())
1581 eraseLoopNest(cast<AffineForOp>(vecRootLoopIt->second));
1582
1583 return failure();
1584 }
1585
1586 // Replace results of reduction loops with the scalar values computed using
1587 // `vector.reduce` or similar ops.
1588 for (auto resPair : state.loopResultScalarReplacement)
1589 resPair.first.replaceAllUsesWith(resPair.second);
1590
1591 assert(state.opVectorReplacement.count(rootLoop) == 1 &&(static_cast <bool> (state.opVectorReplacement.count(rootLoop
) == 1 && "Expected vector replacement for loop nest"
) ? void (0) : __assert_fail ("state.opVectorReplacement.count(rootLoop) == 1 && \"Expected vector replacement for loop nest\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1592
, __extension__ __PRETTY_FUNCTION__))
1592 "Expected vector replacement for loop nest")(static_cast <bool> (state.opVectorReplacement.count(rootLoop
) == 1 && "Expected vector replacement for loop nest"
) ? void (0) : __assert_fail ("state.opVectorReplacement.count(rootLoop) == 1 && \"Expected vector replacement for loop nest\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1592
, __extension__ __PRETTY_FUNCTION__))
;
1593 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ success vectorizing pattern"
; } } while (false)
;
1594 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorization result:\n"do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ vectorization result:\n"
<< *state.opVectorReplacement[rootLoop]; } } while (false
)
1595 << *state.opVectorReplacement[rootLoop])do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ vectorization result:\n"
<< *state.opVectorReplacement[rootLoop]; } } while (false
)
;
1596
1597 // Finish this vectorization pattern.
1598 state.finishVectorizationPattern(rootLoop);
1599 return success();
1600}
1601
1602/// Extracts the matched loops and vectorizes them following a topological
1603/// order. A new vector loop nest will be created if vectorization succeeds. The
1604/// original loop nest won't be modified in any case.
1605static LogicalResult vectorizeRootMatch(NestedMatch m,
1606 const VectorizationStrategy &strategy) {
1607 std::vector<SmallVector<AffineForOp, 2>> loopsToVectorize;
1608 getMatchedAffineLoops(m, loopsToVectorize);
1609 return vectorizeLoopNest(loopsToVectorize, strategy);
1610}
1611
1612/// Traverses all the loop matches and classifies them into intersection
1613/// buckets. Two matches intersect if any of them encloses the other one. A
1614/// match intersects with a bucket if the match intersects with the root
1615/// (outermost) loop in that bucket.
1616static void computeIntersectionBuckets(
1617 ArrayRef<NestedMatch> matches,
1618 std::vector<SmallVector<NestedMatch, 8>> &intersectionBuckets) {
1619 assert(intersectionBuckets.empty() && "Expected empty output")(static_cast <bool> (intersectionBuckets.empty() &&
"Expected empty output") ? void (0) : __assert_fail ("intersectionBuckets.empty() && \"Expected empty output\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1619
, __extension__ __PRETTY_FUNCTION__))
;
1620 // Keeps track of the root (outermost) loop of each bucket.
1621 SmallVector<AffineForOp, 8> bucketRoots;
1622
1623 for (const NestedMatch &match : matches) {
1624 AffineForOp matchRoot = cast<AffineForOp>(match.getMatchedOperation());
1625 bool intersects = false;
1626 for (int i = 0, end = intersectionBuckets.size(); i < end; ++i) {
1627 AffineForOp bucketRoot = bucketRoots[i];
1628 // Add match to the bucket if the bucket root encloses the match root.
1629 if (bucketRoot->isAncestor(matchRoot)) {
1630 intersectionBuckets[i].push_back(match);
1631 intersects = true;
1632 break;
1633 }
1634 // Add match to the bucket if the match root encloses the bucket root. The
1635 // match root becomes the new bucket root.
1636 if (matchRoot->isAncestor(bucketRoot)) {
1637 bucketRoots[i] = matchRoot;
1638 intersectionBuckets[i].push_back(match);
1639 intersects = true;
1640 break;
1641 }
1642 }
1643
1644 // Match doesn't intersect with any existing bucket. Create a new bucket for
1645 // it.
1646 if (!intersects) {
1647 bucketRoots.push_back(matchRoot);
1648 intersectionBuckets.emplace_back();
1649 intersectionBuckets.back().push_back(match);
1650 }
1651 }
1652}
1653
1654/// Internal implementation to vectorize affine loops in 'loops' using the n-D
1655/// vectorization factors in 'vectorSizes'. By default, each vectorization
1656/// factor is applied inner-to-outer to the loops of each loop nest.
1657/// 'fastestVaryingPattern' can be optionally used to provide a different loop
1658/// vectorization order. `reductionLoops` can be provided to specify loops which
1659/// can be vectorized along the reduction dimension.
1660static void vectorizeLoops(Operation *parentOp, DenseSet<Operation *> &loops,
1661 ArrayRef<int64_t> vectorSizes,
1662 ArrayRef<int64_t> fastestVaryingPattern,
1663 const ReductionLoopMap &reductionLoops) {
1664 assert((reductionLoops.empty() || vectorSizes.size() == 1) &&(static_cast <bool> ((reductionLoops.empty() || vectorSizes
.size() == 1) && "Vectorizing reductions is supported only for 1-D vectors"
) ? void (0) : __assert_fail ("(reductionLoops.empty() || vectorSizes.size() == 1) && \"Vectorizing reductions is supported only for 1-D vectors\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1665
, __extension__ __PRETTY_FUNCTION__))
1665 "Vectorizing reductions is supported only for 1-D vectors")(static_cast <bool> ((reductionLoops.empty() || vectorSizes
.size() == 1) && "Vectorizing reductions is supported only for 1-D vectors"
) ? void (0) : __assert_fail ("(reductionLoops.empty() || vectorSizes.size() == 1) && \"Vectorizing reductions is supported only for 1-D vectors\""
, "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1665
, __extension__ __PRETTY_FUNCTION__))
;
1666
1667 // Compute 1-D, 2-D or 3-D loop pattern to be matched on the target loops.
1668 Optional<NestedPattern> pattern =
1669 makePattern(loops, vectorSizes.size(), fastestVaryingPattern);
1670 if (!pattern) {
1671 LLVM_DEBUG(dbgs() << "\n[early-vect] pattern couldn't be computed\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect] pattern couldn't be computed\n"
; } } while (false)
;
1672 return;
1673 }
1674
1675 LLVM_DEBUG(dbgs() << "\n******************************************")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n******************************************"
; } } while (false)
;
1676 LLVM_DEBUG(dbgs() << "\n******************************************")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n******************************************"
; } } while (false)
;
1677 LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on parent op\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect] new pattern on parent op\n"
; } } while (false)
;
1678 LLVM_DEBUG(dbgs() << *parentOp << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << *parentOp << "\n"; } }
while (false)
;
1679
1680 unsigned patternDepth = pattern->getDepth();
1681
1682 // Compute all the pattern matches and classify them into buckets of
1683 // intersecting matches.
1684 SmallVector<NestedMatch, 32> allMatches;
1685 pattern->match(parentOp, &allMatches);
1686 std::vector<SmallVector<NestedMatch, 8>> intersectionBuckets;
1687 computeIntersectionBuckets(allMatches, intersectionBuckets);
1688
1689 // Iterate over all buckets and vectorize the matches eagerly. We can only
1690 // vectorize one match from each bucket since all the matches within a bucket
1691 // intersect.
1692 for (auto &intersectingMatches : intersectionBuckets) {
1693 for (NestedMatch &match : intersectingMatches) {
1694 VectorizationStrategy strategy;
1695 // TODO: depending on profitability, elect to reduce the vector size.
1696 strategy.vectorSizes.assign(vectorSizes.begin(), vectorSizes.end());
1697 strategy.reductionLoops = reductionLoops;
1698 if (failed(analyzeProfitability(match.getMatchedChildren(), 1,
1699 patternDepth, &strategy))) {
1700 continue;
1701 }
1702 vectorizeLoopIfProfitable(match.getMatchedOperation(), 0, patternDepth,
1703 &strategy);
1704 // Vectorize match. Skip the rest of intersecting matches in the bucket if
1705 // vectorization succeeded.
1706 // TODO: if pattern does not apply, report it; alter the cost/benefit.
1707 // TODO: some diagnostics if failure to vectorize occurs.
1708 if (succeeded(vectorizeRootMatch(match, strategy)))
1709 break;
1710 }
1711 }
1712
1713 LLVM_DEBUG(dbgs() << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n"; } } while (false)
;
1714}
1715
1716/// Applies vectorization to the current function by searching over a bunch of
1717/// predetermined patterns.
1718void Vectorize::runOnOperation() {
1719 func::FuncOp f = getOperation();
1720 if (!fastestVaryingPattern.empty() &&
1721 fastestVaryingPattern.size() != vectorSizes.size()) {
1722 f.emitRemark("Fastest varying pattern specified with different size than "
1723 "the vector size.");
1724 return signalPassFailure();
1725 }
1726
1727 if (vectorizeReductions && vectorSizes.size() != 1) {
1728 f.emitError("Vectorizing reductions is supported only for 1-D vectors.");
1729 return signalPassFailure();
1730 }
1731
1732 DenseSet<Operation *> parallelLoops;
1733 ReductionLoopMap reductionLoops;
1734
1735 // If 'vectorize-reduction=true' is provided, we also populate the
1736 // `reductionLoops` map.
1737 if (vectorizeReductions) {
1738 f.walk([&parallelLoops, &reductionLoops](AffineForOp loop) {
1739 SmallVector<LoopReduction, 2> reductions;
1740 if (isLoopParallel(loop, &reductions)) {
1741 parallelLoops.insert(loop);
1742 // If it's not a reduction loop, adding it to the map is not necessary.
1743 if (!reductions.empty())
1744 reductionLoops[loop] = reductions;
1745 }
1746 });
1747 } else {
1748 f.walk([&parallelLoops](AffineForOp loop) {
1749 if (isLoopParallel(loop))
1750 parallelLoops.insert(loop);
1751 });
1752 }
1753
1754 // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
1755 NestedPatternContext mlContext;
1756 vectorizeLoops(f, parallelLoops, vectorSizes, fastestVaryingPattern,
1757 reductionLoops);
1758}
1759
1760/// Verify that affine loops in 'loops' meet the nesting criteria expected by
1761/// SuperVectorizer:
1762/// * There must be at least one loop.
1763/// * There must be a single root loop (nesting level 0).
1764/// * Each loop at a given nesting level must be nested in a loop from a
1765/// previous nesting level.
1766static LogicalResult
1767verifyLoopNesting(const std::vector<SmallVector<AffineForOp, 2>> &loops) {
1768 // Expected at least one loop.
1769 if (loops.empty())
1770 return failure();
1771
1772 // Expected only one root loop.
1773 if (loops[0].size() != 1)
1774 return failure();
1775
1776 // Traverse loops outer-to-inner to check some invariants.
1777 for (int i = 1, end = loops.size(); i < end; ++i) {
1778 for (AffineForOp loop : loops[i]) {
1779 // Check that each loop at this level is nested in one of the loops from
1780 // the previous level.
1781 if (none_of(loops[i - 1], [&](AffineForOp maybeParent) {
1782 return maybeParent->isProperAncestor(loop);
1783 }))
1784 return failure();
1785
1786 // Check that each loop at this level is not nested in another loop from
1787 // this level.
1788 for (AffineForOp sibling : loops[i]) {
1789 if (sibling->isProperAncestor(loop))
1790 return failure();
1791 }
1792 }
1793 }
1794
1795 return success();
1796}
1797
1798namespace mlir {
1799
1800/// External utility to vectorize affine loops in 'loops' using the n-D
1801/// vectorization factors in 'vectorSizes'. By default, each vectorization
1802/// factor is applied inner-to-outer to the loops of each loop nest.
1803/// 'fastestVaryingPattern' can be optionally used to provide a different loop
1804/// vectorization order.
1805/// If `reductionLoops` is not empty, the given reduction loops may be
1806/// vectorized along the reduction dimension.
1807/// TODO: Vectorizing reductions is supported only for 1-D vectorization.
1808void vectorizeAffineLoops(Operation *parentOp, DenseSet<Operation *> &loops,
1809 ArrayRef<int64_t> vectorSizes,
1810 ArrayRef<int64_t> fastestVaryingPattern,
1811 const ReductionLoopMap &reductionLoops) {
1812 // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
1813 NestedPatternContext mlContext;
1814 vectorizeLoops(parentOp, loops, vectorSizes, fastestVaryingPattern,
1815 reductionLoops);
1816}
1817
1818/// External utility to vectorize affine loops from a single loop nest using an
1819/// n-D vectorization strategy (see doc in VectorizationStrategy definition).
1820/// Loops are provided in a 2D vector container. The first dimension represents
1821/// the nesting level relative to the loops to be vectorized. The second
1822/// dimension contains the loops. This means that:
1823/// a) every loop in 'loops[i]' must have a parent loop in 'loops[i-1]',
1824/// b) a loop in 'loops[i]' may or may not have a child loop in 'loops[i+1]'.
1825///
1826/// For example, for the following loop nest:
1827///
1828/// func @vec2d(%in0: memref<64x128x512xf32>, %in1: memref<64x128x128xf32>,
1829/// %out0: memref<64x128x512xf32>,
1830/// %out1: memref<64x128x128xf32>) {
1831/// affine.for %i0 = 0 to 64 {
1832/// affine.for %i1 = 0 to 128 {
1833/// affine.for %i2 = 0 to 512 {
1834/// %ld = affine.load %in0[%i0, %i1, %i2] : memref<64x128x512xf32>
1835/// affine.store %ld, %out0[%i0, %i1, %i2] : memref<64x128x512xf32>
1836/// }
1837/// affine.for %i3 = 0 to 128 {
1838/// %ld = affine.load %in1[%i0, %i1, %i3] : memref<64x128x128xf32>
1839/// affine.store %ld, %out1[%i0, %i1, %i3] : memref<64x128x128xf32>
1840/// }
1841/// }
1842/// }
1843/// return
1844/// }
1845///
1846/// loops = {{%i0}, {%i2, %i3}}, to vectorize the outermost and the two
1847/// innermost loops;
1848/// loops = {{%i1}, {%i2, %i3}}, to vectorize the middle and the two innermost
1849/// loops;
1850/// loops = {{%i2}}, to vectorize only the first innermost loop;
1851/// loops = {{%i3}}, to vectorize only the second innermost loop;
1852/// loops = {{%i1}}, to vectorize only the middle loop.
1853LogicalResult
1854vectorizeAffineLoopNest(std::vector<SmallVector<AffineForOp, 2>> &loops,
1855 const VectorizationStrategy &strategy) {
1856 // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
1857 NestedPatternContext mlContext;
1858 if (failed(verifyLoopNesting(loops)))
1859 return failure();
1860 return vectorizeLoopNest(loops, strategy);
1861}
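// A usage sketch for the nest above, vectorizing %i0 by 4 (assuming
// VectorizationStrategy exposes `vectorSizes` and `loopToVectorDim` as in the
// SuperVectorize headers; `i0`, `i2`, `i3` are the AffineForOps shown above):
//
//   std::vector<SmallVector<AffineForOp, 2>> loops = {{i0}, {i2, i3}};
//   VectorizationStrategy strategy;
//   strategy.vectorSizes.push_back(4);
//   strategy.loopToVectorDim[i0] = 0; // Map %i0 to vector dimension 0.
//   if (failed(vectorizeAffineLoopNest(loops, strategy)))
//     ; // The nest did not meet the nesting criteria above.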
1862
1863} // namespace mlir

/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/include/mlir/IR/OpDefinition.h

1//===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements helper classes for implementing the "Op" types. This
10// includes the Op type, which is the base class for Op class definitions,
11// as well as a number of traits in the OpTrait namespace that provide a
12// declarative way to specify properties of Ops.
13//
14// The purpose of these types is to allow light-weight implementation of
15// concrete ops (like DimOp) with very little boilerplate.
16//
17//===----------------------------------------------------------------------===//
18
19#ifndef MLIR_IR_OPDEFINITION_H
20#define MLIR_IR_OPDEFINITION_H
21
22#include "mlir/IR/Dialect.h"
23#include "mlir/IR/Operation.h"
24#include "llvm/Support/PointerLikeTypeTraits.h"
25
26#include <type_traits>
27
28namespace mlir {
29class Builder;
30class OpBuilder;
31
32/// This class implements `Optional` functionality for ParseResult. We don't
33/// directly use Optional here, because it provides an implicit conversion
34/// to 'bool' which we want to avoid. This class is used to implement tri-state
35/// 'parseOptional' functions that may have a failure mode when parsing that
36/// shouldn't be attributed to "not present".
37class OptionalParseResult {
38public:
39 OptionalParseResult() = default;
40 OptionalParseResult(LogicalResult result) : impl(result) {}
41 OptionalParseResult(ParseResult result) : impl(result) {}
42 OptionalParseResult(const InFlightDiagnostic &)
43 : OptionalParseResult(failure()) {}
44 OptionalParseResult(llvm::NoneType) : impl(llvm::None) {}
45
46 /// Returns true if we contain a valid ParseResult value.
47 bool has_value() const { return impl.has_value(); }
48 LLVM_DEPRECATED("Use has_value instead", "has_value") bool hasValue() const {
49 return impl.has_value();
50 }
51
52 /// Access the internal ParseResult value.
53 ParseResult value() const { return impl.value(); }
54 LLVM_DEPRECATED("Use value instead", "value") ParseResult getValue() const {
55 return impl.value();
56 }
57 ParseResult operator*() const { return value(); }
58
59private:
60 Optional<ParseResult> impl;
61};
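// For example, a parser can distinguish "absent" from "present but malformed"
// (a sketch; `parser` is an OpAsmParser in scope):
//
//   Attribute attr;
//   OptionalParseResult parsed = parser.parseOptionalAttribute(attr);
//   if (parsed.has_value())
//     return *parsed; // Present: propagate its success or failure.
//   // Not present at all: continue with the next parsing alternative.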
62
63// These functions are out-of-line utilities, which avoids them being template
64// instantiated/duplicated.
65namespace impl {
66/// Insert an operation, generated by `buildTerminatorOp`, at the end of the
67/// region's only block if it does not have a terminator already. If the region
68/// is empty, insert a new block first. `buildTerminatorOp` should return the
69/// terminator operation to insert.
70void ensureRegionTerminator(
71 Region &region, OpBuilder &builder, Location loc,
72 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
73void ensureRegionTerminator(
74 Region &region, Builder &builder, Location loc,
75 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
76
77} // namespace impl
78
79/// This is the concrete base class that holds the operation pointer and has
80/// non-generic methods that only depend on State (to avoid having them
81/// instantiated on template types that don't affect them).
82///
83/// This also has the fallback implementations of customization hooks for when
84/// they aren't customized.
85class OpState {
86public:
87 /// Ops are pointer-like, so we allow conversion to bool.
88 explicit operator bool() { return getOperation() != nullptr; }
89
90 /// This implicitly converts to Operation*.
91 operator Operation *() const { return state; }
92
93 /// Shortcut of `->` to access a member of Operation.
94 Operation *operator->() const { return state; }
95
96 /// Return the operation that this refers to.
97 Operation *getOperation() { return state; }
98
99 /// Return the context this operation belongs to.
100 MLIRContext *getContext() { return getOperation()->getContext(); }
101
102 /// Print the operation to the given stream.
103 void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
104 state->print(os, flags);
43
Called C++ object pointer is null
105 }
106 void print(raw_ostream &os, AsmState &asmState) {
107 state->print(os, asmState);
108 }
109
110 /// Dump this operation.
111 void dump() { state->dump(); }
112
113 /// The source location the operation was defined or derived from.
114 Location getLoc() { return state->getLoc(); }
115
116 /// Return true if there are no users of any results of this operation.
117 bool use_empty() { return state->use_empty(); }
118
119 /// Remove this operation from its parent block and delete it.
120 void erase() { state->erase(); }
121
122 /// Emit an error with the op name prefixed, like "'dim' op " which is
123 /// convenient for verifiers.
124 InFlightDiagnostic emitOpError(const Twine &message = {});
125
126 /// Emit an error about fatal conditions with this operation, reporting up to
127 /// any diagnostic handlers that may be listening.
128 InFlightDiagnostic emitError(const Twine &message = {});
129
130 /// Emit a warning about this operation, reporting up to any diagnostic
131 /// handlers that may be listening.
132 InFlightDiagnostic emitWarning(const Twine &message = {});
133
134 /// Emit a remark about this operation, reporting up to any diagnostic
135 /// handlers that may be listening.
136 InFlightDiagnostic emitRemark(const Twine &message = {});
137
138 /// Walk the operation by calling the callback for each nested operation
139 /// (including this one), block or region, depending on the callback provided.
140 /// Regions, blocks and operations at the same nesting level are visited in
141 /// lexicographical order. The walk order for enclosing regions, blocks and
142 /// operations with respect to their nested ones is specified by 'Order'
143 /// (post-order by default). A callback on a block or operation is allowed to
144 /// erase that block or operation if either:
145 /// * the walk is in post-order, or
146 /// * the walk is in pre-order and the walk is skipped after the erasure.
147 /// See Operation::walk for more details.
148 template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
149 typename RetT = detail::walkResultType<FnT>>
150 std::enable_if_t<llvm::function_traits<std::decay_t<FnT>>::num_args == 1,
151 RetT>
152 walk(FnT &&callback) {
153 return state->walk<Order>(std::forward<FnT>(callback));
154 }
155
156 /// Generic walker with a stage-aware callback. Walk the operation by calling
157 /// the callback for each nested operation (including this one) N+1 times,
158 /// where N is the number of regions attached to that operation.
159 ///
160 /// The callback method can take any of the following forms:
161 /// void(Operation *, const WalkStage &) : Walk all operation opaquely
162 /// * op.walk([](Operation *nestedOp, const WalkStage &stage) { ...});
163 /// void(OpT, const WalkStage &) : Walk all operations of the given derived
164 /// type.
165 /// * op.walk([](ReturnOp returnOp, const WalkStage &stage) { ...});
166 /// WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations,
167 /// but allow for interruption/skipping.
168 /// * op.walk([](... op, const WalkStage &stage) {
169 /// // Skip the walk of this op based on some invariant.
170 /// if (some_invariant)
171 /// return WalkResult::skip();
172 /// // Interrupt, i.e., cancel, the walk based on some invariant.
173 /// if (another_invariant)
174 /// return WalkResult::interrupt();
175 /// return WalkResult::advance();
176 /// });
177 template <typename FnT, typename RetT = detail::walkResultType<FnT>>
178 std::enable_if_t<llvm::function_traits<std::decay_t<FnT>>::num_args == 2,
179 RetT>
180 walk(FnT &&callback) {
181 return state->walk(std::forward<FnT>(callback));
182 }
183
184 // These are default implementations of customization hooks.
185public:
186 /// This hook returns any canonicalization pattern rewrites that the operation
187 /// supports, for use by the canonicalization pass.
188 static void getCanonicalizationPatterns(RewritePatternSet &results,
189 MLIRContext *context) {}
190
191 /// This hook populates any unset default attrs.
192 static void populateDefaultAttrs(const RegisteredOperationName &,
193 NamedAttrList &) {}
194
195protected:
196 /// If the concrete type didn't implement a custom verifier hook, just fall
197 /// back to this one which accepts everything.
198 LogicalResult verify() { return success(); }
199 LogicalResult verifyRegions() { return success(); }
200
201 /// Parse the custom form of an operation. Unless overridden, this method will
202 /// first try to get an operation parser from the op's dialect. Otherwise the
203 /// custom assembly form of an op is always rejected. Op implementations
204 /// should implement this to return failure. On success, they should fill in
205 /// result with the fields to use.
206 static ParseResult parse(OpAsmParser &parser, OperationState &result);
207
208 /// Print the operation. Unless overridden, this method will first try to get
209 /// an operation printer from the dialect. Otherwise, it prints the operation
210 /// in generic form.
211 static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect);
212
213 /// Print an operation name, eliding the dialect prefix if necessary.
214 static void printOpName(Operation *op, OpAsmPrinter &p,
215 StringRef defaultDialect);
216
217 /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
218 /// so we can cast it away here.
219 explicit OpState(Operation *state) : state(state) {}
31
Null pointer value stored to 'vecConstant.state'
220
221private:
222 Operation *state;
223
224 /// Allow access to internal hook implementation methods.
225 friend RegisteredOperationName;
226};
227
228// Allow comparing operators.
229inline bool operator==(OpState lhs, OpState rhs) {
230 return lhs.getOperation() == rhs.getOperation();
231}
232inline bool operator!=(OpState lhs, OpState rhs) {
233 return lhs.getOperation() != rhs.getOperation();
234}
235
236raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr);
237
238/// This class represents a single result from folding an operation.
239class OpFoldResult : public PointerUnion<Attribute, Value> {
240 using PointerUnion<Attribute, Value>::PointerUnion;
241
242public:
243 void dump() { llvm::errs() << *this << "\n"; }
244};
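// For instance, a fold hook can return either alternative (hypothetical
// `MyAddOp` with ODS-style getLhs()/getRhs() accessors; a sketch using the
// matchPattern/m_Zero helpers from Matchers.h and constFoldBinaryOp from
// CommonFolders.h):
//
//   OpFoldResult MyAddOp::fold(ArrayRef<Attribute> operands) {
//     if (matchPattern(getRhs(), m_Zero()))
//       return getLhs(); // Fold x + 0 to the existing SSA Value x.
//     // Both operands constant: fold to an Attribute (null if not constant).
//     return constFoldBinaryOp<IntegerAttr>(
//         operands, [](const APInt &a, const APInt &b) { return a + b; });
//   }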
245
246/// Allow printing to a stream.
247inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) {
248 if (Value value = ofr.dyn_cast<Value>())
249 value.print(os);
250 else
251 ofr.dyn_cast<Attribute>().print(os);
252 return os;
253}
254
255/// Allow printing to a stream.
256inline raw_ostream &operator<<(raw_ostream &os, OpState op) {
257 op.print(os, OpPrintingFlags().useLocalScope());
42
Calling 'OpState::print'
258 return os;
259}
260
261//===----------------------------------------------------------------------===//
262// Operation Trait Types
263//===----------------------------------------------------------------------===//
264
265namespace OpTrait {
266
267// These functions are out-of-line implementations of the methods in the
268// corresponding trait classes. This avoids them being template
269// instantiated/duplicated.
270namespace impl {
271OpFoldResult foldIdempotent(Operation *op);
272OpFoldResult foldInvolution(Operation *op);
273LogicalResult verifyZeroOperands(Operation *op);
274LogicalResult verifyOneOperand(Operation *op);
275LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
276LogicalResult verifyIsIdempotent(Operation *op);
277LogicalResult verifyIsInvolution(Operation *op);
278LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
279LogicalResult verifyOperandsAreFloatLike(Operation *op);
280LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
281LogicalResult verifySameTypeOperands(Operation *op);
282LogicalResult verifyZeroRegions(Operation *op);
283LogicalResult verifyOneRegion(Operation *op);
284LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
285LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);
286LogicalResult verifyZeroResults(Operation *op);
287LogicalResult verifyOneResult(Operation *op);
288LogicalResult verifyNResults(Operation *op, unsigned numOperands);
289LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
290LogicalResult verifySameOperandsShape(Operation *op);
291LogicalResult verifySameOperandsAndResultShape(Operation *op);
292LogicalResult verifySameOperandsElementType(Operation *op);
293LogicalResult verifySameOperandsAndResultElementType(Operation *op);
294LogicalResult verifySameOperandsAndResultType(Operation *op);
295LogicalResult verifyResultsAreBoolLike(Operation *op);
296LogicalResult verifyResultsAreFloatLike(Operation *op);
297LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
298LogicalResult verifyIsTerminator(Operation *op);
299LogicalResult verifyZeroSuccessors(Operation *op);
300LogicalResult verifyOneSuccessor(Operation *op);
301LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
302LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
303LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
304 StringRef valueGroupName,
305 size_t expectedCount);
306LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
307LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
308LogicalResult verifyNoRegionArguments(Operation *op);
309LogicalResult verifyElementwise(Operation *op);
310LogicalResult verifyIsIsolatedFromAbove(Operation *op);
311} // namespace impl
312
313/// Helper class for implementing traits. Clients are not expected to interact
314/// with this directly, so its members are all protected.
315template <typename ConcreteType, template <typename> class TraitType>
316class TraitBase {
317protected:
318 /// Return the ultimate Operation being worked on.
319 Operation *getOperation() {
320 auto *concrete = static_cast<ConcreteType *>(this);
321 return concrete->getOperation();
322 }
323};
324
325//===----------------------------------------------------------------------===//
326// Operand Traits
327
328namespace detail {
329/// Utility trait base that provides accessors for derived traits that have
330/// multiple operands.
331template <typename ConcreteType, template <typename> class TraitType>
332struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
333 using operand_iterator = Operation::operand_iterator;
334 using operand_range = Operation::operand_range;
335 using operand_type_iterator = Operation::operand_type_iterator;
336 using operand_type_range = Operation::operand_type_range;
337
338 /// Return the number of operands.
339 unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
340
341 /// Return the operand at index 'i'.
342 Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
343
344 /// Set the operand at index 'i' to 'value'.
345 void setOperand(unsigned i, Value value) {
346 this->getOperation()->setOperand(i, value);
347 }
348
349 /// Operand iterator access.
350 operand_iterator operand_begin() {
351 return this->getOperation()->operand_begin();
352 }
353 operand_iterator operand_end() { return this->getOperation()->operand_end(); }
354 operand_range getOperands() { return this->getOperation()->getOperands(); }
355
356 /// Operand type access.
357 operand_type_iterator operand_type_begin() {
358 return this->getOperation()->operand_type_begin();
359 }
360 operand_type_iterator operand_type_end() {
361 return this->getOperation()->operand_type_end();
362 }
363 operand_type_range getOperandTypes() {
364 return this->getOperation()->getOperandTypes();
365 }
366};
367} // namespace detail
368
369/// `verifyInvariantsImpl` verifies invariants such as types, attributes, etc.
370/// It should be run after core traits and before any other user-defined traits.
371/// To run it in the correct order, wrap it with the OpInvariants trait so that
372/// tblgen can place it at the right position in the trait list.
373template <typename ConcreteType>
374class OpInvariants : public TraitBase<ConcreteType, OpInvariants> {
375public:
376 static LogicalResult verifyTrait(Operation *op) {
377 return cast<ConcreteType>(op).verifyInvariantsImpl();
378 }
379};
380
381/// This class provides the API for ops that are known to have no
382/// SSA operand.
383template <typename ConcreteType>
384class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
385public:
386 static LogicalResult verifyTrait(Operation *op) {
387 return impl::verifyZeroOperands(op);
388 }
389
390private:
391 // Disable these.
392 void getOperand() {}
393 void setOperand() {}
394};
395
396/// This class provides the API for ops that are known to have exactly one
397/// SSA operand.
398template <typename ConcreteType>
399class OneOperand : public TraitBase<ConcreteType, OneOperand> {
400public:
401 Value getOperand() { return this->getOperation()->getOperand(0); }
402
403 void setOperand(Value value) { this->getOperation()->setOperand(0, value); }
404
405 static LogicalResult verifyTrait(Operation *op) {
406 return impl::verifyOneOperand(op);
407 }
408};
409
410/// This class provides the API for ops that are known to have a specified
411/// number of operands. This is used as a trait like this:
412///
413/// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
414///
415template <unsigned N>
416class NOperands {
417public:
418 static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
419
420 template <typename ConcreteType>
421 class Impl
422 : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
423 public:
424 static LogicalResult verifyTrait(Operation *op) {
425 return impl::verifyNOperands(op, N);
426 }
427 };
428};
429
430/// This class provides the API for ops that are known to have at least a
431/// specified number of operands. This is used as a trait like this:
432///
433/// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
434///
435template <unsigned N>
436class AtLeastNOperands {
437public:
438 template <typename ConcreteType>
439 class Impl : public detail::MultiOperandTraitBase<ConcreteType,
440 AtLeastNOperands<N>::Impl> {
441 public:
442 static LogicalResult verifyTrait(Operation *op) {
443 return impl::verifyAtLeastNOperands(op, N);
444 }
445 };
446};
447
448/// This class provides the API for ops which have an unknown number of
449/// SSA operands.
450template <typename ConcreteType>
451class VariadicOperands
452 : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
453
454//===----------------------------------------------------------------------===//
455// Region Traits
456
457/// This class provides verification for ops that are known to have zero
458/// regions.
459template <typename ConcreteType>
460class ZeroRegions : public TraitBase<ConcreteType, ZeroRegions> {
461public:
462 static LogicalResult verifyTrait(Operation *op) {
463 return impl::verifyZeroRegions(op);
464 }
465};
466
467namespace detail {
468/// Utility trait base that provides accessors for derived traits that have
469/// multiple regions.
470template <typename ConcreteType, template <typename> class TraitType>
471struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> {
472 using region_iterator = MutableArrayRef<Region>;
473 using region_range = RegionRange;
474
475 /// Return the number of regions.
476 unsigned getNumRegions() { return this->getOperation()->getNumRegions(); }
477
478 /// Return the region at `index`.
479 Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); }
480
481 /// Region iterator access.
482 region_iterator region_begin() {
483 return this->getOperation()->region_begin();
484 }
485 region_iterator region_end() { return this->getOperation()->region_end(); }
486 region_range getRegions() { return this->getOperation()->getRegions(); }
487};
488} // namespace detail
489
490/// This class provides APIs for ops that are known to have a single region.
491template <typename ConcreteType>
492class OneRegion : public TraitBase<ConcreteType, OneRegion> {
493public:
494 Region &getRegion() { return this->getOperation()->getRegion(0); }
495
496 /// Returns a range of operations within the region of this operation.
497 auto getOps() { return getRegion().getOps(); }
498 template <typename OpT>
499 auto getOps() {
500 return getRegion().template getOps<OpT>();
501 }
502
503 static LogicalResult verifyTrait(Operation *op) {
504 return impl::verifyOneRegion(op);
505 }
506};
507
508/// This class provides the API for ops that are known to have a specified
509/// number of regions.
510template <unsigned N>
511class NRegions {
512public:
513 static_assert(N > 1, "use ZeroRegions/OneRegion for N < 2");
514
515 template <typename ConcreteType>
516 class Impl
517 : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
518 public:
519 static LogicalResult verifyTrait(Operation *op) {
520 return impl::verifyNRegions(op, N);
521 }
522 };
523};
524
525/// This class provides APIs for ops that are known to have at least a specified
526/// number of regions.
527template <unsigned N>
528class AtLeastNRegions {
529public:
530 template <typename ConcreteType>
531 class Impl : public detail::MultiRegionTraitBase<ConcreteType,
532 AtLeastNRegions<N>::Impl> {
533 public:
534 static LogicalResult verifyTrait(Operation *op) {
535 return impl::verifyAtLeastNRegions(op, N);
536 }
537 };
538};
539
540/// This class provides the API for ops which have an unknown number of
541/// regions.
542template <typename ConcreteType>
543class VariadicRegions
544 : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {};
545
546//===----------------------------------------------------------------------===//
547// Result Traits
548
549/// This class provides return value APIs for ops that are known to have
550/// zero results.
551template <typename ConcreteType>
552class ZeroResults : public TraitBase<ConcreteType, ZeroResults> {
553public:
554 static LogicalResult verifyTrait(Operation *op) {
555 return impl::verifyZeroResults(op);
556 }
557};
558
559namespace detail {
560/// Utility trait base that provides accessors for derived traits that have
561/// multiple results.
562template <typename ConcreteType, template <typename> class TraitType>
563struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
564 using result_iterator = Operation::result_iterator;
565 using result_range = Operation::result_range;
566 using result_type_iterator = Operation::result_type_iterator;
567 using result_type_range = Operation::result_type_range;
568
569 /// Return the number of results.
570 unsigned getNumResults() { return this->getOperation()->getNumResults(); }
571
572 /// Return the result at index 'i'.
573 Value getResult(unsigned i) { return this->getOperation()->getResult(i); }
574
575 /// Replace all uses of results of this operation with the provided 'values'.
576 /// 'values' may correspond to an existing operation, or a range of 'Value'.
577 template <typename ValuesT>
578 void replaceAllUsesWith(ValuesT &&values) {
579 this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
580 }
581
582 /// Return the type of the `i`-th result.
583 Type getType(unsigned i) { return getResult(i).getType(); }
584
585 /// Result iterator access.
586 result_iterator result_begin() {
587 return this->getOperation()->result_begin();
588 }
589 result_iterator result_end() { return this->getOperation()->result_end(); }
590 result_range getResults() { return this->getOperation()->getResults(); }
591
592 /// Result type access.
593 result_type_iterator result_type_begin() {
594 return this->getOperation()->result_type_begin();
595 }
596 result_type_iterator result_type_end() {
597 return this->getOperation()->result_type_end();
598 }
599 result_type_range getResultTypes() {
600 return this->getOperation()->getResultTypes();
601 }
602};
603} // namespace detail
604
605/// This class provides return value APIs for ops that are known to have a
606/// single result.
607template <typename ConcreteType>
608class OneResult : public TraitBase<ConcreteType, OneResult> {
609public:
610 Value getResult() { return this->getOperation()->getResult(0); }
611
612 /// If the operation returns a single value, then the Op can be implicitly
613 /// converted to a Value. This yields the value of the only result.
614 operator Value() { return getResult(); }
615
616 /// Replace all uses of 'this' value with the new value, updating anything
617 /// in the IR that uses 'this' to use the other value instead. When this
618 /// returns there are zero uses of 'this'.
619 void replaceAllUsesWith(Value newValue) {
620 getResult().replaceAllUsesWith(newValue);
621 }
622
623 /// Replace all uses of 'this' value with the result of 'op'.
624 void replaceAllUsesWith(Operation *op) {
625 this->getOperation()->replaceAllUsesWith(op);
626 }
627
628 static LogicalResult verifyTrait(Operation *op) {
629 return impl::verifyOneResult(op);
630 }
631};
632
633/// This trait is used for return value APIs for ops that are known to have a
634/// specific type other than `Type`. This allows the "getType()" member to be
635/// more specific for an op. This should be used in conjunction with OneResult,
636/// and occur in the trait list before OneResult.
637template <typename ResultType>
638class OneTypedResult {
639public:
640 /// This class provides return value APIs for ops that are known to have a
641 /// single result. ResultType is the concrete type returned by getType().
642 template <typename ConcreteType>
643 class Impl
644 : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
645 public:
646 ResultType getType() {
647 auto resultTy = this->getOperation()->getResult(0).getType();
648 return resultTy.template cast<ResultType>();
22
Calling 'Type::cast'
24
Returning from 'Type::cast'
649 }
650 };
651};
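// For example (hypothetical alloc-like op; note that OneTypedResult precedes
// OneResult in the trait list, as documented above):
//
//   class AllocLikeOp
//       : public Op<AllocLikeOp, OpTrait::OneTypedResult<MemRefType>::Impl,
//                   OpTrait::OneResult> { ... };
//
//   MemRefType type = allocLikeOp.getType(); // No cast needed at call sites.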
652
653/// This class provides the API for ops that are known to have a specified
654/// number of results. This is used as a trait like this:
655///
656/// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
657///
658template <unsigned N>
659class NResults {
660public:
661 static_assert(N > 1, "use ZeroResults/OneResult for N < 2");
662
663 template <typename ConcreteType>
664 class Impl
665 : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
666 public:
667 static LogicalResult verifyTrait(Operation *op) {
668 return impl::verifyNResults(op, N);
669 }
670 };
671};
672
673/// This class provides the API for ops that are known to have at least a
674/// specified number of results. This is used as a trait like this:
675///
676/// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
677///
678template <unsigned N>
679class AtLeastNResults {
680public:
681 template <typename ConcreteType>
682 class Impl : public detail::MultiResultTraitBase<ConcreteType,
683 AtLeastNResults<N>::Impl> {
684 public:
685 static LogicalResult verifyTrait(Operation *op) {
686 return impl::verifyAtLeastNResults(op, N);
687 }
688 };
689};
690
691/// This class provides the API for ops which have an unknown number of
692/// results.
693template <typename ConcreteType>
694class VariadicResults
695 : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
696
697//===----------------------------------------------------------------------===//
698// Terminator Traits
699
700/// This class indicates that the regions associated with this op don't have
701/// terminators.
702template <typename ConcreteType>
703class NoTerminator : public TraitBase<ConcreteType, NoTerminator> {};
704
705/// This class provides the API for ops that are known to be terminators.
706template <typename ConcreteType>
707class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
708public:
709 static LogicalResult verifyTrait(Operation *op) {
710 return impl::verifyIsTerminator(op);
711 }
712};
713
714/// This class provides verification for ops that are known to have zero
715/// successors.
716template <typename ConcreteType>
717class ZeroSuccessors : public TraitBase<ConcreteType, ZeroSuccessors> {
718public:
719 static LogicalResult verifyTrait(Operation *op) {
720 return impl::verifyZeroSuccessors(op);
721 }
722};
723
724namespace detail {
725/// Utility trait base that provides accessors for derived traits that have
726/// multiple successors.
727template <typename ConcreteType, template <typename> class TraitType>
728struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
729 using succ_iterator = Operation::succ_iterator;
730 using succ_range = SuccessorRange;
731
732 /// Return the number of successors.
733 unsigned getNumSuccessors() {
734 return this->getOperation()->getNumSuccessors();
735 }
736
737 /// Return the successor at `index`.
738 Block *getSuccessor(unsigned i) {
739 return this->getOperation()->getSuccessor(i);
740 }
741
742 /// Set the successor at `index`.
743 void setSuccessor(Block *block, unsigned i) {
744 return this->getOperation()->setSuccessor(block, i);
745 }
746
747 /// Successor iterator access.
748 succ_iterator succ_begin() { return this->getOperation()->succ_begin(); }
749 succ_iterator succ_end() { return this->getOperation()->succ_end(); }
750 succ_range getSuccessors() { return this->getOperation()->getSuccessors(); }
751};
752} // namespace detail
753
754/// This class provides APIs for ops that are known to have a single successor.
755template <typename ConcreteType>
756class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
757public:
758 Block *getSuccessor() { return this->getOperation()->getSuccessor(0); }
759 void setSuccessor(Block *succ) {
760 this->getOperation()->setSuccessor(succ, 0);
761 }
762
763 static LogicalResult verifyTrait(Operation *op) {
764 return impl::verifyOneSuccessor(op);
765 }
766};
767
768/// This class provides the API for ops that are known to have a specified
769/// number of successors.
770template <unsigned N>
771class NSuccessors {
772public:
773 static_assert(N > 1, "use ZeroSuccessors/OneSuccessor for N < 2");
774
775 template <typename ConcreteType>
776 class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
777 NSuccessors<N>::Impl> {
778 public:
779 static LogicalResult verifyTrait(Operation *op) {
780 return impl::verifyNSuccessors(op, N);
781 }
782 };
783};
784
785/// This class provides APIs for ops that are known to have at least a specified
786/// number of successors.
787template <unsigned N>
788class AtLeastNSuccessors {
789public:
790 template <typename ConcreteType>
791 class Impl
792 : public detail::MultiSuccessorTraitBase<ConcreteType,
793 AtLeastNSuccessors<N>::Impl> {
794 public:
795 static LogicalResult verifyTrait(Operation *op) {
796 return impl::verifyAtLeastNSuccessors(op, N);
797 }
798 };
799};
800
801/// This class provides the API for ops which have an unknown number of
802/// successors.
803template <typename ConcreteType>
804class VariadicSuccessors
805 : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
806};
807
808//===----------------------------------------------------------------------===//
809// SingleBlock
810
811/// This class provides APIs and verifiers for ops with regions having a single
812/// block.
813template <typename ConcreteType>
814struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> {
815public:
816 static LogicalResult verifyTrait(Operation *op) {
817 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
818 Region &region = op->getRegion(i);
819
820 // Empty regions are fine.
821 if (region.empty())
822 continue;
823
824 // Non-empty regions must contain a single basic block.
825 if (!llvm::hasSingleElement(region))
826 return op->emitOpError("expects region #")
827 << i << " to have 0 or 1 blocks";
828
829 if (!ConcreteType::template hasTrait<NoTerminator>()) {
830 Block &block = region.front();
831 if (block.empty())
832 return op->emitOpError() << "expects a non-empty block";
833 }
834 }
835 return success();
836 }
837
838 Block *getBody(unsigned idx = 0) {
839 Region &region = this->getOperation()->getRegion(idx);
840 assert(!region.empty() && "unexpected empty region");
841 return &region.front();
842 }
843 Region &getBodyRegion(unsigned idx = 0) {
844 return this->getOperation()->getRegion(idx);
845 }
846
847 //===------------------------------------------------------------------===//
848 // Single Region Utilities
849 //===------------------------------------------------------------------===//
850
851 /// The following are a set of methods only enabled when the parent
852 /// operation has a single region. Each of these methods takes an additional
853 /// template parameter that represents the concrete operation so that we
854 /// can use SFINAE to disable the methods for non-single region operations.
855 template <typename OpT, typename T = void>
856 using enable_if_single_region =
857 typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
858
859 template <typename OpT = ConcreteType>
860 enable_if_single_region<OpT, Block::iterator> begin() {
861 return getBody()->begin();
862 }
863 template <typename OpT = ConcreteType>
864 enable_if_single_region<OpT, Block::iterator> end() {
865 return getBody()->end();
866 }
867 template <typename OpT = ConcreteType>
868 enable_if_single_region<OpT, Operation &> front() {
869 return *begin();
870 }
871
872 /// Insert the operation into the back of the body.
873 template <typename OpT = ConcreteType>
874 enable_if_single_region<OpT> push_back(Operation *op) {
875 insert(Block::iterator(getBody()->end()), op);
876 }
877
878 /// Insert the operation at the given insertion point.
879 template <typename OpT = ConcreteType>
880 enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
881 insert(Block::iterator(insertPt), op);
882 }
883 template <typename OpT = ConcreteType>
884 enable_if_single_region<OpT> insert(Block::iterator insertPt, Operation *op) {
885 getBody()->getOperations().insert(insertPt, op);
886 }
887};
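// A usage sketch (hypothetical single-region op that also has the OneRegion
// trait, which enables the iterator helpers above):
//
//   for (Operation &nested : moduleLikeOp) // Iterate the ops of the body.
//     process(&nested);
//   moduleLikeOp.push_back(newOp);         // Append to the body block.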
888
889//===----------------------------------------------------------------------===//
890// SingleBlockImplicitTerminator
891
892/// This class provides APIs and verifiers for ops with regions having a single
893/// block that must terminate with `TerminatorOpType`.
894template <typename TerminatorOpType>
895struct SingleBlockImplicitTerminator {
896 template <typename ConcreteType>
897 class Impl : public SingleBlock<ConcreteType> {
898 private:
899 using Base = SingleBlock<ConcreteType>;
900 /// Builds a terminator operation without relying on OpBuilder APIs to avoid
901 /// cyclic header inclusion.
902 static Operation *buildTerminator(OpBuilder &builder, Location loc) {
903 OperationState state(loc, TerminatorOpType::getOperationName());
904 TerminatorOpType::build(builder, state);
905 return Operation::create(state);
906 }
907
908 public:
909 /// The type of the operation used as the implicit terminator type.
910 using ImplicitTerminatorOpT = TerminatorOpType;
911
912 static LogicalResult verifyRegionTrait(Operation *op) {
913 if (failed(Base::verifyTrait(op)))
914 return failure();
915 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
916 Region &region = op->getRegion(i);
917 // Empty regions are fine.
918 if (region.empty())
919 continue;
920 Operation &terminator = region.front().back();
921 if (isa<TerminatorOpType>(terminator))
922 continue;
923
924 return op->emitOpError("expects regions to end with '" +
925 TerminatorOpType::getOperationName() +
926 "', found '" +
927 terminator.getName().getStringRef() + "'")
928 .attachNote()
929 << "in custom textual format, the absence of terminator implies "
930 "'"
931 << TerminatorOpType::getOperationName() << '\'';
932 }
933
934 return success();
935 }
936
937 /// Ensure that the given region has the terminator required by this trait.
938 /// If OpBuilder is provided, use it to build the terminator and notify the
939 /// OpBuilder listeners accordingly. If only a Builder is provided, locally
940 /// construct an OpBuilder with no listeners; this should only be used if no
941 /// OpBuilder is available at the call site, e.g., in the parser.
942 static void ensureTerminator(Region &region, Builder &builder,
943 Location loc) {
944 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
945 buildTerminator);
946 }
947 static void ensureTerminator(Region &region, OpBuilder &builder,
948 Location loc) {
949 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
950 buildTerminator);
951 }
952
953 //===------------------------------------------------------------------===//
954 // Single Region Utilities
955 //===------------------------------------------------------------------===//
956 using Base::getBody;
957
958 template <typename OpT, typename T = void>
959 using enable_if_single_region =
960 typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
961
962 /// Insert the operation into the back of the body, before the terminator.
963 template <typename OpT = ConcreteType>
964 enable_if_single_region<OpT> push_back(Operation *op) {
965 insert(Block::iterator(getBody()->getTerminator()), op);
966 }
967
968 /// Insert the operation at the given insertion point. Note: The operation
969 /// is never inserted after the terminator, even if the insertion point is
970 /// end().
971 template <typename OpT = ConcreteType>
972 enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
973 insert(Block::iterator(insertPt), op);
974 }
975 template <typename OpT = ConcreteType>
976 enable_if_single_region<OpT> insert(Block::iterator insertPt,
977 Operation *op) {
978 auto *body = getBody();
979 if (insertPt == body->end())
980 insertPt = Block::iterator(body->getTerminator());
981 body->getOperations().insert(insertPt, op);
982 }
983 };
984};
985
986/// Check if an op defines the `ImplicitTerminatorOpT` member. This is intended
987/// to be used with `llvm::is_detected`.
988template <class T>
989using has_implicit_terminator_t = typename T::ImplicitTerminatorOpT;
990
991/// Support to check if an operation has the SingleBlockImplicitTerminator
992/// trait. We can't just use `hasTrait` because this class is templated on a
993/// specific terminator op.
994template <class Op, bool hasTerminator =
995 llvm::is_detected<has_implicit_terminator_t, Op>::value>
996struct hasSingleBlockImplicitTerminator {
997 static constexpr bool value = std::is_base_of<
998 typename OpTrait::SingleBlockImplicitTerminator<
999 typename Op::ImplicitTerminatorOpT>::template Impl<Op>,
1000 Op>::value;
1001};
1002template <class Op>
1003struct hasSingleBlockImplicitTerminator<Op, false> {
1004 static constexpr bool value = false;
1005};
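// This lets generic code dispatch on the trait, e.g. (a sketch; assumes OpT
// also has a single region):
//
//   template <typename OpT>
//   void finalizeBody(OpT op, OpBuilder &builder) {
//     if constexpr (hasSingleBlockImplicitTerminator<OpT>::value)
//       OpT::ensureTerminator(op.getRegion(), builder, op.getLoc());
//   }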
1006
1007//===----------------------------------------------------------------------===//
1008// Misc Traits
1009
1010/// This class provides verification for ops that are known to have the same
1011/// operand shape: all operands are scalars, vectors/tensors of the same
1012/// shape.
1013template <typename ConcreteType>
1014class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
1015public:
1016 static LogicalResult verifyTrait(Operation *op) {
1017 return impl::verifySameOperandsShape(op);
1018 }
1019};
1020
1021/// This class provides verification for ops that are known to have the same
1022/// operand and result shape: both are scalars, vectors/tensors of the same
1023/// shape.
1024template <typename ConcreteType>
1025class SameOperandsAndResultShape
1026 : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
1027public:
1028 static LogicalResult verifyTrait(Operation *op) {
1029 return impl::verifySameOperandsAndResultShape(op);
1030 }
1031};
1032
1033/// This class provides verification for ops that are known to have the same
1034/// operand element type (or the type itself if it is scalar).
1035///
1036template <typename ConcreteType>
1037class SameOperandsElementType
1038 : public TraitBase<ConcreteType, SameOperandsElementType> {
1039public:
1040 static LogicalResult verifyTrait(Operation *op) {
1041 return impl::verifySameOperandsElementType(op);
1042 }
1043};
1044
1045/// This class provides verification for ops that are known to have the same
1046/// operand and result element type (or the type itself if it is scalar).
1047///
1048template <typename ConcreteType>
1049class SameOperandsAndResultElementType
1050 : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
1051public:
1052 static LogicalResult verifyTrait(Operation *op) {
1053 return impl::verifySameOperandsAndResultElementType(op);
1054 }
1055};
1056
1057/// This class provides verification for ops that are known to have the same
1058/// operand and result type.
1059///
1060/// Note: this trait subsumes the SameOperandsAndResultShape and
1061/// SameOperandsAndResultElementType traits.
1062template <typename ConcreteType>
1063class SameOperandsAndResultType
1064 : public TraitBase<ConcreteType, SameOperandsAndResultType> {
1065public:
1066 static LogicalResult verifyTrait(Operation *op) {
1067 return impl::verifySameOperandsAndResultType(op);
1068 }
1069};
1070
1071/// This class verifies that any results of the specified op have a boolean
1072/// type, a vector thereof, or a tensor thereof.
1073template <typename ConcreteType>
1074class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
1075public:
1076 static LogicalResult verifyTrait(Operation *op) {
1077 return impl::verifyResultsAreBoolLike(op);
1078 }
1079};
1080
1081/// This class verifies that any results of the specified op have a floating
1082/// point type, a vector thereof, or a tensor thereof.
1083template <typename ConcreteType>
1084class ResultsAreFloatLike
1085 : public TraitBase<ConcreteType, ResultsAreFloatLike> {
1086public:
1087 static LogicalResult verifyTrait(Operation *op) {
1088 return impl::verifyResultsAreFloatLike(op);
1089 }
1090};
1091
1092/// This class verifies that any results of the specified op have a signless
1093/// integer or index type, a vector thereof, or a tensor thereof.
1094template <typename ConcreteType>
1095class ResultsAreSignlessIntegerLike
1096 : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
1097public:
1098 static LogicalResult verifyTrait(Operation *op) {
1099 return impl::verifyResultsAreSignlessIntegerLike(op);
1100 }
1101};
1102
1103/// This class adds the property that the operation is commutative.
1104template <typename ConcreteType>
1105class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {};
1106
1107/// This class adds the property that the operation is an involution, i.e.,
1108/// a unary operation "f" that satisfies f(f(x)) = x.
1109template <typename ConcreteType>
1110class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
1111public:
1112 static LogicalResult verifyTrait(Operation *op) {
1113 static_assert(ConcreteType::template hasTrait<OneResult>(),
1114 "expected operation to produce one result");
1115 static_assert(ConcreteType::template hasTrait<OneOperand>(),
1116 "expected operation to take one operand");
1117 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1118 "expected operation to preserve type");
1119 // Involution requires the operation to be side effect free as well
1120 // but currently this check is under a FIXME and is not actually done.
1121 return impl::verifyIsInvolution(op);
1122 }
1123
1124 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1125 return impl::foldInvolution(op);
1126 }
1127};
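// For example, a bitwise-not op is an involution, since not(not(x)) == x, so
// foldInvolution can cancel back-to-back applications (hypothetical op):
//
//   class BitNotOp
//       : public Op<BitNotOp, OpTrait::OneResult, OpTrait::OneOperand,
//                   OpTrait::SameOperandsAndResultType,
//                   OpTrait::IsInvolution> { ... };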
1128
1129/// This class adds the property that the operation is idempotent, i.e., a
1130/// unary operation "f" that satisfies f(f(x)) = f(x),
1131/// or a binary operation "g" that satisfies g(x, x) = x.
1132template <typename ConcreteType>
1133class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
1134public:
1135 static LogicalResult verifyTrait(Operation *op) {
1136 static_assert(ConcreteType::template hasTrait<OneResult>(),
1137 "expected operation to produce one result");
1138 static_assert(ConcreteType::template hasTrait<OneOperand>() ||
1139 ConcreteType::template hasTrait<NOperands<2>::Impl>(),
1140 "expected operation to take one or two operands");
1141 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1142 "expected operation to preserve type");
1143 // Idempotent requires the operation to be side effect free as well
1144 // but currently this check is under a FIXME and is not actually done.
1145 return impl::verifyIsIdempotent(op);
1146 }
1147
1148 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1149 return impl::foldIdempotent(op);
1150 }
1151};
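// For example, bitwise `and` and `or` are idempotent in the binary sense
// (g(x, x) == x), and a clamp-like op is idempotent in the unary sense
// (f(f(x)) == f(x)); foldIdempotent uses this to drop the redundant
// application.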
1152
1153/// This class verifies that all operands of the specified op have a float type,
1154/// a vector thereof, or a tensor thereof.
1155template <typename ConcreteType>
1156class OperandsAreFloatLike
1157 : public TraitBase<ConcreteType, OperandsAreFloatLike> {
1158public:
1159 static LogicalResult verifyTrait(Operation *op) {
1160 return impl::verifyOperandsAreFloatLike(op);
1161 }
1162};
1163
1164/// This class verifies that all operands of the specified op have a signless
1165/// integer or index type, a vector thereof, or a tensor thereof.
1166template <typename ConcreteType>
1167class OperandsAreSignlessIntegerLike
1168 : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
1169public:
1170 static LogicalResult verifyTrait(Operation *op) {
1171 return impl::verifyOperandsAreSignlessIntegerLike(op);
1172 }
1173};
1174
1175/// This class verifies that all operands of the specified op have the same
1176/// type.
1177template <typename ConcreteType>
1178class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
1179public:
1180 static LogicalResult verifyTrait(Operation *op) {
1181 return impl::verifySameTypeOperands(op);
1182 }
1183};
1184
1185/// This class provides the API for a subset of ops that are known to be
1186/// constant-like. These are non-side effecting operations with one result and
1187/// zero operands that can always be folded to a specific attribute value.
1188template <typename ConcreteType>
1189class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
1190public:
1191 static LogicalResult verifyTrait(Operation *op) {
1192 static_assert(ConcreteType::template hasTrait<OneResult>(),
1193 "expected operation to produce one result");
1194 static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
1195 "expected operation to take zero operands");
1196 // TODO: We should verify that the operation can always be folded, but this
1197 // requires that the attributes of the op already be verified. We should add
1198 // support for verifying traits "after" the operation to enable this use
1199 // case.
1200 return success();
1201 }
1202};
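// For example, `arith.constant` is ConstantLike: it takes no operands,
// produces one result, and always folds to its value attribute.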
1203
1204/// This class provides the API for ops that are known to be isolated from
1205/// above.
1206template <typename ConcreteType>
1207class IsIsolatedFromAbove
1208 : public TraitBase<ConcreteType, IsIsolatedFromAbove> {
1209public:
1210 static LogicalResult verifyRegionTrait(Operation *op) {
1211 return impl::verifyIsIsolatedFromAbove(op);
1212 }
1213};
1214
1215/// A trait of region-holding operations that defines a new scope for polyhedral
1216/// optimization purposes. Any SSA values of 'index' type that either dominate
1217/// such an operation or are used at the top-level of such an operation
1218/// automatically become valid symbols for the polyhedral scope defined by that
1219/// operation. For more details, see `Traits.md#AffineScope`.
1220template <typename ConcreteType>
1221class AffineScope : public TraitBase<ConcreteType, AffineScope> {
1222public:
1223 static LogicalResult verifyTrait(Operation *op) {
1224 static_assert(!ConcreteType::template hasTrait<ZeroRegions>(),
1225 "expected operation to have one or more regions");
1226 return success();
1227 }
1228};
1229
1230/// A trait of region-holding operations that define a new scope for automatic
1231/// allocations, i.e., allocations that are freed when control is transferred
1232/// back from the operation's region. Any operations performing such allocations
1233/// (e.g., memref.alloca) will have their allocations automatically freed at
1234/// their closest enclosing operation with this trait.
1235template <typename ConcreteType>
1236class AutomaticAllocationScope
1237 : public TraitBase<ConcreteType, AutomaticAllocationScope> {
1238public:
1239 static LogicalResult verifyTrait(Operation *op) {
1240 static_assert(!ConcreteType::template hasTrait<ZeroRegions>(),
1241 "expected operation to have one or more regions");
1242 return success();
1243 }
1244};
1245
1246/// This class provides a verifier for ops that expect their parent to be one
1247/// of the given parent ops.
1248template <typename... ParentOpTypes>
1249struct HasParent {
1250 template <typename ConcreteType>
1251 class Impl : public TraitBase<ConcreteType, Impl> {
1252 public:
1253 static LogicalResult verifyTrait(Operation *op) {
1254 if (llvm::isa_and_nonnull<ParentOpTypes...>(op->getParentOp()))
1255 return success();
1256
1257 return op->emitOpError()
1258 << "expects parent op "
1259 << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
1260 << llvm::makeArrayRef({ParentOpTypes::getOperationName()...})
1261 << "'";
1262 }
1263 };
1264};
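// A usage sketch (hypothetical yield-like terminator constrained to a
// specific parent op):
//
//   class YieldLikeOp
//       : public Op<YieldLikeOp, OpTrait::HasParent<ForLikeOp>::Impl,
//                   OpTrait::IsTerminator> { ... };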
1265
1266/// A trait for operations that have an attribute specifying operand segments.
1267///
1268/// Certain operations can have multiple variadic operands and their size
1269/// relationship is not always known statically. For such cases, we need
1270/// a per-op-instance specification to divide the operands into logical groups
1271/// or segments. This can be modeled by attributes. The attribute is named
1272/// `operand_segment_sizes`.
1273///
1274/// This trait verifies the attribute for specifying operand segments has
1275/// the correct type (1D vector) and values (non-negative), etc.
1276template <typename ConcreteType>
1277class AttrSizedOperandSegments
1278 : public TraitBase<ConcreteType, AttrSizedOperandSegments> {
1279public:
1280 static StringRef getOperandSegmentSizeAttr() {
1281 return "operand_segment_sizes";
1282 }
1283
1284 static LogicalResult verifyTrait(Operation *op) {
1285 return ::mlir::OpTrait::impl::verifyOperandSizeAttr(
1286 op, getOperandSegmentSizeAttr());
1287 }
1288};
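// For example, an op with two variadic operand groups holding {%a, %b} and
// {%c} would carry (a sketch; the exact attribute encoding depends on the
// MLIR version):
//
//   "test.grouped"(%a, %b, %c)
//       {operand_segment_sizes = dense<[2, 1]> : vector<2xi32>}
//       : (f32, f32, f32) -> ()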
1289
1290/// Similar to AttrSizedOperandSegments but used for results.
1291template <typename ConcreteType>
1292class AttrSizedResultSegments
1293 : public TraitBase<ConcreteType, AttrSizedResultSegments> {
1294public:
1295 static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; }
1296
1297 static LogicalResult verifyTrait(Operation *op) {
1298 return ::mlir::OpTrait::impl::verifyResultSizeAttr(
1299 op, getResultSegmentSizeAttr());
1300 }
1301};
1302
1303/// This trait provides a verifier for ops that expect their regions to have
1304/// no arguments.
1305template <typename ConcreteType>
1306struct NoRegionArguments : public TraitBase<ConcreteType, NoRegionArguments> {
1307 static LogicalResult verifyTrait(Operation *op) {
1308 return ::mlir::OpTrait::impl::verifyNoRegionArguments(op);
1309 }
1310};
1311
1312// This trait is used to flag operations that consume or produce
1313// values of `MemRef` type where those references can be 'normalized'.
1314// TODO: Right now, the operands of an operation are either all normalizable,
1315// or not. In the future, we may want to allow some of the operands to be
1316// normalizable.
1317template <typename ConcreteType>
1318struct MemRefsNormalizable
1319 : public TraitBase<ConcreteType, MemRefsNormalizable> {};
1320
1321/// This trait tags element-wise ops on vectors or tensors.
1322///
1323/// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
1324/// trait. In particular, broadcasting behavior is not allowed.
1325///
1326/// An `Elementwise` op must satisfy the following properties:
1327///
1328/// 1. If any result is a vector/tensor then at least one operand must also be a
1329/// vector/tensor.
1330/// 2. If any operand is a vector/tensor then there must be at least one result
1331/// and all results must be vectors/tensors.
1332/// 3. All operand and result vector/tensor types must be of the same shape. The
1333/// shape may be dynamic in which case the op's behaviour is undefined for
1334/// non-matching shapes.
1335/// 4. The operation must be elementwise on its vector/tensor operands and
1336/// results. When applied to single-element vectors/tensors, the result must
1337/// be the same per element.
1338///
1339/// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new
1340/// interface `ElementwiseTypeInterface` that describes the container types for
1341/// which the operation is elementwise.
1342///
1343/// Rationale:
1344/// - 1. and 2. guarantee a well-defined iteration space and exclude the cases
1345/// of 0 non-scalar operands or 0 non-scalar results, which complicate a
1346/// generic definition of the iteration space.
1347/// - 3. guarantees that folding can be done across scalars/vectors/tensors with
1348/// the same pattern, as otherwise lots of special handling for type
1349/// mismatches would be needed.
1350/// - 4. guarantees that no error handling is needed. Higher-level dialects
1351/// should reify any needed guards or error handling code before lowering to
1352/// an `Elementwise` op.
1353template <typename ConcreteType>
1354struct Elementwise : public TraitBase<ConcreteType, Elementwise> {
1355 static LogicalResult verifyTrait(Operation *op) {
1356 return ::mlir::OpTrait::impl::verifyElementwise(op);
1357 }
1358};
1359
1360/// This trait tags `Elementwise` operations that can be systematically
1361/// scalarized. All vector/tensor operands and results are then replaced by
1362/// scalars of the respective element type. Semantically, this is the operation
1363/// on a single element of the vector/tensor.
1364///
1365/// Rationale:
1366/// Allow to define the vector/tensor semantics of elementwise operations based
1367/// on the same op's behavior on scalars. This provides a constructive procedure
1368/// for IR transformations to, e.g., create scalar loop bodies from tensor ops.
1369///
1370/// Example:
1371/// ```
1372/// %tensor_select = "arith.select"(%pred_tensor, %true_val, %false_val)
1373/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1374/// -> tensor<?xf32>
1375/// ```
1376/// can be scalarized to
1377///
1378/// ```
1379/// %scalar_select = "arith.select"(%pred, %true_val_scalar, %false_val_scalar)
1380/// : (i1, f32, f32) -> f32
1381/// ```
1382template <typename ConcreteType>
1383struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> {
1384 static LogicalResult verifyTrait(Operation *op) {
1385 static_assert(
1386 ConcreteType::template hasTrait<Elementwise>(),
1387 "`Scalarizable` trait is only applicable to `Elementwise` ops.");
1388 return success();
1389 }
1390};
1391
1392/// This trait tags `Elementwise` operatons that can be systematically
1393/// vectorized. All scalar operands and results are then replaced by vectors
1394/// with the respective element type. Semantically, this is the operation on
1395/// multiple elements simultaneously. See also `Tensorizable`.
1396///
1397/// Rationale:
1398/// Provide the reverse to `Scalarizable` which, when chained together, allows
1399/// reasoning about the relationship between the tensor and vector case.
1400/// Additionally, it permits reasoning about promoting scalars to vectors via
1401/// broadcasting in cases like `%select_scalar_pred` below.
1402template <typename ConcreteType>
1403struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> {
1404 static LogicalResult verifyTrait(Operation *op) {
1405 static_assert(
1406 ConcreteType::template hasTrait<Elementwise>(),
1407 "`Vectorizable` trait is only applicable to `Elementwise` ops.");
1408 return success();
1409 }
1410};
1411
1412/// This trait tags `Elementwise` operatons that can be systematically
1413/// tensorized. All scalar operands and results are then replaced by tensors
1414/// with the respective element type. Semantically, this is the operation on
1415/// multiple elements simultaneously. See also `Vectorizable`.
1416///
1417/// Rationale:
1418/// Provide the reverse to `Scalarizable` which, when chained together, allows
1419/// reasoning about the relationship between the tensor and vector case.
1420/// Additionally, it permits reasoning about promoting scalars to tensors via
1421/// broadcasting in cases like `%select_scalar_pred` below.
1422///
1423/// Examples:
1424/// ```
1425/// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32
1426/// ```
1427/// can be tensorized to
1428/// ```
1429/// %tensor = "arith.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>)
1430/// -> tensor<?xf32>
1431/// ```
1432///
1433/// ```
1434/// %scalar_pred = "arith.select"(%pred, %true_val, %false_val)
1435/// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
1436/// ```
1437/// can be tensorized to
1438/// ```
1439/// %tensor_pred = "arith.select"(%pred, %true_val, %false_val)
1440/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1441/// -> tensor<?xf32>
1442/// ```
1443template <typename ConcreteType>
1444struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> {
1445 static LogicalResult verifyTrait(Operation *op) {
1446 static_assert(
1447 ConcreteType::template hasTrait<Elementwise>(),
1448 "`Tensorizable` trait is only applicable to `Elementwise` ops.");
1449 return success();
1450 }
1451};
1452
1453/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable`
1454/// provide an easy way for scalar operations to conveniently generalize their
1455/// behavior to vectors/tensors, and systematize conversion between these forms.
1456bool hasElementwiseMappableTraits(Operation *op);
1457
1458} // namespace OpTrait
1459
1460//===----------------------------------------------------------------------===//
1461// Internal Trait Utilities
1462//===----------------------------------------------------------------------===//
1463
1464namespace op_definition_impl {
1465//===----------------------------------------------------------------------===//
1466// Trait Existence
1467
1468/// Returns true if this given Trait ID matches the IDs of any of the provided
1469/// trait types `Traits`.
1470template <template <typename T> class... Traits>
1471static bool hasTrait(TypeID traitID) {
1472 TypeID traitIDs[] = {TypeID::get<Traits>()...};
1473 for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
1474 if (traitIDs[i] == traitID)
1475 return true;
1476 return false;
1477}
1478
1479//===----------------------------------------------------------------------===//
1480// Trait Folding
1481
1482/// Trait to check if T provides a 'foldTrait' method for single result
1483/// operations.
1484template <typename T, typename... Args>
1485using has_single_result_fold_trait = decltype(T::foldTrait(
1486 std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>()));
1487template <typename T>
1488using detect_has_single_result_fold_trait =
1489 llvm::is_detected<has_single_result_fold_trait, T>;
1490/// Trait to check if T provides a general 'foldTrait' method.
1491template <typename T, typename... Args>
1492using has_fold_trait =
1493 decltype(T::foldTrait(std::declval<Operation *>(),
1494 std::declval<ArrayRef<Attribute>>(),
1495 std::declval<SmallVectorImpl<OpFoldResult> &>()));
1496template <typename T>
1497using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>;
1498/// Trait to check if T provides any `foldTrait` method.
1499template <typename T>
1500using detect_has_any_fold_trait =
1501 std::disjunction<detect_has_fold_trait<T>,
1502 detect_has_single_result_fold_trait<T>>;
1503
1504/// Returns the result of folding a trait that implements a `foldTrait` function
1505/// that is specialized for operations that have a single result.
1506template <typename Trait>
1507static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value,
1508 LogicalResult>
1509foldTrait(Operation *op, ArrayRef<Attribute> operands,
1510 SmallVectorImpl<OpFoldResult> &results) {
1511 assert(op->hasTrait<OpTrait::OneResult>() &&(static_cast <bool> (op->hasTrait<OpTrait::OneResult
>() && "expected trait on non single-result operation to implement the "
"general `foldTrait` method") ? void (0) : __assert_fail ("op->hasTrait<OpTrait::OneResult>() && \"expected trait on non single-result operation to implement the \" \"general `foldTrait` method\""
, "mlir/include/mlir/IR/OpDefinition.h", 1513, __extension__ __PRETTY_FUNCTION__
))
1512 "expected trait on non single-result operation to implement the "(static_cast <bool> (op->hasTrait<OpTrait::OneResult
>() && "expected trait on non single-result operation to implement the "
"general `foldTrait` method") ? void (0) : __assert_fail ("op->hasTrait<OpTrait::OneResult>() && \"expected trait on non single-result operation to implement the \" \"general `foldTrait` method\""
, "mlir/include/mlir/IR/OpDefinition.h", 1513, __extension__ __PRETTY_FUNCTION__
))
1513 "general `foldTrait` method")(static_cast <bool> (op->hasTrait<OpTrait::OneResult
>() && "expected trait on non single-result operation to implement the "
"general `foldTrait` method") ? void (0) : __assert_fail ("op->hasTrait<OpTrait::OneResult>() && \"expected trait on non single-result operation to implement the \" \"general `foldTrait` method\""
, "mlir/include/mlir/IR/OpDefinition.h", 1513, __extension__ __PRETTY_FUNCTION__
))
;
1514 // If a previous trait has already been folded and replaced this operation, we
1515 // fail to fold this trait.
1516 if (!results.empty())
1517 return failure();
1518
1519 if (OpFoldResult result = Trait::foldTrait(op, operands)) {
1520 if (result.template dyn_cast<Value>() != op->getResult(0))
1521 results.push_back(result);
1522 return success();
1523 }
1524 return failure();
1525}
1526/// Returns the result of folding a trait that implements a generalized
1527/// `foldTrait` function that is supports any operation type.
1528template <typename Trait>
1529static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult>
1530foldTrait(Operation *op, ArrayRef<Attribute> operands,
1531 SmallVectorImpl<OpFoldResult> &results) {
1532 // If a previous trait has already been folded and replaced this operation, we
1533 // fail to fold this trait.
1534 return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
1535}
1536template <typename Trait>
1537static inline std::enable_if_t<!detect_has_any_fold_trait<Trait>::value,
1538 LogicalResult>
1539foldTrait(Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) {
1540 return failure();
1541}
1542
1543/// Given a tuple type containing a set of traits, return the result of folding
1544/// the given operation.
1545template <typename... Ts>
1546static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
1547 SmallVectorImpl<OpFoldResult> &results) {
1548 return success((succeeded(foldTrait<Ts>(op, operands, results)) || ...));
1549}
1550
1551//===----------------------------------------------------------------------===//
1552// Trait Verification
1553
1554/// Trait to check if T provides a `verifyTrait` method.
1555template <typename T, typename... Args>
1556using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
1557template <typename T>
1558using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;
1559
1560/// Trait to check if T provides a `verifyTrait` method.
1561template <typename T, typename... Args>
1562using has_verify_region_trait =
1563 decltype(T::verifyRegionTrait(std::declval<Operation *>()));
1564template <typename T>
1565using detect_has_verify_region_trait =
1566 llvm::is_detected<has_verify_region_trait, T>;
1567
1568/// Verify the given trait if it provides a verifier.
1569template <typename T>
1570std::enable_if_t<detect_has_verify_trait<T>::value, LogicalResult>
1571verifyTrait(Operation *op) {
1572 return T::verifyTrait(op);
1573}
1574template <typename T>
1575inline std::enable_if_t<!detect_has_verify_trait<T>::value, LogicalResult>
1576verifyTrait(Operation *) {
1577 return success();
1578}
1579
1580/// Given a set of traits, return the result of verifying the given operation.
1581template <typename... Ts>
1582LogicalResult verifyTraits(Operation *op) {
1583 return success((succeeded(verifyTrait<Ts>(op)) && ...));
1584}
1585
1586/// Verify the given trait if it provides a region verifier.
1587template <typename T>
1588std::enable_if_t<detect_has_verify_region_trait<T>::value, LogicalResult>
1589verifyRegionTrait(Operation *op) {
1590 return T::verifyRegionTrait(op);
1591}
1592template <typename T>
1593inline std::enable_if_t<!detect_has_verify_region_trait<T>::value,
1594 LogicalResult>
1595verifyRegionTrait(Operation *) {
1596 return success();
1597}
1598
1599/// Given a set of traits, return the result of verifying the regions of the
1600/// given operation.
1601template <typename... Ts>
1602LogicalResult verifyRegionTraits(Operation *op) {
1603 return success((succeeded(verifyRegionTrait<Ts>(op)) && ...));
1604}
1605} // namespace op_definition_impl
1606
1607//===----------------------------------------------------------------------===//
1608// Operation Definition classes
1609//===----------------------------------------------------------------------===//
1610
1611/// This provides public APIs that all operations should have. The template
1612/// argument 'ConcreteType' should be the concrete type by CRTP and the others
1613/// are base classes by the policy pattern.
1614template <typename ConcreteType, template <typename T> class... Traits>
1615class Op : public OpState, public Traits<ConcreteType>... {
1616public:
1617 /// Inherit getOperation from `OpState`.
1618 using OpState::getOperation;
1619 using OpState::verify;
1620 using OpState::verifyRegions;
1621
1622 /// Return if this operation contains the provided trait.
1623 template <template <typename T> class Trait>
1624 static constexpr bool hasTrait() {
1625 return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
1626 }
1627
1628 /// Create a deep copy of this operation.
1629 ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }
1630
1631 /// Create a partial copy of this operation without traversing into attached
1632 /// regions. The new operation will have the same number of regions as the
1633 /// original one, but they will be left empty.
1634 ConcreteType cloneWithoutRegions() {
1635 return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
1636 }
1637
1638 /// Return true if this "op class" can match against the specified operation.
1639 static bool classof(Operation *op) {
1640 if (auto info = op->getRegisteredInfo())
1641 return TypeID::get<ConcreteType>() == info->getTypeID();
1642#ifndef NDEBUG
1643 if (op->getName().getStringRef() == ConcreteType::getOperationName())
1644 llvm::report_fatal_error(
1645 "classof on '" + ConcreteType::getOperationName() +
1646 "' failed due to the operation not being registered");
1647#endif
1648 return false;
1649 }
1650 /// Provide `classof` support for other OpBase derived classes, such as
1651 /// Interfaces.
1652 template <typename T>
1653 static std::enable_if_t<std::is_base_of<OpState, T>::value, bool>
1654 classof(const T *op) {
1655 return classof(const_cast<T *>(op)->getOperation());
1656 }
1657
1658 /// Expose the type we are instantiated on to template machinery that may want
1659 /// to introspect traits on this operation.
1660 using ConcreteOpType = ConcreteType;
1661
1662 /// This is a public constructor. Any op can be initialized to null.
1663 explicit Op() : OpState(nullptr) {}
1664 Op(std::nullptr_t) : OpState(nullptr) {}
29
Passing null pointer value via 1st parameter 'state'
30
Calling constructor for 'OpState'
32
Returning from constructor for 'OpState'
1665
1666 /// This is a public constructor to enable access via the llvm::cast family of
1667 /// methods. This should not be used directly.
1668 explicit Op(Operation *state) : OpState(state) {}
1669
1670 /// Methods for supporting PointerLikeTypeTraits.
1671 const void *getAsOpaquePointer() const {
1672 return static_cast<const void *>((Operation *)*this);
1673 }
1674 static ConcreteOpType getFromOpaquePointer(const void *pointer) {
1675 return ConcreteOpType(
1676 reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
1677 }
1678
1679 /// Attach the given models as implementations of the corresponding
1680 /// interfaces for the concrete operation.
1681 template <typename... Models>
1682 static void attachInterface(MLIRContext &context) {
1683 Optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
1684 ConcreteType::getOperationName(), &context);
1685 if (!info)
1686 llvm::report_fatal_error(
1687 "Attempting to attach an interface to an unregistered operation " +
1688 ConcreteType::getOperationName() + ".");
1689 (checkInterfaceTarget<Models>(), ...);
1690 info->attachInterface<Models...>();
1691 }
1692
1693private:
1694 /// Trait to check if T provides a 'fold' method for a single result op.
1695 template <typename T, typename... Args>
1696 using has_single_result_fold =
1697 decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
1698 template <typename T>
1699 using detect_has_single_result_fold =
1700 llvm::is_detected<has_single_result_fold, T>;
1701 /// Trait to check if T provides a general 'fold' method.
1702 template <typename T, typename... Args>
1703 using has_fold = decltype(std::declval<T>().fold(
1704 std::declval<ArrayRef<Attribute>>(),
1705 std::declval<SmallVectorImpl<OpFoldResult> &>()));
1706 template <typename T>
1707 using detect_has_fold = llvm::is_detected<has_fold, T>;
1708 /// Trait to check if T provides a 'print' method.
1709 template <typename T, typename... Args>
1710 using has_print =
1711 decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
1712 template <typename T>
1713 using detect_has_print = llvm::is_detected<has_print, T>;
1714
1715 /// Trait to check if T provides a 'ConcreteEntity' type alias.
1716 template <typename T>
1717 using has_concrete_entity_t = typename T::ConcreteEntity;
1718
1719 /// A struct-wrapped type alias to T::ConcreteEntity if provided and to
1720 /// ConcreteType otherwise. This is akin to std::conditional but doesn't fail
1721 /// on the missing typedef. Useful for checking if the interface is targeting
1722 /// the right class.
1723 template <typename T,
1724 bool = llvm::is_detected<has_concrete_entity_t, T>::value>
1725 struct InterfaceTargetOrOpT {
1726 using type = typename T::ConcreteEntity;
1727 };
1728 template <typename T>
1729 struct InterfaceTargetOrOpT<T, false> {
1730 using type = ConcreteType;
1731 };
1732
1733 /// A hook for static assertion that the external interface model T is
1734 /// targeting the concrete type of this op. The model can also be a fallback
1735 /// model that works for every op.
1736 template <typename T>
1737 static void checkInterfaceTarget() {
1738 static_assert(std::is_same<typename InterfaceTargetOrOpT<T>::type,
1739 ConcreteType>::value,
1740 "attaching an interface to the wrong op kind");
1741 }
1742
1743 /// Returns an interface map containing the interfaces registered to this
1744 /// operation.
1745 static detail::InterfaceMap getInterfaceMap() {
1746 return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
1747 }
1748
1749 /// Return the internal implementations of each of the OperationName
1750 /// hooks.
1751 /// Implementation of `FoldHookFn` OperationName hook.
1752 static OperationName::FoldHookFn getFoldHookFn() {
1753 return getFoldHookFnImpl<ConcreteType>();
1754 }
1755 /// The internal implementation of `getFoldHookFn` above that is invoked if
1756 /// the operation is single result and defines a `fold` method.
1757 template <typename ConcreteOpT>
1758 static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
1759 Traits<ConcreteOpT>...>::value &&
1760 detect_has_single_result_fold<ConcreteOpT>::value,
1761 OperationName::FoldHookFn>
1762 getFoldHookFnImpl() {
1763 return [](Operation *op, ArrayRef<Attribute> operands,
1764 SmallVectorImpl<OpFoldResult> &results) {
1765 return foldSingleResultHook<ConcreteOpT>(op, operands, results);
1766 };
1767 }
1768 /// The internal implementation of `getFoldHookFn` above that is invoked if
1769 /// the operation is not single result and defines a `fold` method.
1770 template <typename ConcreteOpT>
1771 static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
1772 Traits<ConcreteOpT>...>::value &&
1773 detect_has_fold<ConcreteOpT>::value,
1774 OperationName::FoldHookFn>
1775 getFoldHookFnImpl() {
1776 return [](Operation *op, ArrayRef<Attribute> operands,
1777 SmallVectorImpl<OpFoldResult> &results) {
1778 return foldHook<ConcreteOpT>(op, operands, results);
1779 };
1780 }
1781 /// The internal implementation of `getFoldHookFn` above that is invoked if
1782 /// the operation does not define a `fold` method.
1783 template <typename ConcreteOpT>
1784 static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
1785 !detect_has_fold<ConcreteOpT>::value,
1786 OperationName::FoldHookFn>
1787 getFoldHookFnImpl() {
1788 return [](Operation *op, ArrayRef<Attribute> operands,
1789 SmallVectorImpl<OpFoldResult> &results) {
1790 // In this case, we only need to fold the traits of the operation.
1791 return op_definition_impl::foldTraits<Traits<ConcreteType>...>(
1792 op, operands, results);
1793 };
1794 }
1795 /// Return the result of folding a single result operation that defines a
1796 /// `fold` method.
1797 template <typename ConcreteOpT>
1798 static LogicalResult
1799 foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
1800 SmallVectorImpl<OpFoldResult> &results) {
1801 OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);
1802
1803 // If the fold failed or was in-place, try to fold the traits of the
1804 // operation.
1805 if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
1806 if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
1807 op, operands, results)))
1808 return success();
1809 return success(static_cast<bool>(result));
1810 }
1811 results.push_back(result);
1812 return success();
1813 }
1814 /// Return the result of folding an operation that defines a `fold` method.
1815 template <typename ConcreteOpT>
1816 static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
1817 SmallVectorImpl<OpFoldResult> &results) {
1818 LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);
1819
1820 // If the fold failed or was in-place, try to fold the traits of the
1821 // operation.
1822 if (failed(result) || results.empty()) {
1823 if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
1824 op, operands, results)))
1825 return success();
1826 }
1827 return result;
1828 }
1829
1830 /// Implementation of `GetCanonicalizationPatternsFn` OperationName hook.
1831 static OperationName::GetCanonicalizationPatternsFn
1832 getGetCanonicalizationPatternsFn() {
1833 return &ConcreteType::getCanonicalizationPatterns;
1834 }
1835 /// Implementation of `GetHasTraitFn`
1836 static OperationName::HasTraitFn getHasTraitFn() {
1837 return
1838 [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
1839 }
1840 /// Implementation of `ParseAssemblyFn` OperationName hook.
1841 static OperationName::ParseAssemblyFn getParseAssemblyFn() {
1842 return &ConcreteType::parse;
1843 }
1844 /// Implementation of `PrintAssemblyFn` OperationName hook.
1845 static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
1846 return getPrintAssemblyFnImpl<ConcreteType>();
1847 }
1848 /// The internal implementation of `getPrintAssemblyFn` that is invoked when
1849 /// the concrete operation does not define a `print` method.
1850 template <typename ConcreteOpT>
1851 static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
1852 OperationName::PrintAssemblyFn>
1853 getPrintAssemblyFnImpl() {
1854 return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) {
1855 return OpState::print(op, printer, defaultDialect);
1856 };
1857 }
1858 /// The internal implementation of `getPrintAssemblyFn` that is invoked when
1859 /// the concrete operation defines a `print` method.
1860 template <typename ConcreteOpT>
1861 static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
1862 OperationName::PrintAssemblyFn>
1863 getPrintAssemblyFnImpl() {
1864 return &printAssembly;
1865 }
1866 static void printAssembly(Operation *op, OpAsmPrinter &p,
1867 StringRef defaultDialect) {
1868 OpState::printOpName(op, p, defaultDialect);
1869 return cast<ConcreteType>(op).print(p);
1870 }
1871 /// Implementation of `PopulateDefaultAttrsFn` OperationName hook.
1872 static OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn() {
1873 return ConcreteType::populateDefaultAttrs;
1874 }
1875 /// Implementation of `VerifyInvariantsFn` OperationName hook.
1876 static LogicalResult verifyInvariants(Operation *op) {
1877 static_assert(hasNoDataMembers(),
1878 "Op class shouldn't define new data members");
1879 return failure(
1880 failed(op_definition_impl::verifyTraits<Traits<ConcreteType>...>(op)) ||
1881 failed(cast<ConcreteType>(op).verify()));
1882 }
1883 static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
1884 return static_cast<LogicalResult (*)(Operation *)>(&verifyInvariants);
1885 }
1886 /// Implementation of `VerifyRegionInvariantsFn` OperationName hook.
1887 static LogicalResult verifyRegionInvariants(Operation *op) {
1888 static_assert(hasNoDataMembers(),
1889 "Op class shouldn't define new data members");
1890 return failure(
1891 failed(op_definition_impl::verifyRegionTraits<Traits<ConcreteType>...>(
1892 op)) ||
1893 failed(cast<ConcreteType>(op).verifyRegions()));
1894 }
1895 static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() {
1896 return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants);
1897 }
1898
1899 static constexpr bool hasNoDataMembers() {
1900 // Checking that the derived class does not define any member by comparing
1901 // its size to an ad-hoc EmptyOp.
1902 class EmptyOp : public Op<EmptyOp, Traits...> {};
1903 return sizeof(ConcreteType) == sizeof(EmptyOp);
1904 }
1905
1906 /// Allow access to internal implementation methods.
1907 friend RegisteredOperationName;
1908};
1909
1910/// This class represents the base of an operation interface. See the definition
1911/// of `detail::Interface` for requirements on the `Traits` type.
1912template <typename ConcreteType, typename Traits>
1913class OpInterface
1914 : public detail::Interface<ConcreteType, Operation *, Traits,
1915 Op<ConcreteType>, OpTrait::TraitBase> {
1916public:
1917 using Base = OpInterface<ConcreteType, Traits>;
1918 using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
1919 Op<ConcreteType>, OpTrait::TraitBase>;
1920
1921 /// Inherit the base class constructor.
1922 using InterfaceBase::InterfaceBase;
1923
1924protected:
1925 /// Returns the impl interface instance for the given operation.
1926 static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
1927 OperationName name = op->getName();
1928
1929 // Access the raw interface from the operation info.
1930 if (Optional<RegisteredOperationName> rInfo = name.getRegisteredInfo()) {
1931 if (auto *opIface = rInfo->getInterface<ConcreteType>())
1932 return opIface;
1933 // Fallback to the dialect to provide it with a chance to implement this
1934 // interface for this operation.
1935 return rInfo->getDialect().getRegisteredInterfaceForOp<ConcreteType>(
1936 op->getName());
1937 }
1938 // Fallback to the dialect to provide it with a chance to implement this
1939 // interface for this operation.
1940 if (Dialect *dialect = name.getDialect())
1941 return dialect->getRegisteredInterfaceForOp<ConcreteType>(name);
1942 return nullptr;
1943 }
1944
1945 /// Allow access to `getInterfaceFor`.
1946 friend InterfaceBase;
1947};
1948
1949//===----------------------------------------------------------------------===//
1950// CastOpInterface utilities
1951//===----------------------------------------------------------------------===//
1952
1953// These functions are out-of-line implementations of the methods in
1954// CastOpInterface, which avoids them being template instantiated/duplicated.
1955namespace impl {
1956/// Attempt to fold the given cast operation.
1957LogicalResult foldCastInterfaceOp(Operation *op,
1958 ArrayRef<Attribute> attrOperands,
1959 SmallVectorImpl<OpFoldResult> &foldResults);
1960/// Attempt to verify the given cast operation.
1961LogicalResult verifyCastInterfaceOp(
1962 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
1963} // namespace impl
1964} // namespace mlir
1965
1966namespace llvm {
1967
1968template <typename T>
1969struct DenseMapInfo<T,
1970 std::enable_if_t<std::is_base_of<mlir::OpState, T>::value &&
1971 !mlir::detail::IsInterface<T>::value>> {
1972 static inline T getEmptyKey() {
1973 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
1974 return T::getFromOpaquePointer(pointer);
1975 }
1976 static inline T getTombstoneKey() {
1977 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
1978 return T::getFromOpaquePointer(pointer);
1979 }
1980 static unsigned getHashValue(T val) {
1981 return hash_value(val.getAsOpaquePointer());
1982 }
1983 static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
1984};
1985
1986} // namespace llvm
1987
1988#endif

/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/include/mlir/IR/Types.h

1//===- Types.h - MLIR Type Classes ------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_TYPES_H
10#define MLIR_IR_TYPES_H
11
12#include "mlir/IR/TypeSupport.h"
13#include "llvm/ADT/ArrayRef.h"
14#include "llvm/ADT/DenseMapInfo.h"
15#include "llvm/Support/PointerLikeTypeTraits.h"
16
17namespace mlir {
18class AsmState;
19
20/// Instances of the Type class are uniqued, have an immutable identifier and an
21/// optional mutable component. They wrap a pointer to the storage object owned
22/// by MLIRContext. Therefore, instances of Type are passed around by value.
23///
24/// Some types are "primitives" meaning they do not have any parameters, for
25/// example the Index type. Parametric types have additional information that
26/// differentiates the types of the same class, for example the Integer type has
27/// bitwidth, making i8 and i16 belong to the same kind by be different
28/// instances of the IntegerType. Type parameters are part of the unique
29/// immutable key. The mutable component of the type can be modified after the
30/// type is created, but cannot affect the identity of the type.
31///
32/// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
33///
34/// Derived type classes are expected to implement several required
35/// implementation hooks:
36/// * Optional:
37/// - static LogicalResult verify(
38/// function_ref<InFlightDiagnostic()> emitError,
39/// Args... args)
40/// * This method is invoked when calling the 'TypeBase::get/getChecked'
41/// methods to ensure that the arguments passed in are valid to construct
42/// a type instance with.
43/// * This method is expected to return failure if a type cannot be
44/// constructed with 'args', success otherwise.
45/// * 'args' must correspond with the arguments passed into the
46/// 'TypeBase::get' call.
47///
48///
49/// Type storage objects inherit from TypeStorage and contain the following:
50/// - The dialect that defined the type.
51/// - Any parameters of the type.
52/// - An optional mutable component.
53/// For non-parametric types, a convenience DefaultTypeStorage is provided.
54/// Parametric storage types must derive TypeStorage and respect the following:
55/// - Define a type alias, KeyTy, to a type that uniquely identifies the
56/// instance of the type.
57/// * The key type must be constructible from the values passed into the
58/// detail::TypeUniquer::get call.
59/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
60/// storage class must define a hashing method:
61/// 'static unsigned hashKey(const KeyTy &)'
62///
63/// - Provide a method, 'bool operator==(const KeyTy &) const', to
64/// compare the storage instance against an instance of the key type.
65///
66/// - Provide a static construction method:
67/// 'DerivedStorage *construct(TypeStorageAllocator &, const KeyTy &key)'
68/// that builds a unique instance of the derived storage. The arguments to
69/// this function are an allocator to store any uniqued data within the
70/// context and the key type for this storage.
71///
72/// - If they have a mutable component, this component must not be a part of
73/// the key.
74class Type {
75public:
76 /// Utility class for implementing types.
77 template <typename ConcreteType, typename BaseType, typename StorageType,
78 template <typename T> class... Traits>
79 using TypeBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
80 detail::TypeUniquer, Traits...>;
81
82 using ImplType = TypeStorage;
83
84 using AbstractTy = AbstractType;
85
86 constexpr Type() {}
87 /* implicit */ Type(const ImplType *impl)
88 : impl(const_cast<ImplType *>(impl)) {}
89
90 Type(const Type &other) = default;
91 Type &operator=(const Type &other) = default;
92
93 bool operator==(Type other) const { return impl == other.impl; }
94 bool operator!=(Type other) const { return !(*this == other); }
95 explicit operator bool() const { return impl; }
96
97 bool operator!() const { return impl == nullptr; }
98
99 template <typename... Tys>
100 bool isa() const;
101 template <typename... Tys>
102 bool isa_and_nonnull() const;
103 template <typename U>
104 U dyn_cast() const;
105 template <typename U>
106 U dyn_cast_or_null() const;
107 template <typename U>
108 U cast() const;
109
110 // Support type casting Type to itself.
111 static bool classof(Type) { return true; }
112
113 /// Return a unique identifier for the concrete type. This is used to support
114 /// dynamic type casting.
115 TypeID getTypeID() { return impl->getAbstractType().getTypeID(); }
116
117 /// Return the MLIRContext in which this type was uniqued.
118 MLIRContext *getContext() const;
119
120 /// Get the dialect this type is registered to.
121 Dialect &getDialect() const { return impl->getAbstractType().getDialect(); }
122
123 // Convenience predicates. This is only for floating point types,
124 // derived types should use isa/dyn_cast.
125 bool isIndex() const;
126 bool isBF16() const;
127 bool isF16() const;
128 bool isF32() const;
129 bool isF64() const;
130 bool isF80() const;
131 bool isF128() const;
132
133 /// Return true if this is an integer type with the specified width.
134 bool isInteger(unsigned width) const;
135 /// Return true if this is a signless integer type (with the specified width).
136 bool isSignlessInteger() const;
137 bool isSignlessInteger(unsigned width) const;
138 /// Return true if this is a signed integer type (with the specified width).
139 bool isSignedInteger() const;
140 bool isSignedInteger(unsigned width) const;
141 /// Return true if this is an unsigned integer type (with the specified
142 /// width).
143 bool isUnsignedInteger() const;
144 bool isUnsignedInteger(unsigned width) const;
145
146 /// Return the bit width of an integer or a float type, assert failure on
147 /// other types.
148 unsigned getIntOrFloatBitWidth() const;
149
150 /// Return true if this is a signless integer or index type.
151 bool isSignlessIntOrIndex() const;
152 /// Return true if this is a signless integer, index, or float type.
153 bool isSignlessIntOrIndexOrFloat() const;
154 /// Return true of this is a signless integer or a float type.
155 bool isSignlessIntOrFloat() const;
156
157 /// Return true if this is an integer (of any signedness) or an index type.
158 bool isIntOrIndex() const;
159 /// Return true if this is an integer (of any signedness) or a float type.
160 bool isIntOrFloat() const;
161 /// Return true if this is an integer (of any signedness), index, or float
162 /// type.
163 bool isIntOrIndexOrFloat() const;
164
165 /// Print the current type.
166 void print(raw_ostream &os) const;
167 void print(raw_ostream &os, AsmState &state) const;
168 void dump() const;
169
170 friend ::llvm::hash_code hash_value(Type arg);
171
172 /// Methods for supporting PointerLikeTypeTraits.
173 const void *getAsOpaquePointer() const {
174 return static_cast<const void *>(impl);
175 }
176 static Type getFromOpaquePointer(const void *pointer) {
177 return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
178 }
179
180 /// Returns true if the type was registered with a particular trait.
181 template <template <typename T> class Trait>
182 bool hasTrait() {
183 return getAbstractType().hasTrait<Trait>();
184 }
185
186 /// Return the abstract type descriptor for this type.
187 const AbstractTy &getAbstractType() { return impl->getAbstractType(); }
188
189 /// Return the Type implementation.
190 ImplType *getImpl() const { return impl; }
191
192protected:
193 ImplType *impl{nullptr};
194};
195
196inline raw_ostream &operator<<(raw_ostream &os, Type type) {
197 type.print(os);
198 return os;
199}
200
201//===----------------------------------------------------------------------===//
202// TypeTraitBase
203//===----------------------------------------------------------------------===//
204
205namespace TypeTrait {
206/// This class represents the base of a type trait.
207template <typename ConcreteType, template <typename> class TraitType>
208using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>;
209} // namespace TypeTrait
210
211//===----------------------------------------------------------------------===//
212// TypeInterface
213//===----------------------------------------------------------------------===//
214
215/// This class represents the base of a type interface. See the definition of
216/// `detail::Interface` for requirements on the `Traits` type.
217template <typename ConcreteType, typename Traits>
218class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type,
219 TypeTrait::TraitBase> {
220public:
221 using Base = TypeInterface<ConcreteType, Traits>;
222 using InterfaceBase =
223 detail::Interface<ConcreteType, Type, Traits, Type, TypeTrait::TraitBase>;
224 using InterfaceBase::InterfaceBase;
225
226private:
227 /// Returns the impl interface instance for the given type.
228 static typename InterfaceBase::Concept *getInterfaceFor(Type type) {
229 return type.getAbstractType().getInterface<ConcreteType>();
230 }
231
232 /// Allow access to 'getInterfaceFor'.
233 friend InterfaceBase;
234};
235
236//===----------------------------------------------------------------------===//
237// Core TypeTrait
238//===----------------------------------------------------------------------===//
239
240/// This trait is used to determine if a type is mutable or not. It is attached
241/// on a type if the corresponding ImplType defines a `mutate` function with
242/// a proper signature.
243namespace TypeTrait {
244template <typename ConcreteType>
245using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
246} // namespace TypeTrait
247
248//===----------------------------------------------------------------------===//
249// Type Utils
250//===----------------------------------------------------------------------===//
251
252// Make Type hashable.
253inline ::llvm::hash_code hash_value(Type arg) {
254 return DenseMapInfo<const Type::ImplType *>::getHashValue(arg.impl);
255}
256
257template <typename... Tys>
258bool Type::isa() const {
259 return llvm::isa<Tys...>(*this);
260}
261
262template <typename... Tys>
263bool Type::isa_and_nonnull() const {
264 return llvm::isa_and_present<Tys...>(*this);
265}
266
267template <typename U>
268U Type::dyn_cast() const {
269 return llvm::dyn_cast<U>(*this);
270}
271
272template <typename U>
273U Type::dyn_cast_or_null() const {
274 return llvm::dyn_cast_or_null<U>(*this);
275}
276
277template <typename U>
278U Type::cast() const {
279 return llvm::cast<U>(*this);
23
Value assigned to 'DebugFlag', which participates in a condition later
280}
281
282} // namespace mlir
283
284namespace llvm {
285
286// Type hash just like pointers.
287template <>
288struct DenseMapInfo<mlir::Type> {
289 static mlir::Type getEmptyKey() {
290 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
291 return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
292 }
293 static mlir::Type getTombstoneKey() {
294 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
295 return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
296 }
297 static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(val); }
298 static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; }
299};
300template <typename T>
301struct DenseMapInfo<T, std::enable_if_t<std::is_base_of<mlir::Type, T>::value &&
302 !mlir::detail::IsInterface<T>::value>>
303 : public DenseMapInfo<mlir::Type> {
304 static T getEmptyKey() {
305 const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();
306 return T::getFromOpaquePointer(pointer);
307 }
308 static T getTombstoneKey() {
309 const void *pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey();
310 return T::getFromOpaquePointer(pointer);
311 }
312};
313
314/// We align TypeStorage by 8, so allow LLVM to steal the low bits.
315template <>
316struct PointerLikeTypeTraits<mlir::Type> {
317public:
318 static inline void *getAsVoidPointer(mlir::Type I) {
319 return const_cast<void *>(I.getAsOpaquePointer());
320 }
321 static inline mlir::Type getFromVoidPointer(void *P) {
322 return mlir::Type::getFromOpaquePointer(P);
323 }
324 static constexpr int NumLowBitsAvailable = 3;
325};
326
327/// Add support for llvm style casts.
328/// We provide a cast between To and From if From is mlir::Type or derives from
329/// it
330template <typename To, typename From>
331struct CastInfo<
332 To, From,
333 std::enable_if_t<std::is_same_v<mlir::Type, std::remove_const_t<From>> ||
334 std::is_base_of_v<mlir::Type, From>>>
335 : NullableValueCastFailed<To>,
336 DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
337 /// Arguments are taken as mlir::Type here and not as `From`, because when
338 /// casting from an intermediate type of the hierarchy to one of its children,
339 /// the val.getTypeID() inside T::classof will use the static getTypeID of the
340 /// parent instead of the non-static Type::getTypeID that returns the dynamic
341 /// ID. This means that T::classof would end up comparing the static TypeID of
342 /// the children to the static TypeID of its parent, making it impossible to
343 /// downcast from the parent to the child.
344 static inline bool isPossible(mlir::Type ty) {
345 /// Return a constant true instead of a dynamic true when casting to self or
346 /// up the hierarchy.
347 return std::is_same_v<To, std::remove_const_t<From>> ||
348 std::is_base_of_v<To, From> || To::classof(ty);
349 }
350 static inline To doCast(mlir::Type ty) { return To(ty.getImpl()); }
351};
352
353} // namespace llvm
354
355#endif // MLIR_IR_TYPES_H

tools/mlir/include/mlir/Dialect/Arith/IR/ArithOps.h.inc

1/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\
2|* *|
3|* Op Declarations *|
4|* *|
5|* Automatically generated file, do not edit! *|
6|* *|
7\*===----------------------------------------------------------------------===*/
8
9#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)
10#undef GET_OP_FWD_DEFINES
11namespace mlir {
12namespace arith {
13class AddFOp;
14} // namespace arith
15} // namespace mlir
16namespace mlir {
17namespace arith {
18class AddIOp;
19} // namespace arith
20} // namespace mlir
21namespace mlir {
22namespace arith {
23class AddUICarryOp;
24} // namespace arith
25} // namespace mlir
26namespace mlir {
27namespace arith {
28class AndIOp;
29} // namespace arith
30} // namespace mlir
31namespace mlir {
32namespace arith {
33class BitcastOp;
34} // namespace arith
35} // namespace mlir
36namespace mlir {
37namespace arith {
38class CeilDivSIOp;
39} // namespace arith
40} // namespace mlir
41namespace mlir {
42namespace arith {
43class CeilDivUIOp;
44} // namespace arith
45} // namespace mlir
46namespace mlir {
47namespace arith {
48class CmpFOp;
49} // namespace arith
50} // namespace mlir
51namespace mlir {
52namespace arith {
53class CmpIOp;
54} // namespace arith
55} // namespace mlir
56namespace mlir {
57namespace arith {
58class ConstantOp;
59} // namespace arith
60} // namespace mlir
61namespace mlir {
62namespace arith {
63class DivFOp;
64} // namespace arith
65} // namespace mlir
66namespace mlir {
67namespace arith {
68class DivSIOp;
69} // namespace arith
70} // namespace mlir
71namespace mlir {
72namespace arith {
73class DivUIOp;
74} // namespace arith
75} // namespace mlir
76namespace mlir {
77namespace arith {
78class ExtFOp;
79} // namespace arith
80} // namespace mlir
81namespace mlir {
82namespace arith {
83class ExtSIOp;
84} // namespace arith
85} // namespace mlir
86namespace mlir {
87namespace arith {
88class ExtUIOp;
89} // namespace arith
90} // namespace mlir
91namespace mlir {
92namespace arith {
93class FPToSIOp;
94} // namespace arith
95} // namespace mlir
96namespace mlir {
97namespace arith {
98class FPToUIOp;
99} // namespace arith
100} // namespace mlir
101namespace mlir {
102namespace arith {
103class FloorDivSIOp;
104} // namespace arith
105} // namespace mlir
106namespace mlir {
107namespace arith {
108class IndexCastOp;
109} // namespace arith
110} // namespace mlir
111namespace mlir {
112namespace arith {
113class MaxFOp;
114} // namespace arith
115} // namespace mlir
116namespace mlir {
117namespace arith {
118class MaxSIOp;
119} // namespace arith
120} // namespace mlir
121namespace mlir {
122namespace arith {
123class MaxUIOp;
124} // namespace arith
125} // namespace mlir
126namespace mlir {
127namespace arith {
128class MinFOp;
129} // namespace arith
130} // namespace mlir
131namespace mlir {
132namespace arith {
133class MinSIOp;
134} // namespace arith
135} // namespace mlir
136namespace mlir {
137namespace arith {
138class MinUIOp;
139} // namespace arith
140} // namespace mlir
141namespace mlir {
142namespace arith {
143class MulFOp;
144} // namespace arith
145} // namespace mlir
146namespace mlir {
147namespace arith {
148class MulIOp;
149} // namespace arith
150} // namespace mlir
151namespace mlir {
152namespace arith {
153class NegFOp;
154} // namespace arith
155} // namespace mlir
156namespace mlir {
157namespace arith {
158class OrIOp;
159} // namespace arith
160} // namespace mlir
161namespace mlir {
162namespace arith {
163class RemFOp;
164} // namespace arith
165} // namespace mlir
166namespace mlir {
167namespace arith {
168class RemSIOp;
169} // namespace arith
170} // namespace mlir
171namespace mlir {
172namespace arith {
173class RemUIOp;
174} // namespace arith
175} // namespace mlir
176namespace mlir {
177namespace arith {
178class SIToFPOp;
179} // namespace arith
180} // namespace mlir
181namespace mlir {
182namespace arith {
183class ShLIOp;
184} // namespace arith
185} // namespace mlir
186namespace mlir {
187namespace arith {
188class ShRSIOp;
189} // namespace arith
190} // namespace mlir
191namespace mlir {
192namespace arith {
193class ShRUIOp;
194} // namespace arith
195} // namespace mlir
196namespace mlir {
197namespace arith {
198class SubFOp;
199} // namespace arith
200} // namespace mlir
201namespace mlir {
202namespace arith {
203class SubIOp;
204} // namespace arith
205} // namespace mlir
206namespace mlir {
207namespace arith {
208class TruncFOp;
209} // namespace arith
210} // namespace mlir
211namespace mlir {
212namespace arith {
213class TruncIOp;
214} // namespace arith
215} // namespace mlir
216namespace mlir {
217namespace arith {
218class UIToFPOp;
219} // namespace arith
220} // namespace mlir
221namespace mlir {
222namespace arith {
223class XOrIOp;
224} // namespace arith
225} // namespace mlir
226namespace mlir {
227namespace arith {
228class SelectOp;
229} // namespace arith
230} // namespace mlir
231#endif
232
233#ifdef GET_OP_CLASSES
234#undef GET_OP_CLASSES
235
236
237//===----------------------------------------------------------------------===//
238// Local Utility Method Definitions
239//===----------------------------------------------------------------------===//
240
241namespace mlir {
242namespace arith {
243
244//===----------------------------------------------------------------------===//
245// ::mlir::arith::AddFOp declarations
246//===----------------------------------------------------------------------===//
247
248class AddFOpAdaptor {
249public:
250 AddFOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});
251
252 AddFOpAdaptor(AddFOp op);
253
254 ::mlir::ValueRange getOperands();
255 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
256 ::mlir::ValueRange getODSOperands(unsigned index);
257 ::mlir::Value getLhs();
258 ::mlir::Value getRhs();
259 ::mlir::DictionaryAttr getAttributes();
260 ::mlir::LogicalResult verify(::mlir::Location loc);
261private:
262 ::mlir::ValueRange odsOperands;
263 ::mlir::DictionaryAttr odsAttrs;
264 ::mlir::RegionRange odsRegions;
265 ::llvm::Optional<::mlir::OperationName> odsOpName;
266};
267class AddFOp : public ::mlir::Op<AddFOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::IsCommutative, ::mlir::OpTrait::SameOperandsAndResultType, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable, ::mlir::InferTypeOpInterface::Trait> {
268public:
269 using Op::Op;
270 using Op::print;
271 using Adaptor = AddFOpAdaptor;
272public:
273 static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
274 return {};
275 }
276
277 static constexpr ::llvm::StringLiteral getOperationName() {
278 return ::llvm::StringLiteral("arith.addf");
279 }
280
281 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
282 ::mlir::Operation::operand_range getODSOperands(unsigned index);
283 ::mlir::Value getLhs();
284 ::mlir::Value getRhs();
285 ::mlir::MutableOperandRange getLhsMutable();
286 ::mlir::MutableOperandRange getRhsMutable();
287 std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
288 ::mlir::Operation::result_range getODSResults(unsigned index);
289 ::mlir::Value getResult();
290 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value lhs, ::mlir::Value rhs);
291 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs);
292 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs);
293 static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
294 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
295 ::mlir::LogicalResult verifyInvariantsImpl();
296 ::mlir::LogicalResult verifyInvariants();
297 ::mlir::OpFoldResult fold(::llvm::ArrayRef<::mlir::Attribute> operands);
298 static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
299 static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
300 void print(::mlir::OpAsmPrinter &_odsPrinter);
301 void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
302public:
303};
304} // namespace arith
305} // namespace mlir
306MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::AddFOp)namespace mlir { namespace detail { template <> class TypeIDResolver
< ::mlir::arith::AddFOp> { public: static TypeID resolveTypeID
() { return id; } private: static SelfOwningTypeID id; }; } }
307
308namespace mlir {
309namespace arith {
310
311//===----------------------------------------------------------------------===//
312// ::mlir::arith::AddIOp declarations
313//===----------------------------------------------------------------------===//
314
315class AddIOpAdaptor {
316public:
317 AddIOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});
318
319 AddIOpAdaptor(AddIOp op);
320
321 ::mlir::ValueRange getOperands();
322 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
323 ::mlir::ValueRange getODSOperands(unsigned index);
324 ::mlir::Value getLhs();
325 ::mlir::Value getRhs();
326 ::mlir::DictionaryAttr getAttributes();
327 ::mlir::LogicalResult verify(::mlir::Location loc);
328private:
329 ::mlir::ValueRange odsOperands;
330 ::mlir::DictionaryAttr odsAttrs;
331 ::mlir::RegionRange odsRegions;
332 ::llvm::Optional<::mlir::OperationName> odsOpName;
333};
334class AddIOp : public ::mlir::Op<AddIOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::IsCommutative, ::mlir::InferIntRangeInterface::Trait, ::mlir::OpTrait::SameOperandsAndResultType, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable, ::mlir::InferTypeOpInterface::Trait> {
335public:
336 using Op::Op;
337 using Op::print;
338 using Adaptor = AddIOpAdaptor;
339public:
340 static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
341 return {};
342 }
343
344 static constexpr ::llvm::StringLiteral getOperationName() {
345 return ::llvm::StringLiteral("arith.addi");
346 }
347
348 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
349 ::mlir::Operation::operand_range getODSOperands(unsigned index);
350 ::mlir::Value getLhs();
351 ::mlir::Value getRhs();
352 ::mlir::MutableOperandRange getLhsMutable();
353 ::mlir::MutableOperandRange getRhsMutable();
354 std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
355 ::mlir::Operation::result_range getODSResults(unsigned index);
356 ::mlir::Value getResult();
357 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value lhs, ::mlir::Value rhs);
358 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs);
359 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs);
360 static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
361 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
362 ::mlir::LogicalResult verifyInvariantsImpl();
363 ::mlir::LogicalResult verifyInvariants();
364 static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context);
365 ::mlir::OpFoldResult fold(::llvm::ArrayRef<::mlir::Attribute> operands);
366 static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
367 void inferResultRanges(::llvm::ArrayRef<::mlir::ConstantIntRanges> argRanges, ::mlir::SetIntRangeFn setResultRanges);
368 static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
369 void print(::mlir::OpAsmPrinter &_odsPrinter);
370 void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
371public:
372};
373} // namespace arith
374} // namespace mlir
375MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::AddIOp)namespace mlir { namespace detail { template <> class TypeIDResolver
< ::mlir::arith::AddIOp> { public: static TypeID resolveTypeID
() { return id; } private: static SelfOwningTypeID id; }; } }
376
377namespace mlir {
378namespace arith {
379
380//===----------------------------------------------------------------------===//
381// ::mlir::arith::AddUICarryOp declarations
382//===----------------------------------------------------------------------===//
383
384class AddUICarryOpAdaptor {
385public:
386 AddUICarryOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});
387
388 AddUICarryOpAdaptor(AddUICarryOp op);
389
390 ::mlir::ValueRange getOperands();
391 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
392 ::mlir::ValueRange getODSOperands(unsigned index);
393 ::mlir::Value getLhs();
394 ::mlir::Value getRhs();
395 ::mlir::DictionaryAttr getAttributes();
396 ::mlir::LogicalResult verify(::mlir::Location loc);
397private:
398 ::mlir::ValueRange odsOperands;
399 ::mlir::DictionaryAttr odsAttrs;
400 ::mlir::RegionRange odsRegions;
401 ::llvm::Optional<::mlir::OperationName> odsOpName;
402};
403class AddUICarryOp : public ::mlir::Op<AddUICarryOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::NResults<2>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::IsCommutative, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable, ::mlir::OpAsmOpInterface::Trait> {
404public:
405 using Op::Op;
406 using Op::print;
407 using Adaptor = AddUICarryOpAdaptor;
408public:
409 static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
410 return {};
411 }
412
413 void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn);
414 static constexpr ::llvm::StringLiteral getOperationName() {
415 return ::llvm::StringLiteral("arith.addui_carry");
416 }
417
418 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
419 ::mlir::Operation::operand_range getODSOperands(unsigned index);
420 ::mlir::Value getLhs();
421 ::mlir::Value getRhs();
422 ::mlir::MutableOperandRange getLhsMutable();
423 ::mlir::MutableOperandRange getRhsMutable();
424 std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
425 ::mlir::Operation::result_range getODSResults(unsigned index);
426 ::mlir::Value getSum();
427 ::mlir::Value getCarry();
428 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value lhs, Value rhs);
429 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type sum, ::mlir::Type carry, ::mlir::Value lhs, ::mlir::Value rhs);
430 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs);
431 static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
432 ::mlir::LogicalResult verifyInvariantsImpl();
433 ::mlir::LogicalResult verifyInvariants();
434 ::mlir::LogicalResult fold(::llvm::ArrayRef<::mlir::Attribute> operands, ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results);
435 static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
436 void print(::mlir::OpAsmPrinter &_odsPrinter);
437 void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
438public:
439 ::llvm::Optional<::llvm::SmallVector<int64_t, 4>> getShapeForUnroll();
440};
441} // namespace arith
442} // namespace mlir
443MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::AddUICarryOp)namespace mlir { namespace detail { template <> class TypeIDResolver
< ::mlir::arith::AddUICarryOp> { public: static TypeID resolveTypeID
() { return id; } private: static SelfOwningTypeID id; }; } }
444
445namespace mlir {
446namespace arith {
447
448//===----------------------------------------------------------------------===//
449// ::mlir::arith::AndIOp declarations
450//===----------------------------------------------------------------------===//
451
452class AndIOpAdaptor {
453public:
454 AndIOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});
455
456 AndIOpAdaptor(AndIOp op);
457
458 ::mlir::ValueRange getOperands();
459 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
460 ::mlir::ValueRange getODSOperands(unsigned index);
461 ::mlir::Value getLhs();
462 ::mlir::Value getRhs();
463 ::mlir::DictionaryAttr getAttributes();
464 ::mlir::LogicalResult verify(::mlir::Location loc);
465private:
466 ::mlir::ValueRange odsOperands;
467 ::mlir::DictionaryAttr odsAttrs;
468 ::mlir::RegionRange odsRegions;
469 ::llvm::Optional<::mlir::OperationName> odsOpName;
470};
471class AndIOp : public ::mlir::Op<AndIOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::IsCommutative, ::mlir::OpTrait::IsIdempotent, ::mlir::InferIntRangeInterface::Trait, ::mlir::OpTrait::SameOperandsAndResultType, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable, ::mlir::InferTypeOpInterface::Trait> {
472public:
473 using Op::Op;
474 using Op::print;
475 using Adaptor = AndIOpAdaptor;
476public:
477 static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
478 return {};
479 }
480
481 static constexpr ::llvm::StringLiteral getOperationName() {
482 return ::llvm::StringLiteral("arith.andi");
483 }
484
485 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
486 ::mlir::Operation::operand_range getODSOperands(unsigned index);
487 ::mlir::Value getLhs();
488 ::mlir::Value getRhs();
489 ::mlir::MutableOperandRange getLhsMutable();
490 ::mlir::MutableOperandRange getRhsMutable();
491 std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
492 ::mlir::Operation::result_range getODSResults(unsigned index);
493 ::mlir::Value getResult();
494 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value lhs, ::mlir::Value rhs);
495 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs);
496 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs);
497 static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
498 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
499 ::mlir::LogicalResult verifyInvariantsImpl();
500 ::mlir::LogicalResult verifyInvariants();
501 static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context);
502 ::mlir::OpFoldResult fold(::llvm::ArrayRef<::mlir::Attribute> operands);
503 static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
504 void inferResultRanges(::llvm::ArrayRef<::mlir::ConstantIntRanges> argRanges, ::mlir::SetIntRangeFn setResultRanges);
505 static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
506 void print(::mlir::OpAsmPrinter &_odsPrinter);
507 void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
508public:
509};
510} // namespace arith
511} // namespace mlir
512MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::AndIOp)namespace mlir { namespace detail { template <> class TypeIDResolver
< ::mlir::arith::AndIOp> { public: static TypeID resolveTypeID
() { return id; } private: static SelfOwningTypeID id; }; } }
513
514namespace mlir {
515namespace arith {
516
517//===----------------------------------------------------------------------===//
518// ::mlir::arith::BitcastOp declarations
519//===----------------------------------------------------------------------===//
520
521class BitcastOpAdaptor {
522public:
523 BitcastOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});
524
525 BitcastOpAdaptor(BitcastOp op);
526
527 ::mlir::ValueRange getOperands();
528 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
529 ::mlir::ValueRange getODSOperands(unsigned index);
530 ::mlir::Value getIn();
531 ::mlir::DictionaryAttr getAttributes();
532 ::mlir::LogicalResult verify(::mlir::Location loc);
533private:
534 ::mlir::ValueRange odsOperands;
535 ::mlir::DictionaryAttr odsAttrs;
536 ::mlir::RegionRange odsRegions;
537 ::llvm::Optional<::mlir::OperationName> odsOpName;
538};
539class BitcastOp : public ::mlir::Op<BitcastOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::SameOperandsAndResultShape, ::mlir::CastOpInterface::Trait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable> {
540public:
541 using Op::Op;
542 using Op::print;
543 using Adaptor = BitcastOpAdaptor;
544public:
545 static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
546 return {};
547 }
548
549 static constexpr ::llvm::StringLiteral getOperationName() {
550 return ::llvm::StringLiteral("arith.bitcast");
551 }
552
553 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
554 ::mlir::Operation::operand_range getODSOperands(unsigned index);
555 ::mlir::Value getIn();
556 ::mlir::MutableOperandRange getInMutable();
557 std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
558 ::mlir::Operation::result_range getODSResults(unsigned index);
559 ::mlir::Value getOut();
560 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type out, ::mlir::Value in);
561 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value in);
562 static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
563 ::mlir::LogicalResult verifyInvariantsImpl();
564 ::mlir::LogicalResult verifyInvariants();
565 static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context);
566 ::mlir::OpFoldResult fold(::llvm::ArrayRef<::mlir::Attribute> operands);
567 static bool areCastCompatible(::mlir::TypeRange inputs, ::mlir::TypeRange outputs);
568 static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
569 void print(::mlir::OpAsmPrinter &_odsPrinter);
570 void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
571public:
572};
573} // namespace arith
574} // namespace mlir
575MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::BitcastOp)namespace mlir { namespace detail { template <> class TypeIDResolver
< ::mlir::arith::BitcastOp> { public: static TypeID resolveTypeID
() { return id; } private: static SelfOwningTypeID id; }; } }
576
577namespace mlir {
578namespace arith {
579
580//===----------------------------------------------------------------------===//
581// ::mlir::arith::CeilDivSIOp declarations
582//===----------------------------------------------------------------------===//
583
584class CeilDivSIOpAdaptor {
585public:
586 CeilDivSIOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});
587
588 CeilDivSIOpAdaptor(CeilDivSIOp op);
589
590 ::mlir::ValueRange getOperands();
591 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
592 ::mlir::ValueRange getODSOperands(unsigned index);
593 ::mlir::Value getLhs();
594 ::mlir::Value getRhs();
595 ::mlir::DictionaryAttr getAttributes();
596 ::mlir::LogicalResult verify(::mlir::Location loc);
597private:
598 ::mlir::ValueRange odsOperands;
599 ::mlir::DictionaryAttr odsAttrs;
600 ::mlir::RegionRange odsRegions;
601 ::llvm::Optional<::mlir::OperationName> odsOpName;
602};
603class CeilDivSIOp : public ::mlir::Op<CeilDivSIOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::InferIntRangeInterface::Trait, ::mlir::OpTrait::SameOperandsAndResultType, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable, ::mlir::InferTypeOpInterface::Trait> {
604public:
605 using Op::Op;
606 using Op::print;
607 using Adaptor = CeilDivSIOpAdaptor;
608public:
609 static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
610 return {};
611 }
612
613 static constexpr ::llvm::StringLiteral getOperationName() {
614 return ::llvm::StringLiteral("arith.ceildivsi");
615 }
616
617 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
618 ::mlir::Operation::operand_range getODSOperands(unsigned index);
619 ::mlir::Value getLhs();
620 ::mlir::Value getRhs();
621 ::mlir::MutableOperandRange getLhsMutable();
622 ::mlir::MutableOperandRange getRhsMutable();
623 std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
624 ::mlir::Operation::result_range getODSResults(unsigned index);
625 ::mlir::Value getResult();
626 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value lhs, ::mlir::Value rhs);
627 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs);
628 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs);
629 static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
630 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
631 ::mlir::LogicalResult verifyInvariantsImpl();
632 ::mlir::LogicalResult verifyInvariants();
633 ::mlir::OpFoldResult fold(::llvm::ArrayRef<::mlir::Attribute> operands);
634 static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
635 void inferResultRanges(::llvm::ArrayRef<::mlir::ConstantIntRanges> argRanges, ::mlir::SetIntRangeFn setResultRanges);
636 static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
637 void print(::mlir::OpAsmPrinter &_odsPrinter);
638 void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
639public:
640};
641} // namespace arith
642} // namespace mlir
643MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::CeilDivSIOp)namespace mlir { namespace detail { template <> class TypeIDResolver
< ::mlir::arith::CeilDivSIOp> { public: static TypeID resolveTypeID
() { return id; } private: static SelfOwningTypeID id; }; } }
644
645namespace mlir {
646namespace arith {
647
648//===----------------------------------------------------------------------===//
649// ::mlir::arith::CeilDivUIOp declarations
650//===----------------------------------------------------------------------===//
651
652class CeilDivUIOpAdaptor {
653public:
654 CeilDivUIOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});
655
656 CeilDivUIOpAdaptor(CeilDivUIOp op);
657
658 ::mlir::ValueRange getOperands();
659 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
660 ::mlir::ValueRange getODSOperands(unsigned index);
661 ::mlir::Value getLhs();
662 ::mlir::Value getRhs();
663 ::mlir::DictionaryAttr getAttributes();
664 ::mlir::LogicalResult verify(::mlir::Location loc);
665private:
666 ::mlir::ValueRange odsOperands;
667 ::mlir::DictionaryAttr odsAttrs;
668 ::mlir::RegionRange odsRegions;
669 ::llvm::Optional<::mlir::OperationName> odsOpName;
670};
671class CeilDivUIOp : public ::mlir::Op<CeilDivUIOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::InferIntRangeInterface::Trait, ::mlir::OpTrait::SameOperandsAndResultType, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable, ::mlir::InferTypeOpInterface::Trait> {
672public:
673 using Op::Op;
674 using Op::print;
675 using Adaptor = CeilDivUIOpAdaptor;
676public:
677 static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
678 return {};
679 }
680
681 static constexpr ::llvm::StringLiteral getOperationName() {
682 return ::llvm::StringLiteral("arith.ceildivui");
683 }
684
685 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
686 ::mlir::Operation::operand_range getODSOperands(unsigned index);
687 ::mlir::Value getLhs();
688 ::mlir::Value getRhs();
689 ::mlir::MutableOperandRange getLhsMutable();
690 ::mlir::MutableOperandRange getRhsMutable();
691 std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
692 ::mlir::Operation::result_range getODSResults(unsigned index);
693 ::mlir::Value getResult();
694 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value lhs, ::mlir::Value rhs);
695 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs);
696 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs);
697 static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
698 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
699 ::mlir::LogicalResult verifyInvariantsImpl();
700 ::mlir::LogicalResult verifyInvariants();
701 ::mlir::OpFoldResult fold(::llvm::ArrayRef<::mlir::Attribute> operands);
702 static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
703 void inferResultRanges(::llvm::ArrayRef<::mlir::ConstantIntRanges> argRanges, ::mlir::SetIntRangeFn setResultRanges);
704 static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
705 void print(::mlir::OpAsmPrinter &_odsPrinter);
706 void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
707public:
708};
709} // namespace arith
710} // namespace mlir
711MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::CeilDivUIOp)namespace mlir { namespace detail { template <> class TypeIDResolver
< ::mlir::arith::CeilDivUIOp> { public: static TypeID resolveTypeID
() { return id; } private: static SelfOwningTypeID id; }; } }
712
713namespace mlir {
714namespace arith {
715
716//===----------------------------------------------------------------------===//
717// ::mlir::arith::CmpFOp declarations
718//===----------------------------------------------------------------------===//
719
720class CmpFOpAdaptor {
721public:
722 CmpFOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});
723
724 CmpFOpAdaptor(CmpFOp op);
725
726 ::mlir::ValueRange getOperands();
727 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
728 ::mlir::ValueRange getODSOperands(unsigned index);
729 ::mlir::Value getLhs();
730 ::mlir::Value getRhs();
731 ::mlir::DictionaryAttr getAttributes();
732 ::mlir::arith::CmpFPredicateAttr getPredicateAttr();
733 ::mlir::arith::CmpFPredicate getPredicate();
734 ::mlir::LogicalResult verify(::mlir::Location loc);
735private:
736 ::mlir::ValueRange odsOperands;
737 ::mlir::DictionaryAttr odsAttrs;
738 ::mlir::RegionRange odsRegions;
739 ::llvm::Optional<::mlir::OperationName> odsOpName;
740};
741class CmpFOp : public ::mlir::Op<CmpFOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::SameTypeOperands, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable> {
742public:
743 using Op::Op;
744 using Op::print;
745 using Adaptor = CmpFOpAdaptor;
746public:
747 static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
748 static ::llvm::StringRef attrNames[] = {::llvm::StringRef("predicate")};
749 return ::llvm::makeArrayRef(attrNames);
750 }
751
752 ::mlir::StringAttr getPredicateAttrName() {
753 return getAttributeNameForIndex(0);
754 }
755
756 static ::mlir::StringAttr getPredicateAttrName(::mlir::OperationName name) {
757 return getAttributeNameForIndex(name, 0);
758 }
759
760 static constexpr ::llvm::StringLiteral getOperationName() {
761 return ::llvm::StringLiteral("arith.cmpf");
762 }
763
764 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
765 ::mlir::Operation::operand_range getODSOperands(unsigned index);
766 ::mlir::Value getLhs();
767 ::mlir::Value getRhs();
768 ::mlir::MutableOperandRange getLhsMutable();
769 ::mlir::MutableOperandRange getRhsMutable();
770 std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
771 ::mlir::Operation::result_range getODSResults(unsigned index);
772 ::mlir::Value getResult();
773 ::mlir::arith::CmpFPredicateAttr getPredicateAttr();
774 ::mlir::arith::CmpFPredicate getPredicate();
775 void setPredicateAttr(::mlir::arith::CmpFPredicateAttr attr);
776 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, CmpFPredicate predicate, Value lhs, Value rhs);
777 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::arith::CmpFPredicateAttr predicate, ::mlir::Value lhs, ::mlir::Value rhs);
778 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::arith::CmpFPredicateAttr predicate, ::mlir::Value lhs, ::mlir::Value rhs);
779 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::arith::CmpFPredicate predicate, ::mlir::Value lhs, ::mlir::Value rhs);
780 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::arith::CmpFPredicate predicate, ::mlir::Value lhs, ::mlir::Value rhs);
781 static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
782 ::mlir::LogicalResult verifyInvariantsImpl();
783 ::mlir::LogicalResult verifyInvariants();
784 static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context);
785 ::mlir::OpFoldResult fold(::llvm::ArrayRef<::mlir::Attribute> operands);
786 static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
787 void print(::mlir::OpAsmPrinter &_odsPrinter);
788 void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
789private:
790 ::mlir::StringAttr getAttributeNameForIndex(unsigned index) {
791 return getAttributeNameForIndex((*this)->getName(), index);
792 }
793
794 static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) {
795 assert(index < 1 && "invalid attribute index")(static_cast <bool> (index < 1 && "invalid attribute index"
) ? void (0) : __assert_fail ("index < 1 && \"invalid attribute index\""
, "tools/mlir/include/mlir/Dialect/Arith/IR/ArithOps.h.inc", 795
, __extension__ __PRETTY_FUNCTION__))
;
796 return name.getRegisteredInfo()->getAttributeNames()[index];
797 }
798
799public:
800 static arith::CmpFPredicate getPredicateByName(StringRef name);
801};
802} // namespace arith
803} // namespace mlir
804MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::CmpFOp)namespace mlir { namespace detail { template <> class TypeIDResolver
< ::mlir::arith::CmpFOp> { public: static TypeID resolveTypeID
() { return id; } private: static SelfOwningTypeID id; }; } }
805
806namespace mlir {
807namespace arith {
808
809//===----------------------------------------------------------------------===//
810// ::mlir::arith::CmpIOp declarations
811//===----------------------------------------------------------------------===//
812
813class CmpIOpAdaptor {
814public:
815 CmpIOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});
816
817 CmpIOpAdaptor(CmpIOp op);
818
819 ::mlir::ValueRange getOperands();
820 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
821 ::mlir::ValueRange getODSOperands(unsigned index);
822 ::mlir::Value getLhs();
823 ::mlir::Value getRhs();
824 ::mlir::DictionaryAttr getAttributes();
825 ::mlir::arith::CmpIPredicateAttr getPredicateAttr();
826 ::mlir::arith::CmpIPredicate getPredicate();
827 ::mlir::LogicalResult verify(::mlir::Location loc);
828private:
829 ::mlir::ValueRange odsOperands;
830 ::mlir::DictionaryAttr odsAttrs;
831 ::mlir::RegionRange odsRegions;
832 ::llvm::Optional<::mlir::OperationName> odsOpName;
833};
834class CmpIOp : public ::mlir::Op<CmpIOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::InferIntRangeInterface::Trait, ::mlir::OpTrait::SameTypeOperands, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable> {
835public:
836 using Op::Op;
837 using Op::print;
838 using Adaptor = CmpIOpAdaptor;
839public:
840 static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
841 static ::llvm::StringRef attrNames[] = {::llvm::StringRef("predicate")};
842 return ::llvm::makeArrayRef(attrNames);
843 }
844
845 ::mlir::StringAttr getPredicateAttrName() {
846 return getAttributeNameForIndex(0);
847 }
848
849 static ::mlir::StringAttr getPredicateAttrName(::mlir::OperationName name) {
850 return getAttributeNameForIndex(name, 0);
851 }
852
853 static constexpr ::llvm::StringLiteral getOperationName() {
854 return ::llvm::StringLiteral("arith.cmpi");
855 }
856
857 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
858 ::mlir::Operation::operand_range getODSOperands(unsigned index);
859 ::mlir::Value getLhs();
860 ::mlir::Value getRhs();
861 ::mlir::MutableOperandRange getLhsMutable();
862 ::mlir::MutableOperandRange getRhsMutable();
863 std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
864 ::mlir::Operation::result_range getODSResults(unsigned index);
865 ::mlir::Value getResult();
866 ::mlir::arith::CmpIPredicateAttr getPredicateAttr();
867 ::mlir::arith::CmpIPredicate getPredicate();
868 void setPredicateAttr(::mlir::arith::CmpIPredicateAttr attr);
869 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, CmpIPredicate predicate, Value lhs, Value rhs);
870 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::arith::CmpIPredicateAttr predicate, ::mlir::Value lhs, ::mlir::Value rhs);
871 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::arith::CmpIPredicateAttr predicate, ::mlir::Value lhs, ::mlir::Value rhs);
872 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::arith::CmpIPredicate predicate, ::mlir::Value lhs, ::mlir::Value rhs);
873 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::arith::CmpIPredicate predicate, ::mlir::Value lhs, ::mlir::Value rhs);
874 static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
875 ::mlir::LogicalResult verifyInvariantsImpl();
876 ::mlir::LogicalResult verifyInvariants();
877 static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context);
878 ::mlir::OpFoldResult fold(::llvm::ArrayRef<::mlir::Attribute> operands);
879 void inferResultRanges(::llvm::ArrayRef<::mlir::ConstantIntRanges> argRanges, ::mlir::SetIntRangeFn setResultRanges);
880 static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
881 void print(::mlir::OpAsmPrinter &_odsPrinter);
882 void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
883private:
884 ::mlir::StringAttr getAttributeNameForIndex(unsigned index) {
885 return getAttributeNameForIndex((*this)->getName(), index);
886 }
887
888 static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) {
889 assert(index < 1 && "invalid attribute index")(static_cast <bool> (index < 1 && "invalid attribute index"
) ? void (0) : __assert_fail ("index < 1 && \"invalid attribute index\""
, "tools/mlir/include/mlir/Dialect/Arith/IR/ArithOps.h.inc", 889
, __extension__ __PRETTY_FUNCTION__))
;
890 return name.getRegisteredInfo()->getAttributeNames()[index];
891 }
892
893public:
894 static arith::CmpIPredicate getPredicateByName(StringRef name);
895};
896} // namespace arith
897} // namespace mlir
898MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::CmpIOp)namespace mlir { namespace detail { template <> class TypeIDResolver
< ::mlir::arith::CmpIOp> { public: static TypeID resolveTypeID
() { return id; } private: static SelfOwningTypeID id; }; } }
899
900namespace mlir {
901namespace arith {
902
903//===----------------------------------------------------------------------===//
904// ::mlir::arith::ConstantOp declarations
905//===----------------------------------------------------------------------===//
906
907class ConstantOpAdaptor {
908public:
909 ConstantOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});
910
911 ConstantOpAdaptor(ConstantOp op);
912
913 ::mlir::ValueRange getOperands();
914 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
915 ::mlir::ValueRange getODSOperands(unsigned index);
916 ::mlir::DictionaryAttr getAttributes();
917 ::mlir::TypedAttr getValueAttr();
918 ::mlir::TypedAttr getValue();
919 ::mlir::LogicalResult verify(::mlir::Location loc);
920private:
921 ::mlir::ValueRange odsOperands;
922 ::mlir::DictionaryAttr odsAttrs;
923 ::mlir::RegionRange odsRegions;
924 ::llvm::Optional<::mlir::OperationName> odsOpName;
925};
926class ConstantOp : public ::mlir::Op<ConstantOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::ZeroOperands, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::ConstantLike, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpAsmOpInterface::Trait, ::mlir::InferIntRangeInterface::Trait, ::mlir::InferTypeOpInterface::Trait> {
927public:
928 using Op::Op;
929 using Op::print;
930 using Adaptor = ConstantOpAdaptor;
931public:
932 static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
933 static ::llvm::StringRef attrNames[] = {::llvm::StringRef("value")};
934 return ::llvm::makeArrayRef(attrNames);
935 }
936
937 ::mlir::StringAttr getValueAttrName() {
938 return getAttributeNameForIndex(0);
939 }
940
941 static ::mlir::StringAttr getValueAttrName(::mlir::OperationName name) {
942 return getAttributeNameForIndex(name, 0);
943 }
944
945 static constexpr ::llvm::StringLiteral getOperationName() {
946 return ::llvm::StringLiteral("arith.constant");
947 }
948
949 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
950 ::mlir::Operation::operand_range getODSOperands(unsigned index);
951 std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
952 ::mlir::Operation::result_range getODSResults(unsigned index);
953 ::mlir::Value getResult();
954 ::mlir::TypedAttr getValueAttr();
955 ::mlir::TypedAttr getValue();
956 void setValueAttr(::mlir::TypedAttr attr);
957 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Attribute value, Type type);
958 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::TypedAttr value);
959 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypedAttr value);
960 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::TypedAttr value);
961 static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
962 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
963 ::mlir::LogicalResult verifyInvariantsImpl();
964 ::mlir::LogicalResult verifyInvariants();
965 ::mlir::LogicalResult verify();
966 ::mlir::OpFoldResult fold(::llvm::ArrayRef<::mlir::Attribute> operands);
967 static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
968 void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn);
969 void inferResultRanges(::llvm::ArrayRef<::mlir::ConstantIntRanges> argRanges, ::mlir::SetIntRangeFn setResultRanges);
970 static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
971 void print(::mlir::OpAsmPrinter &_odsPrinter);
972 void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
973private:
974 ::mlir::StringAttr getAttributeNameForIndex(unsigned index) {
975 return getAttributeNameForIndex((*this)->getName(), index);
976 }
977
978 static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) {
979 assert(index < 1 && "invalid attribute index")(static_cast <bool> (index < 1 && "invalid attribute index"
) ? void (0) : __assert_fail ("index < 1 && \"invalid attribute index\""
, "tools/mlir/include/mlir/Dialect/Arith/IR/ArithOps.h.inc", 979
, __extension__ __PRETTY_FUNCTION__))
;
980 return name.getRegisteredInfo()->getAttributeNames()[index];
981 }
982
983public:
984 /// Whether the constant op can be constructed with a particular value and
985 /// type.
986 static bool isBuildableWith(Attribute value, Type type);
987};
988} // namespace arith
989} // namespace mlir
990MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::ConstantOp)namespace mlir { namespace detail { template <> class TypeIDResolver
< ::mlir::arith::ConstantOp> { public: static TypeID resolveTypeID
() { return id; } private: static SelfOwningTypeID id; }; } }
991
992namespace mlir {
993namespace arith {
994
995//===----------------------------------------------------------------------===//
996// ::mlir::arith::DivFOp declarations
997//===----------------------------------------------------------------------===//
998
999class DivFOpAdaptor {
1000public:
1001 DivFOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});
1002
1003 DivFOpAdaptor(DivFOp op);
1004
1005 ::mlir::ValueRange getOperands();
1006 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
1007 ::mlir::ValueRange getODSOperands(unsigned index);
1008 ::mlir::Value getLhs();
1009 ::mlir::Value getRhs();
1010 ::mlir::DictionaryAttr getAttributes();
1011 ::mlir::LogicalResult verify(::mlir::Location loc);
1012private:
1013 ::mlir::ValueRange odsOperands;
1014 ::mlir::DictionaryAttr odsAttrs;
1015 ::mlir::RegionRange odsRegions;
1016 ::llvm::Optional<::mlir::OperationName> odsOpName;
1017};
1018class DivFOp : public ::mlir::Op<DivFOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::SameOperandsAndResultType, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable, ::mlir::InferTypeOpInterface::Trait> {
1019public:
1020 using Op::Op;
1021 using Op::print;
1022 using Adaptor = DivFOpAdaptor;
1023public:
1024 static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
1025 return {};
1026 }
1027
1028 static constexpr ::llvm::StringLiteral getOperationName() {
1029 return ::llvm::StringLiteral("arith.divf");
1030 }
1031
1032 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
1033 ::mlir::Operation::operand_range getODSOperands(unsigned index);
1034 ::mlir::Value getLhs();
1035 ::mlir::Value getRhs();
1036 ::mlir::MutableOperandRange getLhsMutable();
1037 ::mlir::MutableOperandRange getRhsMutable();
1038 std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
1039 ::mlir::Operation::result_range getODSResults(unsigned index);
1040 ::mlir::Value getResult();
1041 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value lhs, ::mlir::Value rhs);
1042 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs);
1043 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs);
1044 static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
1045 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
1046 ::mlir::LogicalResult verifyInvariantsImpl();
1047 ::mlir::LogicalResult verifyInvariants();
1048 static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context);
1049 ::mlir::OpFoldResult fold(::llvm::ArrayRef<::mlir::Attribute> operands);
1050 static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
1051 static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
1052 void print(::mlir::OpAsmPrinter &_odsPrinter);
1053 void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
1054public:
1055};
1056} // namespace arith
1057} // namespace mlir
1058MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::DivFOp)namespace mlir { namespace detail { template <> class TypeIDResolver
< ::mlir::arith::DivFOp> { public: static TypeID resolveTypeID
() { return id; } private: static SelfOwningTypeID id; }; } }
1059
1060namespace mlir {
1061namespace arith {
1062
1063//===----------------------------------------------------------------------===//
1064// ::mlir::arith::DivSIOp declarations
1065//===----------------------------------------------------------------------===//
1066
1067class DivSIOpAdaptor {
1068public:
1069 DivSIOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});
1070
1071 DivSIOpAdaptor(DivSIOp op);
1072
1073 ::mlir::ValueRange getOperands();
1074 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
1075 ::mlir::ValueRange getODSOperands(unsigned index);
1076 ::mlir::Value getLhs();
1077 ::mlir::Value getRhs();
1078 ::mlir::DictionaryAttr getAttributes();
1079 ::mlir::LogicalResult verify(::mlir::Location loc);
1080private:
1081 ::mlir::ValueRange odsOperands;
1082 ::mlir::DictionaryAttr odsAttrs;
1083 ::mlir::RegionRange odsRegions;
1084 ::llvm::Optional<::mlir::OperationName> odsOpName;
1085};
1086class DivSIOp : public ::mlir::Op<DivSIOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::InferIntRangeInterface::Trait, ::mlir::OpTrait::SameOperandsAndResultType, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable, ::mlir::InferTypeOpInterface::Trait> {
1087public:
1088 using Op::Op;
1089 using Op::print;
1090 using Adaptor = DivSIOpAdaptor;
1091public:
1092 static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
1093 return {};
1094 }
1095
1096 static constexpr ::llvm::StringLiteral getOperationName() {
1097 return ::llvm::StringLiteral("arith.divsi");
1098 }
1099
1100 std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
1101 ::mlir::Operation::operand_range getODSOperands(unsigned index);
1102 ::mlir::Value getLhs();
1103 ::mlir::Value getRhs();
1104 ::mlir::MutableOperandRange getLhsMutable();
1105 ::mlir::MutableOperandRange getRhsMutable();
1106 std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
1107 ::mlir::Operation::result_range getODSResults(unsigned index);
1108 ::mlir::Value getResult();
1109 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value lhs, ::mlir::Value rhs);
1110 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs);
1111 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs);
1112 static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
1113 static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
1114 ::mlir::LogicalResult verifyInvariantsImpl();
1115 ::mlir::LogicalResult verifyInvariants();
1116 ::mlir::OpFoldResult fold(::llvm::ArrayRef<::mlir::Attribute> operands);
1117 static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
1118 void inferResultRanges(::llvm::ArrayRef<::mlir::ConstantIntRanges> argRanges, ::mlir::SetIntRangeFn setResultRanges);
1119 static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
1120 void print(::mlir::OpAsmPrinter &_odsPrinter);
1121 void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
1122public:
1123};
1124} // namespace arith
1125} // namespace mlir
1126MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::DivSIOp)namespace mlir { namespace detail { template <> class TypeIDResolver
< ::mlir::arith::DivSIOp> { public: static TypeID resolveTypeID
() { return id; } private: static SelfOwningTypeID id; }; } }
1127
1128namespace mlir {
1129namespace arith {
1130
1131//===----------------------------------------------------------------------===//
1132// ::mlir::arith::DivUIOp declarations
1133//===----------------------------------------------------------------------===//
1134
1135class DivUIOpAdaptor {
1136public:
1137 DivUIOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs =