46#include "llvm/IR/IntrinsicsARM.h"
55#define DEBUG_TYPE "mve-tail-predication"
56#define DESC "Transform predicated vector loops to use MVE tail predication"
59 "tail-predication",
cl::desc(
"MVE tail-predication pass options"),
62 "Don't tail-predicate loops"),
64 "enabled-no-reductions",
65 "Enable tail-predication, but not for reduction loops"),
68 "Enable tail-predication, including reduction loops"),
70 "force-enabled-no-reductions",
71 "Enable tail-predication, but not for reduction loops, "
72 "and force this which might be unsafe"),
75 "Enable tail-predication, including reduction loops, "
76 "and force this which might be unsafe")));
81class MVETailPredication :
public LoopPass {
107 bool TryConvertActiveLaneMask(
Value *TripCount);
127 auto &TPC = getAnalysis<TargetPassConfig>();
130 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
131 SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
136 if (!
ST->hasMVEIntegerOps() || !
ST->hasV8_1MMainlineOps()) {
146 for (
auto &
I : *BB) {
147 auto *
Call = dyn_cast<IntrinsicInst>(&
I);
152 if (
ID == Intrinsic::start_loop_iterations ||
153 ID == Intrinsic::test_start_loop_iterations)
154 return cast<IntrinsicInst>(&
I);
171 LLVM_DEBUG(
dbgs() <<
"ARM TP: Running on Loop: " << *L << *Setup <<
"\n");
173 bool Changed = TryConvertActiveLaneMask(
Setup->getArgOperand(0));
197 bool ForceTailPredication =
202 bool Changed =
false;
203 if (!
L->makeLoopInvariant(ElemCount, Changed))
206 const SCEV *
EC = SE->getSCEV(ElemCount);
207 const SCEV *TC = SE->getSCEV(TripCount);
209 cast<FixedVectorType>(ActiveLaneMask->
getType())->getNumElements();
210 if (VectorWidth != 2 && VectorWidth != 4 && VectorWidth != 8 &&
218 if (!SE->isLoopInvariant(EC, L)) {
219 LLVM_DEBUG(
dbgs() <<
"ARM TP: element count must be loop invariant.\n");
229 const SCEV *IVExpr = SE->getSCEV(
IV);
230 auto *AddExpr = dyn_cast<SCEVAddRecExpr>(IVExpr);
237 if (AddExpr->getLoop() != L) {
241 auto *Step = dyn_cast<SCEVConstant>(AddExpr->getOperand(1));
243 LLVM_DEBUG(
dbgs() <<
"ARM TP: induction step is not a constant: ";
244 AddExpr->getOperand(1)->
dump());
247 auto StepValue = Step->getValue()->getSExtValue();
248 if (VectorWidth != StepValue) {
250 <<
" doesn't match vector width " << VectorWidth <<
"\n");
254 if ((ConstElemCount = dyn_cast<ConstantInt>(ElemCount))) {
255 ConstantInt *TC = dyn_cast<ConstantInt>(TripCount);
258 "set.loop.iterations\n");
268 (ConstElemCount->
getZExtValue() + VectorWidth - 1) / VectorWidth;
274 LLVM_DEBUG(
dbgs() <<
"ARM TP: inconsistent constant tripcount values: "
275 << TC1 <<
" from set.loop.iterations, and "
276 << TC2 <<
" from get.active.lane.mask\n");
279 }
else if (!ForceTailPredication) {
293 SE->getSCEV(ConstantInt::get(TripCount->
getType(), VectorWidth));
295 const SCEV *Start = AddExpr->getStart();
296 const SCEV *ECPlusVWMinus1 = SE->getAddExpr(
298 SE->getSCEV(ConstantInt::get(TripCount->
getType(), VectorWidth - 1)));
301 const SCEV *Ceil = SE->getUDivExpr(ECPlusVWMinus1, VW);
306 dbgs() <<
"ARM TP: Analysing overflow behaviour for:\n";
307 dbgs() <<
"ARM TP: - TripCount = " << *TC <<
"\n";
308 dbgs() <<
"ARM TP: - ElemCount = " << *
EC <<
"\n";
309 dbgs() <<
"ARM TP: - Start = " << *Start <<
"\n";
310 dbgs() <<
"ARM TP: - BETC = " << *SE->getBackedgeTakenCount(L) <<
"\n";
311 dbgs() <<
"ARM TP: - VecWidth = " << VectorWidth <<
"\n";
312 dbgs() <<
"ARM TP: - (ElemCount+VW-1) / VW = " << *Ceil <<
"\n";
327 const SCEV *Div = SE->getUDivExpr(
328 SE->getAddExpr(SE->getMulExpr(Ceil, VW), SE->getNegativeSCEV(VW),
329 SE->getNegativeSCEV(Start)),
331 const SCEV *Sub = SE->getMinusSCEV(SE->getBackedgeTakenCount(L), Div);
337 Sub = SE->applyLoopGuards(Sub, L);
341 LLVM_DEBUG(
dbgs() <<
"ARM TP: possible overflow in sub expression.\n");
350 if (
auto *BaseC = dyn_cast<SCEVConstant>(AddExpr->getStart())) {
351 if (BaseC->getAPInt().urem(VectorWidth) == 0)
352 return SE->getMinusSCEV(EC, BaseC);
353 }
else if (
auto *BaseV = dyn_cast<SCEVUnknown>(AddExpr->getStart())) {
354 Type *Ty = BaseV->getType();
358 L->getHeader()->getDataLayout()))
359 return SE->getMinusSCEV(EC, BaseV);
360 }
else if (
auto *BaseMul = dyn_cast<SCEVMulExpr>(AddExpr->getStart())) {
361 if (
auto *BaseC = dyn_cast<SCEVConstant>(BaseMul->getOperand(0)))
362 if (BaseC->getAPInt().urem(VectorWidth) == 0)
363 return SE->getMinusSCEV(EC, BaseC);
364 if (
auto *BaseC = dyn_cast<SCEVConstant>(BaseMul->getOperand(1)))
365 if (BaseC->getAPInt().urem(VectorWidth) == 0)
366 return SE->getMinusSCEV(EC, BaseC);
370 dbgs() <<
"ARM TP: induction base is not know to be a multiple of VF: "
371 << *AddExpr->getOperand(0) <<
"\n");
375void MVETailPredication::InsertVCTPIntrinsic(
IntrinsicInst *ActiveLaneMask,
377 IRBuilder<> Builder(
L->getLoopPreheader()->getTerminator());
378 Module *
M =
L->getHeader()->getModule();
380 unsigned VectorWidth =
381 cast<FixedVectorType>(ActiveLaneMask->
getType())->getNumElements();
384 Builder.SetInsertPoint(
L->getHeader(),
L->getHeader()->getFirstNonPHIIt());
385 PHINode *Processed = Builder.CreatePHI(Ty, 2);
390 Builder.SetInsertPoint(ActiveLaneMask);
391 ConstantInt *Factor = ConstantInt::get(cast<IntegerType>(Ty), VectorWidth);
394 switch (VectorWidth) {
397 case 2: VCTPID = Intrinsic::arm_mve_vctp64;
break;
398 case 4: VCTPID = Intrinsic::arm_mve_vctp32;
break;
399 case 8: VCTPID = Intrinsic::arm_mve_vctp16;
break;
400 case 16: VCTPID = Intrinsic::arm_mve_vctp8;
break;
402 Value *VCTPCall = Builder.CreateIntrinsic(VCTPID, {}, Processed);
407 Value *Remaining = Builder.CreateSub(Processed, Factor);
410 << *Processed <<
"\n"
411 <<
"ARM TP: Inserted VCTP: " << *VCTPCall <<
"\n");
414bool MVETailPredication::TryConvertActiveLaneMask(
Value *TripCount) {
416 for (
auto *BB :
L->getBlocks())
418 if (
auto *
Int = dyn_cast<IntrinsicInst>(&
I))
419 if (
Int->getIntrinsicID() == Intrinsic::get_active_lane_mask)
422 if (ActiveLaneMasks.
empty())
427 for (
auto *ActiveLaneMask : ActiveLaneMasks) {
429 << *ActiveLaneMask <<
"\n");
431 const SCEV *StartSCEV = IsSafeActiveMask(ActiveLaneMask, TripCount);
436 LLVM_DEBUG(
dbgs() <<
"ARM TP: Safe to insert VCTP. Start is " << *StartSCEV
441 Value *Start = Expander.expandCodeFor(StartSCEV, StartSCEV->
getType(), Ins);
442 LLVM_DEBUG(
dbgs() <<
"ARM TP: Created start value " << *Start <<
"\n");
443 InsertVCTPIntrinsic(ActiveLaneMask, Start);
447 for (
auto *
II : ActiveLaneMasks)
449 for (
auto *
I :
L->blocks())
455 return new MVETailPredication();
458char MVETailPredication::ID = 0;
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
cl::opt< TailPredication::Mode > EnableTailPredication("tail-predication", cl::desc("MVE tail-predication pass options"), cl::init(TailPredication::Enabled), cl::values(clEnumValN(TailPredication::Disabled, "disabled", "Don't tail-predicate loops"), clEnumValN(TailPredication::EnabledNoReductions, "enabled-no-reductions", "Enable tail-predication, but not for reduction loops"), clEnumValN(TailPredication::Enabled, "enabled", "Enable tail-predication, including reduction loops"), clEnumValN(TailPredication::ForceEnabledNoReductions, "force-enabled-no-reductions", "Enable tail-predication, but not for reduction loops, " "and force this which might be unsafe"), clEnumValN(TailPredication::ForceEnabled, "force-enabled", "Enable tail-predication, including reduction loops, " "and force this which might be unsafe")))
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Target-Independent Code Generator Pass Configuration Options pass.
static const uint32_t IV[8]
Class for arbitrary precision integers.
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
void setPreservesCFG()
This function should be called by the pass, iff they do not:
LLVM Basic Block Representation.
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
This is the shared class of boolean and integer constants.
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
A wrapper class for inspecting calls to intrinsic functions.
The legacy pass manager's analysis pass to compute loop information.
virtual bool runOnLoop(Loop *L, LPPassManager &LPM)=0
Represents a single loop in the control flow graph.
A Module instance is used to store all the information related to an LLVM module.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
Pass interface - Implemented by all 'passes'.
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
This class uses information about analyze scalars to rewrite expressions in canonical form.
This class represents an analyzed expression in the program.
bool isZero() const
Return true if the expression is a constant zero.
void dump() const
This method is used for debugging.
Type * getType() const
Return the LLVM type of this SCEV expression.
The main scalar evolution driver.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Primary interface to the complete machine description for the target machine.
Target-Independent Code Generator Pass Configuration Options.
The instances of the Type class are immutable: once they are created, they are never changed.
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ ForceEnabledNoReductions
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
unsigned Log2_64(uint64_t Value)
Return the floor log base 2 of the specified value, -1 if the value is zero.
bool MaskedValueIsZero(const Value *V, const APInt &Mask, const SimplifyQuery &SQ, unsigned Depth=0)
Return true if 'V & Mask' is known to be zero.
bool DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
Examine each PHI in the given block and delete it if it is dead.
Pass * createMVETailPredicationPass()
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.