32#include "llvm/IR/IntrinsicsAArch64.h"
42#define DEBUG_TYPE "aarch64-sve-intrinsic-opts"
55 bool coalescePTrueIntrinsicCalls(
BasicBlock &BB,
69void SVEIntrinsicOpts::getAnalysisUsage(
AnalysisUsage &AU)
const {
74char SVEIntrinsicOpts::ID = 0;
75static const char *
name =
"SVE intrinsics optimizations";
81 return new SVEIntrinsicOpts();
101 if (
match(
User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
107 if (ConvertToUses.
empty())
113 const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->
getType());
116 auto *IntrUser = dyn_cast<IntrinsicInst>(
User);
117 if (IntrUser && IntrUser->getIntrinsicID() ==
118 Intrinsic::aarch64_sve_convert_from_svbool) {
119 const auto *IntrUserVTy = cast<ScalableVectorType>(IntrUser->getType());
122 if (IntrUserVTy->getElementCount().getKnownMinValue() >
123 PTrueVTy->getElementCount().getKnownMinValue())
135bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
137 if (PTrues.
size() <= 1)
141 auto *MostEncompassingPTrue =
143 auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
144 auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
145 return PTrue1VTy->getElementCount().getKnownMinValue() <
146 PTrue2VTy->getElementCount().getKnownMinValue();
151 PTrues.
remove(MostEncompassingPTrue);
161 Builder.SetInsertPoint(&BB, ++MostEncompassingPTrue->getIterator());
163 auto *MostEncompassingPTrueVTy =
164 cast<VectorType>(MostEncompassingPTrue->getType());
165 auto *ConvertToSVBool = Builder.CreateIntrinsic(
166 Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy},
167 {MostEncompassingPTrue});
169 bool ConvertFromCreated =
false;
170 for (
auto *PTrue : PTrues) {
171 auto *PTrueVTy = cast<VectorType>(PTrue->getType());
175 if (MostEncompassingPTrueVTy != PTrueVTy) {
176 ConvertFromCreated =
true;
178 Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator());
179 auto *ConvertFromSVBool =
180 Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
181 {PTrueVTy}, {ConvertToSVBool});
182 PTrue->replaceAllUsesWith(ConvertFromSVBool);
184 PTrue->replaceAllUsesWith(MostEncompassingPTrue);
186 PTrue->eraseFromParent();
190 if (!ConvertFromCreated)
191 ConvertToSVBool->eraseFromParent();
244bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
246 bool Changed =
false;
248 for (
auto *
F : Functions) {
249 for (
auto &BB : *
F) {
258 auto *IntrI = dyn_cast<IntrinsicInst>(&
I);
259 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
262 const auto PTruePattern =
263 cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue();
265 if (PTruePattern == AArch64SVEPredPattern::all)
266 SVAllPTrues.
insert(IntrI);
267 if (PTruePattern == AArch64SVEPredPattern::pow2)
268 SVPow2PTrues.
insert(IntrI);
271 Changed |= coalescePTrueIntrinsicCalls(BB, SVAllPTrues);
272 Changed |= coalescePTrueIntrinsicCalls(BB, SVPow2PTrues);
281bool SVEIntrinsicOpts::optimizePredicateStore(
Instruction *
I) {
282 auto *
F =
I->getFunction();
283 auto Attr =
F->getFnAttribute(Attribute::VScaleRange);
287 unsigned MinVScale = Attr.getVScaleRangeMin();
288 std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
290 if (!MaxVScale || MinVScale != MaxVScale)
295 auto *FixedPredType =
299 auto *
Store = dyn_cast<StoreInst>(
I);
300 if (!Store || !
Store->isSimple())
304 if (
Store->getOperand(0)->getType() != FixedPredType)
308 auto *IntrI = dyn_cast<IntrinsicInst>(
Store->getOperand(0));
309 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_extract)
313 if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
317 auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
322 if (BitCast->getOperand(0)->getType() != PredType)
326 Builder.SetInsertPoint(
I);
328 Builder.CreateStore(BitCast->getOperand(0),
Store->getPointerOperand());
330 Store->eraseFromParent();
331 if (IntrI->getNumUses() == 0)
332 IntrI->eraseFromParent();
333 if (BitCast->getNumUses() == 0)
334 BitCast->eraseFromParent();
341bool SVEIntrinsicOpts::optimizePredicateLoad(
Instruction *
I) {
342 auto *
F =
I->getFunction();
343 auto Attr =
F->getFnAttribute(Attribute::VScaleRange);
347 unsigned MinVScale = Attr.getVScaleRangeMin();
348 std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
350 if (!MaxVScale || MinVScale != MaxVScale)
355 auto *FixedPredType =
359 auto *BitCast = dyn_cast<BitCastInst>(
I);
360 if (!BitCast || BitCast->getType() != PredType)
364 auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
365 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_insert)
369 if (!isa<UndefValue>(IntrI->getOperand(0)) ||
370 !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
374 auto *
Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
375 if (!Load || !
Load->isSimple())
379 if (
Load->getType() != FixedPredType)
383 Builder.SetInsertPoint(Load);
385 auto *LoadPred = Builder.CreateLoad(PredType,
Load->getPointerOperand());
387 BitCast->replaceAllUsesWith(LoadPred);
388 BitCast->eraseFromParent();
389 if (IntrI->getNumUses() == 0)
390 IntrI->eraseFromParent();
391 if (
Load->getNumUses() == 0)
392 Load->eraseFromParent();
397bool SVEIntrinsicOpts::optimizeInstructions(
399 bool Changed =
false;
401 for (
auto *
F : Functions) {
402 DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();
408 for (
auto *BB : RPOT) {
410 switch (
I.getOpcode()) {
411 case Instruction::Store:
412 Changed |= optimizePredicateStore(&
I);
414 case Instruction::BitCast:
415 Changed |= optimizePredicateLoad(&
I);
425bool SVEIntrinsicOpts::optimizeFunctions(
427 bool Changed =
false;
429 Changed |= optimizePTrueIntrinsicCalls(Functions);
430 Changed |= optimizeInstructions(Functions);
435bool SVEIntrinsicOpts::runOnModule(
Module &M) {
436 bool Changed =
false;
442 for (
auto &
F :
M.getFunctionList()) {
443 if (!
F.isDeclaration())
446 switch (
F.getIntrinsicID()) {
447 case Intrinsic::vector_extract:
448 case Intrinsic::vector_insert:
449 case Intrinsic::aarch64_sve_ptrue:
450 for (
User *U :
F.users())
458 if (!Functions.
empty())
459 Changed |= optimizeFunctions(Functions);
This file contains the declarations for the subclasses of Constant, which represent the different fla...
static Function * getFunction(Constant *C)
Module.h This file contains the declarations for the Module class.
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
static bool isPTruePromoted(IntrinsicInst *PTrue)
Checks if a ptrue intrinsic call is promoted.
This file implements a set that has insertion order iteration characteristics.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
LLVM Basic Block Representation.
const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
LLVMContext & getContext() const
Get the context in which this basic block lives.
Legacy analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
A wrapper class for inspecting calls to intrinsic functions.
This is an important class for using LLVM in a threaded context.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overridden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
static ScalableVectorType * get(Type *ElementType, unsigned MinNumElts)
bool remove(const value_type &X)
Remove an item from the set vector.
bool remove_if(UnaryPredicate P)
Remove items from the set vector based on a predicate function.
size_type size() const
Determine the number of elements in the SetVector.
bool empty() const
Determine if the SetVector is empty or not.
bool insert(const value_type &X)
Insert a new element into the SetVector.
A SetVector that performs no allocations if smaller than a certain size.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
static IntegerType * getInt1Ty(LLVMContext &C)
static IntegerType * getInt8Ty(LLVMContext &C)
Type * getType() const
All values are typed, get the type of this value.
iterator_range< user_iterator > users()
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
bool match(Val *V, const Pattern &P)
This is an optimization pass for GlobalISel generic memory operations.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
ModulePass * createSVEIntrinsicOptsPass()
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
void initializeSVEIntrinsicOptsPass(PassRegistry &)