51using namespace PatternMatch;
53#define DEBUG_TYPE "aarch64-loop-idiom-transform"
57 cl::desc(
"Disable AArch64 Loop Idiom Transform Pass."));
61 cl::desc(
"Proceed with AArch64 Loop Idiom Transform Pass, but do "
62 "not convert byte-compare loop(s)."));
66 cl::desc(
"Verify loops generated AArch64 Loop Idiom Transform Pass."));
77class AArch64LoopIdiomTransform {
78 Loop *CurLoop =
nullptr;
96 bool runOnCountableLoop();
100 bool recognizeByteCompare();
111class AArch64LoopIdiomTransformLegacyPass :
public LoopPass {
115 explicit AArch64LoopIdiomTransformLegacyPass() :
LoopPass(
ID) {
121 return "Transform AArch64-specific loop idioms";
133bool AArch64LoopIdiomTransformLegacyPass::runOnLoop(
Loop *L,
139 auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
140 auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
141 auto &
TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
142 *
L->getHeader()->getParent());
143 return AArch64LoopIdiomTransform(
144 DT, LI, &
TTI, &
L->getHeader()->getModule()->getDataLayout())
150char AArch64LoopIdiomTransformLegacyPass::ID = 0;
153 AArch64LoopIdiomTransformLegacyPass,
"aarch64-lit",
154 "Transform specific loop idioms into optimized vector forms",
false,
false)
161 AArch64LoopIdiomTransformLegacyPass, "aarch64-
lit",
165 return new AArch64LoopIdiomTransformLegacyPass();
175 const auto *
DL = &L.getHeader()->getModule()->getDataLayout();
177 AArch64LoopIdiomTransform LIT(&AR.
DT, &AR.
LI, &AR.
TTI,
DL);
190bool AArch64LoopIdiomTransform::run(
Loop *L) {
193 Function &
F = *L->getHeader()->getParent();
197 if (
F.hasFnAttribute(Attribute::NoImplicitFloat)) {
199 <<
" due to its NoImplicitFloat attribute");
205 if (!L->getLoopPreheader())
209 << CurLoop->getHeader()->getName() <<
"\n");
211 return recognizeByteCompare();
214bool AArch64LoopIdiomTransform::recognizeByteCompare() {
229 if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
232 PHINode *PN = dyn_cast<PHINode>(&Header->front());
236 auto LoopBlocks = CurLoop->getBlocks();
245 auto CondBBInsts = LoopBlocks[0]->instructionsWithoutDebug();
246 if (std::distance(CondBBInsts.begin(), CondBBInsts.end()) > 4)
260 auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
261 if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) > 7)
265 Value *StartIdx =
nullptr;
276 if (!
Index || !
Index->getType()->isIntegerTy(32) ||
286 for (
User *U :
I.users())
287 if (!CurLoop->contains(cast<Instruction>(U)))
294 if (!
match(Header->getTerminator(),
305 Value *LoadA, *LoadB;
316 LoadInst *LoadAI = cast<LoadInst>(LoadA);
317 LoadInst *LoadBI = cast<LoadInst>(LoadB);
331 if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
370 if (FoundBB == EndBB) {
372 Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header);
373 Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB);
379 if (WhileCondVal != WhileBodyVal &&
380 ((WhileCondVal !=
Index && WhileCondVal != MaxLen) ||
381 (WhileBodyVal !=
Index)))
391 transformByteCompare(GEPA, GEPB, PN, MaxLen,
Index, StartIdx,
true,
396Value *AArch64LoopIdiomTransform::expandFindMismatch(
403 BasicBlock *Preheader = CurLoop->getLoopPreheader();
411 SplitBlock(Preheader, PHBranch, DT, LI,
nullptr,
"mismatch_end");
427 Ctx,
"mismatch_min_it_check", EndBlock->
getParent(), EndBlock);
433 Ctx,
"mismatch_mem_check", EndBlock->
getParent(), EndBlock);
436 Ctx,
"mismatch_sve_loop_preheader", EndBlock->
getParent(), EndBlock);
439 Ctx,
"mismatch_sve_loop", EndBlock->
getParent(), EndBlock);
442 Ctx,
"mismatch_sve_loop_inc", EndBlock->
getParent(), EndBlock);
445 Ctx,
"mismatch_sve_loop_found", EndBlock->
getParent(), EndBlock);
448 Ctx,
"mismatch_loop_pre", EndBlock->
getParent(), EndBlock);
454 Ctx,
"mismatch_loop_inc", EndBlock->
getParent(), EndBlock);
460 auto SVELoop = LI->AllocateLoop();
461 auto ScalarLoop = LI->AllocateLoop();
463 if (CurLoop->getParentLoop()) {
464 CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, *LI);
465 CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, *LI);
466 CurLoop->getParentLoop()->addBasicBlockToLoop(SVELoopPreheaderBlock, *LI);
467 CurLoop->getParentLoop()->addChildLoop(SVELoop);
468 CurLoop->getParentLoop()->addBasicBlockToLoop(SVELoopMismatchBlock, *LI);
469 CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, *LI);
470 CurLoop->getParentLoop()->addChildLoop(ScalarLoop);
472 LI->addTopLevelLoop(SVELoop);
473 LI->addTopLevelLoop(ScalarLoop);
477 SVELoop->addBasicBlockToLoop(SVELoopStartBlock, *LI);
478 SVELoop->addBasicBlockToLoop(SVELoopIncBlock, *LI);
480 ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI);
481 ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI);
496 LLVMContext::MD_prof,
498 Builder.
Insert(MinItCheckBr);
538 Value *CombinedPageCmp = Builder.
CreateOr(LhsPageCmp, RhsPageCmp);
540 LoopPreHeaderBlock, SVELoopPreheaderBlock, CombinedPageCmp);
544 Builder.
Insert(CombinedPageCmpCmpBr);
564 Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
567 VecLen = Builder.
CreateMul(VecLen, ConstantInt::get(I64Type, 16),
"",
574 Builder.
Insert(JumpToSVELoop);
582 PHINode *LoopPred = Builder.
CreatePHI(PredVTy, 2,
"mismatch_sve_loop_pred");
583 LoopPred->
addIncoming(InitialPred, SVELoopPreheaderBlock);
585 SVEIndexPhi->
addIncoming(ExtStart, SVELoopPreheaderBlock);
589 Value *SVELhsGep = Builder.
CreateGEP(LoadType, PtrA, SVEIndexPhi);
591 cast<GetElementPtrInst>(SVELhsGep)->setIsInBounds(
true);
595 Value *SVERhsGep = Builder.
CreateGEP(LoadType, PtrB, SVEIndexPhi);
597 cast<GetElementPtrInst>(SVERhsGep)->setIsInBounds(
true);
602 SVEMatchCmp = Builder.
CreateSelect(LoopPred, SVEMatchCmp, PFalse);
605 SVELoopMismatchBlock, SVELoopIncBlock, SVEMatchHasActiveLanes);
606 Builder.
Insert(SVEEarlyExit);
616 Value *NewSVEIndexPhi = Builder.
CreateAdd(SVEIndexPhi, VecLen,
"",
618 SVEIndexPhi->
addIncoming(NewSVEIndexPhi, SVELoopIncBlock);
621 {PredVTy, I64Type}, {NewSVEIndexPhi, ExtEnd});
624 Value *PredHasActiveLanes =
628 Builder.
Insert(SVELoopBranchBack);
636 PHINode *FoundPred = Builder.
CreatePHI(PredVTy, 1,
"mismatch_sve_found_pred");
637 FoundPred->
addIncoming(SVEMatchCmp, SVELoopStartBlock);
639 Builder.
CreatePHI(PredVTy, 1,
"mismatch_sve_last_loop_pred");
640 LastLoopPred->
addIncoming(LoopPred, SVELoopStartBlock);
642 Builder.
CreatePHI(I64Type, 1,
"mismatch_sve_found_index");
643 SVEFoundIndex->
addIncoming(SVEIndexPhi, SVELoopStartBlock);
647 Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->
getType()},
648 {PredMatchCmp, Builder.
getInt1(
true)});
675 cast<GetElementPtrInst>(LhsGep)->setIsInBounds(
true);
680 cast<GetElementPtrInst>(RhsGep)->setIsInBounds(
true);
686 Builder.
Insert(MatchCmpBr);
693 Value *PhiInc = Builder.
CreateAdd(IndexPhi, ConstantInt::get(ResType, 1),
"",
694 Index->hasNoUnsignedWrap(),
695 Index->hasNoSignedWrap());
716 ResPhi->
addIncoming(SVELoopRes, SVELoopMismatchBlock);
721 ScalarLoop->verifyLoop();
722 SVELoop->verifyLoop();
723 if (!SVELoop->isRecursivelyLCSSAForm(*DT, *LI))
725 if (!ScalarLoop->isRecursivelyLCSSAForm(*DT, *LI))
732void AArch64LoopIdiomTransform::transformByteCompare(
738 BasicBlock *Preheader = CurLoop->getLoopPreheader();
747 Start = Builder.
CreateAdd(Start, ConstantInt::get(Start->getType(), 1));
750 expandFindMismatch(Builder, DTU, GEPA, GEPB,
Index, Start, MaxLen);
754 assert(IndPhi->
hasOneUse() &&
"Index phi node has more than one use!");
755 Index->replaceAllUsesWith(ByteCmpRes);
758 "Expected preheader to terminate with an unconditional branch.");
764 CmpBB->moveBefore(EndBB);
777 if (FoundBB != EndBB) {
788 auto fixSuccessorPhis = [&](
BasicBlock *SuccBB) {
789 for (
PHINode &PN : SuccBB->phis()) {
795 if (
Op == ByteCmpRes) {
812 if (CurLoop->contains(BB)) {
821 fixSuccessorPhis(EndBB);
822 if (EndBB != FoundBB)
823 fixSuccessorPhis(FoundBB);
827 if (!CurLoop->isOutermost())
828 CurLoop->getParentLoop()->addBasicBlockToLoop(CmpBB, *LI);
831 CurLoop->getParentLoop()->verifyLoop();
832 if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI))
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
hexagon loop Recognize Hexagon specific loop idioms
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
A container for analyses that lazily runs them and caches their results.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
LLVM Basic Block Representation.
iterator_range< const_phi_iterator > phis() const
Returns a range that iterates over the phis in the basic block.
const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
const Function * getParent() const
Return the enclosing method, or null if none.
LLVMContext & getContext() const
Get the context in which this basic block lives.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Conditional or Unconditional Branch instruction.
static BranchInst * Create(BasicBlock *IfTrue, BasicBlock::iterator InsertBefore)
bool isUnconditional() const
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
void applyUpdates(ArrayRef< DominatorTree::UpdateType > Updates)
Submit updates to all available trees.
static constexpr UpdateKind Delete
static constexpr UpdateKind Insert
Legacy analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
bool isInBounds() const
Determine whether the GEP has the inbounds flag.
Value * getPointerOperand()
Type * getResultElementType() const
unsigned getNumIndices() const
ConstantInt * getInt1(bool V)
Get a constant value representing either true or false.
IntegerType * getInt1Ty()
Fetch the type representing a single bit.
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")
Return a vector value that contains.
ConstantInt * getTrue()
Get the constant value for i1 true.
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, Instruction *FMFSource=nullptr, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
CallInst * CreateMaskedLoad(Type *Ty, Value *Ptr, Align Alignment, Value *Mask, Value *PassThru=nullptr, const Twine &Name="")
Create a call to Masked Load intrinsic.
Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
Value * CreateLShr(Value *LHS, Value *RHS, const Twine &Name="", bool isExact=false)
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
void SetCurrentDebugLocation(DebugLoc L)
Set location information used by debugging information.
IntegerType * getInt64Ty()
Fetch the type representing a 64-bit integer.
Value * CreateICmpNE(Value *LHS, Value *RHS, const Twine &Name="")
CallInst * CreateOrReduce(Value *Src)
Create a vector int OR reduction intrinsic of the source vector.
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Value * CreateICmpEQ(Value *LHS, Value *RHS, const Twine &Name="")
InstTy * Insert(InstTy *I, const Twine &Name="") const
Insert and return the specified instruction.
BranchInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a conditional 'br Cond, TrueDest, FalseDest' instruction.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Value * CreateZExt(Value *V, Type *DestTy, const Twine &Name="", bool IsNonNeg=false)
Value * CreateAnd(Value *LHS, Value *RHS, const Twine &Name="")
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreatePtrToInt(Value *V, Type *DestTy, const Twine &Name="")
Value * CreateTrunc(Value *V, Type *DestTy, const Twine &Name="", bool IsNUW=false, bool IsNSW=false)
Value * CreateOr(Value *LHS, Value *RHS, const Twine &Name="")
BranchInst * CreateBr(BasicBlock *Dest)
Create an unconditional 'br label X' instruction.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Value * CreateICmpULE(Value *LHS, Value *RHS, const Twine &Name="")
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", bool IsInBounds=false)
IntegerType * getInt8Ty()
Fetch the type representing an 8-bit integer.
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
void setMetadata(unsigned KindID, MDNode *Node)
Set the metadata of the specified kind to the specified node.
void setSuccessor(unsigned Idx, BasicBlock *BB)
Update the specified successor to point at the provided block.
This is an important class for using LLVM in a threaded context.
This class provides an interface for updating the loop pass manager based on mutations to the loop ne...
An instruction for reading from memory.
The legacy pass manager's analysis pass to compute loop information.
virtual bool runOnLoop(Loop *L, LPPassManager &LPM)=0
Represents a single loop in the control flow graph.
MDNode * createBranchWeights(uint32_t TrueWeight, uint32_t FalseWeight)
Return metadata containing two branch weights.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
PassRegistry - This class manages the registration and intitialization of the pass subsystem as appli...
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Pass interface - Implemented by all 'passes'.
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
This class represents an analyzed expression in the program.
Class to represent scalable SIMD vectors.
static ScalableVectorType * get(Type *ElementType, unsigned MinNumElts)
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
StringRef - Represent a constant reference to a string, i.e.
The instances of the Type class are immutable: once they are created, they are never changed.
static IntegerType * getInt8Ty(LLVMContext &C)
bool isIntegerTy() const
True if this is an instance of IntegerType.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
bool hasOneUse() const
Return true if there is exactly one use of this value.
LLVMContext & getContext() const
All values hold a context through their type.
ElementCount getElementCount() const
Return an ElementCount instance to represent the (possibly scalable) number of elements in the vector...
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
bool match(Val *V, const Pattern &P)
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
cst_pred_ty< is_one > m_One()
Match an integer 1 or a vector with all elements equal to 1.
CmpClass_match< LHS, RHS, ICmpInst, ICmpInst::Predicate > m_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R)
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
CastInst_match< OpTy, ZExtInst > m_ZExt(const OpTy &Op)
Matches ZExt.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::Add, true > m_c_Add(const LHS &L, const RHS &R)
Matches a Add with LHS and RHS in either order.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
initializer< Ty > init(const Ty &Val)
PointerTypeMap run(const Module &M)
Compute the PointerTypeMap for the module M.
This is an optimization pass for GlobalISel generic memory operations.
unsigned Log2_64(uint64_t Value)
Return the floor log base 2 of the specified value, -1 if the value is zero.
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Pass * createAArch64LoopIdiomTransformPass()
void initializeAArch64LoopIdiomTransformLegacyPassPass(PassRegistry &)
BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
This struct is a compact representation of a valid (non-zero power of two) alignment.
The adaptor from a function pass to a loop pass computes these analyses and makes them available to t...
TargetTransformInfo & TTI