LLVM 19.0.0git
MVELaneInterleavingPass.cpp
Go to the documentation of this file.
1//===- MVELaneInterleaving.cpp - Interleave for MVE instructions ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass interleaves around sext/zext/trunc instructions. MVE does not have
10// a single sext/zext or trunc instruction that takes the bottom half of a
11// vector and extends to a full width, like NEON has with MOVL. Instead it is
12// expected that this happens through top/bottom instructions. So the MVE
13// equivalent VMOVLT/B instructions take either the even or odd elements of the
14// input and extend them to the larger type, producing a vector with half the
15// number of elements each of double the bitwidth. As there is no simple
16// instruction, we often have to turn sext/zext/trunc into a series of lane
17// moves (or stack loads/stores, which we do not do yet).
18//
19// This pass takes vector code that starts at truncs, looks for interconnected
20// blobs of operations that end with sext/zext (or constants/splats) of the
21// form:
22// %sa = sext v8i16 %a to v8i32
23// %sb = sext v8i16 %b to v8i32
24// %add = add v8i32 %sa, %sb
25// %r = trunc %add to v8i16
26// And adds shuffles to allow the use of VMOVL/VMOVN instructions:
27// %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7>
28// %sa = sext v8i16 %sha to v8i32
29// %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7>
30// %sb = sext v8i16 %shb to v8i32
31// %add = add v8i32 %sa, %sb
32// %r = trunc %add to v8i16
33// %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7>
34// Which can then be split and lowered to MVE instructions efficiently:
35// %sa_b = VMOVLB.s16 %a
36// %sa_t = VMOVLT.s16 %a
37// %sb_b = VMOVLB.s16 %b
38// %sb_t = VMOVLT.s16 %b
39// %add_b = VADD.i32 %sa_b, %sb_b
40// %add_t = VADD.i32 %sa_t, %sb_t
41// %r = VMOVNT.i16 %add_b, %add_t
42//
43//===----------------------------------------------------------------------===//
44
45#include "ARM.h"
46#include "ARMBaseInstrInfo.h"
47#include "ARMSubtarget.h"
48#include "llvm/ADT/SetVector.h"
53#include "llvm/IR/BasicBlock.h"
54#include "llvm/IR/Constant.h"
55#include "llvm/IR/Constants.h"
57#include "llvm/IR/Function.h"
58#include "llvm/IR/IRBuilder.h"
60#include "llvm/IR/InstrTypes.h"
61#include "llvm/IR/Instruction.h"
64#include "llvm/IR/Intrinsics.h"
65#include "llvm/IR/IntrinsicsARM.h"
67#include "llvm/IR/Type.h"
68#include "llvm/IR/Value.h"
70#include "llvm/Pass.h"
72#include <algorithm>
73#include <cassert>
74
75using namespace llvm;
76
77#define DEBUG_TYPE "mve-laneinterleave"
78
// Remaining arguments of the hidden `cl::opt<bool> EnableInterleave` option
// (its opening declaration line is not shown in this rendering). It defaults
// to true; when false, MVELaneInterleaving::runOnFunction returns immediately
// without changing anything.
80 "enable-mve-interleave", cl::Hidden, cl::init(true),
81 cl::desc("Enable interleave MVE vector operation lowering"));
82
83namespace {
84
// Function pass that inserts lane-interleaving shuffles around vector
// trunc/fptrunc instructions and add reductions (see runOnFunction and
// tryInterleave), so that the extends/truncs in between can be lowered with
// MVE top/bottom (VMOVL*/VMOVN*) instructions instead of lane moves.
// NOTE(review): this rendering omits the constructor's body and part of
// getAnalysisUsage. runOnFunction calls getAnalysis<TargetPassConfig>(), so
// the hidden lines presumably register that requirement — verify against the
// full file.
85class MVELaneInterleaving : public FunctionPass {
86public:
87 static char ID; // Pass identification, replacement for typeid
88
89 explicit MVELaneInterleaving() : FunctionPass(ID) {
91 }
92
  // Per-function entry point; returns true if any group was interleaved.
93 bool runOnFunction(Function &F) override;
94
95 StringRef getPassName() const override { return "MVE lane interleaving"; }
96
  // The transform only adds shuffles/extends and rewrites uses inside existing
  // blocks, so the CFG is preserved.
97 void getAnalysisUsage(AnalysisUsage &AU) const override {
98 AU.setPreservesCFG();
101 }
102};
103
104} // end anonymous namespace
105
// Storage for the pass ID declared in the class; per its comment it serves as
// the pass identification in place of typeid.
106char MVELaneInterleaving::ID = 0;
107
// Register the pass under DEBUG_TYPE ("mve-laneinterleave") with the display
// name "MVE lane interleaving". The two trailing `false` arguments are the
// `cfg` and `analysis` flags of INITIALIZE_PASS(passName, arg, name, cfg,
// analysis).
108INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false,
109 false)
110
// Body of `Pass *createMVELaneInterleavingPass()` (signature line not shown):
// returns a newly allocated instance of the pass. Ownership presumably moves
// to the caller / pass manager — confirm at the call site.
112 return new MVELaneInterleaving();
113}
114
// Body of `static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4>
// &Exts, SmallSetVector<Instruction *, 4> &Truncs)` (signature line not shown).
// Heuristic deciding whether interleaving a found group is worthwhile:
//  - true if any extend is an fpext or does not extend a load;
//  - else true if any trunc has exactly one use and that use is not a store;
//  - else true only if every extend has exactly one user and it is a mul.
117 // This is not always beneficial to transform. Exts can be incorporated into
118 // loads, Truncs can be folded into stores.
119 // Truncs are usually the same number of instructions,
120 // VSTRH.32(A);VSTRH.32(B) vs VSTRH.16(VMOVNT A, B) with interleaving
121 // Exts are unfortunately more instructions in the general case:
122 // A=VLDRH.32; B=VLDRH.32;
123 // vs with interleaving:
124 // T=VLDRH.16; A=VMOVLB T; B=VMOVLT T
125 // But those VMOVL may be folded into a VMULL.
126
127 // But expensive extends/truncs are always good to remove. FPExts always
128 // involve extra VCVT's so are always considered to be beneficial to convert.
129 for (auto *E : Exts) {
130 if (isa<FPExtInst>(E) || !isa<LoadInst>(E->getOperand(0))) {
131 LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n");
132 return true;
133 }
134 }
  // A single-use trunc that does not feed a store cannot be folded into a
  // truncating store, so removing it is a win.
135 for (auto *T : Truncs) {
136 if (T->hasOneUse() && !isa<StoreInst>(*T->user_begin())) {
137 LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n");
138 return true;
139 }
140 }
141
142 // Otherwise every extend is an ext(load). Only interleave if each Extend has
143 // a single user that is a mul (a likely VMULL). Simple, imperfect heuristic.
144 for (auto *E : Exts) {
145 if (!E->hasOneUse() ||
146 cast<Instruction>(*E->user_begin())->getOpcode() != Instruction::Mul) {
147 LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n");
148 return false;
149 }
150 }
151 return true;
152}
153
// Try to interleave the group of vector operations connected to Start, which
// is a vector trunc/fptrunc or an add reduction.
//
// The walk follows operands and users of supported element-wise operations.
// It stops at extends (sext/zext/fpext: users walked, operand not), truncs and
// add reductions (neither walked further), and non-instruction vector operands
// (collected as OtherLeafs). Truncs and reductions found are added to Visited,
// so the caller does not start another walk from them.
//
// If the group's types and profitability checks pass, the function inserts:
//  - a LeafMask shuffle on the input of every extend and of every OtherLeaf;
//  - a TruncMask shuffle after every trunc.
// Reductions get no output shuffle.
//
// Returns true if the IR was modified.
//
// NOTE(review): the visible text omits the line declaring the second
// parameter (`SmallPtrSetImpl<Instruction *> &Visited`) and the local
// declarations of Truncs, Reducts, Exts and Ops used below.
154static bool tryInterleave(Instruction *Start,
156 LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n");
157
  // The walk is seeded with Start's operand, which must be an instruction.
158 if (!isa<Instruction>(Start->getOperand(0)))
159 return false;
160
161 // Look for connected operations starting from Ext's, terminating at Truncs.
162 std::vector<Instruction *> Worklist;
163 Worklist.push_back(Start);
164 Worklist.push_back(cast<Instruction>(Start->getOperand(0)));
165
169 SmallSetVector<Use *, 4> OtherLeafs;
171
172 while (!Worklist.empty()) {
173 Instruction *I = Worklist.back();
174 Worklist.pop_back();
175
176 switch (I->getOpcode()) {
177 // Truncs
178 case Instruction::Trunc:
179 case Instruction::FPTrunc:
180 if (!Truncs.insert(I))
181 continue;
182 Visited.insert(I);
183 break;
184
185 // Extend leafs
      // Only users are followed; the extend's operand is outside the group.
186 case Instruction::SExt:
187 case Instruction::ZExt:
188 case Instruction::FPExt:
189 if (Exts.count(I))
190 continue;
191 for (auto *Use : I->users())
192 Worklist.push_back(cast<Instruction>(Use));
193 Exts.insert(I);
194 break;
195
196 case Instruction::Call: {
197 IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
198 if (!II)
199 return false;
200
      // Add reductions terminate the walk like truncs do.
201 if (II->getIntrinsicID() == Intrinsic::vector_reduce_add) {
202 if (!Reducts.insert(I))
203 continue;
204 Visited.insert(I);
205 break;
206 }
207
208 switch (II->getIntrinsicID()) {
209 case Intrinsic::abs:
210 case Intrinsic::smin:
211 case Intrinsic::smax:
212 case Intrinsic::umin:
213 case Intrinsic::umax:
214 case Intrinsic::sadd_sat:
215 case Intrinsic::ssub_sat:
216 case Intrinsic::uadd_sat:
217 case Intrinsic::usub_sat:
218 case Intrinsic::minnum:
219 case Intrinsic::maxnum:
220 case Intrinsic::fabs:
221 case Intrinsic::fma:
222 case Intrinsic::ceil:
223 case Intrinsic::floor:
224 case Intrinsic::rint:
225 case Intrinsic::round:
226 case Intrinsic::trunc:
227 break;
228 default:
229 return false;
230 }
231 [[fallthrough]]; // Fall through to treating these like an operator below.
232 }
233 // Binary/tertiary ops
234 case Instruction::Add:
235 case Instruction::Sub:
236 case Instruction::Mul:
237 case Instruction::AShr:
238 case Instruction::LShr:
239 case Instruction::Shl:
240 case Instruction::ICmp:
241 case Instruction::FCmp:
242 case Instruction::FAdd:
243 case Instruction::FMul:
244 case Instruction::Select:
245 if (!Ops.insert(I))
246 continue;
247
      // Only fixed-vector operands take part. Scalar operands are left
      // untouched. Vector operands that are not instructions (constants,
      // arguments, ...) become leaves that receive a LeafMask shuffle later.
248 for (Use &Op : I->operands()) {
249 if (!isa<FixedVectorType>(Op->getType()))
250 continue;
251 if (isa<Instruction>(Op))
252 Worklist.push_back(cast<Instruction>(&Op));
253 else
254 OtherLeafs.insert(&Op);
255 }
256
257 for (auto *Use : I->users())
258 Worklist.push_back(cast<Instruction>(Use));
259 break;
260
261 case Instruction::ShuffleVector:
262 // A shuffle of a splat is a splat.
263 if (cast<ShuffleVectorInst>(I)->isZeroEltSplat())
264 continue;
265 [[fallthrough]];
266
267 default:
268 LLVM_DEBUG(dbgs() << " Unhandled instruction: " << *I << "\n");
269 return false;
270 }
271 }
272
273 if (Exts.empty() && OtherLeafs.empty())
274 return false;
275
276 LLVM_DEBUG({
277 dbgs() << "Found group:\n Exts:\n";
278 for (auto *I : Exts)
279 dbgs() << " " << *I << "\n";
280 dbgs() << " Ops:\n";
281 for (auto *I : Ops)
282 dbgs() << " " << *I << "\n";
283 dbgs() << " OtherLeafs:\n";
284 for (auto *I : OtherLeafs)
285 dbgs() << " " << *I->get() << " of " << *I->getUser() << "\n";
286 dbgs() << " Truncs:\n";
287 for (auto *I : Truncs)
288 dbgs() << " " << *I << "\n";
289 dbgs() << " Reducts:\n";
290 for (auto *I : Reducts)
291 dbgs() << " " << *I << "\n";
292 });
293
294 assert((!Truncs.empty() || !Reducts.empty()) &&
295 "Expected some truncs or reductions");
296 if (Truncs.empty() && Exts.empty())
297 return false;
298
  // The narrow vector type: the trunc result type, or otherwise the source
  // type of the extends (non-empty here by the check above).
299 auto *VT = !Truncs.empty()
300 ? cast<FixedVectorType>(Truncs[0]->getType())
301 : cast<FixedVectorType>(Exts[0]->getOperand(0)->getType());
302 LLVM_DEBUG(dbgs() << "Using VT:" << *VT << "\n");
303
304 // Check types
  // Only 16-bit or 8-bit elements are handled. BaseElts is the element count
  // for which BaseElts * element width == 128 bits, and the total element
  // count must be a multiple of it.
305 unsigned NumElts = VT->getNumElements();
306 unsigned BaseElts = VT->getScalarSizeInBits() == 16
307 ? 8
308 : (VT->getScalarSizeInBits() == 8 ? 16 : 0);
309 if (BaseElts == 0 || NumElts % BaseElts != 0) {
310 LLVM_DEBUG(dbgs() << " Type is unsupported\n");
311 return false;
312 }
313 if (Start->getOperand(0)->getType()->getScalarSizeInBits() !=
314 VT->getScalarSizeInBits() * 2) {
315 LLVM_DEBUG(dbgs() << " Type not double sized\n");
316 return false;
317 }
318 for (Instruction *I : Exts)
319 if (I->getOperand(0)->getType() != VT) {
320 LLVM_DEBUG(dbgs() << " Wrong type on " << *I << "\n");
321 return false;
322 }
323 for (Instruction *I : Truncs)
324 if (I->getType() != VT) {
325 LLVM_DEBUG(dbgs() << " Wrong type on " << *I << "\n");
326 return false;
327 }
328
329 // Check that it looks beneficial
  // A reduction group whose ops are all Mul/Select/ICmp (or that has no ops)
  // is rejected.
330 if (!isProfitableToInterleave(Exts, Truncs))
331 return false;
332 if (!Reducts.empty() && (Ops.empty() || all_of(Ops, [](Instruction *I) {
333 return I->getOpcode() == Instruction::Mul ||
334 I->getOpcode() == Instruction::Select ||
335 I->getOpcode() == Instruction::ICmp;
336 }))) {
337 LLVM_DEBUG(dbgs() << "Reduction does not look profitable\n");
338 return false;
339 }
340
341 // Create new shuffles around the extends / truncs / other leaves.
342 IRBuilder<> Builder(Start);
343
344 SmallVector<int, 16> LeafMask;
345 SmallVector<int, 16> TruncMask;
346 // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7 8, 10, 12, 14, 9, 11, 13, 15
347 // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7 8, 12, 9, 13, 10, 14, 11, 15
348 for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
349 for (unsigned i = 0; i < BaseElts / 2; i++)
350 LeafMask.push_back(Base + i * 2);
351 for (unsigned i = 0; i < BaseElts / 2; i++)
352 LeafMask.push_back(Base + i * 2 + 1);
353 }
354 for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
355 for (unsigned i = 0; i < BaseElts / 2; i++) {
356 TruncMask.push_back(Base + i);
357 TruncMask.push_back(Base + i + BaseElts / 2);
358 }
359 }
360
  // Rebuild each extend on a LeafMask-shuffled input. The old extend loses
  // all its uses but is not erased here.
361 for (Instruction *I : Exts) {
362 LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n");
363 Builder.SetInsertPoint(I);
364 Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask);
365 bool FPext = isa<FPExtInst>(I);
366 bool Sext = isa<SExtInst>(I);
367 Value *Ext = FPext ? Builder.CreateFPExt(Shuffle, I->getType())
368 : Sext ? Builder.CreateSExt(Shuffle, I->getType())
369 : Builder.CreateZExt(Shuffle, I->getType());
370 I->replaceAllUsesWith(Ext);
371 LLVM_DEBUG(dbgs() << " with " << *Shuffle << "\n");
372 }
373
374 for (Use *I : OtherLeafs) {
  // Print the leaf operand itself; streaming the Use would go through its
  // implicit Value* conversion and print only a pointer address.
375 LLVM_DEBUG(dbgs() << "Replacing leaf " << *I->get() << "\n");
376 Builder.SetInsertPoint(cast<Instruction>(I->getUser()));
377 Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask);
378 I->getUser()->setOperand(I->getOperandNo(), Shuffle);
379 LLVM_DEBUG(dbgs() << " with " << *Shuffle << "\n");
380 }
381
382 for (Instruction *I : Truncs) {
383 LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n");
384
385 Builder.SetInsertPoint(I->getParent(), ++I->getIterator());
386 Value *Shuf = Builder.CreateShuffleVector(I, TruncMask);
    // RAUW also rewrites Shuf's own operand (Shuf uses I), so restore it.
387 I->replaceAllUsesWith(Shuf);
388 cast<Instruction>(Shuf)->setOperand(0, I);
389
390 LLVM_DEBUG(dbgs() << " with " << *Shuf << "\n");
391 }
392
393 return true;
394}
395
396// Add reductions are fairly common and associative, meaning we can start the
397// interleaving from them and don't need to emit a shuffle.
// Body of `static bool isAddReduction(Instruction &I)` (signature line not
// shown): true iff I is a call to the llvm.vector.reduce.add intrinsic.
399 if (auto *II = dyn_cast<IntrinsicInst>(&I))
400 return II->getIntrinsicID() == Intrinsic::vector_reduce_add;
401 return false;
402}
403
// Pass entry point. It does nothing if the option is disabled or the subtarget
// lacks MVE integer ops. Otherwise it scans F's instructions in reverse order
// and starts tryInterleave from each vector trunc/fptrunc or add reduction not
// already claimed by an earlier group. Returns true if anything changed.
// NOTE(review): the declaration of `Visited` is not shown in this rendering.
// It is passed to tryInterleave as a SmallPtrSetImpl<Instruction *> &.
404bool MVELaneInterleaving::runOnFunction(Function &F) {
405 if (!EnableInterleave)
406 return false;
407 auto &TPC = getAnalysis<TargetPassConfig>();
408 auto &TM = TPC.getTM<TargetMachine>();
409 auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
410 if (!ST->hasMVEIntegerOps())
411 return false;
412
413 bool Changed = false;
414
  // tryInterleave inserts instructions while this loop runs. Truncs and
  // reductions it groups are recorded in Visited, so they are skipped here.
416 for (Instruction &I : reverse(instructions(F))) {
417 if (((I.getType()->isVectorTy() &&
418 (isa<TruncInst>(I) || isa<FPTruncInst>(I))) ||
419 isAddReduction(I)) &&
420 !Visited.count(&I))
421 Changed |= tryInterleave(&I, Visited);
422 }
423
424 return Changed;
425}
Expand Atomic instructions
This file contains the declarations for the subclasses of Constant, which represent the different fla...
#define LLVM_DEBUG(X)
Definition: Debug.h:101
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
static bool isProfitableToInterleave(SmallSetVector< Instruction *, 4 > &Exts, SmallSetVector< Instruction *, 4 > &Truncs)
static bool tryInterleave(Instruction *Start, SmallPtrSetImpl< Instruction * > &Visited)
cl::opt< bool > EnableInterleave("enable-mve-interleave", cl::Hidden, cl::init(true), cl::desc("Enable interleave MVE vector operation lowering"))
#define DEBUG_TYPE
static bool isAddReduction(Instruction &I)
const char LLVMTargetMachineRef TM
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:38
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file implements a set that has insertion order iteration characteristics.
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:40
This file describes how to lower LLVM code to machine code.
Target-Independent Code Generator Pass Configuration Options pass.
This pass exposes codegen information to IR-level passes.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:269
This class represents an Operation in the Expression.
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:311
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
Value * CreateSExt(Value *V, Type *DestTy, const Twine &Name="")
Definition: IRBuilder.h:2033
Value * CreateZExt(Value *V, Type *DestTy, const Twine &Name="", bool IsNonNeg=false)
Definition: IRBuilder.h:2021
Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")
Definition: IRBuilder.h:2494
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:180
Value * CreateFPExt(Value *V, Type *DestTy, const Twine &Name="")
Definition: IRBuilder.h:2110
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2666
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:47
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
Definition: IntrinsicInst.h:54
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Pass interface - Implemented by all 'passes'.
Definition: Pass.h:94
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition: Pass.cpp:98
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:264
bool empty() const
Determine if the SetVector is empty or not.
Definition: SetVector.h:93
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:162
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:321
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:360
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:342
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:427
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:370
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
Primary interface to the complete machine description for the target machine.
Definition: TargetMachine.h:76
Target-Independent Code Generator Pass Configuration Options.
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
LLVM Value Representation.
Definition: Value.h:74
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:450
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1722
Pass * createMVELaneInterleavingPass()
void initializeMVELaneInterleavingPass(PassRegistry &)
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:419
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163