46#include "llvm/IR/IntrinsicsARM.h"
57#define DEBUG_TYPE "mve-tail-predication"
58#define DESC "Transform predicated vector loops to use MVE tail predication"
61 "tail-predication",
cl::desc(
"MVE tail-predication pass options"),
64 "Don't tail-predicate loops"),
66 "enabled-no-reductions",
67 "Enable tail-predication, but not for reduction loops"),
70 "Enable tail-predication, including reduction loops"),
72 "force-enabled-no-reductions",
73 "Enable tail-predication, but not for reduction loops, "
74 "and force this which might be unsafe"),
77 "Enable tail-predication, including reduction loops, "
78 "and force this which might be unsafe")));
83class MVETailPredication :
public LoopPass {
109 bool TryConvertActiveLaneMask(
Value *TripCount);
129 auto &TPC = getAnalysis<TargetPassConfig>();
132 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
133 SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
138 if (!
ST->hasMVEIntegerOps() || !
ST->hasV8_1MMainlineOps()) {
148 for (
auto &
I : *BB) {
149 auto *
Call = dyn_cast<IntrinsicInst>(&
I);
154 if (
ID == Intrinsic::start_loop_iterations ||
155 ID == Intrinsic::test_start_loop_iterations)
156 return cast<IntrinsicInst>(&
I);
173 LLVM_DEBUG(
dbgs() <<
"ARM TP: Running on Loop: " << *L << *Setup <<
"\n");
175 bool Changed = TryConvertActiveLaneMask(
Setup->getArgOperand(0));
199 bool ForceTailPredication =
204 bool Changed =
false;
205 if (!
L->makeLoopInvariant(ElemCount, Changed))
208 const SCEV *
EC = SE->getSCEV(ElemCount);
209 const SCEV *TC = SE->getSCEV(TripCount);
211 cast<FixedVectorType>(ActiveLaneMask->
getType())->getNumElements();
212 if (VectorWidth != 2 && VectorWidth != 4 && VectorWidth != 8 &&
220 if (!SE->isLoopInvariant(EC, L)) {
221 LLVM_DEBUG(
dbgs() <<
"ARM TP: element count must be loop invariant.\n");
231 const SCEV *IVExpr = SE->getSCEV(
IV);
232 auto *AddExpr = dyn_cast<SCEVAddRecExpr>(IVExpr);
239 if (AddExpr->getLoop() != L) {
243 auto *Step = dyn_cast<SCEVConstant>(AddExpr->getOperand(1));
245 LLVM_DEBUG(
dbgs() <<
"ARM TP: induction step is not a constant: ";
246 AddExpr->getOperand(1)->
dump());
249 auto StepValue = Step->getValue()->getSExtValue();
250 if (VectorWidth != StepValue) {
252 <<
" doesn't match vector width " << VectorWidth <<
"\n");
256 if ((ConstElemCount = dyn_cast<ConstantInt>(ElemCount))) {
257 ConstantInt *TC = dyn_cast<ConstantInt>(TripCount);
260 "set.loop.iterations\n");
270 (ConstElemCount->
getZExtValue() + VectorWidth - 1) / VectorWidth;
276 LLVM_DEBUG(
dbgs() <<
"ARM TP: inconsistent constant tripcount values: "
277 << TC1 <<
" from set.loop.iterations, and "
278 << TC2 <<
" from get.active.lane.mask\n");
281 }
else if (!ForceTailPredication) {
295 SE->getSCEV(ConstantInt::get(TripCount->
getType(), VectorWidth));
297 const SCEV *Start = AddExpr->getStart();
298 const SCEV *ECPlusVWMinus1 = SE->getAddExpr(
300 SE->getSCEV(ConstantInt::get(TripCount->
getType(), VectorWidth - 1)));
303 const SCEV *Ceil = SE->getUDivExpr(ECPlusVWMinus1, VW);
308 dbgs() <<
"ARM TP: Analysing overflow behaviour for:\n";
309 dbgs() <<
"ARM TP: - TripCount = " << *TC <<
"\n";
310 dbgs() <<
"ARM TP: - ElemCount = " << *
EC <<
"\n";
311 dbgs() <<
"ARM TP: - Start = " << *Start <<
"\n";
312 dbgs() <<
"ARM TP: - BETC = " << *SE->getBackedgeTakenCount(L) <<
"\n";
313 dbgs() <<
"ARM TP: - VecWidth = " << VectorWidth <<
"\n";
314 dbgs() <<
"ARM TP: - (ElemCount+VW-1) / VW = " << *Ceil <<
"\n";
329 const SCEV *Div = SE->getUDivExpr(
330 SE->getAddExpr(SE->getMulExpr(Ceil, VW), SE->getNegativeSCEV(VW),
331 SE->getNegativeSCEV(Start)),
333 const SCEV *Sub = SE->getMinusSCEV(SE->getBackedgeTakenCount(L), Div);
339 Sub = SE->applyLoopGuards(Sub, L);
343 LLVM_DEBUG(
dbgs() <<
"ARM TP: possible overflow in sub expression.\n");
352 if (
auto *BaseC = dyn_cast<SCEVConstant>(AddExpr->getStart())) {
353 if (BaseC->getAPInt().urem(VectorWidth) == 0)
354 return SE->getMinusSCEV(EC, BaseC);
355 }
else if (
auto *BaseV = dyn_cast<SCEVUnknown>(AddExpr->getStart())) {
356 Type *Ty = BaseV->getType();
360 L->getHeader()->getDataLayout()))
361 return SE->getMinusSCEV(EC, BaseV);
362 }
else if (
auto *BaseMul = dyn_cast<SCEVMulExpr>(AddExpr->getStart())) {
363 if (
auto *BaseC = dyn_cast<SCEVConstant>(BaseMul->getOperand(0)))
364 if (BaseC->getAPInt().urem(VectorWidth) == 0)
365 return SE->getMinusSCEV(EC, BaseC);
366 if (
auto *BaseC = dyn_cast<SCEVConstant>(BaseMul->getOperand(1)))
367 if (BaseC->getAPInt().urem(VectorWidth) == 0)
368 return SE->getMinusSCEV(EC, BaseC);
372 dbgs() <<
"ARM TP: induction base is not know to be a multiple of VF: "
373 << *AddExpr->getOperand(0) <<
"\n");
377void MVETailPredication::InsertVCTPIntrinsic(
IntrinsicInst *ActiveLaneMask,
379 IRBuilder<> Builder(
L->getLoopPreheader()->getTerminator());
380 Module *
M =
L->getHeader()->getModule();
382 unsigned VectorWidth =
383 cast<FixedVectorType>(ActiveLaneMask->
getType())->getNumElements();
386 Builder.SetInsertPoint(
L->getHeader(),
L->getHeader()->getFirstNonPHIIt());
387 PHINode *Processed = Builder.CreatePHI(Ty, 2);
392 Builder.SetInsertPoint(ActiveLaneMask);
393 ConstantInt *Factor = ConstantInt::get(cast<IntegerType>(Ty), VectorWidth);
396 switch (VectorWidth) {
399 case 2: VCTPID = Intrinsic::arm_mve_vctp64;
break;
400 case 4: VCTPID = Intrinsic::arm_mve_vctp32;
break;
401 case 8: VCTPID = Intrinsic::arm_mve_vctp16;
break;
402 case 16: VCTPID = Intrinsic::arm_mve_vctp8;
break;
405 Value *VCTPCall = Builder.CreateCall(VCTP, Processed);
410 Value *Remaining = Builder.CreateSub(Processed, Factor);
413 << *Processed <<
"\n"
414 <<
"ARM TP: Inserted VCTP: " << *VCTPCall <<
"\n");
417bool MVETailPredication::TryConvertActiveLaneMask(
Value *TripCount) {
419 for (
auto *BB :
L->getBlocks())
421 if (
auto *
Int = dyn_cast<IntrinsicInst>(&
I))
422 if (
Int->getIntrinsicID() == Intrinsic::get_active_lane_mask)
425 if (ActiveLaneMasks.
empty())
430 for (
auto *ActiveLaneMask : ActiveLaneMasks) {
432 << *ActiveLaneMask <<
"\n");
434 const SCEV *StartSCEV = IsSafeActiveMask(ActiveLaneMask, TripCount);
439 LLVM_DEBUG(
dbgs() <<
"ARM TP: Safe to insert VCTP. Start is " << *StartSCEV
444 Value *Start = Expander.expandCodeFor(StartSCEV, StartSCEV->
getType(), Ins);
445 LLVM_DEBUG(
dbgs() <<
"ARM TP: Created start value " << *Start <<
"\n");
446 InsertVCTPIntrinsic(ActiveLaneMask, Start);
450 for (
auto *
II : ActiveLaneMasks)
452 for (
auto *
I :
L->blocks())
458 return new MVETailPredication();
461char MVETailPredication::ID = 0;
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
cl::opt< TailPredication::Mode > EnableTailPredication("tail-predication", cl::desc("MVE tail-predication pass options"), cl::init(TailPredication::Enabled), cl::values(clEnumValN(TailPredication::Disabled, "disabled", "Don't tail-predicate loops"), clEnumValN(TailPredication::EnabledNoReductions, "enabled-no-reductions", "Enable tail-predication, but not for reduction loops"), clEnumValN(TailPredication::Enabled, "enabled", "Enable tail-predication, including reduction loops"), clEnumValN(TailPredication::ForceEnabledNoReductions, "force-enabled-no-reductions", "Enable tail-predication, but not for reduction loops, " "and force this which might be unsafe"), clEnumValN(TailPredication::ForceEnabled, "force-enabled", "Enable tail-predication, including reduction loops, " "and force this which might be unsafe")))
uint64_t IntrinsicInst * II
const char LLVMTargetMachineRef TM
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Target-Independent Code Generator Pass Configuration Options pass.
static const uint32_t IV[8]
Class for arbitrary precision integers.
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
void setPreservesCFG()
This function should be called by the pass, iff they do not:
LLVM Basic Block Representation.
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
This is the shared class of boolean and integer constants.
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
A wrapper class for inspecting calls to intrinsic functions.
The legacy pass manager's analysis pass to compute loop information.
virtual bool runOnLoop(Loop *L, LPPassManager &LPM)=0
Represents a single loop in the control flow graph.
A Module instance is used to store all the information related to an LLVM module.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
Pass interface - Implemented by all 'passes'.
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
This class uses information about analyze scalars to rewrite expressions in canonical form.
This class represents an analyzed expression in the program.
bool isZero() const
Return true if the expression is a constant zero.
void dump() const
This method is used for debugging.
Type * getType() const
Return the LLVM type of this SCEV expression.
The main scalar evolution driver.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Primary interface to the complete machine description for the target machine.
Target-Independent Code Generator Pass Configuration Options.
The instances of the Type class are immutable: once they are created, they are never changed.
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=std::nullopt)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
@ ForceEnabledNoReductions
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
bool MaskedValueIsZero(const Value *V, const APInt &Mask, const SimplifyQuery &DL, unsigned Depth=0)
Return true if 'V & Mask' is known to be zero.
unsigned Log2_64(uint64_t Value)
Return the floor log base 2 of the specified value, -1 if the value is zero.
bool DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
Examine each PHI in the given block and delete it if it is dead.
Pass * createMVETailPredicationPass()
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.