29using namespace PatternMatch;
31#define DEBUG_TYPE "riscv-gather-scatter-lowering"
62 return "RISC-V gather/scatter lowering";
68 std::pair<Value *, Value *> determineBaseAndStride(
Instruction *
Ptr,
78char RISCVGatherScatterLowering::ID = 0;
81 "RISC-V gather/scatter lowering pass",
false,
false)
84 return new RISCVGatherScatterLowering();
89 if (!isa<FixedVectorType>(StartC->
getType()))
90 return std::make_pair(
nullptr,
nullptr);
92 unsigned NumElts = cast<FixedVectorType>(StartC->
getType())->getNumElements();
98 return std::make_pair(
nullptr,
nullptr);
99 APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
101 for (
unsigned i = 1; i != NumElts; ++i) {
104 return std::make_pair(
nullptr,
nullptr);
108 StrideVal = LocalStride;
109 else if (StrideVal != LocalStride)
110 return std::make_pair(
nullptr,
nullptr);
115 Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);
117 return std::make_pair(StartVal, Stride);
123 auto *StartC = dyn_cast<Constant>(Start);
128 if (
match(Start, m_Intrinsic<Intrinsic::stepvector>())) {
129 auto *Ty = Start->getType()->getScalarType();
130 return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
135 auto *BO = dyn_cast<BinaryOperator>(Start);
136 if (!BO || (BO->getOpcode() != Instruction::Add &&
137 BO->getOpcode() != Instruction::Or &&
138 BO->getOpcode() != Instruction::Shl &&
139 BO->getOpcode() != Instruction::Mul))
140 return std::make_pair(
nullptr,
nullptr);
142 if (BO->getOpcode() == Instruction::Or &&
143 !cast<PossiblyDisjointInst>(BO)->isDisjoint())
144 return std::make_pair(
nullptr,
nullptr);
147 unsigned OtherIndex = 0;
154 return std::make_pair(
nullptr,
nullptr);
160 return std::make_pair(
nullptr,
nullptr);
166 switch (BO->getOpcode()) {
169 case Instruction::Or:
173 case Instruction::Add:
176 case Instruction::Mul:
180 case Instruction::Shl:
186 return std::make_pair(Start, Stride);
193bool RISCVGatherScatterLowering::matchStridedRecurrence(
Value *Index,
Loop *L,
199 if (
auto *Phi = dyn_cast<PHINode>(Index)) {
202 if (
Phi->getParent() !=
L->getHeader())
209 assert(
Phi->getNumIncomingValues() == 2 &&
"Expected 2 operand phi.");
210 unsigned IncrementingBlock =
Phi->getIncomingValue(0) == Inc ? 0 : 1;
211 assert(
Phi->getIncomingValue(IncrementingBlock) == Inc &&
212 "Expected one operand of phi to be Inc");
215 if (!
L->isLoopInvariant(Step))
226 assert(Stride !=
nullptr);
231 Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->
getName() +
".scalar",
233 BasePtr->addIncoming(Start,
Phi->getIncomingBlock(1 - IncrementingBlock));
234 BasePtr->addIncoming(Inc,
Phi->getIncomingBlock(IncrementingBlock));
237 MaybeDeadPHIs.push_back(Phi);
242 auto *BO = dyn_cast<BinaryOperator>(Index);
246 switch (BO->getOpcode()) {
249 case Instruction::Or:
251 if (!cast<PossiblyDisjointInst>(BO)->isDisjoint())
254 case Instruction::Add:
256 case Instruction::Shl:
258 case Instruction::Mul:
264 if (isa<Instruction>(BO->getOperand(0)) &&
265 L->contains(cast<Instruction>(BO->getOperand(0)))) {
266 Index = cast<Instruction>(BO->getOperand(0));
267 OtherOp = BO->getOperand(1);
268 }
else if (isa<Instruction>(BO->getOperand(1)) &&
269 L->contains(cast<Instruction>(BO->getOperand(1))) &&
271 Index = cast<Instruction>(BO->getOperand(1));
272 OtherOp = BO->getOperand(0);
278 if (!
L->isLoopInvariant(OtherOp))
287 if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
292 unsigned StartBlock =
BasePtr->getOperand(0) == Inc ? 1 : 0;
298 BasePtr->getIncomingBlock(StartBlock)->getTerminator());
301 switch (BO->getOpcode()) {
304 case Instruction::Add:
305 case Instruction::Or: {
308 Start = Builder.
CreateAdd(Start, SplatOp,
"start");
311 case Instruction::Mul: {
312 Start = Builder.
CreateMul(Start, SplatOp,
"start");
313 Step = Builder.
CreateMul(Step, SplatOp,
"step");
314 Stride = Builder.
CreateMul(Stride, SplatOp,
"stride");
317 case Instruction::Shl: {
318 Start = Builder.
CreateShl(Start, SplatOp,
"start");
319 Step = Builder.
CreateShl(Step, SplatOp,
"step");
320 Stride = Builder.
CreateShl(Stride, SplatOp,
"stride");
326 BasePtr->setIncomingValue(StartBlock, Start);
330std::pair<Value *, Value *>
331RISCVGatherScatterLowering::determineBaseAndStride(
Instruction *
Ptr,
337 return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
340 auto *
GEP = dyn_cast<GetElementPtrInst>(
Ptr);
342 return std::make_pair(
nullptr,
nullptr);
344 auto I = StridedAddrs.find(
GEP);
345 if (
I != StridedAddrs.end())
352 if (
auto *BaseInst = dyn_cast<Instruction>(
Base);
353 BaseInst && BaseInst->getType()->isVectorTy()) {
355 auto IsScalar = [](
Value *
Idx) {
return !
Idx->getType()->isVectorTy(); };
357 auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder);
362 Builder.
CreateGEP(
GEP->getSourceElementType(), BaseBase, Indices,
363 GEP->getName() +
"offset",
GEP->isInBounds());
364 return {OffsetBase, Stride};
374 return std::make_pair(
nullptr,
nullptr);
377 std::optional<unsigned> VecOperand;
378 unsigned TypeScale = 0;
382 for (
unsigned i = 1, e =
GEP->getNumOperands(); i != e; ++i, ++GTI) {
383 if (!Ops[i]->
getType()->isVectorTy())
387 return std::make_pair(
nullptr,
nullptr);
393 return std::make_pair(
nullptr,
nullptr);
400 return std::make_pair(
nullptr,
nullptr);
407 Value *VecIndex = Ops[*VecOperand];
408 Type *VecIntPtrTy =
DL->getIntPtrType(
GEP->getType());
409 if (VecIndex->
getType() != VecIntPtrTy) {
410 auto *VecIndexC = dyn_cast<Constant>(VecIndex);
412 return std::make_pair(
nullptr,
nullptr);
427 Ops[*VecOperand] = Start;
428 Type *SourceTy =
GEP->getSourceElementType();
438 Stride = Builder.
CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
440 auto P = std::make_pair(BasePtr, Stride);
441 StridedAddrs[
GEP] =
P;
446 Loop *
L = LI->getLoopFor(
GEP->getParent());
447 if (!L || !
L->getLoopPreheader() || !
L->getLoopLatch())
448 return std::make_pair(
nullptr,
nullptr);
452 if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
453 return std::make_pair(
nullptr,
nullptr);
456 unsigned IncrementingBlock = BasePhi->
getOperand(0) == Inc ? 0 : 1;
458 "Expected one operand of phi to be Inc");
463 Ops[*VecOperand] = BasePhi;
464 Type *SourceTy =
GEP->getSourceElementType();
478 Stride = Builder.
CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
480 auto P = std::make_pair(BasePtr, Stride);
481 StridedAddrs[
GEP] =
P;
485bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(
IntrinsicInst *
II) {
489 switch (
II->getIntrinsicID()) {
490 case Intrinsic::masked_gather:
491 DataType = cast<VectorType>(
II->getType());
492 Ptr =
II->getArgOperand(0);
493 MA = cast<ConstantInt>(
II->getArgOperand(1))->getMaybeAlignValue();
494 Mask =
II->getArgOperand(2);
496 case Intrinsic::vp_gather:
497 DataType = cast<VectorType>(
II->getType());
498 Ptr =
II->getArgOperand(0);
499 MA =
II->getParamAlign(0).value_or(
500 DL->getABITypeAlign(DataType->getElementType()));
501 Mask =
II->getArgOperand(1);
502 EVL =
II->getArgOperand(2);
504 case Intrinsic::masked_scatter:
505 DataType = cast<VectorType>(
II->getArgOperand(0)->getType());
506 StoreVal =
II->getArgOperand(0);
507 Ptr =
II->getArgOperand(1);
508 MA = cast<ConstantInt>(
II->getArgOperand(2))->getMaybeAlignValue();
509 Mask =
II->getArgOperand(3);
511 case Intrinsic::vp_scatter:
512 DataType = cast<VectorType>(
II->getArgOperand(0)->getType());
513 StoreVal =
II->getArgOperand(0);
514 Ptr =
II->getArgOperand(1);
515 MA =
II->getParamAlign(1).value_or(
516 DL->getABITypeAlign(DataType->getElementType()));
517 Mask =
II->getArgOperand(2);
518 EVL =
II->getArgOperand(3);
525 EVT DataTypeVT = TLI->getValueType(*
DL, DataType);
526 if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
530 if (!TLI->isTypeLegal(DataTypeVT))
534 auto *PtrI = dyn_cast<Instruction>(
Ptr);
543 std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
546 assert(Stride !=
nullptr);
552 Builder.
getInt32Ty(), cast<VectorType>(DataType)->getElementCount());
558 Intrinsic::experimental_vp_strided_load,
563 if (
II->getIntrinsicID() == Intrinsic::masked_gather)
568 Intrinsic::experimental_vp_strided_store,
573 II->replaceAllUsesWith(Call);
574 II->eraseFromParent();
576 if (PtrI->use_empty())
582bool RISCVGatherScatterLowering::runOnFunction(
Function &
F) {
586 auto &TPC = getAnalysis<TargetPassConfig>();
589 if (!
ST->hasVInstructions() || !
ST->useRVVForFixedLengthVectors())
592 TLI =
ST->getTargetLowering();
593 DL = &
F.getDataLayout();
594 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
596 StridedAddrs.clear();
600 bool Changed =
false;
607 switch (
II->getIntrinsicID()) {
608 case Intrinsic::masked_gather:
609 case Intrinsic::masked_scatter:
610 case Intrinsic::vp_gather:
611 case Intrinsic::vp_scatter:
621 for (
auto *
II : Worklist)
622 Changed |= tryCreateStridedLoadStore(
II);
625 while (!MaybeDeadPHIs.empty()) {
626 if (
auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Returns the sub type a function will return at a given Idx. Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
static std::pair< Value *, Value * > matchStridedStart(Value *Start, IRBuilderBase &Builder)
static std::pair< Value *, Value * > matchStridedConstant(Constant *StartC)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static SymbolRef::Type getType(const Symbol *Sym)
Target-Independent Code Generator Pass Configuration Options pass.
Class for arbitrary precision integers.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
LLVM Basic Block Representation.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
BinaryOps getOpcode() const
This class represents a function call, abstracting a target machine's calling convention.
This is the shared class of boolean and integer constants.
const APInt & getValue() const
Return the constant as an APInt value reference.
This is an important base class in LLVM.
Constant * getAggregateElement(unsigned Elt) const
For aggregates (struct/array/vector) return the constant that corresponds to the specified element if...
A parsed version of the target data layout string, and methods for querying it.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
Common base class shared among various IRBuilders.
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
void SetCurrentDebugLocation(DebugLoc L)
Set location information used by debugging information.
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateElementCount(Type *DstType, ElementCount EC)
Create an expression which evaluates to the number of elements in EC at runtime.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
InstSimplifyFolder - Use InstructionSimplify to fold operations to existing values.
bool isCommutative() const LLVM_READONLY
Return true if the instruction is commutative:
A wrapper class for inspecting calls to intrinsic functions.
This is an important class for using LLVM in a threaded context.
The legacy pass manager's analysis pass to compute loop information.
Represents a single loop in the control flow graph.
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number i.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
Target-Independent Code Generator Pass Configuration Options.
The instances of the Type class are immutable: once they are created, they are never changed.
bool isVectorTy() const
True if this is an instance of VectorType.
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
void setOperand(unsigned i, Value *Val)
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
StringRef getName() const
Return a constant reference to the value's name.
constexpr ScalarTy getFixedValue() const
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
TypeSize getSequentialElementStride(const DataLayout &DL) const
self_iterator getIterator()
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
bool match(Val *V, const Pattern &P)
NodeAddr< PhiNode * > Phi
This is an optimization pass for GlobalISel generic memory operations.
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Value * getSplatValue(const Value *V)
Get splat value if the input is a splat vector or return nullptr.
FunctionPass * createRISCVGatherScatterLoweringPass()
bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
gep_type_iterator gep_type_begin(const User *GEP)
bool RecursivelyDeleteDeadPHINode(PHINode *PN, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
If the specified value is an effectively dead PHI node, due to being a def-use chain of single-use no...
Constant * ConstantFoldCastInstruction(unsigned opcode, Constant *V, Type *DestTy)
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.