LLVM  13.0.0git
MVEGatherScatterLowering.cpp
Go to the documentation of this file.
1 //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// This pass custom lowers llvm.gather and llvm.scatter instructions to
10 /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
11 /// produce a better final result as we go.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "ARM.h"
16 #include "ARMBaseInstrInfo.h"
17 #include "ARMSubtarget.h"
18 #include "llvm/Analysis/LoopInfo.h"
23 #include "llvm/InitializePasses.h"
24 #include "llvm/IR/BasicBlock.h"
25 #include "llvm/IR/Constant.h"
26 #include "llvm/IR/Constants.h"
27 #include "llvm/IR/DerivedTypes.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/InstrTypes.h"
30 #include "llvm/IR/Instruction.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/IntrinsicInst.h"
33 #include "llvm/IR/Intrinsics.h"
34 #include "llvm/IR/IntrinsicsARM.h"
35 #include "llvm/IR/IRBuilder.h"
36 #include "llvm/IR/PatternMatch.h"
37 #include "llvm/IR/Type.h"
38 #include "llvm/IR/Value.h"
39 #include "llvm/Pass.h"
40 #include "llvm/Support/Casting.h"
42 #include <algorithm>
43 #include <cassert>
44 
45 using namespace llvm;
46 
47 #define DEBUG_TYPE "arm-mve-gather-scatter-lowering"
48 
50  "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
51  cl::desc("Enable the generation of masked gathers and scatters"));
52 
53 namespace {
54 
55 class MVEGatherScatterLowering : public FunctionPass {
56 public:
57  static char ID; // Pass identification, replacement for typeid
58 
59  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
61  }
62 
63  bool runOnFunction(Function &F) override;
64 
65  StringRef getPassName() const override {
66  return "MVE gather/scatter lowering";
67  }
68 
69  void getAnalysisUsage(AnalysisUsage &AU) const override {
70  AU.setPreservesCFG();
74  }
75 
76 private:
77  LoopInfo *LI = nullptr;
78 
79  // Check this is a valid gather with correct alignment
80  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
81  Align Alignment);
82  // Check whether Ptr is hidden behind a bitcast and look through it
83  void lookThroughBitcast(Value *&Ptr);
84  // Check for a getelementptr and deduce base and offsets from it, on success
85  // returning the base directly and the offsets indirectly using the Offsets
86  // argument
89  // Compute the scale of this gather/scatter instruction
90  int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
91  // If the value is a constant, or derived from constants via additions
92  // and multilications, return its numeric value
93  Optional<int64_t> getIfConst(const Value *V);
94  // If Inst is an add instruction, check whether one summand is a
95  // constant. If so, scale this constant and return it together with
96  // the other summand.
97  std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);
98 
99  Value *lowerGather(IntrinsicInst *I);
100  // Create a gather from a base + vector of offsets
101  Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
102  Instruction *&Root, IRBuilder<> &Builder);
103  // Create a gather from a vector of pointers
104  Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
105  IRBuilder<> &Builder, int64_t Increment = 0);
106  // Create an incrementing gather from a vector of pointers
107  Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
109  int64_t Increment = 0);
110 
111  Value *lowerScatter(IntrinsicInst *I);
112  // Create a scatter to a base + vector of offsets
113  Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
115  // Create a scatter to a vector of pointers
116  Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
118  int64_t Increment = 0);
119  // Create an incrementing scatter from a vector of pointers
120  Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
122  int64_t Increment = 0);
123 
124  // QI gathers and scatters can increment their offsets on their own if
125  // the increment is a constant value (digit)
126  Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr,
127  Value *Ptr, GetElementPtrInst *GEP,
129  // QI gathers/scatters can increment their offsets on their own if the
130  // increment is a constant value (digit) - this creates a writeback QI
131  // gather/scatter
132  Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
133  Value *Ptr, unsigned TypeScale,
135 
136  // Optimise the base and offsets of the given address
137  bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
138  // Try to fold consecutive geps together into one
140  // Check whether these offsets could be moved out of the loop they're in
141  bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
142  // Pushes the given add out of the loop
143  void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
144  // Pushes the given mul out of the loop
145  void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
146  Value *OffsSecondOperand, unsigned LoopIncrement,
148 };
149 
150 } // end anonymous namespace
151 
153 
154 INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
155  "MVE gather/scattering lowering pass", false, false)
156 
158  return new MVEGatherScatterLowering();
159 }
160 
161 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
162  unsigned ElemSize,
163  Align Alignment) {
164  if (((NumElements == 4 &&
165  (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
166  (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
167  (NumElements == 16 && ElemSize == 8)) &&
168  Alignment >= ElemSize / 8)
169  return true;
170  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
171  << "valid alignment or vector type \n");
172  return false;
173 }
174 
175 static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
176  // Offsets that are not of type <N x i32> are sign extended by the
177  // getelementptr instruction, and MVE gathers/scatters treat the offset as
178  // unsigned. Thus, if the element size is smaller than 32, we can only allow
179  // positive offsets - i.e., the offsets are not allowed to be variables we
180  // can't look into.
181  // Additionally, <N x i32> offsets have to either originate from a zext of a
182  // vector with element types smaller or equal the type of the gather we're
183  // looking at, or consist of constants that we can check are small enough
184  // to fit into the gather type.
185  // Thus we check that 0 < value < 2^TargetElemSize.
186  unsigned TargetElemSize = 128 / TargetElemCount;
187  unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
188  ->getElementType()
189  ->getScalarSizeInBits();
190  if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
191  Constant *ConstOff = dyn_cast<Constant>(Offsets);
192  if (!ConstOff)
193  return false;
194  int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
195  auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
196  ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
197  if (!OConst)
198  return false;
199  int SExtValue = OConst->getSExtValue();
200  if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
201  return false;
202  return true;
203  };
204  if (isa<FixedVectorType>(ConstOff->getType())) {
205  for (unsigned i = 0; i < TargetElemCount; i++) {
206  if (!CheckValueSize(ConstOff->getAggregateElement(i)))
207  return false;
208  }
209  } else {
210  if (!CheckValueSize(ConstOff))
211  return false;
212  }
213  }
214  return true;
215 }
216 
217 Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, FixedVectorType *Ty,
219  IRBuilder<> &Builder) {
220  if (!GEP) {
221  LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer "
222  << "found\n");
223  return nullptr;
224  }
225  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
226  << " Looking at intrinsic for base + vector of offsets\n");
227  Value *GEPPtr = GEP->getPointerOperand();
228  Offsets = GEP->getOperand(1);
229  if (GEPPtr->getType()->isVectorTy() ||
230  !isa<FixedVectorType>(Offsets->getType()))
231  return nullptr;
232 
233  if (GEP->getNumOperands() != 2) {
234  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
235  << " operands. Expanding.\n");
236  return nullptr;
237  }
238  Offsets = GEP->getOperand(1);
239  unsigned OffsetsElemCount =
240  cast<FixedVectorType>(Offsets->getType())->getNumElements();
241  // Paranoid check whether the number of parallel lanes is the same
242  assert(Ty->getNumElements() == OffsetsElemCount);
243 
244  ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
245  if (ZextOffs)
246  Offsets = ZextOffs->getOperand(0);
247  FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());
248 
249  // If the offsets are already being zext-ed to <N x i32>, that relieves us of
250  // having to make sure that they won't overflow.
251  if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
252  ->getElementType()
253  ->getScalarSizeInBits() != 32)
254  if (!checkOffsetSize(Offsets, OffsetsElemCount))
255  return nullptr;
256 
257  // The offset sizes have been checked; if any truncating or zext-ing is
258  // required to fix them, do that now
259  if (Ty != Offsets->getType()) {
260  if ((Ty->getElementType()->getScalarSizeInBits() <
261  OffsetType->getElementType()->getScalarSizeInBits())) {
262  Offsets = Builder.CreateTrunc(Offsets, Ty);
263  } else {
264  Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
265  }
266  }
267  // If none of the checks failed, return the gep's base pointer
268  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
269  return GEPPtr;
270 }
271 
272 void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
273  // Look through bitcast instruction if #elements is the same
274  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
275  auto *BCTy = cast<FixedVectorType>(BitCast->getType());
276  auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
277  if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
278  LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through "
279  << "bitcast\n");
280  Ptr = BitCast->getOperand(0);
281  }
282  }
283 }
284 
285 int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
286  unsigned MemoryElemSize) {
287  // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2,
288  // or a 8bit, 16bit or 32bit load/store scaled by 1
289  if (GEPElemSize == 32 && MemoryElemSize == 32)
290  return 2;
291  else if (GEPElemSize == 16 && MemoryElemSize == 16)
292  return 1;
293  else if (GEPElemSize == 8)
294  return 0;
295  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
296  << "create intrinsic\n");
297  return -1;
298 }
299 
300 Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
301  const Constant *C = dyn_cast<Constant>(V);
302  if (C != nullptr)
303  return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
304  if (!isa<Instruction>(V))
305  return Optional<int64_t>{};
306 
307  const Instruction *I = cast<Instruction>(V);
308  if (I->getOpcode() == Instruction::Add ||
309  I->getOpcode() == Instruction::Mul) {
310  Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
311  Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
312  if (!Op0 || !Op1)
313  return Optional<int64_t>{};
314  if (I->getOpcode() == Instruction::Add)
315  return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
316  if (I->getOpcode() == Instruction::Mul)
317  return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
318  }
319  return Optional<int64_t>{};
320 }
321 
322 std::pair<Value *, int64_t>
323 MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
324  std::pair<Value *, int64_t> ReturnFalse =
325  std::pair<Value *, int64_t>(nullptr, 0);
326  // At this point, the instruction we're looking at must be an add or we
327  // bail out
328  Instruction *Add = dyn_cast<Instruction>(Inst);
329  if (Add == nullptr || Add->getOpcode() != Instruction::Add)
330  return ReturnFalse;
331 
332  Value *Summand;
334  // Find out which operand the value that is increased is
335  if ((Const = getIfConst(Add->getOperand(0))))
336  Summand = Add->getOperand(1);
337  else if ((Const = getIfConst(Add->getOperand(1))))
338  Summand = Add->getOperand(0);
339  else
340  return ReturnFalse;
341 
342  // Check that the constant is small enough for an incrementing gather
343  int64_t Immediate = Const.getValue() << TypeScale;
344  if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
345  return ReturnFalse;
346 
347  return std::pair<Value *, int64_t>(Summand, Immediate);
348 }
349 
// Lower one llvm.masked.gather call to an MVE gather intrinsic, preferring
// the base+offsets form and falling back to a vector-of-pointers gather.
// Returns the replacement value, or nullptr when the intrinsic has to be
// expanded generically instead.
Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
                    << *I << "\n");

  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
  // Attempt to turn the masked gather in I into a MVE intrinsic
  // Potentially optimising the addressing modes as we do so.
  auto *Ty = cast<FixedVectorType>(I->getType());
  Value *Ptr = I->getArgOperand(0);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
  Value *Mask = I->getArgOperand(2);
  Value *PassThru = I->getArgOperand(3);

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  // Root may be redirected to a sext/zext user when the offset form folds
  // an extension into the gather; it is the instruction whose uses get
  // replaced below.
  Instruction *Root = I;
  Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  if (!Load)
    return nullptr;

  // The MVE intrinsics have no passthru operand, so a non-trivial passthru
  // is emulated with an explicit select on the mask.
  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
                      << "creating select\n");
    Load = Builder.CreateSelect(Mask, Load, PassThru);
  }

  Root->replaceAllUsesWith(Load);
  Root->eraseFromParent();
  if (Root != I)
    // If this was an extending gather, we need to get rid of the sext/zext
    // sext/zext as well as of the gather itself
    I->eraseFromParent();

  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
                    << *Load << "\n");
  return Load;
}
398 
399 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I,
400  Value *Ptr,
402  int64_t Increment) {
403  using namespace PatternMatch;
404  auto *Ty = cast<FixedVectorType>(I->getType());
405  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
406  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
407  // Can't build an intrinsic for this
408  return nullptr;
409  Value *Mask = I->getArgOperand(2);
410  if (match(Mask, m_One()))
411  return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
412  {Ty, Ptr->getType()},
413  {Ptr, Builder.getInt32(Increment)});
414  else
415  return Builder.CreateIntrinsic(
416  Intrinsic::arm_mve_vldr_gather_base_predicated,
417  {Ty, Ptr->getType(), Mask->getType()},
418  {Ptr, Builder.getInt32(Increment), Mask});
419 }
420 
// Build a writeback gather from a vector of pointers: the intrinsic returns
// both the loaded values and the post-incremented pointer vector. Only
// v4i32 is supported; returns the new intrinsic call or nullptr.
Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with "
                    << "writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  // An all-ones mask selects the unpredicated form of the intrinsic.
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}
441 
// Build a base+offsets gather, possibly folding a following sext/zext into
// the intrinsic (an "extending gather"). On success Root is updated to the
// extend instruction so the caller replaces/erases the right value; returns
// the new intrinsic call or nullptr.
Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  using namespace PatternMatch;

  Type *OriginalTy = I->getType();
  Type *ResultTy = OriginalTy;

  // 1 = zero-extending load, 0 = sign-extending; encoded as an intrinsic arg.
  unsigned Unsigned = 1;
  // The size of the gather was already checked in isLegalTypeAndAlignment;
  // if it was not a full vector width an appropriate extend should follow.
  auto *Extend = Root;
  if (OriginalTy->getPrimitiveSizeInBits() < 128) {
    // Only transform gathers with exactly one use
    if (!I->hasOneUse())
      return nullptr;

    // The correct root to replace is not the CallInst itself, but the
    // instruction which extends it
    Extend = cast<Instruction>(*I->users().begin());
    if (isa<SExtInst>(Extend)) {
      Unsigned = 0;
    } else if (!isa<ZExtInst>(Extend)) {
      LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. "
                        << "Expanding\n");
      return nullptr;
    }
    LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n");
    ResultTy = Extend->getType();
    // The final size of the gather must be a full vector width
    if (ResultTy->getPrimitiveSizeInBits() != 128) {
      LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. "
                        << "Expanding\n");
      return nullptr;
    }
  }

  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr =
      checkGEP(Offsets, cast<FixedVectorType>(ResultTy), GEP, Builder);
  if (!BasePtr)
    return nullptr;
  // Check whether the offset is a constant increment that could be merged into
  // a QI gather
  Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
  if (Load)
    return Load;

  int Scale =
      computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
                   OriginalTy->getScalarSizeInBits());
  if (Scale == -1)
    return nullptr;
  // From here on the transform will succeed, so redirect the caller's
  // replacement point to the extend (no-op when there was no extend).
  Root = Extend;

  Value *Mask = I->getArgOperand(2);
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset_predicated,
        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset,
        {ResultTy, BasePtr->getType(), Offsets->getType()},
        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
}
511 
// Lower one llvm.masked.scatter call to an MVE scatter intrinsic, preferring
// the base+offsets form and falling back to a vector-of-pointers scatter.
// On success erases the original intrinsic and returns the new store value;
// returns nullptr when the intrinsic must be expanded generically.
Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
                    << *I << "\n");

  // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
  // Attempt to turn the masked scatter in I into a MVE intrinsic
  // Potentially optimising the addressing modes as we do so.
  Value *Input = I->getArgOperand(0);
  Value *Ptr = I->getArgOperand(1);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
  auto *Ty = cast<FixedVectorType>(Input->getType());

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;

  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  if (!Store)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
                    << *Store << "\n");
  I->eraseFromParent();
  return Store;
}
547 
// Build a scatter that stores directly to a vector of pointers (optionally
// with a constant per-lane Increment). Only v4i32 data is supported here;
// returns the new intrinsic call or nullptr.
Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  // Only QR variants allow truncating
  if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
    // Can't build an intrinsic for this
    return nullptr;
  }
  Value *Mask = I->getArgOperand(3);
  // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  // An all-ones mask selects the unpredicated form of the intrinsic.
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}
571 
572 Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
573  IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
574  using namespace PatternMatch;
575  Value *Input = I->getArgOperand(0);
576  auto *Ty = cast<FixedVectorType>(Input->getType());
577  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers "
578  << "with writeback\n");
579  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
580  // Can't build an intrinsic for this
581  return nullptr;
582  Value *Mask = I->getArgOperand(3);
583  if (match(Mask, m_One()))
584  return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
585  {Ptr->getType(), Input->getType()},
586  {Ptr, Builder.getInt32(Increment), Input});
587  else
588  return Builder.CreateIntrinsic(
589  Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
590  {Ptr->getType(), Input->getType(), Mask->getType()},
591  {Ptr, Builder.getInt32(Increment), Input, Mask});
592 }
593 
594 Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
596  using namespace PatternMatch;
597  Value *Input = I->getArgOperand(0);
598  Value *Mask = I->getArgOperand(3);
599  Type *InputTy = Input->getType();
600  Type *MemoryTy = InputTy;
601  LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
602  << " to base + vector of offsets\n");
603  // If the input has been truncated, try to integrate that trunc into the
604  // scatter instruction (we don't care about alignment here)
605  if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
606  Value *PreTrunc = Trunc->getOperand(0);
607  Type *PreTruncTy = PreTrunc->getType();
608  if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
609  Input = PreTrunc;
610  InputTy = PreTruncTy;
611  }
612  }
613  if (InputTy->getPrimitiveSizeInBits() != 128) {
614  LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
615  "non-standard input types. Expanding.\n");
616  return nullptr;
617  }
618 
619  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
620  Value *Offsets;
621  Value *BasePtr =
622  checkGEP(Offsets, cast<FixedVectorType>(InputTy), GEP, Builder);
623  if (!BasePtr)
624  return nullptr;
625  // Check whether the offset is a constant increment that could be merged into
626  // a QI gather
627  Value *Store =
628  tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
629  if (Store)
630  return Store;
631  int Scale =
632  computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
633  MemoryTy->getScalarSizeInBits());
634  if (Scale == -1)
635  return nullptr;
636 
637  if (!match(Mask, m_One()))
638  return Builder.CreateIntrinsic(
639  Intrinsic::arm_mve_vstr_scatter_offset_predicated,
640  {BasePtr->getType(), Offsets->getType(), Input->getType(),
641  Mask->getType()},
642  {BasePtr, Offsets, Input,
643  Builder.getInt32(MemoryTy->getScalarSizeInBits()),
644  Builder.getInt32(Scale), Mask});
645  else
646  return Builder.CreateIntrinsic(
647  Intrinsic::arm_mve_vstr_scatter_offset,
648  {BasePtr->getType(), Offsets->getType(), Input->getType()},
649  {BasePtr, Offsets, Input,
650  Builder.getInt32(MemoryTy->getScalarSizeInBits()),
651  Builder.getInt32(Scale)});
652 }
653 
654 Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
656  IRBuilder<> &Builder) {
657  FixedVectorType *Ty;
658  if (I->getIntrinsicID() == Intrinsic::masked_gather)
659  Ty = cast<FixedVectorType>(I->getType());
660  else
661  Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
662  // Incrementing gathers only exist for v4i32
663  if (Ty->getNumElements() != 4 ||
664  Ty->getScalarSizeInBits() != 32)
665  return nullptr;
666  Loop *L = LI->getLoopFor(I->getParent());
667  if (L == nullptr)
668  // Incrementing gathers are not beneficial outside of a loop
669  return nullptr;
670  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
671  "wb gather/scatter\n");
672 
673  // The gep was in charge of making sure the offsets are scaled correctly
674  // - calculate that factor so it can be applied by hand
675  DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout();
676  int TypeScale =
677  computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()),
678  DT.getTypeSizeInBits(GEP->getType()) /
679  cast<FixedVectorType>(GEP->getType())->getNumElements());
680  if (TypeScale == -1)
681  return nullptr;
682 
683  if (GEP->hasOneUse()) {
684  // Only in this case do we want to build a wb gather, because the wb will
685  // change the phi which does affect other users of the gep (which will still
686  // be using the phi in the old way)
687  Value *Load =
688  tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder);
689  if (Load != nullptr)
690  return Load;
691  }
692  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
693  "non-wb gather/scatter\n");
694 
695  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
696  if (Add.first == nullptr)
697  return nullptr;
698  Value *OffsetsIncoming = Add.first;
699  int64_t Immediate = Add.second;
700 
701  // Make sure the offsets are scaled correctly
702  Instruction *ScaledOffsets = BinaryOperator::Create(
703  Instruction::Shl, OffsetsIncoming,
704  Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)),
705  "ScaledIndex", I);
706  // Add the base to the offsets
707  OffsetsIncoming = BinaryOperator::Create(
708  Instruction::Add, ScaledOffsets,
709  Builder.CreateVectorSplat(
710  Ty->getNumElements(),
711  Builder.CreatePtrToInt(
712  BasePtr,
713  cast<VectorType>(ScaledOffsets->getType())->getElementType())),
714  "StartIndex", I);
715 
716  if (I->getIntrinsicID() == Intrinsic::masked_gather)
717  return cast<IntrinsicInst>(
718  tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate));
719  else
720  return cast<IntrinsicInst>(
721  tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate));
722 }
723 
// Try to build a writeback incrementing gather/scatter: the offsets must be
// a loop-IV phi incremented by a constant. The phi is rewritten to carry
// absolute (scaled, base-added, pre-decremented) addresses, and the
// intrinsic's writeback result replaces the old increment. Returns the new
// value or nullptr if the pattern does not apply.
Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
    IRBuilder<> &Builder) {
  // Check whether this gather's offset is incremented by a constant - if so,
  // and the load is of the right type, we can merge this into a QI gather
  Loop *L = LI->getLoopFor(I->getParent());
  // Offsets that are worth merging into this instruction will be incremented
  // by a constant, thus we're looking for an add of a phi and a constant
  PHINode *Phi = dyn_cast<PHINode>(Offsets);
  if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
      Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
    // No phi means no IV to write back to; if there is a phi, we expect it
    // to have exactly two incoming values; the only phis we are interested in
    // will be loop IV's and have exactly two uses, one in their increment and
    // one in the gather's gep
    return nullptr;

  // Index 1 - IncrementIndex is the preheader/start edge throughout below.
  unsigned IncrementIndex =
      Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
  // Look through the phi to the phi increment
  Offsets = Phi->getIncomingValue(IncrementIndex);

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;
  if (OffsetsIncoming != Phi)
    // Then the increment we are looking at is not an increment of the
    // induction variable, and we don't want to do a writeback
    return nullptr;

  // Materialise the start value at the end of the preheader block.
  Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  unsigned NumElems =
      cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
      "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          NumElems,
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // The gather is pre-incrementing
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Sub, OffsetsIncoming,
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
      "PreIncrementStartIndex",
      &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);

  Builder.SetInsertPoint(I);

  Value *EndResult;
  Value *NewInduction;
  if (I->getIntrinsicID() == Intrinsic::masked_gather) {
    // Build the incrementing gather
    Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
    // One value to be handed to whoever uses the gather, one is the loop
    // increment
    EndResult = Builder.CreateExtractValue(Load, 0, "Gather");
    NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement");
  } else {
    // Build the incrementing scatter
    NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
    EndResult = NewInduction;
  }
  // Replace the old add-based increment with the intrinsic's writeback
  // result and wire it back into the phi.
  Instruction *AddInst = cast<Instruction>(Offsets);
  AddInst->replaceAllUsesWith(NewInduction);
  AddInst->eraseFromParent();
  Phi->setIncomingValue(IncrementIndex, NewInduction);

  return EndResult;
}
805 
806 void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
807  Value *OffsSecondOperand,
808  unsigned StartIndex) {
809  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
810  Instruction *InsertionPoint =
811  &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
812  // Initialize the phi with a vector that contains a sum of the constants
814  Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
815  "PushedOutAdd", InsertionPoint);
816  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
817 
818  // Order such that start index comes first (this reduces mov's)
819  Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
820  Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
821  Phi->getIncomingBlock(IncrementIndex));
822  Phi->removeIncomingValue(IncrementIndex);
823  Phi->removeIncomingValue(StartIndex);
824 }
825 
// Hoist a loop-invariant multiply out of the loop: rather than computing
// phi * OffsSecondOperand each iteration, pre-multiply the phi's start value
// once outside the loop and change the per-iteration step from
// IncrementPerRound to IncrementPerRound * OffsSecondOperand.
//
// \param Phi               the induction phi being rewritten
// \param IncrementPerRound the phi's original per-iteration step
// \param OffsSecondOperand the loop-invariant factor being pushed out
// \param LoopIncrement     index of the phi's latch (increment) edge
// \param Builder           IRBuilder (insertion points are set explicitly below)
void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
                                          Value *IncrementPerRound,
                                          Value *OffsSecondOperand,
                                          unsigned LoopIncrement,
                                          IRBuilder<> &Builder) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");

  // Create a new scalar add outside of the loop and transform it to a splat
  // by which loop variable can be incremented
  // (Insert just before the terminator of the non-latch (preheader) block.)
  Instruction *InsertionPoint = &cast<Instruction>(
      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());

  // Create a new index: pre-multiplied start value, computed once outside the
  // loop.
  Value *StartIndex = BinaryOperator::Create(
      Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
      OffsSecondOperand, "PushedOutMul", InsertionPoint);

  // The new per-iteration step: original step scaled by the pushed-out factor.
  Instruction *Product =
      BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);
  // Increment NewIndex by Product instead of the multiplication
  // (placed just before the latch block's terminator).
  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul",
      cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
          .getPrevNode());

  // Rebuild the phi's edges with the new start/step, then drop the two
  // original edges (the two removeIncomingValue(0) calls pop them in order).
  Phi->addIncoming(StartIndex,
                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  Phi->removeIncomingValue((unsigned)0);
  Phi->removeIncomingValue((unsigned)0);
}
858 
859 // Check whether all usages of this instruction are as offsets of
860 // gathers/scatters or simple arithmetics only used by gathers/scatters
862  if (I->hasNUses(0)) {
863  return false;
864  }
865  bool Gatscat = true;
866  for (User *U : I->users()) {
867  if (!isa<Instruction>(U))
868  return false;
869  if (isa<GetElementPtrInst>(U) ||
870  isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
871  return Gatscat;
872  } else {
873  unsigned OpCode = cast<Instruction>(U)->getOpcode();
874  if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
875  hasAllGatScatUsers(cast<Instruction>(U))) {
876  continue;
877  }
878  return false;
879  }
880  }
881  return Gatscat;
882 }
883 
// Try to simplify a gather/scatter offset expression of the form
// `phi + inv` or `phi * inv` (inv loop-invariant) by pushing the add/mul out
// of the loop and absorbing it into the induction phi itself (via pushOutAdd
// or pushOutMul). Returns true if the IR was changed.
bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
                                               LoopInfo *LI) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n"
                    << *Offsets << "\n");
  // Optimise the addresses of gathers/scatters by moving invariant
  // calculations out of the loop
  if (!isa<Instruction>(Offsets))
    return false;
  Instruction *Offs = cast<Instruction>(Offsets);
  // Only add/mul offset computations are handled.
  if (Offs->getOpcode() != Instruction::Add &&
      Offs->getOpcode() != Instruction::Mul)
    return false;
  Loop *L = LI->getLoopFor(BB);
  if (L == nullptr)
    return false;
  // With multiple users, we may only transform if every (transitive) user is
  // a gather/scatter offset.
  if (!Offs->hasOneUse()) {
    if (!hasAllGatScatUsers(Offs))
      return false;
  }

  // Find out which, if any, operand of the instruction
  // is a phi node
  PHINode *Phi;
  int OffsSecondOp;  // index of the non-phi operand
  if (isa<PHINode>(Offs->getOperand(0))) {
    Phi = cast<PHINode>(Offs->getOperand(0));
    OffsSecondOp = 1;
  } else if (isa<PHINode>(Offs->getOperand(1))) {
    Phi = cast<PHINode>(Offs->getOperand(1));
    OffsSecondOp = 0;
  } else {
    // Neither operand is a phi yet: recursively optimise the operands first,
    // which may replace one of them with a phi.
    // NOTE(review): initialising Changed to true makes the `if (!Changed)`
    // below dead code; harmless (the re-check returns false anyway when
    // nothing changed), but confirm whether `false` was intended.
    bool Changed = true;
    if (isa<Instruction>(Offs->getOperand(0)) &&
        L->contains(cast<Instruction>(Offs->getOperand(0))))
      Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
    if (isa<Instruction>(Offs->getOperand(1)) &&
        L->contains(cast<Instruction>(Offs->getOperand(1))))
      Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
    if (!Changed) {
      return false;
    } else {
      // Re-check: a successful recursion replaces an operand with a phi.
      if (isa<PHINode>(Offs->getOperand(0))) {
        Phi = cast<PHINode>(Offs->getOperand(0));
        OffsSecondOp = 1;
      } else if (isa<PHINode>(Offs->getOperand(1))) {
        Phi = cast<PHINode>(Offs->getOperand(1));
        OffsSecondOp = 0;
      } else {
        return false;
      }
    }
  }
  // A phi node we want to perform this function on should be from the
  // loop header, and shouldn't have more than 2 incoming values
  if (Phi->getParent() != L->getHeader() ||
      Phi->getNumIncomingValues() != 2)
    return false;

  // The phi must be an induction variable
  // (one incoming value must be `phi + step`).
  int IncrementingBlock = -1;

  for (int i = 0; i < 2; i++)
    if (auto *Op = dyn_cast<Instruction>(Phi->getIncomingValue(i)))
      if (Op->getOpcode() == Instruction::Add &&
          (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi))
        IncrementingBlock = i;
  if (IncrementingBlock == -1)
    return false;

  Instruction *IncInstruction =
      cast<Instruction>(Phi->getIncomingValue(IncrementingBlock));

  // If the phi is not used by anything else, we can just adapt it when
  // replacing the instruction; if it is, we'll have to duplicate it
  PHINode *NewPhi;
  // The step: whichever operand of the increment is not the phi itself.
  Value *IncrementPerRound = IncInstruction->getOperand(
      (IncInstruction->getOperand(0) == Phi) ? 1 : 0);

  // Get the value that is added to/multiplied with the phi
  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);

  if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
      !L->isLoopInvariant(OffsSecondOperand))
    // Something has gone wrong, abort
    return false;

  // Only proceed if the increment per round is a constant or an instruction
  // which does not originate from within the loop
  if (!isa<Constant>(IncrementPerRound) &&
      !(isa<Instruction>(IncrementPerRound) &&
        !L->contains(cast<Instruction>(IncrementPerRound))))
    return false;

  if (Phi->getNumUses() == 2) {
    // No other users -> reuse existing phi (One user is the instruction
    // we're looking at, the other is the phi increment)
    if (IncInstruction->getNumUses() != 1) {
      // If the incrementing instruction does have more users than
      // our phi, we need to copy it
      IncInstruction = BinaryOperator::Create(
          Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
          IncrementPerRound, "LoopIncrement", IncInstruction);
      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
    }
    NewPhi = Phi;
  } else {
    // There are other users -> create a new phi
    NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi);
    // NOTE(review): `Increases` is never used — candidate for removal.
    std::vector<Value *> Increases;
    // Copy the incoming values of the old phi
    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
    IncInstruction = BinaryOperator::Create(
        Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
        IncrementPerRound, "LoopIncrement", IncInstruction);
    NewPhi->addIncoming(IncInstruction,
                        Phi->getIncomingBlock(IncrementingBlock));
    // The new phi's edges were added start-first, so the latch edge is 1.
    IncrementingBlock = 1;
  }

  IRBuilder<> Builder(BB->getContext());
  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Offs->getDebugLoc());

  // Push the invariant operation out of the loop and into the phi.
  switch (Offs->getOpcode()) {
  case Instruction::Add:
    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
    break;
  case Instruction::Mul:
    pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
               Builder);
    break;
  default:
    return false;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable "
                    << "add/mul\n");

  // The instruction has now been "absorbed" into the phi value
  Offs->replaceAllUsesWith(NewPhi);
  if (Offs->hasNUses(0))
    Offs->eraseFromParent();
  // Clean up the old increment in case it's unused because we built a new
  // one
  if (IncInstruction->hasNUses(0))
    IncInstruction->eraseFromParent();

  return true;
}
1033 
1035  IRBuilder<> &Builder) {
1036  // Splat the non-vector value to a vector of the given type - if the value is
1037  // a constant (and its value isn't too big), we can even use this opportunity
1038  // to scale it to the size of the vector elements
1039  auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
1040  ConstantInt *Const;
1041  if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
1042  VT->getElementType() != NonVectorVal->getType()) {
1043  unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
1044  uint64_t N = Const->getZExtValue();
1045  if (N < (unsigned)(1 << (TargetElemSize - 1))) {
1046  NonVectorVal = Builder.CreateVectorSplat(
1047  VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
1048  return;
1049  }
1050  }
1051  NonVectorVal =
1052  Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
1053  };
1054 
1055  FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
1056  FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
1057  // If one of X, Y is not a vector, we have to splat it in order
1058  // to add the two of them.
1059  if (XElType && !YElType) {
1060  FixSummands(XElType, Y);
1061  YElType = cast<FixedVectorType>(Y->getType());
1062  } else if (YElType && !XElType) {
1063  FixSummands(YElType, X);
1064  XElType = cast<FixedVectorType>(X->getType());
1065  }
1066  assert(XElType && YElType && "Unknown vector types");
1067  // Check that the summands are of compatible types
1068  if (XElType != YElType) {
1069  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
1070  return nullptr;
1071  }
1072 
1073  if (XElType->getElementType()->getScalarSizeInBits() != 32) {
1074  // Check that by adding the vectors we do not accidentally
1075  // create an overflow
1076  Constant *ConstX = dyn_cast<Constant>(X);
1077  Constant *ConstY = dyn_cast<Constant>(Y);
1078  if (!ConstX || !ConstY)
1079  return nullptr;
1080  unsigned TargetElemSize = 128 / XElType->getNumElements();
1081  for (unsigned i = 0; i < XElType->getNumElements(); i++) {
1082  ConstantInt *ConstXEl =
1083  dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
1084  ConstantInt *ConstYEl =
1085  dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
1086  if (!ConstXEl || !ConstYEl ||
1087  ConstXEl->getZExtValue() + ConstYEl->getZExtValue() >=
1088  (unsigned)(1 << (TargetElemSize - 1)))
1089  return nullptr;
1090  }
1091  }
1092 
1093  Value *Add = Builder.CreateAdd(X, Y);
1094 
1095  FixedVectorType *GEPType = cast<FixedVectorType>(GEP->getType());
1096  if (checkOffsetSize(Add, GEPType->getNumElements()))
1097  return Add;
1098  else
1099  return nullptr;
1100 }
1101 
1102 Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
1103  Value *&Offsets,
1104  IRBuilder<> &Builder) {
1105  Value *GEPPtr = GEP->getPointerOperand();
1106  Offsets = GEP->getOperand(1);
1107  // We only merge geps with constant offsets, because only for those
1108  // we can make sure that we do not cause an overflow
1109  if (!isa<Constant>(Offsets))
1110  return nullptr;
1111  GetElementPtrInst *BaseGEP;
1112  if ((BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr))) {
1113  // Merge the two geps into one
1114  Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Builder);
1115  if (!BaseBasePtr)
1116  return nullptr;
1117  Offsets =
1118  CheckAndCreateOffsetAdd(Offsets, GEP->getOperand(1), GEP, Builder);
1119  if (Offsets == nullptr)
1120  return nullptr;
1121  return BaseBasePtr;
1122  }
1123  return GEPPtr;
1124 }
1125 
1126 bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
1127  LoopInfo *LI) {
1128  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
1129  if (!GEP)
1130  return false;
1131  bool Changed = false;
1132  if (GEP->hasOneUse() &&
1133  dyn_cast<GetElementPtrInst>(GEP->getPointerOperand())) {
1134  IRBuilder<> Builder(GEP->getContext());
1135  Builder.SetInsertPoint(GEP);
1136  Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
1137  Value *Offsets;
1138  Value *Base = foldGEP(GEP, Offsets, Builder);
1139  // We only want to merge the geps if there is a real chance that they can be
1140  // used by an MVE gather; thus the offset has to have the correct size
1141  // (always i32 if it is not of vector type) and the base has to be a
1142  // pointer.
1143  if (Offsets && Base && Base != GEP) {
1145  GEP->getSourceElementType(), Base, Offsets, "gep.merged", GEP);
1146  GEP->replaceAllUsesWith(NewAddress);
1147  GEP = NewAddress;
1148  Changed = true;
1149  }
1150  }
1151  Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
1152  return Changed;
1153 }
1154 
1157  return false;
1158  auto &TPC = getAnalysis<TargetPassConfig>();
1159  auto &TM = TPC.getTM<TargetMachine>();
1160  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
1161  if (!ST->hasMVEIntegerOps())
1162  return false;
1163  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1166 
1167  bool Changed = false;
1168 
1169  for (BasicBlock &BB : F) {
1170  for (Instruction &I : BB) {
1171  IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
1172  if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
1173  isa<FixedVectorType>(II->getType())) {
1174  Gathers.push_back(II);
1175  Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
1176  } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
1177  isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
1178  Scatters.push_back(II);
1179  Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
1180  }
1181  }
1182  }
1183  for (unsigned i = 0; i < Gathers.size(); i++) {
1184  IntrinsicInst *I = Gathers[i];
1185  Value *L = lowerGather(I);
1186  if (L == nullptr)
1187  continue;
1188 
1189  // Get rid of any now dead instructions
1190  SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent());
1191  Changed = true;
1192  }
1193 
1194  for (unsigned i = 0; i < Scatters.size(); i++) {
1195  IntrinsicInst *I = Scatters[i];
1196  Value *S = lowerScatter(I);
1197  if (S == nullptr)
1198  continue;
1199 
1200  // Get rid of any now dead instructions
1201  SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent());
1202  Changed = true;
1203  }
1204  return Changed;
1205 }
i
i
Definition: README.txt:29
ARMSubtarget.h
llvm::Loop::isLoopInvariant
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition: LoopInfo.cpp:64
hasAllGatScatUsers
static bool hasAllGatScatUsers(Instruction *I)
Definition: MVEGatherScatterLowering.cpp:861
llvm
Definition: AllocatorList.h:23
CheckAndCreateOffsetAdd
static Value * CheckAndCreateOffsetAdd(Value *X, Value *Y, Value *GEP, IRBuilder<> &Builder)
Definition: MVEGatherScatterLowering.cpp:1034
llvm::DataLayout
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:112
llvm::Value::hasOneUse
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:435
IntrinsicInst.h
llvm::ARMSubtarget
Definition: ARMSubtarget.h:46
llvm::Function
Definition: Function.h:61
llvm::Loop
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:530
Pass.h
llvm::LoopBase::contains
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
Definition: LoopInfo.h:122
llvm::IntrinsicInst::getIntrinsicID
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
Definition: IntrinsicInst.h:52
llvm::DataLayout::getTypeSizeInBits
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition: DataLayout.h:655
llvm::SmallVector
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1167
llvm::IRBuilder<>
llvm::PHINode::removeIncomingValue
Value * removeIncomingValue(unsigned Idx, bool DeletePHIIfEmpty=true)
Remove an incoming value.
Definition: Instructions.cpp:109
Local.h
llvm::cl::Hidden
@ Hidden
Definition: CommandLine.h:143
llvm::SPII::Load
@ Load
Definition: SparcInstrInfo.h:32
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:46
llvm::LoopInfoWrapperPass
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:1258
llvm::Optional< int64_t >
llvm::VectorType::getElementType
Type * getElementType() const
Definition: DerivedTypes.h:424
llvm::FixedVectorType
Class to represent fixed width SIMD vectors.
Definition: DerivedTypes.h:527
llvm::BitmaskEnumDetail::Mask
std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:80
llvm::PHINode::setIncomingValue
void setIncomingValue(unsigned i, Value *V)
Definition: Instructions.h:2699
LLVM_DEBUG
#define LLVM_DEBUG(X)
Definition: Debug.h:122
llvm::CastInst::getDestTy
Type * getDestTy() const
Return the destination type, as a convenience.
Definition: InstrTypes.h:686
F
#define F(x, y, z)
Definition: MD5.cpp:56
llvm::BasicBlock
LLVM Basic Block Representation.
Definition: BasicBlock.h:58
llvm::dbgs
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:132
Instruction.h
llvm::FixedVectorType::getNumElements
unsigned getNumElements() const
Definition: DerivedTypes.h:570
TargetLowering.h
llvm::ConstantInt
This is the shared class of boolean and integer constants.
Definition: Constants.h:77
llvm::Instruction::getOpcode
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
Definition: Instruction.h:160
llvm::SimplifyInstructionsInBlock
bool SimplifyInstructionsInBlock(BasicBlock *BB, const TargetLibraryInfo *TLI=nullptr)
Scan the specified basic block and try to simplify any instructions in it and recursively delete dead...
Definition: Local.cpp:686
llvm::PassRegistry::getPassRegistry
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Definition: PassRegistry.cpp:31
Constants.h
llvm::PatternMatch::match
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
llvm::PHINode::getIncomingValue
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
Definition: Instructions.h:2696
INITIALIZE_PASS
INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE, "MVE gather/scattering lowering pass", false, false) Pass *llvm
Definition: MVEGatherScatterLowering.cpp:154
llvm::isGatherScatter
bool isGatherScatter(IntrinsicInst *IntInst)
Definition: ARMBaseInstrInfo.h:933
llvm::User
Definition: User.h:44
Intrinsics.h
C
(vector float) vec_cmpeq(*A, *B) C
Definition: README_ALTIVEC.txt:86
InstrTypes.h
llvm::VectorType::getInteger
static VectorType * getInteger(VectorType *VTy)
This static method gets a VectorType with the same number of elements as the input type,...
Definition: DerivedTypes.h:442
Y
static GCMetadataPrinterRegistry::Add< OcamlGCMetadataPrinter > Y("ocaml", "ocaml 3.10-compatible collector")
llvm::AnalysisUsage
Represent the analysis usage information of a pass.
Definition: PassAnalysisSupport.h:47
llvm::Type::isVectorTy
bool isVectorTy() const
True if this is an instance of VectorType.
Definition: Type.h:235
llvm::Instruction
Definition: Instruction.h:45
llvm::Type::getScalarSizeInBits
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Definition: Type.cpp:147
DEBUG_TYPE
#define DEBUG_TYPE
Definition: MVEGatherScatterLowering.cpp:47
llvm::codeview::EncodedFramePtrReg::BasePtr
@ BasePtr
PatternMatch.h
llvm::Align
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
llvm::PHINode::getNumIncomingValues
unsigned getNumIncomingValues() const
Return the number of incoming edges.
Definition: Instructions.h:2692
Type.h
X
static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")
llvm::PatternMatch::m_One
cst_pred_ty< is_one > m_One()
Match an integer 1 or a vector with all elements equal to 1.
Definition: PatternMatch.h:513
LoopInfo.h
llvm::TargetPassConfig
Target-Independent Code Generator Pass Configuration Options.
Definition: TargetPassConfig.h:84
BasicBlock.h
llvm::cl::opt< bool >
llvm::PatternMatch::m_Zero
is_zero m_Zero()
Match any null constant or a vector with all elements equal to 0.
Definition: PatternMatch.h:535
llvm::Constant
This is an important base class in LLVM.
Definition: Constant.h:41
llvm::Instruction::eraseFromParent
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:78
llvm::ARM_MB::ST
@ ST
Definition: ARMBaseInfo.h:73
llvm::TruncInst
This class represents a truncation of integer types.
Definition: Instructions.h:4721
llvm::PHINode::addIncoming
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
Definition: Instructions.h:2750
llvm::LoopInfoBase::getLoopFor
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
Definition: LoopInfo.h:964
I
#define I(x, y, z)
Definition: MD5.cpp:59
llvm::GetElementPtrInst
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
Definition: Instructions.h:931
llvm::createMVEGatherScatterLoweringPass
Pass * createMVEGatherScatterLoweringPass()
llvm::cl::init
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
TargetPassConfig.h
llvm::LoopBase::getLoopLatch
BlockT * getLoopLatch() const
If there is a single latch block for this loop, return it.
Definition: LoopInfoImpl.h:216
IRBuilder.h
assert
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
llvm::TargetMachine
Primary interface to the complete machine description for the target machine.
Definition: TargetMachine.h:77
ARMBaseInstrInfo.h
ARM.h
Builder
assume Assume Builder
Definition: AssumeBundleBuilder.cpp:651
llvm::ZExtInst
This class represents zero extension of integer types.
Definition: Instructions.h:4760
llvm::GetElementPtrInst::Create
static GetElementPtrInst * Create(Type *PointeeType, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
Definition: Instructions.h:957
llvm::LoopInfo
Definition: LoopInfo.h:1080
llvm::AnalysisUsage::setPreservesCFG
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:253
llvm::StringRef
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:58
llvm::Constant::getAggregateElement
Constant * getAggregateElement(unsigned Elt) const
For aggregates (struct/array/vector) return the constant that corresponds to the specified element if...
Definition: Constants.cpp:420
llvm::Value::getNumUses
unsigned getNumUses() const
This method computes the number of uses of this Value.
Definition: Value.cpp:240
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:256
llvm::Value::replaceAllUsesWith
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:517
getParent
static const Function * getParent(const Value *V)
Definition: BasicAliasAnalysis.cpp:767
llvm::initializeMVEGatherScatterLoweringPass
void initializeMVEGatherScatterLoweringPass(PassRegistry &)
TargetSubtargetInfo.h
Unsigned
@ Unsigned
Definition: NVPTXISelLowering.cpp:4545
S
add sub stmia L5 ldr r0 bl L_printf $stub Instead of a and a wouldn t it be better to do three moves *Return an aggregate type is even return S
Definition: README.txt:210
llvm::ConstantInt::getSExtValue
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
Definition: Constants.h:146
llvm::ConstantInt::getZExtValue
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition: Constants.h:140
runOnFunction
static bool runOnFunction(Function &F, bool PostInlining)
Definition: EntryExitInstrumenter.cpp:69
Constant.h
llvm::AMDGPU::SendMsg::Op
Op
Definition: SIDefines.h:314
llvm::PHINode::Create
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
Definition: Instructions.h:2642
Casting.h
Function.h
llvm::Value::hasNUses
bool hasNUses(unsigned N) const
Return true if this Value has exactly N uses.
Definition: Value.cpp:146
llvm::LoopBase::getHeader
BlockT * getHeader() const
Definition: LoopInfo.h:104
checkOffsetSize
static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount)
Definition: MVEGatherScatterLowering.cpp:175
llvm::MCID::Add
@ Add
Definition: MCInstrDesc.h:183
llvm::codeview::ModifierOptions::Const
@ Const
llvm::IntrinsicInst
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:45
llvm::SPII::Store
@ Store
Definition: SparcInstrInfo.h:33
llvm::Instruction::BinaryOps
BinaryOps
Definition: Instruction.h:773
llvm::BasicBlock::back
const Instruction & back() const
Definition: BasicBlock.h:310
llvm::Pass
Pass interface - Implemented by all 'passes'.
Definition: Pass.h:91
EnableMaskedGatherScatters
cl::opt< bool > EnableMaskedGatherScatters("enable-arm-maskedgatscat", cl::Hidden, cl::init(true), cl::desc("Enable the generation of masked gathers and scatters"))
Instructions.h
llvm::Instruction::getDebugLoc
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
Definition: Instruction.h:370
llvm::CallBase::getArgOperand
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1341
N
#define N
llvm::Instruction::getParent
const BasicBlock * getParent() const
Definition: Instruction.h:94
llvm::PHINode::getIncomingBlock
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Definition: Instructions.h:2716
TargetTransformInfo.h
llvm::PHINode
Definition: Instructions.h:2600
llvm::Pass::getAnalysisUsage
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:93
DerivedTypes.h
TM
const char LLVMTargetMachineRef TM
Definition: PassBuilderBindings.cpp:47
llvm::FunctionPass
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:298
BB
Common register allocation spilling lr str ldr sxth r3 ldr mla r4 can lr mov lr str ldr sxth r3 mla r4 and then merge mul and lr str ldr sxth r3 mla r4 It also increase the likelihood the store may become dead bb27 Successors according to LLVM BB
Definition: README.txt:39
GEP
Hexagon Common GEP
Definition: HexagonCommonGEP.cpp:171
llvm::AnalysisUsage::addRequired
AnalysisUsage & addRequired()
Definition: PassAnalysisSupport.h:75
llvm::User::getOperand
Value * getOperand(unsigned i) const
Definition: User.h:169
llvm::cl::desc
Definition: CommandLine.h:414
llvm::SI::KernelInputOffsets::Offsets
Offsets
Offsets in bytes from the start of the input buffer.
Definition: SIInstrInfo.h:1246
llvm::BinaryOperator::Create
static BinaryOperator * Create(BinaryOps Op, Value *S1, Value *S2, const Twine &Name=Twine(), Instruction *InsertBefore=nullptr)
Construct a binary instruction, given the opcode and the two operands.
Definition: Instructions.cpp:2550
Value.h
InitializePasses.h
llvm::Value
LLVM Value Representation.
Definition: Value.h:75
llvm::Optional::getValue
constexpr const T & getValue() const LLVM_LVALUE_FUNCTION
Definition: Optional.h:282
llvm::sampleprof::Base
@ Base
Definition: Discriminator.h:58
llvm::Type::getPrimitiveSizeInBits
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition: Type.cpp:122
llvm::Intrinsic::ID
unsigned ID
Definition: TargetTransformInfo.h:38