34 #include "llvm/IR/IntrinsicsARM.h"
47 #define DEBUG_TYPE "arm-mve-gather-scatter-lowering"
51 cl::desc(
"Enable the generation of masked gathers and scatters"));
66 return "MVE gather/scatter lowering";
80 bool isLegalTypeAndAlignment(
unsigned NumElements,
unsigned ElemSize,
// Strips a bitcast from Ptr. In the definition, when Ptr is a BitCastInst
// whose result and source vector types have the same number of elements, Ptr
// (taken by reference) is replaced with the bitcast's source operand.
83 void lookThroughBitcast(
Value *&Ptr);
// Derives the scale for an offset gather/scatter from the GEP element size
// and the memory element size. The definition tests for the pairings 32/32
// and 16/16 and for a GEP element size of 8, and has an "incorrect scale"
// diagnostic for the remaining case. The concrete return values are not
// visible in this excerpt.
90 int computeScale(
unsigned GEPElemSize,
unsigned MemoryElemSize);
// Splits an offsets value into a variable part and a constant part. On the
// visible success path the definition returns (Summand, Immediate), where
// Immediate is the constant shifted left by TypeScale; a (nullptr, 0) pair is
// prepared as the failure value, and callers test .first against nullptr.
97 std::pair<Value *, int64_t> getVarAndConst(
Value *Inst,
int TypeScale);
109 int64_t Increment = 0);
118 int64_t Increment = 0);
122 int64_t Increment = 0);
133 Value *Ptr,
unsigned TypeScale,
// Helper for the offset optimisation of add instructions (the definition logs
// "optimising add instruction" and creates a value named "PushedOutAdd").
// Phi is taken by reference. StartIndex is 0 or 1 and the definition uses the
// other index as IncrementIndex -- presumably these select the start and
// increment incoming values of Phi; confirm against the full definition.
143 void pushOutAdd(
PHINode *&Phi,
Value *OffsSecondOperand,
unsigned StartIndex);
145 void pushOutMul(
PHINode *&Phi,
Value *IncrementPerRound,
146 Value *OffsSecondOperand,
unsigned LoopIncrement,
155 "MVE gather/scattering lowering pass",
false,
false)
158 return new MVEGatherScatterLowering();
// Checks whether a gather/scatter with NumElements elements of size ElemSize
// and the given Alignment has a shape this pass can handle.
161 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(
unsigned NumElements,
// Shapes tested for: 4 x {32, 16, 8}, 8 x {16, 8} and 16 x 8. In every case
// Alignment must be at least ElemSize / 8 (the element size in bytes,
// assuming ElemSize is expressed in bits -- TODO confirm).
164 if (((NumElements == 4 &&
165 (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
166 (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
167 (NumElements == 16 && ElemSize == 8)) &&
168 Alignment >= ElemSize / 8)
// Diagnostic for the rejected case; the return statements around it are not
// part of this excerpt.
170 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: instruction does not have "
171 <<
"valid alignment or vector type \n");
186 unsigned TargetElemSize = 128 / TargetElemCount;
187 unsigned OffsetElemSize = cast<FixedVectorType>(
Offsets->getType())
189 ->getScalarSizeInBits();
190 if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
194 int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
195 auto CheckValueSize = [TargetElemMaxSize](
Value *OffsetElem) {
196 ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
200 if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
204 if (isa<FixedVectorType>(ConstOff->
getType())) {
205 for (
unsigned i = 0;
i < TargetElemCount;
i++) {
210 if (!CheckValueSize(ConstOff))
222 dbgs() <<
"masked gathers/scatters: no getelementpointer found\n");
225 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: getelementpointer found."
226 <<
" Looking at intrinsic for base + vector of offsets\n");
227 Value *GEPPtr =
GEP->getPointerOperand();
230 !isa<FixedVectorType>(
Offsets->getType()))
233 if (
GEP->getNumOperands() != 2) {
234 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: getelementptr with too many"
235 <<
" operands. Expanding.\n");
239 unsigned OffsetsElemCount =
240 cast<FixedVectorType>(
Offsets->getType())->getNumElements();
251 if (!ZextOffs || cast<FixedVectorType>(ZextOffs->
getDestTy())
253 ->getScalarSizeInBits() != 32)
259 if (Ty !=
Offsets->getType()) {
268 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: found correct offsets\n");
// Looks through a bitcast on Ptr when the cast keeps the vector element count
// unchanged, updating Ptr in place.
272 void MVEGatherScatterLowering::lookThroughBitcast(
Value *&Ptr) {
274 if (
auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
// Both the result type and the source type go through cast<FixedVectorType>,
// so this path expects fixed-width vectors on both sides of the bitcast.
275 auto *BCTy = cast<FixedVectorType>(BitCast->getType());
276 auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
// Only casts that preserve the number of elements are looked through.
277 if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
279 dbgs() <<
"masked gathers/scatters: looking through bitcast\n");
// Ptr is a reference parameter, so the caller observes the un-cast operand.
280 Ptr = BitCast->getOperand(0);
// Chooses the scale for an offset gather/scatter from the GEP element size
// and the memory element size.
285 int MVEGatherScatterLowering::computeScale(
unsigned GEPElemSize,
286 unsigned MemoryElemSize) {
// Pairings tested: 32/32, 16/16, and a GEP element size of 8 regardless of
// the memory element size. The value returned for each case is not visible
// in this excerpt.
289 if (GEPElemSize == 32 && MemoryElemSize == 32)
291 else if (GEPElemSize == 16 && MemoryElemSize == 16)
293 else if (GEPElemSize == 8)
// Diagnostic presumably reached when none of the pairings above matched.
295 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: incorrect scale. Can't "
296 <<
"create intrinsic\n");
301 const Constant *
C = dyn_cast<Constant>(V);
304 if (!isa<Instruction>(V))
309 I->getOpcode() == Instruction::Mul) {
316 if (
I->getOpcode() == Instruction::Mul)
// Splits an offsets computation into a variable summand and a constant
// immediate scaled by TypeScale. The visible success path returns
// (Summand, Immediate); ReturnFalse below is the prepared failure value.
322 std::pair<Value *, int64_t>
323 MVEGatherScatterLowering::getVarAndConst(
Value *Inst,
int TypeScale) {
// Failure value: a null Value pointer paired with 0.
324 std::pair<Value *, int64_t> ReturnFalse =
325 std::pair<Value *, int64_t>(
nullptr, 0);
// Whichever operand of Add getIfConst() recognises becomes Const; the other
// operand becomes the variable Summand. Operand 0 is tried first.
335 if ((Const = getIfConst(
Add->getOperand(0))))
336 Summand =
Add->getOperand(1);
337 else if ((Const = getIfConst(
Add->getOperand(1))))
338 Summand =
Add->getOperand(0);
// Scale the constant by shifting it left by TypeScale.
343 int64_t Immediate =
Const.getValue() << TypeScale;
// Tests for an immediate outside [-512, 512] or not a multiple of 4
// (presumably the range the base-addressed intrinsics can encode -- TODO
// confirm; the statement guarded by this condition is not shown here).
344 if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
347 return std::pair<Value *, int64_t>(Summand, Immediate);
351 using namespace PatternMatch;
352 LLVM_DEBUG(
dbgs() <<
"masked gathers: checking transform preconditions\n");
357 auto *Ty = cast<FixedVectorType>(
I->getType());
358 Value *Ptr =
I->getArgOperand(0);
359 Align Alignment = cast<ConstantInt>(
I->getArgOperand(1))->getAlignValue();
361 Value *PassThru =
I->getArgOperand(3);
366 lookThroughBitcast(Ptr);
371 Builder.SetCurrentDebugLocation(
I->getDebugLoc());
380 if (!isa<UndefValue>(PassThru) && !
match(PassThru,
m_Zero())) {
381 LLVM_DEBUG(
dbgs() <<
"masked gathers: found non-trivial passthru - "
382 <<
"creating select\n");
391 I->eraseFromParent();
393 LLVM_DEBUG(
dbgs() <<
"masked gathers: successfully built masked gather\n");
401 using namespace PatternMatch;
402 auto *Ty = cast<FixedVectorType>(
I->getType());
403 LLVM_DEBUG(
dbgs() <<
"masked gathers: loading from vector of pointers\n");
409 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
411 {Ptr,
Builder.getInt32(Increment)});
413 return Builder.CreateIntrinsic(
414 Intrinsic::arm_mve_vldr_gather_base_predicated,
// Emits an MVE writeback gather from a vector of pointers. Two forms are
// created in the visible code: arm_mve_vldr_gather_base_wb with operands
// {Ptr, i32 Increment}, and the _predicated variant. The condition that
// selects between them is not part of this excerpt.
419 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
421 using namespace PatternMatch;
// The gather's own result type; cast<> expects it to be a fixed vector.
422 auto *Ty = cast<FixedVectorType>(
I->getType());
425 <<
"masked gathers: loading from vector of pointers with writeback\n");
// Unpredicated form: Increment is passed as a 32-bit constant.
431 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
433 {Ptr,
Builder.getInt32(Increment)});
// Predicated form (its operand list is not visible here).
435 return Builder.CreateIntrinsic(
436 Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
// Emits an MVE gather of the form base + vector of offsets, optionally
// folding a following sign/zero extend into the gather. Ends by creating
// either arm_mve_vldr_gather_offset or its _predicated variant.
441 Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
443 using namespace PatternMatch;
// ResultTy starts as the gather's own type and is widened below when an
// extend is folded in.
445 Type *OriginalTy =
I->getType();
446 Type *ResultTy = OriginalTy;
// Extend is taken to be the first user of the gather.
// NOTE(review): this dereferences users().begin(), so a check that I has a
// user must precede it -- that guard is not visible here, confirm.
459 Extend = cast<Instruction>(*
I->users().begin());
// Only sext and zext users are accepted as the extend; anything else takes
// the "extend needed but not provided" path.
460 if (isa<SExtInst>(Extend)) {
462 }
else if (!isa<ZExtInst>(Extend)) {
463 LLVM_DEBUG(
dbgs() <<
"masked gathers: extend needed but not provided. "
467 LLVM_DEBUG(
dbgs() <<
"masked gathers: found an extending gather\n");
// The extending gather produces the extend's (wider) type directly.
468 ResultTy = Extend->getType();
471 LLVM_DEBUG(
dbgs() <<
"masked gathers: extending from the wrong type. "
// The first computeScale argument is the bit size of BasePtr's pointee type.
// NOTE(review): getPointerElementType() depends on typed pointers and is not
// available with opaque pointers -- confirm against the LLVM version this
// file targets.
489 int Scale = computeScale(
490 BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
498 return Builder.CreateIntrinsic(
499 Intrinsic::arm_mve_vldr_gather_offset_predicated,
504 return Builder.CreateIntrinsic(
505 Intrinsic::arm_mve_vldr_gather_offset,
512 using namespace PatternMatch;
513 LLVM_DEBUG(
dbgs() <<
"masked scatters: checking transform preconditions\n");
518 Value *Input =
I->getArgOperand(0);
519 Value *Ptr =
I->getArgOperand(1);
520 Align Alignment = cast<ConstantInt>(
I->getArgOperand(2))->getAlignValue();
521 auto *Ty = cast<FixedVectorType>(Input->
getType());
527 lookThroughBitcast(Ptr);
532 Builder.SetCurrentDebugLocation(
I->getDebugLoc());
540 LLVM_DEBUG(
dbgs() <<
"masked scatters: successfully built masked scatter\n");
541 I->eraseFromParent();
// Emits an MVE scatter to a vector of pointers. Two forms are created in the
// visible code: arm_mve_vstr_scatter_base with operands
// {Ptr, i32 Increment, Input}, and the _predicated variant. The condition
// that selects between them is not part of this excerpt.
545 Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
547 using namespace PatternMatch;
// Operand 0 of the scatter call is the value being stored; Ty is its type.
548 Value *Input =
I->getArgOperand(0);
549 auto *Ty = cast<FixedVectorType>(Input->
getType());
557 LLVM_DEBUG(
dbgs() <<
"masked scatters: storing to a vector of pointers\n");
// Unpredicated form: Increment is passed as a 32-bit constant.
559 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
561 {Ptr,
Builder.getInt32(Increment), Input});
// Predicated form (its operand list is not visible here).
563 return Builder.CreateIntrinsic(
564 Intrinsic::arm_mve_vstr_scatter_base_predicated,
// Emits an MVE writeback scatter to a vector of pointers. Two forms are
// created in the visible code: arm_mve_vstr_scatter_base_wb with operands
// {Ptr, i32 Increment, Input}, and the _predicated variant. The condition
// that selects between them is not part of this excerpt.
569 Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
571 using namespace PatternMatch;
// Operand 0 of the scatter call is the value being stored; Ty is its type.
572 Value *Input =
I->getArgOperand(0);
573 auto *Ty = cast<FixedVectorType>(Input->
getType());
576 <<
"masked scatters: storing to a vector of pointers with writeback\n");
// Unpredicated form: Increment is passed as a 32-bit constant.
582 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
584 {Ptr,
Builder.getInt32(Increment), Input});
// Predicated form (its operand list is not visible here).
586 return Builder.CreateIntrinsic(
587 Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
// Emits an MVE scatter of the form base + vector of offsets, optionally
// looking through a truncate of the stored value. Ends by creating either
// arm_mve_vstr_scatter_offset or its _predicated variant.
592 Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
594 using namespace PatternMatch;
// Operand 0 of the scatter call is the value being stored.
595 Value *Input =
I->getArgOperand(0);
// MemoryTy keeps the original input type; InputTy may be replaced below by
// the pre-truncation type.
598 Type *MemoryTy = InputTy;
599 LLVM_DEBUG(
dbgs() <<
"masked scatters: getelementpointer found. Storing"
600 <<
" to base + vector of offsets\n");
// When the stored value is a trunc, consider its un-truncated operand.
603 if (
TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
604 Value *PreTrunc = Trunc->getOperand(0);
608 InputTy = PreTruncTy;
613 dbgs() <<
"masked scatters: cannot create scatters for non-standard"
614 <<
" input types. Expanding.\n");
// The first computeScale argument is the bit size of BasePtr's pointee type.
// NOTE(review): getPointerElementType() depends on typed pointers and is not
// available with opaque pointers -- confirm against the LLVM version this
// file targets.
630 int Scale = computeScale(
631 BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
637 return Builder.CreateIntrinsic(
638 Intrinsic::arm_mve_vstr_scatter_offset_predicated,
645 return Builder.CreateIntrinsic(
646 Intrinsic::arm_mve_vstr_scatter_offset,
// Tries to build an incrementing gather or scatter for I. The visible flow
// is: first attempt the writeback form when the GEP has a single use, then
// fall back to the non-writeback form, which splits the offsets with
// getVarAndConst() and emits a base gather/scatter carrying the constant as
// its immediate.
653 Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
// Ty is the vector type being moved: the call's result type for a gather,
// otherwise the type of argument 0 (the stored value).
657 if (
I->getIntrinsicID() == Intrinsic::masked_gather)
658 Ty = cast<FixedVectorType>(
I->getType());
660 Ty = cast<FixedVectorType>(
I->getArgOperand(0)->getType());
// Loop containing the instruction's block, as reported by LoopInfo.
665 Loop *L = LI->getLoopFor(
I->getParent());
669 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: trying to build incrementing "
670 "wb gather/scatter\n");
// NOTE(review): this declares a DataLayout by value, i.e. a copy of the
// module's data layout; a const reference would avoid the copy if DT is only
// read (its uses are not visible here). The name DT is also easy to confuse
// with a dominator tree.
674 DataLayout DT =
I->getParent()->getParent()->getParent()->getDataLayout();
678 cast<FixedVectorType>(
GEP->getType())->getNumElements());
// The writeback form is only attempted when the GEP has exactly one use.
682 if (
GEP->hasOneUse()) {
687 tryCreateIncrementingWBGatScat(
I, BasePtr,
Offsets, TypeScale,
Builder);
691 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: trying to build incrementing "
692 "non-wb gather/scatter\n");
// Split the offsets into a variable part and a scaled constant; a null first
// element signals that no such split was found.
694 std::pair<Value *, int64_t>
Add = getVarAndConst(
Offsets, TypeScale);
695 if (
Add.first ==
nullptr)
698 int64_t Immediate =
Add.second;
702 Instruction::Shl, OffsetsIncoming,
712 cast<VectorType>(ScaledOffsets->
getType())->getElementType())),
// Emit the base-addressed gather or scatter with the constant as immediate.
715 if (
I->getIntrinsicID() == Intrinsic::masked_gather)
716 return cast<IntrinsicInst>(
717 tryCreateMaskedGatherBase(
I, OffsetsIncoming,
Builder, Immediate));
719 return cast<IntrinsicInst>(
720 tryCreateMaskedScatterBase(
I, OffsetsIncoming,
Builder, Immediate));
// Tries to build the writeback form of an incrementing gather/scatter, in
// which the intrinsic itself produces the next value of the loop's offset
// induction (NewInduction below).
723 Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
// Loop containing the instruction's block, as reported by LoopInfo.
728 Loop *L = LI->getLoopFor(
I->getParent());
740 unsigned IncrementIndex =
745 std::pair<Value *, int64_t>
Add = getVarAndConst(
Offsets, TypeScale);
746 if (
Add.first ==
nullptr)
749 int64_t Immediate =
Add.second;
750 if (OffsetsIncoming != Phi)
757 cast<FixedVectorType>(OffsetsIncoming->
getType())->getNumElements();
771 cast<VectorType>(ScaledOffsets->
getType())->getElementType())),
775 Instruction::Sub, OffsetsIncoming,
777 "PreIncrementStartIndex",
785 if (
I->getIntrinsicID() == Intrinsic::masked_gather) {
// Gather: Load is an aggregate. Element 0 is the gathered data and element 1
// becomes the new induction value.
790 EndResult =
Builder.CreateExtractValue(
Load, 0,
"Gather");
791 NewInduction =
Builder.CreateExtractValue(
Load, 1,
"GatherIncrement");
// Scatter: the writeback scatter's result is both the new induction value
// and the end result.
794 NewInduction = tryCreateMaskedScatterBaseWB(
I, Phi,
Builder, Immediate);
795 EndResult = NewInduction;
// Offset-optimisation helper for add instructions. Phi is taken by
// reference. Only fragments of the body are visible in this excerpt.
805 void MVEGatherScatterLowering::pushOutAdd(
PHINode *&Phi,
806 Value *OffsSecondOperand,
807 unsigned StartIndex) {
808 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: optimising add instruction\n");
// Tail of a call that creates a value named "PushedOutAdd" at InsertionPoint.
814 "PushedOutAdd", InsertionPoint);
// StartIndex is 0 or 1; IncrementIndex is the other of the two (presumably
// the two incoming slots of Phi -- confirm against the full body).
815 unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
// Offset-optimisation helper for mul instructions. Phi is taken by
// reference. Only fragments of the body are visible in this excerpt: two
// values named "PushedOutMul" and "Product" are created at InsertionPoint,
// each by a call that receives OffsSecondOperand.
825 void MVEGatherScatterLowering::pushOutMul(
PHINode *&Phi,
826 Value *IncrementPerRound,
827 Value *OffsSecondOperand,
828 unsigned LoopIncrement,
830 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: optimising mul instruction\n");
840 OffsSecondOperand,
"PushedOutMul", InsertionPoint);
844 OffsSecondOperand,
"Product", InsertionPoint);
861 if (
I->hasNUses(0)) {
865 for (
User *U :
I->users()) {
866 if (!isa<Instruction>(U))
868 if (isa<GetElementPtrInst>(U) ||
872 unsigned OpCode = cast<Instruction>(U)->getOpcode();
885 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: trying to optimize\n");
888 if (!isa<Instruction>(
Offsets))
909 }
else if (isa<PHINode>(Offs->
getOperand(1))) {
916 Changed |= optimiseOffsets(Offs->
getOperand(0),
BB, LI);
919 Changed |= optimiseOffsets(Offs->
getOperand(1),
BB, LI);
926 }
else if (isa<PHINode>(Offs->
getOperand(1))) {
941 int IncrementingBlock = -1;
943 for (
int i = 0;
i < 2;
i++)
946 (
Op->getOperand(0) == Phi ||
Op->getOperand(1) == Phi))
947 IncrementingBlock =
i;
948 if (IncrementingBlock == -1)
958 (IncInstruction->
getOperand(0) == Phi) ? 1 : 0);
969 if (!isa<Constant>(IncrementPerRound) &&
970 !(isa<Instruction>(IncrementPerRound) &&
971 !L->
contains(cast<Instruction>(IncrementPerRound))))
982 IncrementPerRound,
"LoopIncrement", IncInstruction);
989 std::vector<Value *> Increases;
995 IncrementPerRound,
"LoopIncrement", IncInstruction);
998 IncrementingBlock = 1;
1007 pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
1009 case Instruction::Mul:
1010 pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
1017 dbgs() <<
"masked gathers/scatters: simplified loop variable add/mul\n");
1038 if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
1041 uint64_t
N = Const->getZExtValue();
1042 if (
N < (
unsigned)(1 << (TargetElemSize - 1))) {
1043 NonVectorVal =
Builder.CreateVectorSplat(
1056 if (XElType && !YElType) {
1057 FixSummands(XElType,
Y);
1058 YElType = cast<FixedVectorType>(
Y->getType());
1059 }
else if (YElType && !XElType) {
1060 FixSummands(YElType,
X);
1061 XElType = cast<FixedVectorType>(
X->getType());
1063 assert(XElType && YElType &&
"Unknown vector types");
1065 if (XElType != YElType) {
1066 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: incompatible gep offsets\n");
1073 Constant *ConstX = dyn_cast<Constant>(
X);
1074 Constant *ConstY = dyn_cast<Constant>(
Y);
1075 if (!ConstX || !ConstY)
1083 if (!ConstXEl || !ConstYEl ||
1085 (
unsigned)(1 << (TargetElemSize - 1)))
1102 Value *GEPPtr =
GEP->getPointerOperand();
1109 if ((BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr))) {
1128 bool Changed =
false;
1129 if (
GEP->hasOneUse() &&
1130 dyn_cast<GetElementPtrInst>(
GEP->getPointerOperand())) {
1133 Builder.SetCurrentDebugLocation(
GEP->getDebugLoc());
1144 GEP->replaceAllUsesWith(NewAddress);
1149 Changed |= optimiseOffsets(
GEP->getOperand(1),
GEP->getParent(), LI);
1156 auto &TPC = getAnalysis<TargetPassConfig>();
1159 if (!
ST->hasMVEIntegerOps())
1161 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1165 bool Changed =
false;
1171 isa<FixedVectorType>(II->
getType())) {
1172 Gathers.push_back(II);
1174 }
else if (II && II->
getIntrinsicID() == Intrinsic::masked_scatter &&
1176 Scatters.push_back(II);
1181 for (
unsigned i = 0;
i < Gathers.size();
i++) {
1183 Value *L = lowerGather(
I);
1192 for (
unsigned i = 0;
i < Scatters.size();
i++) {