35 #include "llvm/IR/IntrinsicsARM.h"
48 #define DEBUG_TYPE "arm-mve-gather-scatter-lowering"
52 cl::desc(
"Enable the generation of masked gathers and scatters"));
67 return "MVE gather/scatter lowering";
// NOTE(review): fragment of the MVEGatherScatterLowering class body — only
// member-function declarations are visible; the class header and many
// interior lines are missing from this view.
// True when an NumElements x ElemSize(bits) access with the given alignment
// can be lowered to an MVE gather/scatter (see definition further below).
82 bool isLegalTypeAndAlignment(
unsigned NumElements,
unsigned ElemSize,
// Strips a bitcast from Ptr when it preserves the vector element count.
85 void lookThroughBitcast(
Value *&Ptr);
// Maps (GEP element size, memory element size) to the scale factor encoded
// in the MVE intrinsic; -1 when no legal scale exists (see definition below).
98 int computeScale(
unsigned GEPElemSize,
unsigned MemoryElemSize);
// Splits an add-like offset expression into (variable part, scaled constant).
105 std::pair<Value *, int64_t> getVarAndConst(
Value *Inst,
int TypeScale);
// The lines below are the trailing default 'Increment' parameters of the
// tryCreateMasked{Gather,Scatter}* helper declarations — presumably; the
// leading parts of those declarations are missing from this view.
115 int64_t Increment = 0);
119 int64_t Increment = 0);
128 int64_t Increment = 0);
132 int64_t Increment = 0);
142 Value *Ptr,
unsigned TypeScale,
// Loop-strength helpers used by optimiseOffsets: hoist an add / mul / shl on
// a loop-carried offset PHI out of the loop (definitions further below).
153 void pushOutAdd(
PHINode *&Phi,
Value *OffsSecondOperand,
unsigned StartIndex);
155 void pushOutMulShl(
unsigned Opc,
PHINode *&Phi,
Value *IncrementPerRound,
156 Value *OffsSecondOperand,
unsigned LoopIncrement,
// Tail of the INITIALIZE_PASS boilerplate (pass description, not-CFG-only,
// not-analysis) followed by the factory used by the ARM target pipeline.
165 "MVE gather/scattering lowering pass",
false,
false)
168 return new MVEGatherScatterLowering();
171 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(
unsigned NumElements,
// MVE vectors are 128 bits: legal shapes are 4x(32|16|8), 8x(16|8), 16x8,
// and the access must be at least naturally aligned (Alignment is bytes,
// ElemSize bits, hence the /8). The 'return' for the legal case falls in a
// line missing from this view.
174 if (((NumElements == 4 &&
175 (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
176 (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
177 (NumElements == 16 && ElemSize == 8)) &&
178 Alignment >= ElemSize / 8)
180 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: instruction does not have "
181 <<
"valid alignment or vector type \n");
// Fragment (presumably checkOffsetSize): verifies that constant offsets fit
// the unsigned range addressable at the target element size.
196 unsigned TargetElemSize = 128 / TargetElemCount;
197 unsigned OffsetElemSize = cast<FixedVectorType>(
Offsets->getType())
199 ->getScalarSizeInBits();
// The range check is skipped only when BOTH the offset and target element
// sizes are 32 bits (the '||' makes the condition false just in that case);
// every narrower combination must prove its constants are encodable.
200 if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
204 int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
205 auto CheckValueSize = [TargetElemMaxSize](
Value *OffsetElem) {
206 ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
// Reject negative or too-large constants — not representable as an
// unsigned TargetElemSize-bit offset.
210 if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
// Vector constants are validated lane by lane; scalar splats fall through
// to lines missing from this view.
214 if (isa<FixedVectorType>(ConstOff->
getType())) {
215 for (
unsigned i = 0;
i < TargetElemCount;
i++) {
220 if (!CheckValueSize(ConstOff))
// Fragment (presumably decomposePtr): when the pointer operand is a GEP,
// derive the hardware scale from the GEP source element width; a scale of
// -1 means the access cannot be expressed as an MVE base+offsets intrinsic.
231 if (
auto *
GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
234 computeScale(
GEP->getSourceElementType()->getPrimitiveSizeInBits(),
236 return Scale == -1 ? nullptr : V;
// Fragment (presumably decomposeGEP): splits a getelementptr into a scalar
// base pointer plus a <N x i32> vector of offsets suitable for the MVE
// base+offsets gather/scatter intrinsics. Interior lines are missing.
259 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: no getelementpointer "
263 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: getelementpointer found."
264 <<
" Looking at intrinsic for base + vector of offsets\n");
265 Value *GEPPtr =
GEP->getPointerOperand();
// The offsets operand must itself be a fixed-width vector.
268 !isa<FixedVectorType>(
Offsets->getType()))
// Only the simple base+offsets form (exactly one index operand) is handled;
// multi-index GEPs are left for generic expansion.
271 if (
GEP->getNumOperands() != 2) {
272 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: getelementptr with too many"
273 <<
" operands. Expanding.\n");
277 unsigned OffsetsElemCount =
278 cast<FixedVectorType>(
Offsets->getType())->getNumElements();
// A zext of the offsets is only acceptable when it widens to i32 lanes —
// the width the intrinsics consume.
289 if (!ZextOffs || cast<FixedVectorType>(ZextOffs->
getDestTy())
291 ->getScalarSizeInBits() != 32)
297 if (Ty !=
Offsets->getType()) {
306 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: found correct offsets\n");
// Replace Ptr with the operand of a vector-to-vector bitcast when the cast
// preserves the element count — such a cast is a no-op for addressing
// purposes. (Lines between the count check and the assignment are missing
// from this view.)
310 void MVEGatherScatterLowering::lookThroughBitcast(
Value *&Ptr) {
312 if (
auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
313 auto *BCTy = cast<FixedVectorType>(BitCast->getType());
314 auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
315 if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
318 Ptr = BitCast->getOperand(0);
323 int MVEGatherScatterLowering::computeScale(
unsigned GEPElemSize,
324 unsigned MemoryElemSize) {
// Only matching 32/32 and 16/16 widths, or byte-sized GEP elements, can be
// encoded; the 'return <scale>' statements between these branches are
// missing from this view. Everything else is rejected below.
327 if (GEPElemSize == 32 && MemoryElemSize == 32)
329 else if (GEPElemSize == 16 && MemoryElemSize == 16)
331 else if (GEPElemSize == 8)
333 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: incorrect scale. Can't "
334 <<
"create intrinsic\n");
// Fragment (presumably getIfConst and an add-like-Or predicate): extracts a
// splat-constant value, otherwise recurses through Add/Shl/Or instructions.
// Interior lines (the returned Optional and recursion) are missing.
339 const Constant *
C = dyn_cast<Constant>(V);
340 if (
C &&
C->getSplatValue())
342 if (!isa<Instruction>(V))
348 I->getOpcode() == Instruction::Shl) {
357 if (
I->getOpcode() == Instruction::Shl)
359 if (
I->getOpcode() == Instruction::Or)
// An Or only behaves like an Add when the operands have no common bits —
// presumably checked on the (missing) continuation of this line.
368 return I->getOpcode() == Instruction::Or &&
// Splits an add-like 'Inst' into (variable summand, constant immediate),
// with the constant shifted left by TypeScale. Returns {nullptr, 0} when
// the expression doesn't match or the immediate can't be encoded.
372 std::pair<Value *, int64_t>
373 MVEGatherScatterLowering::getVarAndConst(
Value *Inst,
int TypeScale) {
374 std::pair<Value *, int64_t> ReturnFalse =
375 std::pair<Value *, int64_t>(
nullptr, 0);
379 if (Add ==
nullptr ||
// The constant may sit on either side of the add.
386 if ((Const = getIfConst(
Add->getOperand(0))))
387 Summand =
Add->getOperand(1);
388 else if ((Const = getIfConst(
Add->getOperand(1))))
389 Summand =
Add->getOperand(0);
// Scale the constant to bytes; the encodable immediate must be a multiple
// of 4 in [-512, 512], per the check below.
394 int64_t Immediate = *
Const << TypeScale;
395 if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
398 return std::pair<Value *, int64_t>(Summand, Immediate);
// Fragment (presumably lowerGather): rewrites a @llvm.masked.gather call as
// an MVE gather intrinsic, honoring a non-trivial passthru with a select,
// then erases the original call. Interior lines are missing from this view.
402 using namespace PatternMatch;
403 LLVM_DEBUG(
dbgs() <<
"masked gathers: checking transform preconditions\n"
// masked.gather operands: 0 = vector-of-pointers, 1 = alignment,
// 3 = passthru (operand 2, the mask, is read on a missing line).
409 auto *Ty = cast<FixedVectorType>(
I->getType());
410 Value *Ptr =
I->getArgOperand(0);
411 Align Alignment = cast<ConstantInt>(
I->getArgOperand(1))->getAlignValue();
413 Value *PassThru =
I->getArgOperand(3);
418 lookThroughBitcast(Ptr);
423 Builder.SetCurrentDebugLocation(
I->getDebugLoc());
429 Load = tryCreateMaskedGatherOffset(
I, Ptr, Root,
Builder);
// A passthru of undef or zero needs no fixup; anything else requires a
// select between the gathered value and the passthru.
435 if (!isa<UndefValue>(PassThru) && !
match(PassThru,
m_Zero())) {
436 LLVM_DEBUG(
dbgs() <<
"masked gathers: found non-trivial passthru - "
437 <<
"creating select\n");
447 I->eraseFromParent();
449 LLVM_DEBUG(
dbgs() <<
"masked gathers: successfully built masked gather\n"
// Emits a vldr.gather.base intrinsic (unpredicated or predicated depending
// on the mask, which is tested on a line missing from this view) that loads
// from a vector of pointers plus a constant Increment.
454 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
456 using namespace PatternMatch;
457 auto *Ty = cast<FixedVectorType>(
I->getType());
458 LLVM_DEBUG(
dbgs() <<
"masked gathers: loading from vector of pointers\n");
464 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
466 {Ptr,
Builder.getInt32(Increment)});
468 return Builder.CreateIntrinsic(
469 Intrinsic::arm_mve_vldr_gather_base_predicated,
// Write-back variant of tryCreateMaskedGatherBase: emits vldr.gather.base.wb
// so the pointer vector is post-incremented by the intrinsic itself.
474 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
476 using namespace PatternMatch;
477 auto *Ty = cast<FixedVectorType>(
I->getType());
478 LLVM_DEBUG(
dbgs() <<
"masked gathers: loading from vector of pointers with "
485 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
487 {Ptr,
Builder.getInt32(Increment)});
489 return Builder.CreateIntrinsic(
490 Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
// Emits a vldr.gather.offset intrinsic (base pointer + vector of offsets).
// When the gather's single user is a sext/zext, the extension is folded
// into the intrinsic; small result types are widened and truncated back.
// Many interior lines are missing from this view.
495 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
497 using namespace PatternMatch;
499 Type *MemoryTy =
I->getType();
500 Type *ResultTy = MemoryTy;
506 bool TruncResult =
false;
// Extension folding is only valid when the gather has exactly one user.
508 if (
I->hasOneUse()) {
513 if (isa<SExtInst>(
User) &&
520 }
else if (isa<ZExtInst>(
User) &&
523 << *ResultTy <<
"\n");
534 128 / cast<FixedVectorType>(ResultTy)->getNumElements());
536 LLVM_DEBUG(
dbgs() <<
"masked gathers: Small input type, truncing to: "
537 << *ResultTy <<
"\n");
542 LLVM_DEBUG(
dbgs() <<
"masked gathers: Extend needed but not provided "
543 "from the correct type. Expanding\n");
551 Ptr,
Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy,
Builder);
// Predicated vs. unpredicated form selected by the mask (test missing here).
560 Intrinsic::arm_mve_vldr_gather_offset_predicated,
566 Intrinsic::arm_mve_vldr_gather_offset,
// Fragment (presumably lowerScatter): rewrites @llvm.masked.scatter as an
// MVE scatter intrinsic and erases the original. masked.scatter operands:
// 1 = vector-of-pointers, 2 = alignment; the stored Input is read on a
// line missing from this view.
579 using namespace PatternMatch;
580 LLVM_DEBUG(
dbgs() <<
"masked scatters: checking transform preconditions\n"
587 Value *Ptr =
I->getArgOperand(1);
588 Align Alignment = cast<ConstantInt>(
I->getArgOperand(2))->getAlignValue();
589 auto *Ty = cast<FixedVectorType>(
Input->getType());
595 lookThroughBitcast(Ptr);
600 Builder.SetCurrentDebugLocation(
I->getDebugLoc());
610 LLVM_DEBUG(
dbgs() <<
"masked scatters: successfully built masked scatter\n"
612 I->eraseFromParent();
// Emits a vstr.scatter.base intrinsic storing Input through a vector of
// pointers plus a constant Increment; predicated form chosen by the mask
// (tested on a line missing from this view).
616 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
618 using namespace PatternMatch;
620 auto *Ty = cast<FixedVectorType>(
Input->getType());
628 LLVM_DEBUG(
dbgs() <<
"masked scatters: storing to a vector of pointers\n");
630 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
632 {Ptr,
Builder.getInt32(Increment), Input});
634 return Builder.CreateIntrinsic(
635 Intrinsic::arm_mve_vstr_scatter_base_predicated,
// Write-back variant of tryCreateMaskedScatterBase: the pointer vector is
// post-incremented by the vstr.scatter.base.wb intrinsic itself.
640 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
642 using namespace PatternMatch;
644 auto *Ty = cast<FixedVectorType>(
Input->getType());
645 LLVM_DEBUG(
dbgs() <<
"masked scatters: storing to a vector of pointers "
646 <<
"with writeback\n");
652 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
654 {Ptr,
Builder.getInt32(Increment), Input});
656 return Builder.CreateIntrinsic(
657 Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
// Emits a vstr.scatter.offset intrinsic (base pointer + vector of offsets).
// A trunc feeding the store is folded into the intrinsic's narrowing store;
// small inputs are widened first. Interior lines are missing from this view.
662 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
664 using namespace PatternMatch;
668 Type *MemoryTy = InputTy;
670 LLVM_DEBUG(
dbgs() <<
"masked scatters: getelementpointer found. Storing"
671 <<
" to base + vector of offsets\n");
// If the stored value is a trunc, store the pre-truncated value and let the
// scatter intrinsic do the narrowing to MemoryTy.
674 if (
TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
675 Value *PreTrunc = Trunc->getOperand(0);
679 InputTy = PreTruncTy;
682 bool ExtendInput =
false;
690 128 / cast<FixedVectorType>(InputTy)->getNumElements());
692 LLVM_DEBUG(
dbgs() <<
"masked scatters: Small input type, will extend:\n"
696 LLVM_DEBUG(
dbgs() <<
"masked scatters: cannot create scatters for "
697 "non-standard input types. Expanding.\n");
704 Ptr,
Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy,
Builder);
711 return Builder.CreateIntrinsic(
712 Intrinsic::arm_mve_vstr_scatter_offset_predicated,
719 return Builder.CreateIntrinsic(
720 Intrinsic::arm_mve_vstr_scatter_offset,
// Tries to lower a gather/scatter whose offsets are a loop-incremented
// value into a base+immediate-increment intrinsic, preferring the
// write-back form when the GEP is only used once inside a loop.
// Many interior lines are missing from this view.
727 Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
// The element type comes from the result for gathers and from the stored
// operand for scatters.
730 if (
I->getIntrinsicID() == Intrinsic::masked_gather)
731 Ty = cast<FixedVectorType>(
I->getType());
733 Ty = cast<FixedVectorType>(
I->getArgOperand(0)->getType());
739 Loop *L = LI->getLoopFor(
I->getParent());
750 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: trying to build incrementing "
751 "wb gather/scatter\n");
// Scale derived from the GEP operand width vs. per-lane width of the GEP
// result type.
756 computeScale(
DL->getTypeSizeInBits(
GEP->getOperand(0)->getType()),
757 DL->getTypeSizeInBits(
GEP->getType()) /
758 cast<FixedVectorType>(
GEP->getType())->getNumElements());
// Write-back is only safe when this instruction is the GEP's sole user.
762 if (
GEP->hasOneUse()) {
766 if (
auto *
Load = tryCreateIncrementingWBGatScat(
I, BasePtr,
Offsets,
771 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: trying to build incrementing "
772 "non-wb gather/scatter\n");
// Non-WB fallback: split the offsets into variable + encodable immediate.
774 std::pair<Value *, int64_t>
Add = getVarAndConst(
Offsets, TypeScale);
775 if (
Add.first ==
nullptr)
778 int64_t Immediate =
Add.second;
// Offsets are scaled to byte offsets (shl) before forming pointer lanes.
782 Instruction::Shl, OffsetsIncoming,
792 cast<VectorType>(ScaledOffsets->
getType())->getElementType())),
795 if (
I->getIntrinsicID() == Intrinsic::masked_gather)
796 return tryCreateMaskedGatherBase(
I, OffsetsIncoming,
Builder, Immediate);
798 return tryCreateMaskedScatterBase(
I, OffsetsIncoming,
Builder, Immediate);
// Builds the write-back form: the offsets PHI itself is turned into the
// vector of lane pointers, pre-decremented before the loop so the in-loop
// WB intrinsic produces the correct addresses each iteration.
// Many interior lines are missing from this view.
801 Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
806 Loop *L = LI->getLoopFor(
I->getParent());
818 unsigned IncrementIndex =
823 std::pair<Value *, int64_t>
Add = getVarAndConst(
Offsets, TypeScale);
824 if (
Add.first ==
nullptr)
827 int64_t Immediate =
Add.second;
// The variable part must be exactly the loop-carried PHI.
828 if (OffsetsIncoming != Phi)
835 cast<FixedVectorType>(OffsetsIncoming->
getType())->getNumElements();
849 cast<VectorType>(ScaledOffsets->
getType())->getElementType())),
// Subtract one increment up front so the first in-loop write-back lands on
// the intended start addresses.
853 Instruction::Sub, OffsetsIncoming,
855 "PreIncrementStartIndex",
863 if (
I->getIntrinsicID() == Intrinsic::masked_gather) {
874 EndResult = NewInduction =
875 tryCreateMaskedScatterBaseWB(
I, Phi,
Builder, Immediate);
// Hoists an add on the loop-carried offset PHI out of the loop by folding
// OffsSecondOperand into the PHI's start value ("PushedOutAdd").
// StartIndex selects which incoming edge holds the start value; the other
// edge is the loop-increment edge.
885 void MVEGatherScatterLowering::pushOutAdd(
PHINode *&Phi,
886 Value *OffsSecondOperand,
887 unsigned StartIndex) {
888 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: optimising add instruction\n");
894 "PushedOutAdd", InsertionPoint);
895 unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
// Hoists a mul/shl on the loop-carried offset PHI out of the loop: applies
// the operation to the PHI's start value ("PushedOutMul") and to the
// per-round increment ("Product"), preserving the per-iteration values.
905 void MVEGatherScatterLowering::pushOutMulShl(
unsigned Opcode,
PHINode *&Phi,
906 Value *IncrementPerRound,
907 Value *OffsSecondOperand,
908 unsigned LoopIncrement,
910 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: optimising mul instruction\n");
921 OffsSecondOperand,
"PushedOutMul", InsertionPoint);
925 OffsSecondOperand,
"Product", InsertionPoint);
// Fragment (presumably the check that every user of an instruction is a
// gather/scatter or feeds one through GEP/add/mul/shl chains); instructions
// with no users trivially fail. Interior lines are missing from this view.
942 if (
I->hasNUses(0)) {
946 for (
User *U :
I->users()) {
947 if (!isa<Instruction>(U))
949 if (isa<GetElementPtrInst>(U) ||
953 unsigned OpCode = cast<Instruction>(U)->getOpcode();
955 OpCode == Instruction::Shl ||
// Fragment (presumably optimiseOffsets): rewrites loop-dependent offset
// computations (add/or/mul/shl over a loop PHI) so the per-iteration work is
// pushed outside the loop via pushOutAdd / pushOutMulShl. Interior lines
// are missing from this view.
968 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: trying to optimize: "
972 if (!isa<Instruction>(
Offsets))
// Recurse on both operands when neither side is the loop PHI directly.
994 }
else if (isa<PHINode>(Offs->
getOperand(1))) {
998 bool Changed =
false;
1001 Changed |= optimiseOffsets(Offs->
getOperand(0),
BB, LI);
1004 Changed |= optimiseOffsets(Offs->
getOperand(1),
BB, LI);
1010 }
else if (isa<PHINode>(Offs->
getOperand(1))) {
1024 Value *Start, *IncrementPerRound;
// Identify which incoming edge of the PHI carries the loop increment.
1029 int IncrementingBlock = Phi->
getIncomingValue(0) == IncInstruction ? 0 : 1;
1034 if (IncrementPerRound->
getType() != OffsSecondOperand->
getType() ||
// The per-round increment must be loop-invariant: either a constant or an
// instruction defined outside the loop.
1041 if (!isa<Constant>(IncrementPerRound) &&
1042 !(isa<Instruction>(IncrementPerRound) &&
1043 !L->
contains(cast<Instruction>(IncrementPerRound))))
1057 IncrementPerRound,
"LoopIncrement", IncInstruction);
1069 IncrementPerRound,
"LoopIncrement", IncInstruction);
1072 IncrementingBlock = 1;
// Dispatch on the offset operation: Or is treated as an add-like push-out;
// Mul/Shl share pushOutMulShl.
1081 case Instruction::Or:
1082 pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
1085 case Instruction::Shl:
1086 pushOutMulShl(Offs->
getOpcode(), NewPhi, IncrementPerRound,
1087 OffsSecondOperand, IncrementingBlock,
Builder);
1092 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: simplified loop variable "
// Fragment (presumably the helper that merges two GEP offset summands):
// splats scalar summands to vectors, requires matching vector types, and —
// for constant operands — verifies the lane-wise sum stays within the
// signed range of the target element size. Interior lines are missing.
1114 if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
// A scalar constant may be splatted only if it fits in TargetElemSize-1
// value bits (sign bit reserved).
1118 if (
N < (
unsigned)(1 << (TargetElemSize - 1))) {
1119 NonVectorVal =
Builder.CreateVectorSplat(
// Promote whichever side is still scalar to match the vector side.
1132 if (XElType && !YElType) {
1133 FixSummands(XElType,
Y);
1134 YElType = cast<FixedVectorType>(
Y->getType());
1135 }
else if (YElType && !XElType) {
1136 FixSummands(YElType,
X);
1137 XElType = cast<FixedVectorType>(
X->getType());
1139 assert(XElType && YElType &&
"Unknown vector types");
1141 if (XElType != YElType) {
1142 LLVM_DEBUG(
dbgs() <<
"masked gathers/scatters: incompatible gep offsets\n");
// Overflow checking is only possible when both summands are constants.
1149 Constant *ConstX = dyn_cast<Constant>(
X);
1150 Constant *ConstY = dyn_cast<Constant>(
Y);
1151 if (!ConstX || !ConstY)
1159 if (!ConstXEl || !ConstYEl ||
1162 (
unsigned)(1 << (TargetElemSize - 1)))
// Fragment (presumably foldGEP): folds a chain of single-index GEPs with
// constant offsets into one combined offset, scaled by the source element
// allocation size. Interior lines are missing from this view.
1185 Value *GEPPtr =
GEP->getPointerOperand();
1187 Scale =
DL->getTypeAllocSize(
GEP->getSourceElementType());
// Only single-index GEPs with constant offsets can be folded.
1190 if (
GEP->getNumIndices() != 1 || !isa<Constant>(
Offsets))
1199 DL->getTypeAllocSize(
GEP->getSourceElementType()),
Builder);
// Fragment (presumably optimiseAddress): merges nested GEPs (via foldGEP)
// into a single base + offsets address, replacing the original GEP, then
// tries to simplify the remaining offsets. Interior lines are missing.
1213 bool Changed =
false;
1214 if (
GEP->hasOneUse() && isa<GetElementPtrInst>(
GEP->getPointerOperand())) {
1217 Builder.SetCurrentDebugLocation(
GEP->getDebugLoc());
// After folding, the combined GEP must index with byte granularity.
1226 assert(Scale == 1 &&
"Expected to fold GEP to a scale of 1");
1228 if (
auto *VecTy = dyn_cast<FixedVectorType>(
Base->getType()))
1234 <<
"\n new : " << *NewAddress <<
"\n");
// Bitcast keeps the replacement type-compatible with the old GEP's users.
1235 GEP->replaceAllUsesWith(
1236 Builder.CreateBitCast(NewAddress,
GEP->getType()));
1241 Changed |= optimiseOffsets(
GEP->getOperand(1),
GEP->getParent(), LI);
// Fragment (presumably runOnFunction): bails out unless the subtarget has
// MVE integer ops, collects masked.gather/masked.scatter intrinsic calls on
// fixed vectors, then lowers each collected call. Interior lines (including
// the lowering loop bodies) are missing from this view.
1248 auto &TPC = getAnalysis<TargetPassConfig>();
1251 if (!
ST->hasMVEIntegerOps())
1253 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1254 DL = &
F.getParent()->getDataLayout();
1258 bool Changed =
false;
// Collect first, lower afterwards — lowering erases instructions, which
// would invalidate a live instruction iterator.
1266 isa<FixedVectorType>(II->
getType())) {
1267 Gathers.push_back(II);
1269 }
else if (II && II->
getIntrinsicID() == Intrinsic::masked_scatter &&
1271 Scatters.push_back(II);
1276 for (
unsigned i = 0;
i < Gathers.size();
i++) {
1287 for (
unsigned i = 0;
i < Scatters.size();
i++) {