#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {
class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
  // ...
  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }
};
} // end anonymous namespace

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
// ...
FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;
  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }
  return true;
}
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
                                unsigned Idx) {
  return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
}
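// --- Illustrative sketch (not part of the pass) ------------------------------
// When a <8 x i1> mask is bitcast to an i8, lane 0 occupies bit 0 on a
// little-endian target but bit 7 on a big-endian one; adjustForEndian picks
// the bit to test for a given lane. This standalone demo (hypothetical names,
// a sketch of that mapping only) prints the lane-to-bit correspondence.
#include <cstdio>

static unsigned adjustForEndianDemo(bool BigEndian, unsigned VectorWidth,
                                    unsigned Idx) {
  return BigEndian ? VectorWidth - 1 - Idx : Idx;
}

int main() {
  for (unsigned Idx = 0; Idx < 8; ++Idx)
    std::printf("lane %u -> bit %u (little-endian), bit %u (big-endian)\n",
                Idx, adjustForEndianDemo(false, 8, Idx),
                adjustForEndianDemo(true, 8, Idx));
}
// -----------------------------------------------------------------------------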
static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence,
                                CallInst *CI, DomTreeUpdater *DTU,
                                bool &ModifiedDT) {
  // ...
  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  // ...
  Type *EltTy = VecType->getElementType();
  // ...

  // Short-cut if the mask is all-true: emit a plain vector load.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    LoadInst *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    NewI->copyMetadata(*CI);
    NewI->takeName(CI);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instructions.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector.
  Value *VResult = Src0;

  // Constant mask: load exactly the active lanes, no control flow needed.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    // ...
  }

  // Splat mask: guard one full-width load with the first mask element.
  if (isSplatValue(Mask, /*Index=*/0)) {
    Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0),
                                                    Mask->getName() + ".first");
    // ...
    CondBlock->setName("cond.load");
    // ...
    Load->copyMetadata(*CI);
    // ...
    Phi->addIncoming(Load, CondBlock);
    Phi->addIncoming(Src0, IfBlock);
    // ...
  }

  // General case: bitcast the mask to an iN scalar once, then test one bit
  // per lane (skipped when branch divergence makes per-lane branches costly).
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    }
    // ...
    CondBlock->setName("cond.load");
    // ...
    IfBlock = NewIfBlock;
    // ...
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
  }
  // ...
}
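// --- Illustrative sketch (not part of the pass) ------------------------------
// A plain-C++ model of the semantics the scalarized llvm.masked.load computes
// (hypothetical names; this mirrors the intrinsic's documented behavior, not
// the IR the pass emits): each active lane loads from the corresponding
// element of Ptr, and inactive lanes keep the pass-through value.
template <typename T>
void maskedLoadModel(const T *Ptr, const bool *Mask, const T *PassThru,
                     T *Result, unsigned N) {
  for (unsigned Idx = 0; Idx < N; ++Idx)
    Result[Idx] = Mask[Idx] ? Ptr[Idx] : PassThru[Idx];
}
// -----------------------------------------------------------------------------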
static void scalarizeMaskedStore(const DataLayout &DL,
                                 bool HasBranchDivergence, CallInst *CI,
                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
  // ...
  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());
  Type *EltTy = VecType->getElementType();
  // ...

  // Short-cut if the mask is all-true: emit a plain vector store.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    Store->copyMetadata(*CI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instructions.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // Constant mask: store exactly the active lanes.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      // ...
    }
    // ...
  }

  // Splat mask: guard one full-width store with the first mask element.
  if (isSplatValue(Mask, /*Index=*/0)) {
    Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0),
                                                    Mask->getName() + ".first");
    // ...
    CondBlock->setName("cond.store");
    // ...
    Store->copyMetadata(*CI);
    // ...
  }

  // General case: one bit test and conditional scalar store per lane.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    }
    // ...
    CondBlock->setName("cond.store");
    // ...
  }
  // ...
}
static void scalarizeMaskedGather(const DataLayout &DL,
                                  bool HasBranchDivergence, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
  // ...
  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();
  // ...
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  // ...

  // The result vector.
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Constant mask: load through exactly the active lanes' pointers.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      // ...
    }
    // ...
  }

  // General case: per-lane bit test and conditional load.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    }
    // ...
    CondBlock->setName("cond.load");
    // ...
    IfBlock = NewIfBlock;
    // ...
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
  }
  // ...
}
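// --- Illustrative sketch (not part of the pass) ------------------------------
// A plain-C++ model of llvm.masked.gather's semantics (hypothetical names):
// unlike masked.load, each lane has its own pointer, so active lanes may read
// from arbitrary, non-contiguous addresses.
template <typename T>
void maskedGatherModel(T *const *Ptrs, const bool *Mask, const T *PassThru,
                       T *Result, unsigned N) {
  for (unsigned Idx = 0; Idx < N; ++Idx)
    Result[Idx] = Mask[Idx] ? *Ptrs[Idx] : PassThru[Idx];
}
// -----------------------------------------------------------------------------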
static void scalarizeMaskedScatter(const DataLayout &DL,
                                   bool HasBranchDivergence, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
  // ...
  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  // ...
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Constant mask: store through exactly the active lanes' pointers.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      // ...
    }
    // ...
  }

  // General case: per-lane bit test and conditional store.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    }
    // ...
    CondBlock->setName("cond.store");
    // ...
  }
  // ...
}
static void scalarizeMaskedExpandLoad(const DataLayout &DL,
                                      bool HasBranchDivergence, CallInst *CI,
                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
  // ...
  auto *VecType = cast<FixedVectorType>(CI->getType());
  // ...
  Type *EltTy = VecType->getElementType();
  // ...
  unsigned VectorWidth = VecType->getNumElements();

  // The result vector.
  Value *VResult = PassThru;

  // Adjust alignment for the scalar instructions.
  const Align AdjustedAlignment =
      commonAlignment(CI->getParamAlign(0).valueOrOne(),
                      EltTy->getPrimitiveSizeInBits() / 8);

  // Constant mask: consume consecutive memory elements for the active lanes
  // and merge the inactive lanes back in from PassThru with one shuffle.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    // ...
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      // ...
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        // ...
        ShuffleMask[Idx] = Idx + VectorWidth;
      }
      // ...
    }
    // ...
  }

  // General case: per-lane bit test; the pointer advances only past lanes
  // that actually loaded an element.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    }
    // ...
    CondBlock->setName("cond.load");
    // ...
    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
    // ...
    IfBlock = NewIfBlock;
    // ...
    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  // ...
}
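// --- Illustrative sketch (not part of the pass) ------------------------------
// A plain-C++ model of llvm.masked.expandload's semantics (hypothetical
// names): memory is read contiguously, and MemIndex advances only when a
// lane is active. E.g. Mask = {1,0,1,0} loads Ptr[0] into lane 0 and Ptr[1]
// into lane 2; lanes 1 and 3 keep PassThru.
template <typename T>
void expandLoadModel(const T *Ptr, const bool *Mask, const T *PassThru,
                     T *Result, unsigned N) {
  unsigned MemIndex = 0; // counts consumed memory elements, not lanes
  for (unsigned Idx = 0; Idx < N; ++Idx)
    Result[Idx] = Mask[Idx] ? Ptr[MemIndex++] : PassThru[Idx];
}
// -----------------------------------------------------------------------------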
static void scalarizeMaskedCompressStore(const DataLayout &DL,
                                         bool HasBranchDivergence,
                                         CallInst *CI, DomTreeUpdater *DTU,
                                         bool &ModifiedDT) {
  // ...
  auto *VecType = cast<FixedVectorType>(Src->getType());
  // ...
  Type *EltTy = VecType->getElementType();

  // Adjust alignment for the scalar instructions.
  const Align AdjustedAlignment =
      commonAlignment(CI->getParamAlign(1).valueOrOne(),
                      EltTy->getPrimitiveSizeInBits() / 8);

  unsigned VectorWidth = VecType->getNumElements();

  // Constant mask: pack the active lanes into consecutive memory slots.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      // ...
    }
    // ...
  }

  // General case: per-lane bit test; the store pointer advances only after
  // an active lane has been written.
  Value *SclrMask = nullptr;
  if (VectorWidth != 1 && !HasBranchDivergence) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate;
    if (SclrMask != nullptr) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    }
    // ...
    CondBlock->setName("cond.store");
    // ...
    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
    // ...
    IfBlock = NewIfBlock;
    // ...
    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  // ...
}
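// --- Illustrative sketch (not part of the pass) ------------------------------
// A plain-C++ model of llvm.masked.compressstore's semantics (hypothetical
// names), the store-side dual of expandload: active lanes are written to
// consecutive memory slots. E.g. Mask = {1,0,1,0} stores Val[0] to Ptr[0]
// and Val[2] to Ptr[1].
template <typename T>
void compressStoreModel(T *Ptr, const bool *Mask, const T *Val, unsigned N) {
  unsigned MemIndex = 0; // next packed slot; advances only on active lanes
  for (unsigned Idx = 0; Idx < N; ++Idx)
    if (Mask[Idx])
      Ptr[MemIndex++] = Val[Idx];
}
// -----------------------------------------------------------------------------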
static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI,
                                           DomTreeUpdater *DTU,
                                           bool &ModifiedDT) {
  // ...
  auto *AddrType = cast<FixedVectorType>(Ptrs->getType());
  // ...
  unsigned VectorWidth = AddrType->getNumElements();

  // Constant mask: update exactly the active lanes' buckets.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
      Value *Add = Builder.CreateAdd(Load, Inc);
      Builder.CreateStore(Add, Ptr);
    }
    CI->eraseFromParent();
    return;
  }

  // General case: one conditional read-modify-write block per lane.
  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // ...
    CondBlock->setName("cond.histogram.update");
    // ...
  }
  // ...
}
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    DominatorTree *DT) {
  std::optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getDataLayout();
  bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
  while (MadeChange) {
    MadeChange = false;
    for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
                                  HasBranchDivergence, DTU ? &*DTU : nullptr);

      // Restart BB iteration if the dominator tree of the Function was changed.
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}
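// --- Illustrative sketch (not part of the pass) ------------------------------
// A minimal model of runImpl's fixed-point driver (the Step callback is
// hypothetical, not an LLVM API): keep rescanning until one full pass over
// the function changes nothing, since each scalarization introduces new
// basic blocks that must themselves be visited.
#include <functional>

static bool runToFixedPoint(const std::function<bool()> &Step) {
  bool EverMadeChange = false;
  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = Step(); // one full scan; true if anything was rewritten
    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}
// -----------------------------------------------------------------------------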
bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}
static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          bool HasBranchDivergence, DomTreeUpdater *DTU) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL,
                                     HasBranchDivergence, DTU);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, bool HasBranchDivergence,
                             DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->args(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::experimental_vector_histogram_add:
      // ... (skip if the target supports the histogram intrinsic natively)
      scalarizeMaskedVectorHistogram(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load.
      // ...
      scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      // ...
      scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
      Type *LoadTy = CI->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MA, LoadTy->getScalarType());
      // ...
      scalarizeMaskedGather(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
      Align Alignment = DL.getValueOrABITypeAlignment(
          MA, CI->getArgOperand(0)->getType()->getScalarType());
      // ...
      scalarizeMaskedScatter(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      // ...
      scalarizeMaskedExpandLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      // ...
      scalarizeMaskedCompressStore(DL, HasBranchDivergence, CI, DTU,
                                   ModifiedDT);
      return true;
    }
  }

  return false;
}
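// --- Illustrative sketch (not part of the pass) ------------------------------
// A minimal model (std::list standing in for a BasicBlock's instruction
// list; hypothetical names, not an LLVM API) of why optimizeBlock
// post-increments CurInstIterator before dispatching: optimizeCallInst may
// erase the visited instruction, so the iterator must already point past it.
#include <functional>
#include <list>

static bool visitAllModel(
    std::list<int> &Insts,
    const std::function<bool(std::list<int>::iterator)> &Optimize) {
  bool MadeChange = false;
  auto It = Insts.begin();
  while (It != Insts.end())
    MadeChange |= Optimize(It++); // advance first; callee may erase *It
  return MadeChange;
}
// -----------------------------------------------------------------------------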