23#include "llvm/IR/IntrinsicsRISCV.h"
29using namespace PatternMatch;
31#define DEBUG_TYPE "riscv-gather-scatter-lowering"
62 return "RISC-V gather/scatter lowering";
66 bool isLegalTypeAndAlignment(
Type *DataType,
Value *AlignOp);
81char RISCVGatherScatterLowering::ID = 0;
84 "RISC-V gather/scatter lowering pass",
false,
false)
87 return new RISCVGatherScatterLowering();
90bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(
Type *DataType,
93 if (!TLI->isLegalElementTypeForRVV(ScalarType))
96 MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
97 if (MA && MA->value() <
DL->getTypeStoreSize(ScalarType).getFixedValue())
101 EVT DataVT = TLI->getValueType(*
DL, DataType);
102 if (!TLI->isTypeLegal(DataVT))
110 unsigned NumElts = cast<FixedVectorType>(StartC->
getType())->getNumElements();
116 return std::make_pair(
nullptr,
nullptr);
117 APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
119 for (
unsigned i = 1; i != NumElts; ++i) {
122 return std::make_pair(
nullptr,
nullptr);
126 StrideVal = LocalStride;
127 else if (StrideVal != LocalStride)
128 return std::make_pair(
nullptr,
nullptr);
135 return std::make_pair(StartVal, Stride);
141 auto *StartC = dyn_cast<Constant>(Start);
146 if (
match(Start, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
147 auto *Ty = Start->getType()->getScalarType();
153 auto *BO = dyn_cast<BinaryOperator>(Start);
154 if (!BO || (BO->getOpcode() != Instruction::Add &&
155 BO->getOpcode() != Instruction::Mul))
156 return std::make_pair(
nullptr,
nullptr);
159 unsigned OtherIndex = 1;
166 return std::make_pair(
nullptr,
nullptr);
172 return std::make_pair(
nullptr,
nullptr);
178 if (BO->getOpcode() == Instruction::Add) {
179 Start =
Builder.CreateAdd(Start, Splat);
181 assert(BO->getOpcode() == Instruction::Mul &&
"Unexpected opcode");
182 Start =
Builder.CreateMul(Start, Splat);
183 Stride =
Builder.CreateMul(Stride, Splat);
185 return std::make_pair(Start, Stride);
192bool RISCVGatherScatterLowering::matchStridedRecurrence(
Value *
Index,
Loop *L,
198 if (
auto *Phi = dyn_cast<PHINode>(
Index)) {
201 if (Phi->getParent() !=
L->getHeader())
208 assert(Phi->getNumIncomingValues() == 2 &&
"Expected 2 operand phi.");
209 unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
210 assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
211 "Expected one operand of phi to be Inc");
214 if (!
L->isLoopInvariant(Step))
225 assert(Stride !=
nullptr);
230 Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->
getName() +
".scalar",
232 BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
233 BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));
236 MaybeDeadPHIs.push_back(Phi);
241 auto *BO = dyn_cast<BinaryOperator>(
Index);
245 if (BO->getOpcode() != Instruction::Add &&
246 BO->getOpcode() != Instruction::Or &&
247 BO->getOpcode() != Instruction::Mul &&
248 BO->getOpcode() != Instruction::Shl)
252 if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1)))
256 if (BO->getOpcode() == Instruction::Or &&
262 if (isa<Instruction>(BO->getOperand(0)) &&
263 L->contains(cast<Instruction>(BO->getOperand(0)))) {
264 Index = cast<Instruction>(BO->getOperand(0));
265 OtherOp = BO->getOperand(1);
266 }
else if (isa<Instruction>(BO->getOperand(1)) &&
267 L->contains(cast<Instruction>(BO->getOperand(1)))) {
268 Index = cast<Instruction>(BO->getOperand(1));
269 OtherOp = BO->getOperand(0);
275 if (!
L->isLoopInvariant(OtherOp))
284 if (!matchStridedRecurrence(
Index, L, Stride, BasePtr, Inc, Builder))
289 unsigned StartBlock =
BasePtr->getOperand(0) == Inc ? 1 : 0;
295 BasePtr->getIncomingBlock(StartBlock)->getTerminator());
298 switch (BO->getOpcode()) {
301 case Instruction::Add:
302 case Instruction::Or: {
307 if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->
isZero())
310 Start =
Builder.CreateAdd(Start, SplatOp,
"start");
311 BasePtr->setIncomingValue(StartBlock, Start);
314 case Instruction::Mul: {
316 if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->
isZero())
317 Start =
Builder.CreateMul(Start, SplatOp,
"start");
319 Step =
Builder.CreateMul(Step, SplatOp,
"step");
322 if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne())
325 Stride =
Builder.CreateMul(Stride, SplatOp,
"stride");
327 BasePtr->setIncomingValue(StartBlock, Start);
330 case Instruction::Shl: {
332 if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->
isZero())
333 Start =
Builder.CreateShl(Start, SplatOp,
"start");
334 Step =
Builder.CreateShl(Step, SplatOp,
"step");
335 Stride =
Builder.CreateShl(Stride, SplatOp,
"stride");
337 BasePtr->setIncomingValue(StartBlock, Start);
345std::pair<Value *, Value *>
349 auto I = StridedAddrs.find(
GEP);
350 if (
I != StridedAddrs.end())
356 if (Ops[0]->
getType()->isVectorTy())
357 return std::make_pair(
nullptr,
nullptr);
359 std::optional<unsigned> VecOperand;
360 unsigned TypeScale = 0;
364 for (
unsigned i = 1, e =
GEP->getNumOperands(); i != e; ++i, ++GTI) {
365 if (!Ops[i]->
getType()->isVectorTy())
369 return std::make_pair(
nullptr,
nullptr);
375 return std::make_pair(
nullptr,
nullptr);
382 return std::make_pair(
nullptr,
nullptr);
388 Value *VecIndex = Ops[*VecOperand];
389 Type *VecIntPtrTy =
DL->getIntPtrType(
GEP->getType());
390 if (VecIndex->
getType() != VecIntPtrTy)
391 return std::make_pair(
nullptr,
nullptr);
401 Ops[*VecOperand] = Start;
402 Type *SourceTy =
GEP->getSourceElementType();
414 auto P = std::make_pair(BasePtr, Stride);
415 StridedAddrs[
GEP] =
P;
420 Loop *
L = LI->getLoopFor(
GEP->getParent());
421 if (!L || !
L->getLoopPreheader() || !
L->getLoopLatch())
422 return std::make_pair(
nullptr,
nullptr);
426 if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
427 return std::make_pair(
nullptr,
nullptr);
430 unsigned IncrementingBlock = BasePhi->
getOperand(0) == Inc ? 0 : 1;
432 "Expected one operand of phi to be Inc");
437 Ops[*VecOperand] = BasePhi;
438 Type *SourceTy =
GEP->getSourceElementType();
454 auto P = std::make_pair(BasePtr, Stride);
455 StridedAddrs[
GEP] =
P;
459bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(
IntrinsicInst *II,
464 if (!isLegalTypeAndAlignment(DataType, AlignOp))
468 auto *
GEP = dyn_cast<GetElementPtrInst>(
Ptr);
475 std::tie(BasePtr, Stride) = determineBaseAndStride(
GEP, Builder);
478 assert(Stride !=
nullptr);
485 Intrinsic::riscv_masked_strided_load,
490 Intrinsic::riscv_masked_strided_store,
498 if (
GEP->use_empty())
504bool RISCVGatherScatterLowering::runOnFunction(
Function &
F) {
508 auto &TPC = getAnalysis<TargetPassConfig>();
511 if (!
ST->hasVInstructions() || !
ST->useRVVForFixedLengthVectors())
514 TLI =
ST->getTargetLowering();
515 DL = &
F.getParent()->getDataLayout();
516 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
518 StridedAddrs.clear();
523 bool Changed =
false;
530 }
else if (II && II->
getIntrinsicID() == Intrinsic::masked_scatter) {
537 for (
auto *II : Gathers)
538 Changed |= tryCreateStridedLoadStore(
540 for (
auto *II : Scatters)
546 while (!MaybeDeadPHIs.empty()) {
547 if (
auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
const char LLVMTargetMachineRef TM
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
static std::pair< Value *, Value * > matchStridedStart(Value *Start, IRBuilder<> &Builder)
static std::pair< Value *, Value * > matchStridedConstant(Constant *StartC)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static SymbolRef::Type getType(const Symbol *Sym)
Target-Independent Code Generator Pass Configuration Options pass.
Class for arbitrary precision integers.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
LLVM Basic Block Representation.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well formed.
BinaryOps getOpcode() const
Value * getArgOperand(unsigned i) const
This class represents a function call, abstracting a target machine's calling convention.
This is the shared class of boolean and integer constants.
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
const APInt & getValue() const
Return the constant as an APInt value reference.
This is an important base class in LLVM.
Constant * getAggregateElement(unsigned Elt) const
For aggregates (struct/array/vector) return the constant that corresponds to the specified element if...
A parsed version of the target data layout string, and methods for querying it.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
A wrapper class for inspecting calls to intrinsic functions.
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
The legacy pass manager's analysis pass to compute loop information.
Represents a single loop in the control flow graph.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will have.
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
Target-Independent Code Generator Pass Configuration Options.
The instances of the Type class are immutable: once they are created, they are never changed.
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
void setOperand(unsigned i, Value *Val)
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
StringRef getName() const
Return a constant reference to the value's name.
constexpr ScalarTy getFixedValue() const
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Type * getIndexedType() const
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
bool match(Val *V, const Pattern &P)
This is an optimization pass for GlobalISel generic memory operations.
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Value * getSplatValue(const Value *V)
Get splat value if the input is a splat vector or return nullptr.
FunctionPass * createRISCVGatherScatterLoweringPass()
bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
bool haveNoCommonBitsSet(const Value *LHS, const Value *RHS, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Return true if LHS and RHS have no common bits set.
gep_type_iterator gep_type_begin(const User *GEP)
bool RecursivelyDeleteDeadPHINode(PHINode *PN, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
If the specified value is an effectively dead PHI node, due to being a def-use chain of single-use no...
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.