42#define DEBUG_TYPE "scalarize-masked-mem-intrin"
46class ScalarizeMaskedMemIntrinLegacyPass :
public FunctionPass {
58 return "Scalarize Masked Memory Intrinsics";
76char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
79 "Scalarize unsupported masked memory intrinsics",
false,
88 return new ScalarizeMaskedMemIntrinLegacyPass();
96 unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
97 for (
unsigned i = 0; i != NumElts; ++i) {
98 Constant *CElt =
C->getAggregateElement(i);
99 if (!CElt || !isa<ConstantInt>(CElt))
108 return DL.isBigEndian() ? VectorWidth - 1 -
Idx :
Idx;
150 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
153 Type *EltTy = VecType->getElementType();
163 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
171 const Align AdjustedAlignVal =
173 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
176 Value *VResult = Src0;
179 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
180 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
194 if (VectorWidth != 1) {
196 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
199 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
208 if (VectorWidth != 1) {
212 Builder.
getIntN(VectorWidth, 0));
228 CondBlock->
setName(
"cond.load");
239 IfBlock = NewIfBlock;
244 Phi->addIncoming(NewVResult, CondBlock);
245 Phi->addIncoming(VResult, PrevIfBlock);
288 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
289 auto *VecType = cast<VectorType>(Src->getType());
291 Type *EltTy = VecType->getElementType();
299 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
306 const Align AdjustedAlignVal =
308 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
311 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
312 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
325 if (VectorWidth != 1) {
327 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
330 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
338 if (VectorWidth != 1) {
342 Builder.
getIntN(VectorWidth, 0));
358 CondBlock->
setName(
"cond.store");
412 auto *VecType = cast<FixedVectorType>(CI->
getType());
413 Type *EltTy = VecType->getElementType();
419 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
424 Value *VResult = Src0;
425 unsigned VectorWidth = VecType->getNumElements();
429 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
430 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
446 if (VectorWidth != 1) {
448 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
451 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
460 if (VectorWidth != 1) {
464 Builder.
getIntN(VectorWidth, 0));
480 CondBlock->
setName(
"cond.load");
493 IfBlock = NewIfBlock;
498 Phi->addIncoming(NewVResult, CondBlock);
499 Phi->addIncoming(VResult, PrevIfBlock);
542 auto *SrcFVTy = cast<FixedVectorType>(Src->getType());
545 isa<VectorType>(Ptrs->
getType()) &&
546 isa<PointerType>(cast<VectorType>(Ptrs->
getType())->getElementType()) &&
547 "Vector of pointers is expected in masked scatter intrinsic");
554 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
555 unsigned VectorWidth = SrcFVTy->getNumElements();
559 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
560 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
574 if (VectorWidth != 1) {
576 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
579 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
587 if (VectorWidth != 1) {
591 Builder.
getIntN(VectorWidth, 0));
607 CondBlock->
setName(
"cond.store");
632 auto *VecType = cast<FixedVectorType>(CI->
getType());
634 Type *EltTy = VecType->getElementType();
643 unsigned VectorWidth = VecType->getNumElements();
646 Value *VResult = PassThru;
649 const Align AdjustedAlignment =
656 unsigned MemIndex = 0;
659 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
661 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue()) {
663 ShuffleMask[
Idx] =
Idx + VectorWidth;
684 if (VectorWidth != 1) {
686 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
689 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
698 if (VectorWidth != 1) {
702 Builder.
getIntN(VectorWidth, 0));
718 CondBlock->
setName(
"cond.load");
726 if ((
Idx + 1) != VectorWidth)
733 IfBlock = NewIfBlock;
743 if ((
Idx + 1) != VectorWidth) {
765 auto *VecType = cast<FixedVectorType>(Src->getType());
774 Type *EltTy = VecType->getElementType();
777 const Align AdjustedAlignment =
780 unsigned VectorWidth = VecType->getNumElements();
784 unsigned MemIndex = 0;
785 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
786 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
801 if (VectorWidth != 1) {
803 SclrMask = Builder.
CreateBitCast(Mask, SclrMaskTy,
"scalar_mask");
806 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
813 if (VectorWidth != 1) {
817 Builder.
getIntN(VectorWidth, 0));
833 CondBlock->
setName(
"cond.store");
841 if ((
Idx + 1) != VectorWidth)
848 IfBlock = NewIfBlock;
853 if ((
Idx + 1) != VectorWidth) {
875 auto *AddrType = cast<FixedVectorType>(Ptrs->
getType());
885 unsigned VectorWidth = AddrType->getNumElements();
889 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
890 if (cast<Constant>(Mask)->getAggregateElement(
Idx)->isNullValue())
901 for (
unsigned Idx = 0;
Idx < VectorWidth; ++
Idx) {
910 CondBlock->
setName(
"cond.histogram.update");
930 std::optional<DomTreeUpdater> DTU;
932 DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);
934 bool EverMadeChange =
false;
935 bool MadeChange =
true;
936 auto &
DL =
F.getDataLayout();
940 bool ModifiedDTOnIteration =
false;
942 DTU ? &*DTU :
nullptr);
945 if (ModifiedDTOnIteration)
949 EverMadeChange |= MadeChange;
951 return EverMadeChange;
954bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(
Function &
F) {
955 auto &
TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
957 if (
auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
958 DT = &DTWP->getDomTree();
977 bool MadeChange =
false;
980 while (CurInstIterator != BB.
end()) {
981 if (
CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
996 if (isa<ScalableVectorType>(
II->getType()) ||
998 [](
Value *V) { return isa<ScalableVectorType>(V->getType()); }))
1001 switch (
II->getIntrinsicID()) {
1004 case Intrinsic::experimental_vector_histogram_add:
1010 case Intrinsic::masked_load:
1018 case Intrinsic::masked_store:
1025 case Intrinsic::masked_gather: {
1027 cast<ConstantInt>(CI->
getArgOperand(1))->getMaybeAlignValue();
1029 Align Alignment =
DL.getValueOrABITypeAlignment(MA,
1037 case Intrinsic::masked_scatter: {
1039 cast<ConstantInt>(CI->
getArgOperand(2))->getMaybeAlignValue();
1041 Align Alignment =
DL.getValueOrABITypeAlignment(MA,
1050 case Intrinsic::masked_expandload:
1057 case Intrinsic::masked_compressstore:
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
static bool runImpl(Function &F, const TargetLowering &TLI)
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth, unsigned Idx)
static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, const TargetTransformInfo &TTI, const DataLayout &DL, DomTreeUpdater *DTU)
static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT, const TargetTransformInfo &TTI, const DataLayout &DL, DomTreeUpdater *DTU)
static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static bool runImpl(Function &F, const TargetTransformInfo &TTI, DominatorTree *DT)
static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static bool isConstantIntVector(Value *Mask)
Scalarize unsupported masked memory intrinsics
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
A container for analyses that lazily runs them and caches their results.
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
AttributeSet getParamAttrs(unsigned ArgNo) const
The attributes for the argument or parameter at the given index are returned.
MaybeAlign getAlignment() const
LLVM Basic Block Representation.
iterator begin()
Instruction iterator methods.
InstListType::iterator iterator
Instruction iterators...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well formed.
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Value * getArgOperand(unsigned i) const
AttributeList getAttributes() const
Return the parameter attributes for this call.
This class represents a function call, abstracting a target machine's calling convention.
This is an important base class in LLVM.
A parsed version of the target data layout string in and methods for querying it.
Analysis pass which computes a DominatorTree.
Legacy analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
IntegerType * getIntNTy(unsigned N)
Fetch the type representing an N-bit integer.
LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)
Value * CreateConstInBoundsGEP1_32(Type *Ty, Value *Ptr, unsigned Idx0, const Twine &Name="")
void SetCurrentDebugLocation(DebugLoc L)
Set location information used by debugging information.
Value * CreateICmpNE(Value *LHS, Value *RHS, const Twine &Name="")
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Value * CreateBitCast(Value *V, Type *DestTy, const Twine &Name="")
ConstantInt * getIntN(unsigned N, uint64_t C)
Get a constant N-bit value, zero extended or truncated from a 64-bit value.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")
Value * CreateAnd(Value *LHS, Value *RHS, const Twine &Name="")
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
StoreInst * CreateAlignedStore(Value *Val, Value *Ptr, MaybeAlign Align, bool isVolatile=false)
ConstantInt * getInt(const APInt &AI)
Get a constant integer value.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
BasicBlock * getSuccessor(unsigned Idx) const LLVM_READONLY
Return the specified successor. This instruction must be a terminator.
A wrapper class for inspecting calls to intrinsic functions.
An instruction for reading from memory.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserve()
Mark an analysis as preserved.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
Analysis pass providing the TargetTransformInfo.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
The instances of the Type class are immutable: once they are created, they are never changed.
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
bool isVoidTy() const
Return true if this is 'void'.
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void setName(const Twine &Name)
Change the name of the value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
LLVMContext & getContext() const
All values hold a context through their type.
const ParentTy * getParent() const
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
This is an optimization pass for GlobalISel generic memory operations.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
FunctionPass * createScalarizeMaskedMemIntrinLegacyPass()
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
void initializeScalarizeMaskedMemIntrinLegacyPassPass(PassRegistry &)
constexpr int PoisonMaskElem
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Instruction * SplitBlockAndInsertIfThen(Value *Cond, BasicBlock::iterator SplitBefore, bool Unreachable, MDNode *BranchWeights=nullptr, DomTreeUpdater *DTU=nullptr, LoopInfo *LI=nullptr, BasicBlock *ThenBlock=nullptr)
Split the containing block at the specified instruction - everything before SplitBefore stays in the ...
This struct is a compact representation of a valid (non-zero power of two) alignment.
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Align valueOrOne() const
For convenience, returns a valid alignment or 1 if undefined.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)