LLVM 23.0.0git
SVEShuffleOpts.cpp
Go to the documentation of this file.
1//===------- SVEShuffleOpts - SVE Shuffle Optimization --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Tries to pattern match and combine scalable vector shuffles that could
10// be more efficiently performed by tbl instructions.
11//
12// An example would be a loop with 4 multiply-accumulate reductions, where the
13// new data in each vector iteration comes from a 4-way deinterleaving of
14// smaller datatypes loaded from memory which are then zero extended.
15//
16// Something like the following:
17// %bgra = call ... @llvm.masked.load
18// %deinterleave = call ... @llvm.vector.deinterleave4(%bgra)
19// If the load was of a <vscale x 8 x i16>, we now have 4 deinterleaved
20// <vscale x 2 x i16> values.
21// %b.i16 = extractvalue %deinterleave, 0
22// %b.i64 = zext <vscale x 2 x i16> %b.i16 to <vscale x 2 x i64>
23// %acc.b.next = add <vscale x 2 x i64> %acc.b, %b.i64
24// <repeat for the other 3 subvectors>
25//
26// If the initial load is a legal vector rather than 4x the size (generating a
27// structured ld4 instead), we would see multiple uunpkhi/lo instructions for
28// the extensions, followed by uzp1/2 instructions for the deinterleave.
29// Instead, we can replace all of those with 4 tbl instructions. The tradeoff,
30// of course, is that we now have 4 mask values to maintain which may increase
31// register pressure.
32//
33// This basic transform could be performed in CodeGenPrepare (as the equivalent
34// for NEON is), or in a DAG Combine. However, we hope to extend it to detect
35// other shuffles that we can fold into the tbl. Extending the above example,
36// if instead of directly adding to the accumulator we multiplied it by a
37// common term for all 4 components that had been reversed:
38// %common.load = call @llvm.masked.load
39// %common.reverse = call @llvm.vector.reverse
40// These would be loaded at the extended size, <vscale x 2 x i64> in our
41// example.
42// %b.mul = mul <vscale x 2 x i64> %b.i64, %common.reverse
43// %acc.b.next = add <vscale x 2 x i64> %acc.b, %b.mul
44// <repeat for the other 3 subvectors, using %common.reverse for each>
45//
46// In this case, the reverse isn't applied to the deinterleaved data in the
47// original IR, but to the common term multiplied by the individual bgra
48// elements. If the order of the elements in the accumulator is important, we
49// cannot change that. If, however, we know that the accumulator is reduced to
50// a single scalar after the loop and the data is either integers or floating
51// point with reassociation allowed, we could instead choose a different mask
52// for the tbls to reverse the individual bgra elements instead, removing an
53// additional instruction from the loop. This does require looking beyond the
54// blocks in the loop, so DAGCombine won't help.
55//
56// We should also be able to introduce new shuffles in order to balance out
57// SVE's bottom/top instruction pairs, which act on even/odd lanes instead of
58// the high or low half of a register.
59//
60// This pass may end up being a temporary solution that is removed if we can
61// create a generic vector shuffle intrinsic and move this feature to
62// LoopVectorize itself, as that would allow for better cost modelling.
63//
64//===----------------------------------------------------------------------===//
65
66#include "AArch64.h"
67#include "AArch64Subtarget.h"
78#include "llvm/IR/Constants.h"
79#include "llvm/IR/IRBuilder.h"
82#include "llvm/IR/IntrinsicsAArch64.h"
83#include "llvm/IR/LLVMContext.h"
84#include "llvm/IR/PassManager.h"
87#include <array>
88
89using namespace llvm;
90using namespace llvm::PatternMatch;
91
92#define DEBUG_TYPE "aarch64-sve-shuffle-opts"
93
94/// A mapping between a vector_deinterleaveN intrinsic and extending cast
95/// instructions used on the resulting subvectors.
97
98/// Evaluate a deinterleave and see what the uses are. If we find other
99/// operations that we can combine into a tbl shuffle, add the deinterleave and
100/// the operations (currently only zext or uitofp) to the candidates map.
102 Loop &L, const AArch64TargetLowering &TL,
103 const DataLayout DL) {
104 assert(I->getIntrinsicID() == Intrinsic::vector_deinterleave4 &&
105 "Only deinterleave4 supported currently");
106
107 ConstantRange VScaleRange = getVScaleRange(I->getFunction(), 64);
108 // TBL zeroes elements with an out-of-bounds index, but for the largest
109 // possible SVE vector (2048b) the maximum value for i8 elements (255) is not
110 // large enough to encode an 'out of bounds' value. So we can only perform
111 // this optimization for i8 elements if we know vscale is < 16.
112 EVT InputVT = TL.getValueType(DL, I->getOperand(0)->getType());
113 if (!InputVT.isScalableVector() ||
114 (InputVT.getScalarSizeInBits() < 16 &&
115 (!VScaleRange.getUpper().ult(16) || VScaleRange.isUpperWrapped())) ||
116 TL.getTypeConversion(I->getContext(), InputVT).first !=
118 return;
119
120 std::array<CastInst *, 4> Extends = {};
121 unsigned Opcode = 0;
122 Type *DestTy = nullptr;
123 for (User *U : I->users()) {
124 auto *Extract = dyn_cast<ExtractValueInst>(U);
125 if (!Extract || !Extract->hasOneUse())
126 return;
127
128 // We expect only a single cast instruction as a user for the extract.
129 auto *Extend = dyn_cast_if_present<CastInst>(*Extract->users().begin());
130 if (!Extend || (!isa<ZExtInst>(Extend) && !isa<UIToFPInst>(Extend)))
131 return;
132
133 // We're only interested if the uses are in the loop. This is almost
134 // certainly the case.
135 if (!L.contains(Extend))
136 return;
137
138 Opcode = Extend->getOpcode();
139 DestTy = Extend->getDestTy();
140
141 // Make sure DestTy matches the input size.
142 if (DestTy->getPrimitiveSizeInBits() != InputVT.getSizeInBits())
143 return;
144
145 Extends[Extract->getIndices().front()] = Extend;
146 }
147
148 // Check that all extracted values are being extended the same way, and that
149 // we have the expected number of extensions.
150 if (!all_of(Extends, [DestTy, Opcode](CastInst *CI) {
151 return !CI || (CI->getDestTy() == DestTy && CI->getOpcode() == Opcode);
152 }))
153 return;
154
155 Candidates.try_emplace(I, Extends);
156}
157
158/// Given a map of deinterleaves to zext or uitofp casts, remove the operations
159/// and replace them with tbl shuffles.
161 for (auto &[Deinterleave, Extends] : Deinterleaves) {
162 VectorType *DestTy = cast<VectorType>(Extends[0]->getDestTy());
163 VectorType *SrcTy = cast<VectorType>(Extends[0]->getSrcTy());
164 unsigned DstBits = DestTy->getScalarSizeInBits();
165 unsigned SrcBits = SrcTy->getScalarSizeInBits();
166 bool IsUIToFP = isa<UIToFPInst>(Extends[0]);
167 VectorType *StepVecTy = VectorType::getInteger(DestTy);
168 Value *Input = Deinterleave->getOperand(0);
169 Type *InputTy = Input->getType();
170
172 for (auto [Idx, Extend] : enumerate(Extends)) {
173 // If not all lanes were extracted, we can have gaps. Skip over them.
174 if (!Extend)
175 continue;
176 // Build the mask using stepvectors and casting.
177 // We want to select the Idx'th element, and every 4 elements after that.
178 // Each element needs to be zero extended; we can do that by providing
179 // tbl index values that are out of range. We can't do that nicely with
180 // a stepvector of the same element type as the input type, but we can
181 // do it with elements the size of the output type.
182 // E.g. for element 0 of a 16b -> 64b zext, we would start with a mask of
183 // 0xFFFF_FFFF_FFFF_0000 + Idx for the start of the stepvector, and use a
184 // step of 4. We then cast that back to an element size of 16b, yielding
185 // <0x0000 + Idx, 0xFFFF, 0xFFFF, 0xFFFF, 0x0004 + Idx, 0xFFFF...>.
186 APInt StartIdx = Invalid << SrcBits;
187 StartIdx += Idx;
188 IRBuilder<> Builder(Extend);
189 Value *StepVector = Builder.CreateStepVector(StepVecTy);
190 Value *ScaledSteps =
191 Builder.CreateNUWMul(StepVector, ConstantInt::get(StepVecTy, 4));
192 Value *ZextTbl = Builder.CreateNUWAdd(
193 ScaledSteps, ConstantInt::get(StepVecTy, StartIdx));
194 Value *FinalMask = Builder.CreateBitCast(ZextTbl, InputTy);
195
196 // Replace the deinterleave, extractvalue, and extension chain with
197 // a tbl directly on the input value.
198 Value *Tbl = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_tbl,
199 {InputTy}, {Input, FinalMask});
200 Value *Widen = Builder.CreateBitCast(Tbl, StepVecTy);
201 if (IsUIToFP)
202 Widen = Builder.CreateUIToFP(Widen, DestTy);
203 LLVM_DEBUG(dbgs() << "SVETBLOPT: Replaced " << *Extend << " with "
204 << *Widen << "\n");
205 Extend->replaceAllUsesWith(Widen);
206 Extend->eraseFromParent();
207 }
208
209 // Delete the unused extracts and deinterleave.
210 for (User *U : make_early_inc_range(Deinterleave->users()))
211 cast<Instruction>(U)->eraseFromParent();
212 Deinterleave->eraseFromParent();
213 }
214}
215
216static bool processLoop(Loop &L, const AArch64Subtarget &ST, DataLayout DL) {
217 // At present, we only want to do this for innermost loops when SVE
218 // is available.
219 if (!L.isInnermost() || !ST.isSVEorStreamingSVEAvailable())
220 return false;
221
222 // TODO: Pull other shuffles into the tbl where possible.
223 // TODO: Add more advanced cases, such as introducing shuffles so that
224 // the SVE odd/even BT narrowing instructions can be used.
225 // TODO: Support other deinterleaves.
226 const AArch64TargetLowering &TL = *ST.getTargetLowering();
227 assert(DL.isLittleEndian() &&
228 "Shuffle optimizations unsupported for big endian targets.");
229 DeinterleaveMap Candidates;
230 for (auto *BB : L.blocks())
231 for (auto &I : *BB)
233 evaluateDeinterleave(cast<IntrinsicInst>(&I), Candidates, L, TL, DL);
234
235 if (Candidates.empty())
236 return false;
237
239 return true;
240}
241
242namespace {
243struct SVEShuffleOpts : public LoopPass {
244 static char ID; // Pass identification, replacement for typeid
245 SVEShuffleOpts() : LoopPass(ID) {}
246
247 bool runOnLoop(Loop *L, LPPassManager &PM) override {
248 if (skipLoop(L))
249 return false;
250
251 TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
252 const AArch64TargetMachine &TM = TPC.getTM<AArch64TargetMachine>();
253 const AArch64Subtarget &ST =
254 *TM.getSubtargetImpl(*L->getHeader()->getParent());
255
256 return processLoop(*L, ST, TM.createDataLayout());
257 }
258
259 void getAnalysisUsage(AnalysisUsage &AU) const override {
260 AU.addRequired<TargetPassConfig>();
261 AU.setPreservesCFG();
262 }
263
264 StringRef getPassName() const override { return "SVE Shuffle Optimizations"; }
265};
266} // end anonymous namespace
267
268char SVEShuffleOpts::ID = 0;
// Display name used when registering the legacy pass; it is the same string
// that SVEShuffleOpts::getPassName() returns.
269static const char *name = "SVE Shuffle Optimizations";
// Begin legacy registration of SVEShuffleOpts with DEBUG_TYPE
// ("aarch64-sve-shuffle-opts") as its argument name. The two trailing 'false'
// arguments are the cfg and analysis flags.
270INITIALIZE_PASS_BEGIN(SVEShuffleOpts, DEBUG_TYPE, name, false, false)
273
// Factory for the legacy pass-manager version of this pass. It returns a
// newly allocated pass, presumably owned by whichever pass manager it is
// added to (confirm at the call site).
274Pass *llvm::createSVEShuffleOptsPass() { return new SVEShuffleOpts(); }
275
278 LPMUpdater &U) {
// New-pass-manager entry point. It does the same work as the legacy
// runOnLoop, calling processLoop with the subtarget of the function that
// contains L.
279 const AArch64Subtarget &ST =
280 *TM.getSubtargetImpl(*L.getHeader()->getParent());
281
282 if (processLoop(L, ST, TM.createDataLayout())) {
// The IR was rewritten: report only the analyses recorded in PA as preserved.
288 return PA;
289 }
290
// processLoop made no change, so every analysis is still valid.
291 return PreservedAnalyses::all();
292}
static SDValue Widen(SelectionDAG *CurDAG, SDValue N)
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file contains the declarations for the subclasses of Constant, which represent the different fla...
#define DEBUG_TYPE
This header defines various interfaces for pass management in LLVM.
#define I(x, y, z)
Definition MD5.cpp:57
This file exposes an interface to building/using memory SSA to walk memory instructions using a use/d...
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
static const char * name
SmallDenseMap< CallInst *, std::array< CastInst *, 4 > > DeinterleaveMap
A mapping between a vector_deinterleaveN intrinsic and extending cast instructions used on the result...
static bool processLoop(Loop &L, const AArch64Subtarget &ST, DataLayout DL)
static void optimizeSVEDeinterleavedExtends(DeinterleaveMap Deinterleaves)
Given a map of deinterleaves to zext or uitofp casts, remove the operations and replace them with tbl...
static void evaluateDeinterleave(IntrinsicInst *I, DeinterleaveMap &Candidates, Loop &L, const AArch64TargetLowering &TL, const DataLayout DL)
Evaluate a deinterleave and see what the uses are.
#define LLVM_DEBUG(...)
Definition Debug.h:119
This file describes how to lower LLVM code to machine code.
Target-Independent Code Generator Pass Configuration Options pass.
This pass exposes codegen information to IR-level passes.
The Input class is used to parse a yaml document into in-memory structs and vectors.
const AArch64Subtarget * getSubtargetImpl(const Function &F) const override
Virtual method implemented by subclasses that returns a reference to that target's TargetSubtargetInf...
Class for arbitrary precision integers.
Definition APInt.h:78
static APInt getAllOnes(unsigned numBits)
Return an APInt of a specified width with all bits set.
Definition APInt.h:235
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1118
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition Pass.cpp:275
A function analysis which provides an AssumptionCache.
Represents analyses that only rely on functions' control flow.
Definition Analysis.h:73
This is the base class for all instructions that perform data casts.
Definition InstrTypes.h:512
Instruction::CastOps getOpcode() const
Return the opcode of this CastInst.
Definition InstrTypes.h:674
Type * getDestTy() const
Return the destination type, as a convenience.
Definition InstrTypes.h:681
This class represents a range of values.
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI bool isUpperWrapped() const
Return true if the exclusive upper bound wraps around the unsigned domain.
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition IRBuilder.h:2900
A wrapper class for inspecting calls to intrinsic functions.
This class provides an interface for updating the loop pass manager based on mutations to the loop ne...
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
An analysis that produces MemorySSA for a function.
Definition MemorySSA.h:922
Pass interface - Implemented by all 'passes'.
Definition Pass.h:99
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserveSet()
Mark an analysis set as preserved.
Definition Analysis.h:151
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition Analysis.h:132
LLVM_ABI PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &U)
Analysis pass providing the TargetTransformInfo.
EVT getValueType(const DataLayout &DL, Type *Ty, bool AllowUnknown=false) const
Return the EVT corresponding to this LLVM type.
LegalizeKind getTypeConversion(LLVMContext &Context, EVT VT) const
Return pair that represents the legalization kind (first) that needs to happen to EVT (second) in ord...
const DataLayout createDataLayout() const
Create a DataLayout.
Target-Independent Code Generator Pass Configuration Options.
TMC & getTM() const
Get the right type of TargetMachine for this target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:197
LLVM Value Representation.
Definition Value.h:75
static VectorType * getInteger(VectorType *VTy)
This static method gets a VectorType with the same number of elements as the input type,...
bool match(Val *V, const Pattern &P)
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
auto m_Value()
Match an arbitrary value and ignore it.
This is an optimization pass for GlobalISel generic memory operations.
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1738
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B,...
Definition STLExtras.h:2553
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition STLExtras.h:633
AnalysisManager< Loop, LoopStandardAnalysisResults & > LoopAnalysisManager
The loop analysis manager.
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
Pass * createSVEShuffleOptsPass()
Extended Value Type.
Definition ValueTypes.h:35
TypeSize getSizeInBits() const
Return the size of the specified value type in bits.
Definition ValueTypes.h:396
uint64_t getScalarSizeInBits() const
Definition ValueTypes.h:408
bool isScalableVector() const
Return true if this is a vector type where the runtime length is machine dependent.
Definition ValueTypes.h:187
The adaptor from a function pass to a loop pass computes these analyses and makes them available to t...