46#include "llvm/IR/IntrinsicsARM.h"
55#define DEBUG_TYPE "mve-tail-predication"
56#define DESC "Transform predicated vector loops to use MVE tail predication"
59 "tail-predication",
cl::desc(
"MVE tail-predication pass options"),
62 "Don't tail-predicate loops"),
64 "enabled-no-reductions",
65 "Enable tail-predication, but not for reduction loops"),
68 "Enable tail-predication, including reduction loops"),
70 "force-enabled-no-reductions",
71 "Enable tail-predication, but not for reduction loops, "
72 "and force this which might be unsafe"),
75 "Enable tail-predication, including reduction loops, "
76 "and force this which might be unsafe")));
81class MVETailPredication :
public LoopPass {
93 void getAnalysisUsage(AnalysisUsage &AU)
const override {
102 bool runOnLoop(Loop *L, LPPassManager&)
override;
107 bool TryConvertActiveLaneMask(
Value *TripCount);
113 const SCEV *IsSafeActiveMask(IntrinsicInst *ActiveLaneMask,
Value *TripCount);
116 void InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask,
Value *Start);
127 auto &TPC = getAnalysis<TargetPassConfig>();
128 auto &
TM = TPC.getTM<TargetMachine>();
129 ST = &
TM.getSubtarget<ARMSubtarget>(
F);
130 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
131 SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
136 if (!
ST->hasMVEIntegerOps() || !
ST->hasV8_1MMainlineOps()) {
145 auto FindLoopIterations = [](
BasicBlock *BB) -> IntrinsicInst* {
146 for (
auto &
I : *BB) {
152 if (
ID == Intrinsic::start_loop_iterations ||
153 ID == Intrinsic::test_start_loop_iterations)
160 IntrinsicInst *
Setup = FindLoopIterations(Preheader);
171 LLVM_DEBUG(
dbgs() <<
"ARM TP: Running on Loop: " << *L << *Setup <<
"\n");
173 bool Changed = TryConvertActiveLaneMask(
Setup->getArgOperand(0));
195const SCEV *MVETailPredication::IsSafeActiveMask(IntrinsicInst *ActiveLaneMask,
197 bool ForceTailPredication =
203 if (!
L->makeLoopInvariant(ElemCount,
Changed))
207 const SCEV *TC = SE->
getSCEV(TripCount);
210 if (VectorWidth != 2 && VectorWidth != 4 && VectorWidth != 8 &&
213 ConstantInt *ConstElemCount =
nullptr;
219 LLVM_DEBUG(
dbgs() <<
"ARM TP: element count must be loop invariant.\n");
237 if (AddExpr->getLoop() != L) {
243 LLVM_DEBUG(
dbgs() <<
"ARM TP: induction step is not a constant: ";
244 AddExpr->getOperand(1)->
dump());
247 auto StepValue = Step->getValue()->getSExtValue();
248 if (VectorWidth != StepValue) {
250 <<
" doesn't match vector width " << VectorWidth <<
"\n");
258 "set.loop.iterations\n");
268 (ConstElemCount->
getZExtValue() + VectorWidth - 1) / VectorWidth;
274 LLVM_DEBUG(
dbgs() <<
"ARM TP: inconsistent constant tripcount values: "
275 << TC1 <<
" from set.loop.iterations, and "
276 << TC2 <<
" from get.active.lane.mask\n");
279 }
else if (!ForceTailPredication) {
295 const SCEV *
Start = AddExpr->getStart();
298 SE->
getSCEV(ConstantInt::get(TripCount->
getType(), VectorWidth - 1)));
301 const SCEV *Ceil = SE->
getUDivExpr(ECPlusVWMinus1, VW);
306 dbgs() <<
"ARM TP: Analysing overflow behaviour for:\n";
307 dbgs() <<
"ARM TP: - TripCount = " << *TC <<
"\n";
308 dbgs() <<
"ARM TP: - ElemCount = " << *
EC <<
"\n";
309 dbgs() <<
"ARM TP: - Start = " << *
Start <<
"\n";
311 dbgs() <<
"ARM TP: - VecWidth = " << VectorWidth <<
"\n";
312 dbgs() <<
"ARM TP: - (ElemCount+VW-1) / VW = " << *Ceil <<
"\n";
340 if (!
Sub->isZero()) {
341 LLVM_DEBUG(
dbgs() <<
"ARM TP: possible overflow in sub expression.\n");
351 if (BaseC->getAPInt().urem(VectorWidth) == 0)
354 Type *Ty = BaseV->getType();
358 L->getHeader()->getDataLayout()))
362 if (BaseC->getAPInt().urem(VectorWidth) == 0)
365 if (BaseC->getAPInt().urem(VectorWidth) == 0)
370 dbgs() <<
"ARM TP: induction base is not know to be a multiple of VF: "
371 << *AddExpr->getOperand(0) <<
"\n");
375void MVETailPredication::InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask,
377 IRBuilder<> Builder(
L->getLoopPreheader()->getTerminator());
378 Module *
M =
L->getHeader()->getModule();
380 unsigned VectorWidth =
384 Builder.SetInsertPoint(
L->getHeader(),
L->getHeader()->getFirstNonPHIIt());
385 PHINode *Processed = Builder.CreatePHI(Ty, 2);
390 Builder.SetInsertPoint(ActiveLaneMask);
394 switch (VectorWidth) {
397 case 2: VCTPID = Intrinsic::arm_mve_vctp64;
break;
398 case 4: VCTPID = Intrinsic::arm_mve_vctp32;
break;
399 case 8: VCTPID = Intrinsic::arm_mve_vctp16;
break;
400 case 16: VCTPID = Intrinsic::arm_mve_vctp8;
break;
402 Value *VCTPCall = Builder.CreateIntrinsic(VCTPID, Processed);
407 Value *Remaining = Builder.CreateSub(Processed, Factor);
410 << *Processed <<
"\n"
411 <<
"ARM TP: Inserted VCTP: " << *VCTPCall <<
"\n");
414bool MVETailPredication::TryConvertActiveLaneMask(
Value *TripCount) {
416 for (
auto *BB :
L->getBlocks())
419 if (
Int->getIntrinsicID() == Intrinsic::get_active_lane_mask)
422 if (ActiveLaneMasks.
empty())
427 for (
auto *ActiveLaneMask : ActiveLaneMasks) {
429 << *ActiveLaneMask <<
"\n");
431 const SCEV *StartSCEV = IsSafeActiveMask(ActiveLaneMask, TripCount);
436 LLVM_DEBUG(
dbgs() <<
"ARM TP: Safe to insert VCTP. Start is " << *StartSCEV
438 SCEVExpander Expander(*SE,
L->getHeader()->getDataLayout(),
442 LLVM_DEBUG(
dbgs() <<
"ARM TP: Created start value " << *Start <<
"\n");
443 InsertVCTPIntrinsic(ActiveLaneMask, Start);
447 for (
auto *
II : ActiveLaneMasks)
449 for (
auto *
I :
L->blocks())
455 return new MVETailPredication();
458char MVETailPredication::ID = 0;
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
cl::opt< TailPredication::Mode > EnableTailPredication("tail-predication", cl::desc("MVE tail-predication pass options"), cl::init(TailPredication::Enabled), cl::values(clEnumValN(TailPredication::Disabled, "disabled", "Don't tail-predicate loops"), clEnumValN(TailPredication::EnabledNoReductions, "enabled-no-reductions", "Enable tail-predication, but not for reduction loops"), clEnumValN(TailPredication::Enabled, "enabled", "Enable tail-predication, including reduction loops"), clEnumValN(TailPredication::ForceEnabledNoReductions, "force-enabled-no-reductions", "Enable tail-predication, but not for reduction loops, " "and force this which might be unsafe"), clEnumValN(TailPredication::ForceEnabled, "force-enabled", "Enable tail-predication, including reduction loops, " "and force this which might be unsafe")))
Machine Check Debug Module
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Target-Independent Code Generator Pass Configuration Options pass.
static const uint32_t IV[8]
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
LLVM_ABI Intrinsic::ID getIntrinsicID() const
Returns the intrinsic ID of the intrinsic called or Intrinsic::not_intrinsic if the called function i...
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Represents a single loop in the control flow graph.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
Pass interface - Implemented by all 'passes'.
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
The main scalar evolution driver.
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Value * getOperand(unsigned i) const
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ BasicBlock
Various leaf nodes.
@ ForceEnabledNoReductions
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
friend class Instruction
Iterator for Instructions in a `BasicBlock.
This is an optimization pass for GlobalISel generic memory operations.
void dump(const SparseBitVector< ElementSize > &LHS, raw_ostream &out)
FunctionAddr VTableAddr Value
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
unsigned Log2_64(uint64_t Value)
Return the floor log base 2 of the specified value, -1 if the value is zero.
LLVM_ABI bool MaskedValueIsZero(const Value *V, const APInt &Mask, const SimplifyQuery &SQ, unsigned Depth=0)
Return true if 'V & Mask' is known to be zero.
LLVM_ABI bool DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
Examine each PHI in the given block and delete it if it is dead.
Pass * createMVETailPredicationPass()
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
@ Sub
Subtraction of integers.
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.