31#define DEBUG_TYPE "riscv-gather-scatter-lowering"
62 return "RISC-V gather/scatter lowering";
68 std::pair<Value *, Value *> determineBaseAndStride(
Instruction *
Ptr,
78char RISCVGatherScatterLowering::ID = 0;
81 "RISC-V gather/scatter lowering pass",
false,
false)
84 return new RISCVGatherScatterLowering();
90 return std::make_pair(
nullptr,
nullptr);
98 return std::make_pair(
nullptr,
nullptr);
99 APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
101 for (
unsigned i = 1; i != NumElts; ++i) {
104 return std::make_pair(
nullptr,
nullptr);
108 StrideVal = LocalStride;
109 else if (StrideVal != LocalStride)
110 return std::make_pair(
nullptr,
nullptr);
115 Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);
117 return std::make_pair(StartVal, Stride);
129 auto *Ty = Start->getType()->getScalarType();
130 return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
136 if (!BO || (BO->getOpcode() != Instruction::Add &&
137 BO->getOpcode() != Instruction::Or &&
138 BO->getOpcode() != Instruction::Shl &&
139 BO->getOpcode() != Instruction::Mul))
140 return std::make_pair(
nullptr,
nullptr);
142 if (BO->getOpcode() == Instruction::Or &&
144 return std::make_pair(
nullptr,
nullptr);
147 unsigned OtherIndex = 0;
154 return std::make_pair(
nullptr,
nullptr);
160 return std::make_pair(
nullptr,
nullptr);
162 Builder.SetInsertPoint(BO);
163 Builder.SetCurrentDebugLocation(
DebugLoc());
166 switch (BO->getOpcode()) {
169 case Instruction::Or:
170 Start = Builder.CreateOr(Start,
Splat,
"",
true);
172 case Instruction::Add:
173 Start = Builder.CreateAdd(Start,
Splat);
175 case Instruction::Mul:
176 Start = Builder.CreateMul(Start,
Splat);
177 Stride = Builder.CreateMul(Stride,
Splat);
179 case Instruction::Shl:
180 Start = Builder.CreateShl(Start,
Splat);
181 Stride = Builder.CreateShl(Stride,
Splat);
185 return std::make_pair(Start, Stride);
192bool RISCVGatherScatterLowering::matchStridedRecurrence(
Value *Index,
Loop *L,
201 if (
Phi->getParent() !=
L->getHeader())
208 assert(
Phi->getNumIncomingValues() == 2 &&
"Expected 2 operand phi.");
209 unsigned IncrementingBlock =
Phi->getIncomingValue(0) == Inc ? 0 : 1;
210 assert(
Phi->getIncomingValue(IncrementingBlock) == Inc &&
211 "Expected one operand of phi to be Inc");
221 assert(Stride !=
nullptr);
226 Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->
getName() +
".scalar",
228 BasePtr->addIncoming(Start,
Phi->getIncomingBlock(1 - IncrementingBlock));
229 BasePtr->addIncoming(Inc,
Phi->getIncomingBlock(IncrementingBlock));
232 MaybeDeadPHIs.push_back(Phi);
241 switch (BO->getOpcode()) {
244 case Instruction::Or:
249 case Instruction::Add:
251 case Instruction::Shl:
253 case Instruction::Mul:
262 OtherOp = BO->getOperand(1);
267 OtherOp = BO->getOperand(0);
273 if (!
L->isLoopInvariant(OtherOp))
282 if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
287 unsigned StartBlock =
BasePtr->getOperand(0) == Inc ? 1 : 0;
293 BasePtr->getIncomingBlock(StartBlock)->getTerminator());
297 switch (BO->getOpcode()) {
300 case Instruction::Add:
301 case Instruction::Or: {
307 case Instruction::Mul: {
309 Stride = Builder.
CreateMul(Stride, SplatOp,
"stride");
312 case Instruction::Shl: {
314 Stride = Builder.
CreateShl(Stride, SplatOp,
"stride");
324 switch (BO->getOpcode()) {
327 case Instruction::Mul:
328 Step = Builder.
CreateMul(Step, SplatOp,
"step");
330 case Instruction::Shl:
331 Step = Builder.
CreateShl(Step, SplatOp,
"step");
336 BasePtr->setIncomingValue(StartBlock, Start);
340std::pair<Value *, Value *>
341RISCVGatherScatterLowering::determineBaseAndStride(Instruction *
Ptr,
342 IRBuilderBase &Builder) {
347 return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
352 return std::make_pair(
nullptr,
nullptr);
354 auto I = StridedAddrs.find(
GEP);
355 if (
I != StridedAddrs.end())
358 SmallVector<Value *, 2>
Ops(
GEP->operands());
363 BaseInst && BaseInst->getType()->isVectorTy()) {
365 auto IsScalar = [](
Value *Idx) {
return !Idx->getType()->isVectorTy(); };
367 auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder);
372 Builder.
CreateGEP(
GEP->getSourceElementType(), BaseBase, Indices,
373 GEP->getName() +
"offset",
GEP->isInBounds());
374 return {OffsetBase, Stride};
384 return std::make_pair(
nullptr,
nullptr);
387 std::optional<unsigned> VecOperand;
388 unsigned TypeScale = 0;
392 for (
unsigned i = 1, e =
GEP->getNumOperands(); i != e; ++i, ++GTI) {
397 return std::make_pair(
nullptr,
nullptr);
403 return std::make_pair(
nullptr,
nullptr);
410 return std::make_pair(
nullptr,
nullptr);
418 Type *VecIntPtrTy =
DL->getIntPtrType(
GEP->getType());
419 if (VecIndex->
getType() != VecIntPtrTy) {
422 return std::make_pair(
nullptr,
nullptr);
438 Type *SourceTy =
GEP->getSourceElementType();
448 Stride = Builder.
CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
450 auto P = std::make_pair(BasePtr, Stride);
451 StridedAddrs[
GEP] =
P;
457 if (!L || !
L->getLoopPreheader() || !
L->getLoopLatch())
458 return std::make_pair(
nullptr,
nullptr);
462 if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
463 return std::make_pair(
nullptr,
nullptr);
466 unsigned IncrementingBlock = BasePhi->
getOperand(0) == Inc ? 0 : 1;
468 "Expected one operand of phi to be Inc");
473 Ops[*VecOperand] = BasePhi;
474 Type *SourceTy =
GEP->getSourceElementType();
488 Stride = Builder.
CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
490 auto P = std::make_pair(BasePtr, Stride);
491 StridedAddrs[
GEP] =
P;
495bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *
II) {
499 switch (
II->getIntrinsicID()) {
500 case Intrinsic::masked_gather:
502 Ptr =
II->getArgOperand(0);
504 Mask =
II->getArgOperand(2);
506 case Intrinsic::vp_gather:
508 Ptr =
II->getArgOperand(0);
509 MA =
II->getParamAlign(0).value_or(
510 DL->getABITypeAlign(DataType->getElementType()));
511 Mask =
II->getArgOperand(1);
512 EVL =
II->getArgOperand(2);
514 case Intrinsic::masked_scatter:
516 StoreVal =
II->getArgOperand(0);
517 Ptr =
II->getArgOperand(1);
519 Mask =
II->getArgOperand(3);
521 case Intrinsic::vp_scatter:
523 StoreVal =
II->getArgOperand(0);
524 Ptr =
II->getArgOperand(1);
525 MA =
II->getParamAlign(1).value_or(
526 DL->getABITypeAlign(DataType->getElementType()));
527 Mask =
II->getArgOperand(2);
528 EVL =
II->getArgOperand(3);
548 LLVMContext &Ctx = PtrI->getContext();
553 std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
556 assert(Stride !=
nullptr);
568 Intrinsic::experimental_vp_strided_load,
573 if (
II->getIntrinsicID() == Intrinsic::masked_gather)
577 Intrinsic::experimental_vp_strided_store,
582 II->replaceAllUsesWith(
Call);
583 II->eraseFromParent();
585 if (PtrI->use_empty())
591bool RISCVGatherScatterLowering::runOnFunction(Function &
F) {
595 auto &TPC = getAnalysis<TargetPassConfig>();
596 auto &
TM = TPC.getTM<RISCVTargetMachine>();
597 ST = &
TM.getSubtarget<RISCVSubtarget>(
F);
598 if (!
ST->hasVInstructions() || !
ST->useRVVForFixedLengthVectors())
601 TLI =
ST->getTargetLowering();
602 DL = &
F.getDataLayout();
603 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
605 StridedAddrs.clear();
611 for (BasicBlock &BB :
F) {
612 for (Instruction &
I : BB) {
616 switch (
II->getIntrinsicID()) {
617 case Intrinsic::masked_gather:
618 case Intrinsic::masked_scatter:
619 case Intrinsic::vp_gather:
620 case Intrinsic::vp_scatter:
630 for (
auto *
II : Worklist)
631 Changed |= tryCreateStridedLoadStore(
II);
634 while (!MaybeDeadPHIs.empty()) {
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static bool runOnFunction(Function &F, bool PostInlining)
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
static std::pair< Value *, Value * > matchStridedStart(Value *Start, IRBuilderBase &Builder)
static std::pair< Value *, Value * > matchStridedConstant(Constant *StartC)
static SymbolRef::Type getType(const Symbol *Sym)
Target-Independent Code Generator Pass Configuration Options pass.
Class for arbitrary precision integers.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
BinaryOps getOpcode() const
This is the shared class of boolean and integer constants.
const APInt & getValue() const
Return the constant as an APInt value reference.
This is an important base class in LLVM.
LLVM_ABI Constant * getAggregateElement(unsigned Elt) const
For aggregates (struct/array/vector) return the constant that corresponds to the specified element if...
A parsed version of the target data layout string in and methods for querying it.
FunctionPass class - This class is used to implement most global optimizations.
Common base class shared among various IRBuilders.
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
void SetCurrentDebugLocation(DebugLoc L)
Set location information used by debugging information.
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
LLVM_ABI Value * CreateElementCount(Type *Ty, ElementCount EC)
Create an expression which evaluates to the number of elements in EC at runtime.
LLVM_ABI bool isCommutative() const LLVM_READONLY
Return true if the instruction is commutative:
A wrapper class for inspecting calls to intrinsic functions.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Represents a single loop in the control flow graph.
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
bool isLegalStridedLoadStore(EVT DataType, Align Alignment) const
Return true if a stride load store of the given result type and alignment is legal.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
EVT getValueType(const DataLayout &DL, Type *Ty, bool AllowUnknown=false) const
Return the EVT corresponding to this LLVM type.
bool isTypeLegal(EVT VT) const
Return true if the target has native support for the specified value type.
Target-Independent Code Generator Pass Configuration Options.
bool isVectorTy() const
True if this is an instance of VectorType.
LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
void setOperand(unsigned i, Value *Val)
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
LLVM_ABI void takeName(Value *V)
Transfer the name from V to this value.
constexpr ScalarTy getFixedValue() const
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
TypeSize getSequentialElementStride(const DataLayout &DL) const
self_iterator getIterator()
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
bool match(Val *V, const Pattern &P)
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
NodeAddr< PhiNode * > Phi
This is an optimization pass for GlobalISel generic memory operations.
FunctionAddr VTableAddr Value
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
LLVM_ABI Value * getSplatValue(const Value *V)
Get splat value if the input is a splat vector or return nullptr.
FunctionPass * createRISCVGatherScatterLoweringPass()
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
generic_gep_type_iterator<> gep_type_iterator
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
gep_type_iterator gep_type_begin(const User *GEP)
LLVM_ABI bool RecursivelyDeleteDeadPHINode(PHINode *PN, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
If the specified value is an effectively dead PHI node, due to being a def-use chain of single-use no...
LLVM_ABI Constant * ConstantFoldCastInstruction(unsigned opcode, Constant *V, Type *DestTy)