32#include "llvm/IR/IntrinsicsAArch64.h"
43#define DEBUG_TYPE "aarch64-sve-intrinsic-opts"
56 bool coalescePTrueIntrinsicCalls(
BasicBlock &BB,
70void SVEIntrinsicOpts::getAnalysisUsage(
AnalysisUsage &AU)
const {
75char SVEIntrinsicOpts::ID = 0;
76static const char *
name =
"SVE intrinsics optimizations";
82 return new SVEIntrinsicOpts();
102 if (
match(
User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
108 if (ConvertToUses.
empty())
114 const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->
getType());
117 auto *IntrUser = dyn_cast<IntrinsicInst>(
User);
118 if (IntrUser && IntrUser->getIntrinsicID() ==
119 Intrinsic::aarch64_sve_convert_from_svbool) {
120 const auto *IntrUserVTy = cast<ScalableVectorType>(IntrUser->getType());
123 if (IntrUserVTy->getElementCount().getKnownMinValue() >
124 PTrueVTy->getElementCount().getKnownMinValue())
136bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
138 if (PTrues.
size() <= 1)
142 auto *MostEncompassingPTrue =
144 auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
145 auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
146 return PTrue1VTy->getElementCount().getKnownMinValue() <
147 PTrue2VTy->getElementCount().getKnownMinValue();
152 PTrues.
remove(MostEncompassingPTrue);
162 Builder.SetInsertPoint(&BB, ++MostEncompassingPTrue->getIterator());
164 auto *MostEncompassingPTrueVTy =
165 cast<VectorType>(MostEncompassingPTrue->getType());
166 auto *ConvertToSVBool = Builder.CreateIntrinsic(
167 Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy},
168 {MostEncompassingPTrue});
170 bool ConvertFromCreated =
false;
171 for (
auto *PTrue : PTrues) {
172 auto *PTrueVTy = cast<VectorType>(PTrue->getType());
176 if (MostEncompassingPTrueVTy != PTrueVTy) {
177 ConvertFromCreated =
true;
179 Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator());
180 auto *ConvertFromSVBool =
181 Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
182 {PTrueVTy}, {ConvertToSVBool});
183 PTrue->replaceAllUsesWith(ConvertFromSVBool);
185 PTrue->replaceAllUsesWith(MostEncompassingPTrue);
187 PTrue->eraseFromParent();
191 if (!ConvertFromCreated)
192 ConvertToSVBool->eraseFromParent();
245bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
247 bool Changed =
false;
249 for (
auto *
F : Functions) {
250 for (
auto &BB : *
F) {
259 auto *IntrI = dyn_cast<IntrinsicInst>(&
I);
260 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
263 const auto PTruePattern =
264 cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue();
266 if (PTruePattern == AArch64SVEPredPattern::all)
267 SVAllPTrues.
insert(IntrI);
268 if (PTruePattern == AArch64SVEPredPattern::pow2)
269 SVPow2PTrues.
insert(IntrI);
272 Changed |= coalescePTrueIntrinsicCalls(BB, SVAllPTrues);
273 Changed |= coalescePTrueIntrinsicCalls(BB, SVPow2PTrues);
282bool SVEIntrinsicOpts::optimizePredicateStore(
Instruction *
I) {
283 auto *
F =
I->getFunction();
284 auto Attr =
F->getFnAttribute(Attribute::VScaleRange);
288 unsigned MinVScale = Attr.getVScaleRangeMin();
289 std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
291 if (!MaxVScale || MinVScale != MaxVScale)
296 auto *FixedPredType =
300 auto *
Store = dyn_cast<StoreInst>(
I);
301 if (!Store || !
Store->isSimple())
305 if (
Store->getOperand(0)->getType() != FixedPredType)
309 auto *IntrI = dyn_cast<IntrinsicInst>(
Store->getOperand(0));
310 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_extract)
314 if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
318 auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
323 if (BitCast->getOperand(0)->getType() != PredType)
327 Builder.SetInsertPoint(
I);
329 Builder.CreateStore(BitCast->getOperand(0),
Store->getPointerOperand());
331 Store->eraseFromParent();
332 if (IntrI->getNumUses() == 0)
333 IntrI->eraseFromParent();
334 if (BitCast->getNumUses() == 0)
335 BitCast->eraseFromParent();
342bool SVEIntrinsicOpts::optimizePredicateLoad(
Instruction *
I) {
343 auto *
F =
I->getFunction();
344 auto Attr =
F->getFnAttribute(Attribute::VScaleRange);
348 unsigned MinVScale = Attr.getVScaleRangeMin();
349 std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
351 if (!MaxVScale || MinVScale != MaxVScale)
356 auto *FixedPredType =
360 auto *BitCast = dyn_cast<BitCastInst>(
I);
361 if (!BitCast || BitCast->getType() != PredType)
365 auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
366 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_insert)
370 if (!isa<UndefValue>(IntrI->getOperand(0)) ||
371 !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
375 auto *
Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
376 if (!Load || !
Load->isSimple())
380 if (
Load->getType() != FixedPredType)
384 Builder.SetInsertPoint(Load);
386 auto *LoadPred = Builder.CreateLoad(PredType,
Load->getPointerOperand());
388 BitCast->replaceAllUsesWith(LoadPred);
389 BitCast->eraseFromParent();
390 if (IntrI->getNumUses() == 0)
391 IntrI->eraseFromParent();
392 if (
Load->getNumUses() == 0)
393 Load->eraseFromParent();
398bool SVEIntrinsicOpts::optimizeInstructions(
400 bool Changed =
false;
402 for (
auto *
F : Functions) {
403 DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();
409 for (
auto *BB : RPOT) {
411 switch (
I.getOpcode()) {
412 case Instruction::Store:
413 Changed |= optimizePredicateStore(&
I);
415 case Instruction::BitCast:
416 Changed |= optimizePredicateLoad(&
I);
426bool SVEIntrinsicOpts::optimizeFunctions(
428 bool Changed =
false;
430 Changed |= optimizePTrueIntrinsicCalls(Functions);
431 Changed |= optimizeInstructions(Functions);
436bool SVEIntrinsicOpts::runOnModule(
Module &M) {
437 bool Changed =
false;
443 for (
auto &
F :
M.getFunctionList()) {
444 if (!
F.isDeclaration())
447 switch (
F.getIntrinsicID()) {
448 case Intrinsic::vector_extract:
449 case Intrinsic::vector_insert:
450 case Intrinsic::aarch64_sve_ptrue:
451 for (
User *U :
F.users())
459 if (!Functions.
empty())
460 Changed |= optimizeFunctions(Functions);
This file contains the declarations for the subclasses of Constant, which represent the different fla...
static Function * getFunction(Constant *C)
Module.h This file contains the declarations for the Module class.
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
static bool isPTruePromoted(IntrinsicInst *PTrue)
Checks if a ptrue intrinsic call is promoted.
This file implements a set that has insertion order iteration characteristics.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
LLVM Basic Block Representation.
const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
LLVMContext & getContext() const
Get the context in which this basic block lives.
Legacy analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
A wrapper class for inspecting calls to intrinsic functions.
This is an important class for using LLVM in a threaded context.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
static ScalableVectorType * get(Type *ElementType, unsigned MinNumElts)
bool remove(const value_type &X)
Remove an item from the set vector.
bool remove_if(UnaryPredicate P)
Remove items from the set vector based on a predicate function.
size_type size() const
Determine the number of elements in the SetVector.
bool empty() const
Determine if the SetVector is empty or not.
bool insert(const value_type &X)
Insert a new element into the SetVector.
A SetVector that performs no allocations if smaller than a certain size.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
static IntegerType * getInt1Ty(LLVMContext &C)
static IntegerType * getInt8Ty(LLVMContext &C)
Type * getType() const
All values are typed, get the type of this value.
iterator_range< user_iterator > users()
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
bool match(Val *V, const Pattern &P)
This is an optimization pass for GlobalISel generic memory operations.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
ModulePass * createSVEIntrinsicOptsPass()
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
void initializeSVEIntrinsicOptsPass(PassRegistry &)