43#define DEBUG_TYPE "scalarize-masked-mem-intrin"
47class ScalarizeMaskedMemIntrinLegacyPass :
public FunctionPass {
59 return "Scalarize Masked Memory Intrinsics";
77char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
80 "Scalarize unsupported masked memory intrinsics",
false,
89 return new ScalarizeMaskedMemIntrinLegacyPass();
97 unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
98 for (
unsigned i = 0; i != NumElts; ++i) {
99 Constant *CElt =
C->getAggregateElement(i);
100 if (!CElt || !isa<ConstantInt>(CElt))
109 return DL.isBigEndian() ? VectorWidth - 1 -
Idx :
Idx;
151 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
154 Type *EltTy = VecType->getElementType();
164 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
174 const Align AdjustedAlignVal =
176 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
179 Value *VResult = Src0;
182 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
183 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
199 Mask->getName() +
".first");
205 CondBlock->
setName(
"cond.load");
209 Load->copyMetadata(*CI);
214 Phi->addIncoming(Load, CondBlock);
215 Phi->addIncoming(Src0, IfBlock);
228 if (VectorWidth != 1) {
230 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
233 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
242 if (VectorWidth != 1) {
246 Builder.
getIntN(VectorWidth, 0));
262 CondBlock->
setName(
"cond.load");
273 IfBlock = NewIfBlock;
278 Phi->addIncoming(NewVResult, CondBlock);
279 Phi->addIncoming(VResult, PrevIfBlock);
322 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
323 auto *VecType = cast<VectorType>(Src->getType());
325 Type *EltTy = VecType->getElementType();
333 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
336 Store->copyMetadata(*CI);
342 const Align AdjustedAlignVal =
344 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
347 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
348 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
363 Mask->getName() +
".first");
368 CondBlock->
setName(
"cond.store");
373 Store->copyMetadata(*CI);
384 if (VectorWidth != 1) {
386 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
389 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
397 if (VectorWidth != 1) {
401 Builder.
getIntN(VectorWidth, 0));
417 CondBlock->
setName(
"cond.store");
471 auto *VecType = cast<FixedVectorType>(CI->
getType());
472 Type *EltTy = VecType->getElementType();
478 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
483 Value *VResult = Src0;
484 unsigned VectorWidth = VecType->getNumElements();
488 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
489 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
505 if (VectorWidth != 1) {
507 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
510 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
519 if (VectorWidth != 1) {
523 Builder.
getIntN(VectorWidth, 0));
539 CondBlock->
setName(
"cond.load");
552 IfBlock = NewIfBlock;
557 Phi->addIncoming(NewVResult, CondBlock);
558 Phi->addIncoming(VResult, PrevIfBlock);
601 auto *SrcFVTy = cast<FixedVectorType>(Src->getType());
604 isa<VectorType>(Ptrs->
getType()) &&
605 isa<PointerType>(cast<VectorType>(Ptrs->
getType())->getElementType()) &&
606 "Vector of pointers is expected in masked scatter intrinsic");
613 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
614 unsigned VectorWidth = SrcFVTy->getNumElements();
618 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
619 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
633 if (VectorWidth != 1) {
635 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
638 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
646 if (VectorWidth != 1) {
650 Builder.
getIntN(VectorWidth, 0));
666 CondBlock->
setName(
"cond.store");
691 auto *VecType = cast<FixedVectorType>(CI->
getType());
693 Type *EltTy = VecType->getElementType();
702 unsigned VectorWidth = VecType->getNumElements();
705 Value *VResult = PassThru;
708 const Align AdjustedAlignment =
715 unsigned MemIndex = 0;
718 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
720 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue()) {
722 ShuffleMask[
Idx] =
Idx + VectorWidth;
743 if (VectorWidth != 1) {
745 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
748 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
757 if (VectorWidth != 1) {
761 Builder.
getIntN(VectorWidth, 0));
777 CondBlock->
setName(
"cond.load");
785 if ((
Idx + 1) != VectorWidth)
792 IfBlock = NewIfBlock;
802 if ((
Idx + 1) != VectorWidth) {
824 auto *VecType = cast<FixedVectorType>(Src->getType());
833 Type *EltTy = VecType->getElementType();
836 const Align AdjustedAlignment =
839 unsigned VectorWidth = VecType->getNumElements();
843 unsigned MemIndex = 0;
844 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
845 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
860 if (VectorWidth != 1) {
862 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
865 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
872 if (VectorWidth != 1) {
876 Builder.
getIntN(VectorWidth, 0));
892 CondBlock->
setName(
"cond.store");
900 if ((
Idx + 1) != VectorWidth)
907 IfBlock = NewIfBlock;
912 if ((
Idx + 1) != VectorWidth) {
934 auto *AddrType = cast<FixedVectorType>(Ptrs->
getType());
944 unsigned VectorWidth = AddrType->getNumElements();
948 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
949 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
960 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
969 CondBlock->
setName(
"cond.histogram.update");
989 std::optional<DomTreeUpdater> DTU;
991 DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);
993 bool EverMadeChange =
false;
994 bool MadeChange =
true;
995 auto &
DL =
F.getDataLayout();
999 bool ModifiedDTOnIteration =
false;
1001 DTU ? &*DTU :
nullptr);
1004 if (ModifiedDTOnIteration)
1008 EverMadeChange |= MadeChange;
1010 return EverMadeChange;
1013bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(
Function &
F) {
1014 auto &
TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
1016 if (
auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
1017 DT = &DTWP->getDomTree();
1036 bool MadeChange =
false;
1039 while (CurInstIterator != BB.
end()) {
1040 if (
CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
1055 if (isa<ScalableVectorType>(
II->getType()) ||
1057 [](
Value *V) { return isa<ScalableVectorType>(V->getType()); }))
1059 switch (
II->getIntrinsicID()) {
1062 case Intrinsic::experimental_vector_histogram_add:
1068 case Intrinsic::masked_load:
1076 case Intrinsic::masked_store:
1083 case Intrinsic::masked_gather: {
1085 cast<ConstantInt>(CI->
getArgOperand(1))->getMaybeAlignValue();
1087 Align Alignment =
DL.getValueOrABITypeAlignment(MA,
1095 case Intrinsic::masked_scatter: {
1097 cast<ConstantInt>(CI->
getArgOperand(2))->getMaybeAlignValue();
1099 Align Alignment =
DL.getValueOrABITypeAlignment(MA,
1108 case Intrinsic::masked_expandload:
1115 case Intrinsic::masked_compressstore:
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx. Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
static bool runImpl(Function &F, const TargetLowering &TLI)
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth, unsigned Idx)
static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, const TargetTransformInfo &TTI, const DataLayout &DL, DomTreeUpdater *DTU)
static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT, const TargetTransformInfo &TTI, const DataLayout &DL, DomTreeUpdater *DTU)
static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static bool runImpl(Function &F, const TargetTransformInfo &TTI, DominatorTree *DT)
static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static bool isConstantIntVector(Value *Mask)
Scalarize unsupported masked memory intrinsics
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
A container for analyses that lazily runs them and caches their results.
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
AttributeSet getParamAttrs(unsigned ArgNo) const
The attributes for the argument or parameter at the given index are returned.
MaybeAlign getAlignment() const
LLVM Basic Block Representation.
iterator begin()
Instruction iterator methods.
InstListType::iterator iterator
Instruction iterators...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Value * getArgOperand(unsigned i) const
AttributeList getAttributes() const
Return the parameter attributes for this call.
This class represents a function call, abstracting a target machine's calling convention.
This is an important base class in LLVM.
A parsed version of the target data layout string, and methods for querying it.
Analysis pass which computes a DominatorTree.
Legacy analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
IntegerType * getIntNTy(unsigned N)
Fetch the type representing an N-bit integer.
LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)
Value * CreateConstInBoundsGEP1_32(Type *Ty, Value *Ptr, unsigned Idx0, const Twine &Name="")
void SetCurrentDebugLocation(DebugLoc L)
Set location information used by debugging information.
Value * CreateICmpNE(Value *LHS, Value *RHS, const Twine &Name="")
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Value * CreateBitCast(Value *V, Type *DestTy, const Twine &Name="")
ConstantInt * getIntN(unsigned N, uint64_t C)
Get a constant N-bit value, zero extended or truncated from a 64-bit value.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")
Value * CreateAnd(Value *LHS, Value *RHS, const Twine &Name="")
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
StoreInst * CreateAlignedStore(Value *Val, Value *Ptr, MaybeAlign Align, bool isVolatile=false)
ConstantInt * getInt(const APInt &AI)
Get a constant integer value.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
BasicBlock * getSuccessor(unsigned Idx) const LLVM_READONLY
Return the specified successor. This instruction must be a terminator.
void copyMetadata(const Instruction &SrcInst, ArrayRef< unsigned > WL=ArrayRef< unsigned >())
Copy metadata from SrcInst to this instruction.
A wrapper class for inspecting calls to intrinsic functions.
An instruction for reading from memory.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserve()
Mark an analysis as preserved.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
StringRef - Represent a constant reference to a string, i.e.
Analysis pass providing the TargetTransformInfo.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
The instances of the Type class are immutable: once they are created, they are never changed.
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
bool isVoidTy() const
Return true if this is 'void'.
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void setName(const Twine &Name)
Change the name of the value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
LLVMContext & getContext() const
All values hold a context through their type.
StringRef getName() const
Return a constant reference to the value's name.
void takeName(Value *V)
Transfer the name from V to this value.
const ParentTy * getParent() const
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
This is an optimization pass for GlobalISel generic memory operations.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
FunctionPass * createScalarizeMaskedMemIntrinLegacyPass()
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
bool isSplatValue(const Value *V, int Index=-1, unsigned Depth=0)
Return true if each element of the vector value V is poisoned or equal to every other non-poisoned el...
void initializeScalarizeMaskedMemIntrinLegacyPassPass(PassRegistry &)
constexpr int PoisonMaskElem
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Instruction * SplitBlockAndInsertIfThen(Value *Cond, BasicBlock::iterator SplitBefore, bool Unreachable, MDNode *BranchWeights=nullptr, DomTreeUpdater *DTU=nullptr, LoopInfo *LI=nullptr, BasicBlock *ThenBlock=nullptr)
Split the containing block at the specified instruction - everything before SplitBefore stays in the ...
This struct is a compact representation of a valid (non-zero power of two) alignment.
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Align valueOrOne() const
For convenience, returns a valid alignment or 1 if undefined.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)