LLVM 18.0.0git
MVEGatherScatterLowering.cpp
Go to the documentation of this file.
//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This pass custom lowers llvm.gather and llvm.scatter instructions to
/// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
/// produce a better final result as we go.
//
//===----------------------------------------------------------------------===//
14
15#include "ARM.h"
16#include "ARMBaseInstrInfo.h"
17#include "ARMSubtarget.h"
25#include "llvm/IR/BasicBlock.h"
26#include "llvm/IR/Constant.h"
27#include "llvm/IR/Constants.h"
29#include "llvm/IR/Function.h"
30#include "llvm/IR/InstrTypes.h"
31#include "llvm/IR/Instruction.h"
34#include "llvm/IR/Intrinsics.h"
35#include "llvm/IR/IntrinsicsARM.h"
36#include "llvm/IR/IRBuilder.h"
38#include "llvm/IR/Type.h"
39#include "llvm/IR/Value.h"
40#include "llvm/Pass.h"
43#include <algorithm>
44#include <cassert>
45
46using namespace llvm;
47
48#define DEBUG_TYPE "arm-mve-gather-scatter-lowering"
49
51 "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
52 cl::desc("Enable the generation of masked gathers and scatters"));
53
54namespace {
55
56class MVEGatherScatterLowering : public FunctionPass {
57public:
58 static char ID; // Pass identification, replacement for typeid
59
60 explicit MVEGatherScatterLowering() : FunctionPass(ID) {
62 }
63
64 bool runOnFunction(Function &F) override;
65
66 StringRef getPassName() const override {
67 return "MVE gather/scatter lowering";
68 }
69
70 void getAnalysisUsage(AnalysisUsage &AU) const override {
71 AU.setPreservesCFG();
75 }
76
77private:
78 LoopInfo *LI = nullptr;
79 const DataLayout *DL;
80
81 // Check this is a valid gather with correct alignment
82 bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
83 Align Alignment);
84 // Check whether Ptr is hidden behind a bitcast and look through it
85 void lookThroughBitcast(Value *&Ptr);
86 // Decompose a ptr into Base and Offsets, potentially using a GEP to return a
87 // scalar base and vector offsets, or else fallback to using a base of 0 and
88 // offset of Ptr where possible.
89 Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale,
90 FixedVectorType *Ty, Type *MemoryTy,
91 IRBuilder<> &Builder);
92 // Check for a getelementptr and deduce base and offsets from it, on success
93 // returning the base directly and the offsets indirectly using the Offsets
94 // argument
97 // Compute the scale of this gather/scatter instruction
98 int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
99 // If the value is a constant, or derived from constants via additions
100 // and multilications, return its numeric value
101 std::optional<int64_t> getIfConst(const Value *V);
102 // If Inst is an add instruction, check whether one summand is a
103 // constant. If so, scale this constant and return it together with
104 // the other summand.
105 std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);
106
107 Instruction *lowerGather(IntrinsicInst *I);
108 // Create a gather from a base + vector of offsets
109 Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
110 Instruction *&Root,
111 IRBuilder<> &Builder);
112 // Create a gather from a vector of pointers
113 Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
114 IRBuilder<> &Builder,
115 int64_t Increment = 0);
116 // Create an incrementing gather from a vector of pointers
117 Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
118 IRBuilder<> &Builder,
119 int64_t Increment = 0);
120
121 Instruction *lowerScatter(IntrinsicInst *I);
122 // Create a scatter to a base + vector of offsets
123 Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
124 IRBuilder<> &Builder);
125 // Create a scatter to a vector of pointers
126 Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
127 IRBuilder<> &Builder,
128 int64_t Increment = 0);
129 // Create an incrementing scatter from a vector of pointers
130 Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
131 IRBuilder<> &Builder,
132 int64_t Increment = 0);
133
134 // QI gathers and scatters can increment their offsets on their own if
135 // the increment is a constant value (digit)
136 Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr,
137 IRBuilder<> &Builder);
138 // QI gathers/scatters can increment their offsets on their own if the
139 // increment is a constant value (digit) - this creates a writeback QI
140 // gather/scatter
141 Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
142 Value *Ptr, unsigned TypeScale,
143 IRBuilder<> &Builder);
144
145 // Optimise the base and offsets of the given address
146 bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
147 // Try to fold consecutive geps together into one
148 Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, unsigned &Scale,
149 IRBuilder<> &Builder);
150 // Check whether these offsets could be moved out of the loop they're in
151 bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
152 // Pushes the given add out of the loop
153 void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
154 // Pushes the given mul or shl out of the loop
155 void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound,
156 Value *OffsSecondOperand, unsigned LoopIncrement,
157 IRBuilder<> &Builder);
158};
159
160} // end anonymous namespace
161
// Out-of-class definition of the static pass identifier declared in the class
// ("Pass identification, replacement for typeid").
162char MVEGatherScatterLowering::ID = 0;
163
// Registers the pass under the DEBUG_TYPE name with a human-readable
// description.
// NOTE(review): the two trailing `false` arguments are presumably the usual
// INITIALIZE_PASS "CFG only" / "is analysis" flags - confirm against the
// macro's definition.
164INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
165 "MVE gather/scattering lowering pass", false, false)
166
168 return new MVEGatherScatterLowering();
169}
170
171bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
172 unsigned ElemSize,
173 Align Alignment) {
174 if (((NumElements == 4 &&
175 (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
176 (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
177 (NumElements == 16 && ElemSize == 8)) &&
178 Alignment >= ElemSize / 8)
179 return true;
180 LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
181 << "valid alignment or vector type \n");
182 return false;
183}
184
185static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
186 // Offsets that are not of type <N x i32> are sign extended by the
187 // getelementptr instruction, and MVE gathers/scatters treat the offset as
188 // unsigned. Thus, if the element size is smaller than 32, we can only allow
189 // positive offsets - i.e., the offsets are not allowed to be variables we
190 // can't look into.
191 // Additionally, <N x i32> offsets have to either originate from a zext of a
192 // vector with element types smaller or equal the type of the gather we're
193 // looking at, or consist of constants that we can check are small enough
194 // to fit into the gather type.
195 // Thus we check that 0 < value < 2^TargetElemSize.
196 unsigned TargetElemSize = 128 / TargetElemCount;
197 unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
198 ->getElementType()
199 ->getScalarSizeInBits();
200 if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
201 Constant *ConstOff = dyn_cast<Constant>(Offsets);
202 if (!ConstOff)
203 return false;
204 int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
205 auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
206 ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
207 if (!OConst)
208 return false;
209 int SExtValue = OConst->getSExtValue();
210 if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
211 return false;
212 return true;
213 };
214 if (isa<FixedVectorType>(ConstOff->getType())) {
215 for (unsigned i = 0; i < TargetElemCount; i++) {
216 if (!CheckValueSize(ConstOff->getAggregateElement(i)))
217 return false;
218 }
219 } else {
220 if (!CheckValueSize(ConstOff))
221 return false;
222 }
223 }
224 return true;
225}
226
227Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets,
228 int &Scale, FixedVectorType *Ty,
229 Type *MemoryTy,
230 IRBuilder<> &Builder) {
231 if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
232 if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) {
233 Scale =
234 computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
235 MemoryTy->getScalarSizeInBits());
236 return Scale == -1 ? nullptr : V;
237 }
238 }
239
240 // If we couldn't use the GEP (or it doesn't exist), attempt to use a
241 // BasePtr of 0 with Ptr as the Offsets, so long as there are only 4
242 // elements.
243 FixedVectorType *PtrTy = cast<FixedVectorType>(Ptr->getType());
244 if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() == 32)
245 return nullptr;
246 Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
247 Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getPtrTy());
248 Offsets = Builder.CreatePtrToInt(
249 Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4));
250 Scale = 0;
251 return BasePtr;
252}
253
254Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets,
255 FixedVectorType *Ty,
257 IRBuilder<> &Builder) {
258 if (!GEP) {
259 LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer "
260 << "found\n");
261 return nullptr;
262 }
263 LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
264 << " Looking at intrinsic for base + vector of offsets\n");
265 Value *GEPPtr = GEP->getPointerOperand();
266 Offsets = GEP->getOperand(1);
267 if (GEPPtr->getType()->isVectorTy() ||
268 !isa<FixedVectorType>(Offsets->getType()))
269 return nullptr;
270
271 if (GEP->getNumOperands() != 2) {
272 LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
273 << " operands. Expanding.\n");
274 return nullptr;
275 }
276 Offsets = GEP->getOperand(1);
277 unsigned OffsetsElemCount =
278 cast<FixedVectorType>(Offsets->getType())->getNumElements();
279 // Paranoid check whether the number of parallel lanes is the same
280 assert(Ty->getNumElements() == OffsetsElemCount);
281
282 ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
283 if (ZextOffs)
284 Offsets = ZextOffs->getOperand(0);
285 FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());
286
287 // If the offsets are already being zext-ed to <N x i32>, that relieves us of
288 // having to make sure that they won't overflow.
289 if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
290 ->getElementType()
291 ->getScalarSizeInBits() != 32)
292 if (!checkOffsetSize(Offsets, OffsetsElemCount))
293 return nullptr;
294
295 // The offset sizes have been checked; if any truncating or zext-ing is
296 // required to fix them, do that now
297 if (Ty != Offsets->getType()) {
298 if ((Ty->getElementType()->getScalarSizeInBits() <
299 OffsetType->getElementType()->getScalarSizeInBits())) {
300 Offsets = Builder.CreateTrunc(Offsets, Ty);
301 } else {
302 Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
303 }
304 }
305 // If none of the checks failed, return the gep's base pointer
306 LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
307 return GEPPtr;
308}
309
310void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
311 // Look through bitcast instruction if #elements is the same
312 if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
313 auto *BCTy = cast<FixedVectorType>(BitCast->getType());
314 auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
315 if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
316 LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through "
317 << "bitcast\n");
318 Ptr = BitCast->getOperand(0);
319 }
320 }
321}
322
323int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
324 unsigned MemoryElemSize) {
325 // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2,
326 // or a 8bit, 16bit or 32bit load/store scaled by 1
327 if (GEPElemSize == 32 && MemoryElemSize == 32)
328 return 2;
329 else if (GEPElemSize == 16 && MemoryElemSize == 16)
330 return 1;
331 else if (GEPElemSize == 8)
332 return 0;
333 LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
334 << "create intrinsic\n");
335 return -1;
336}
337
338std::optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
339 const Constant *C = dyn_cast<Constant>(V);
340 if (C && C->getSplatValue())
341 return std::optional<int64_t>{C->getUniqueInteger().getSExtValue()};
342 if (!isa<Instruction>(V))
343 return std::optional<int64_t>{};
344
345 const Instruction *I = cast<Instruction>(V);
346 if (I->getOpcode() == Instruction::Add || I->getOpcode() == Instruction::Or ||
347 I->getOpcode() == Instruction::Mul ||
348 I->getOpcode() == Instruction::Shl) {
349 std::optional<int64_t> Op0 = getIfConst(I->getOperand(0));
350 std::optional<int64_t> Op1 = getIfConst(I->getOperand(1));
351 if (!Op0 || !Op1)
352 return std::optional<int64_t>{};
353 if (I->getOpcode() == Instruction::Add)
354 return std::optional<int64_t>{*Op0 + *Op1};
355 if (I->getOpcode() == Instruction::Mul)
356 return std::optional<int64_t>{*Op0 * *Op1};
357 if (I->getOpcode() == Instruction::Shl)
358 return std::optional<int64_t>{*Op0 << *Op1};
359 if (I->getOpcode() == Instruction::Or)
360 return std::optional<int64_t>{*Op0 | *Op1};
361 }
362 return std::optional<int64_t>{};
363}
364
365// Return true if I is an Or instruction that is equivalent to an add, due to
366// the operands having no common bits set.
367static bool isAddLikeOr(Instruction *I, const DataLayout &DL) {
368 return I->getOpcode() == Instruction::Or &&
369 haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL);
370}
371
372std::pair<Value *, int64_t>
373MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
374 std::pair<Value *, int64_t> ReturnFalse =
375 std::pair<Value *, int64_t>(nullptr, 0);
376 // At this point, the instruction we're looking at must be an add or an
377 // add-like-or.
378 Instruction *Add = dyn_cast<Instruction>(Inst);
379 if (Add == nullptr ||
380 (Add->getOpcode() != Instruction::Add && !isAddLikeOr(Add, *DL)))
381 return ReturnFalse;
382
383 Value *Summand;
384 std::optional<int64_t> Const;
385 // Find out which operand the value that is increased is
386 if ((Const = getIfConst(Add->getOperand(0))))
387 Summand = Add->getOperand(1);
388 else if ((Const = getIfConst(Add->getOperand(1))))
389 Summand = Add->getOperand(0);
390 else
391 return ReturnFalse;
392
393 // Check that the constant is small enough for an incrementing gather
394 int64_t Immediate = *Const << TypeScale;
395 if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
396 return ReturnFalse;
397
398 return std::pair<Value *, int64_t>(Summand, Immediate);
399}
400
401Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
402 using namespace PatternMatch;
403 LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
404 << *I << "\n");
405
406 // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
407 // Attempt to turn the masked gather in I into a MVE intrinsic
408 // Potentially optimising the addressing modes as we do so.
409 auto *Ty = cast<FixedVectorType>(I->getType());
410 Value *Ptr = I->getArgOperand(0);
411 Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
412 Value *Mask = I->getArgOperand(2);
413 Value *PassThru = I->getArgOperand(3);
414
415 if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
416 Alignment))
417 return nullptr;
418 lookThroughBitcast(Ptr);
419 assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
420
421 IRBuilder<> Builder(I->getContext());
422 Builder.SetInsertPoint(I);
423 Builder.SetCurrentDebugLocation(I->getDebugLoc());
424
425 Instruction *Root = I;
426
427 Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder);
428 if (!Load)
429 Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
430 if (!Load)
431 Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
432 if (!Load)
433 return nullptr;
434
435 if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
436 LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
437 << "creating select\n");
438 Load = SelectInst::Create(Mask, Load, PassThru);
439 Builder.Insert(Load);
440 }
441
442 Root->replaceAllUsesWith(Load);
443 Root->eraseFromParent();
444 if (Root != I)
445 // If this was an extending gather, we need to get rid of the sext/zext
446 // sext/zext as well as of the gather itself
447 I->eraseFromParent();
448
449 LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
450 << *Load << "\n");
451 return Load;
452}
453
454Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
455 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
456 using namespace PatternMatch;
457 auto *Ty = cast<FixedVectorType>(I->getType());
458 LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
459 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
460 // Can't build an intrinsic for this
461 return nullptr;
462 Value *Mask = I->getArgOperand(2);
463 if (match(Mask, m_One()))
464 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
465 {Ty, Ptr->getType()},
466 {Ptr, Builder.getInt32(Increment)});
467 else
468 return Builder.CreateIntrinsic(
469 Intrinsic::arm_mve_vldr_gather_base_predicated,
470 {Ty, Ptr->getType(), Mask->getType()},
471 {Ptr, Builder.getInt32(Increment), Mask});
472}
473
474Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
475 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
476 using namespace PatternMatch;
477 auto *Ty = cast<FixedVectorType>(I->getType());
478 LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with "
479 << "writeback\n");
480 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
481 // Can't build an intrinsic for this
482 return nullptr;
483 Value *Mask = I->getArgOperand(2);
484 if (match(Mask, m_One()))
485 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
486 {Ty, Ptr->getType()},
487 {Ptr, Builder.getInt32(Increment)});
488 else
489 return Builder.CreateIntrinsic(
490 Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
491 {Ty, Ptr->getType(), Mask->getType()},
492 {Ptr, Builder.getInt32(Increment), Mask});
493}
494
495Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
496 IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
497 using namespace PatternMatch;
498
499 Type *MemoryTy = I->getType();
500 Type *ResultTy = MemoryTy;
501
502 unsigned Unsigned = 1;
503 // The size of the gather was already checked in isLegalTypeAndAlignment;
504 // if it was not a full vector width an appropriate extend should follow.
505 auto *Extend = Root;
506 bool TruncResult = false;
507 if (MemoryTy->getPrimitiveSizeInBits() < 128) {
508 if (I->hasOneUse()) {
509 // If the gather has a single extend of the correct type, use an extending
510 // gather and replace the ext. In which case the correct root to replace
511 // is not the CallInst itself, but the instruction which extends it.
512 Instruction* User = cast<Instruction>(*I->users().begin());
513 if (isa<SExtInst>(User) &&
514 User->getType()->getPrimitiveSizeInBits() == 128) {
515 LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
516 << *User << "\n");
517 Extend = User;
518 ResultTy = User->getType();
519 Unsigned = 0;
520 } else if (isa<ZExtInst>(User) &&
521 User->getType()->getPrimitiveSizeInBits() == 128) {
522 LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
523 << *ResultTy << "\n");
524 Extend = User;
525 ResultTy = User->getType();
526 }
527 }
528
529 // If an extend hasn't been found and the type is an integer, create an
530 // extending gather and truncate back to the original type.
531 if (ResultTy->getPrimitiveSizeInBits() < 128 &&
532 ResultTy->isIntOrIntVectorTy()) {
533 ResultTy = ResultTy->getWithNewBitWidth(
534 128 / cast<FixedVectorType>(ResultTy)->getNumElements());
535 TruncResult = true;
536 LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncing to: "
537 << *ResultTy << "\n");
538 }
539
540 // The final size of the gather must be a full vector width
541 if (ResultTy->getPrimitiveSizeInBits() != 128) {
542 LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided "
543 "from the correct type. Expanding\n");
544 return nullptr;
545 }
546 }
547
548 Value *Offsets;
549 int Scale;
550 Value *BasePtr = decomposePtr(
551 Ptr, Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy, Builder);
552 if (!BasePtr)
553 return nullptr;
554
555 Root = Extend;
556 Value *Mask = I->getArgOperand(2);
557 Instruction *Load = nullptr;
558 if (!match(Mask, m_One()))
559 Load = Builder.CreateIntrinsic(
560 Intrinsic::arm_mve_vldr_gather_offset_predicated,
561 {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
562 {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
563 Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
564 else
565 Load = Builder.CreateIntrinsic(
566 Intrinsic::arm_mve_vldr_gather_offset,
567 {ResultTy, BasePtr->getType(), Offsets->getType()},
568 {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
569 Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
570
571 if (TruncResult) {
572 Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy);
573 Builder.Insert(Load);
574 }
575 return Load;
576}
577
578Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
579 using namespace PatternMatch;
580 LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
581 << *I << "\n");
582
583 // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
584 // Attempt to turn the masked scatter in I into a MVE intrinsic
585 // Potentially optimising the addressing modes as we do so.
586 Value *Input = I->getArgOperand(0);
587 Value *Ptr = I->getArgOperand(1);
588 Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
589 auto *Ty = cast<FixedVectorType>(Input->getType());
590
591 if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
592 Alignment))
593 return nullptr;
594
595 lookThroughBitcast(Ptr);
596 assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
597
598 IRBuilder<> Builder(I->getContext());
599 Builder.SetInsertPoint(I);
600 Builder.SetCurrentDebugLocation(I->getDebugLoc());
601
602 Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder);
603 if (!Store)
604 Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
605 if (!Store)
606 Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
607 if (!Store)
608 return nullptr;
609
610 LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
611 << *Store << "\n");
612 I->eraseFromParent();
613 return Store;
614}
615
616Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
617 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
618 using namespace PatternMatch;
619 Value *Input = I->getArgOperand(0);
620 auto *Ty = cast<FixedVectorType>(Input->getType());
621 // Only QR variants allow truncating
622 if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
623 // Can't build an intrinsic for this
624 return nullptr;
625 }
626 Value *Mask = I->getArgOperand(3);
627 // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
628 LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
629 if (match(Mask, m_One()))
630 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
631 {Ptr->getType(), Input->getType()},
632 {Ptr, Builder.getInt32(Increment), Input});
633 else
634 return Builder.CreateIntrinsic(
635 Intrinsic::arm_mve_vstr_scatter_base_predicated,
636 {Ptr->getType(), Input->getType(), Mask->getType()},
637 {Ptr, Builder.getInt32(Increment), Input, Mask});
638}
639
640Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
641 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
642 using namespace PatternMatch;
643 Value *Input = I->getArgOperand(0);
644 auto *Ty = cast<FixedVectorType>(Input->getType());
645 LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers "
646 << "with writeback\n");
647 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
648 // Can't build an intrinsic for this
649 return nullptr;
650 Value *Mask = I->getArgOperand(3);
651 if (match(Mask, m_One()))
652 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
653 {Ptr->getType(), Input->getType()},
654 {Ptr, Builder.getInt32(Increment), Input});
655 else
656 return Builder.CreateIntrinsic(
657 Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
658 {Ptr->getType(), Input->getType(), Mask->getType()},
659 {Ptr, Builder.getInt32(Increment), Input, Mask});
660}
661
662Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
663 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
664 using namespace PatternMatch;
665 Value *Input = I->getArgOperand(0);
666 Value *Mask = I->getArgOperand(3);
667 Type *InputTy = Input->getType();
668 Type *MemoryTy = InputTy;
669
670 LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
671 << " to base + vector of offsets\n");
672 // If the input has been truncated, try to integrate that trunc into the
673 // scatter instruction (we don't care about alignment here)
674 if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
675 Value *PreTrunc = Trunc->getOperand(0);
676 Type *PreTruncTy = PreTrunc->getType();
677 if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
678 Input = PreTrunc;
679 InputTy = PreTruncTy;
680 }
681 }
682 bool ExtendInput = false;
683 if (InputTy->getPrimitiveSizeInBits() < 128 &&
684 InputTy->isIntOrIntVectorTy()) {
685 // If we can't find a trunc to incorporate into the instruction, create an
686 // implicit one with a zext, so that we can still create a scatter. We know
687 // that the input type is 4x/8x/16x and of type i8/i16/i32, so any type
688 // smaller than 128 bits will divide evenly into a 128bit vector.
689 InputTy = InputTy->getWithNewBitWidth(
690 128 / cast<FixedVectorType>(InputTy)->getNumElements());
691 ExtendInput = true;
692 LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n"
693 << *Input << "\n");
694 }
695 if (InputTy->getPrimitiveSizeInBits() != 128) {
696 LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
697 "non-standard input types. Expanding.\n");
698 return nullptr;
699 }
700
701 Value *Offsets;
702 int Scale;
703 Value *BasePtr = decomposePtr(
704 Ptr, Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy, Builder);
705 if (!BasePtr)
706 return nullptr;
707
708 if (ExtendInput)
709 Input = Builder.CreateZExt(Input, InputTy);
710 if (!match(Mask, m_One()))
711 return Builder.CreateIntrinsic(
712 Intrinsic::arm_mve_vstr_scatter_offset_predicated,
713 {BasePtr->getType(), Offsets->getType(), Input->getType(),
714 Mask->getType()},
715 {BasePtr, Offsets, Input,
716 Builder.getInt32(MemoryTy->getScalarSizeInBits()),
717 Builder.getInt32(Scale), Mask});
718 else
719 return Builder.CreateIntrinsic(
720 Intrinsic::arm_mve_vstr_scatter_offset,
721 {BasePtr->getType(), Offsets->getType(), Input->getType()},
722 {BasePtr, Offsets, Input,
723 Builder.getInt32(MemoryTy->getScalarSizeInBits()),
724 Builder.getInt32(Scale)});
725}
726
727Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
728 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
729 FixedVectorType *Ty;
730 if (I->getIntrinsicID() == Intrinsic::masked_gather)
731 Ty = cast<FixedVectorType>(I->getType());
732 else
733 Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
734
735 // Incrementing gathers only exist for v4i32
736 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
737 return nullptr;
738 // Incrementing gathers are not beneficial outside of a loop
739 Loop *L = LI->getLoopFor(I->getParent());
740 if (L == nullptr)
741 return nullptr;
742
743 // Decompose the GEP into Base and Offsets
744 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
745 Value *Offsets;
746 Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder);
747 if (!BasePtr)
748 return nullptr;
749
750 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
751 "wb gather/scatter\n");
752
753 // The gep was in charge of making sure the offsets are scaled correctly
754 // - calculate that factor so it can be applied by hand
755 int TypeScale =
756 computeScale(DL->getTypeSizeInBits(GEP->getOperand(0)->getType()),
757 DL->getTypeSizeInBits(GEP->getType()) /
758 cast<FixedVectorType>(GEP->getType())->getNumElements());
759 if (TypeScale == -1)
760 return nullptr;
761
762 if (GEP->hasOneUse()) {
763 // Only in this case do we want to build a wb gather, because the wb will
764 // change the phi which does affect other users of the gep (which will still
765 // be using the phi in the old way)
766 if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets,
767 TypeScale, Builder))
768 return Load;
769 }
770
771 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
772 "non-wb gather/scatter\n");
773
774 std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
775 if (Add.first == nullptr)
776 return nullptr;
777 Value *OffsetsIncoming = Add.first;
778 int64_t Immediate = Add.second;
779
780 // Make sure the offsets are scaled correctly
781 Instruction *ScaledOffsets = BinaryOperator::Create(
782 Instruction::Shl, OffsetsIncoming,
783 Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)),
784 "ScaledIndex", I);
785 // Add the base to the offsets
786 OffsetsIncoming = BinaryOperator::Create(
787 Instruction::Add, ScaledOffsets,
788 Builder.CreateVectorSplat(
789 Ty->getNumElements(),
790 Builder.CreatePtrToInt(
791 BasePtr,
792 cast<VectorType>(ScaledOffsets->getType())->getElementType())),
793 "StartIndex", I);
794
795 if (I->getIntrinsicID() == Intrinsic::masked_gather)
796 return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate);
797 else
798 return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate);
799}
800
801Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
802 IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
803 IRBuilder<> &Builder) {
804 // Check whether this gather's offset is incremented by a constant - if so,
805 // and the load is of the right type, we can merge this into a QI gather
806 Loop *L = LI->getLoopFor(I->getParent());
807 // Offsets that are worth merging into this instruction will be incremented
808 // by a constant, thus we're looking for an add of a phi and a constant
809 PHINode *Phi = dyn_cast<PHINode>(Offsets);
810 if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
811 Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
812 // No phi means no IV to write back to; if there is a phi, we expect it
813 // to have exactly two incoming values; the only phis we are interested in
814 // will be loop IV's and have exactly two uses, one in their increment and
815 // one in the gather's gep
816 return nullptr;
817
818 unsigned IncrementIndex =
819 Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
820 // Look through the phi to the phi increment
821 Offsets = Phi->getIncomingValue(IncrementIndex);
822
823 std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
824 if (Add.first == nullptr)
825 return nullptr;
826 Value *OffsetsIncoming = Add.first;
827 int64_t Immediate = Add.second;
828 if (OffsetsIncoming != Phi)
829 // Then the increment we are looking at is not an increment of the
830 // induction variable, and we don't want to do a writeback
831 return nullptr;
832
833 Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
834 unsigned NumElems =
835 cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();
836
837 // Make sure the offsets are scaled correctly
838 Instruction *ScaledOffsets = BinaryOperator::Create(
839 Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
840 Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
841 "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
842 // Add the base to the offsets
843 OffsetsIncoming = BinaryOperator::Create(
844 Instruction::Add, ScaledOffsets,
845 Builder.CreateVectorSplat(
846 NumElems,
847 Builder.CreatePtrToInt(
848 BasePtr,
849 cast<VectorType>(ScaledOffsets->getType())->getElementType())),
850 "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
851 // The gather is pre-incrementing
852 OffsetsIncoming = BinaryOperator::Create(
853 Instruction::Sub, OffsetsIncoming,
854 Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
855 "PreIncrementStartIndex",
856 &Phi->getIncomingBlock(1 - IncrementIndex)->back());
857 Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);
858
859 Builder.SetInsertPoint(I);
860
861 Instruction *EndResult;
862 Instruction *NewInduction;
863 if (I->getIntrinsicID() == Intrinsic::masked_gather) {
864 // Build the incrementing gather
865 Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
866 // One value to be handed to whoever uses the gather, one is the loop
867 // increment
868 EndResult = ExtractValueInst::Create(Load, 0, "Gather");
869 NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement");
870 Builder.Insert(EndResult);
871 Builder.Insert(NewInduction);
872 } else {
873 // Build the incrementing scatter
874 EndResult = NewInduction =
875 tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
876 }
877 Instruction *AddInst = cast<Instruction>(Offsets);
878 AddInst->replaceAllUsesWith(NewInduction);
879 AddInst->eraseFromParent();
880 Phi->setIncomingValue(IncrementIndex, NewInduction);
881
882 return EndResult;
883}
884
885void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
886 Value *OffsSecondOperand,
887 unsigned StartIndex) {
888 LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
889 Instruction *InsertionPoint =
890 &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
891 // Initialize the phi with a vector that contains a sum of the constants
893 Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
894 "PushedOutAdd", InsertionPoint);
895 unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;
896
897 // Order such that start index comes first (this reduces mov's)
898 Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
899 Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
900 Phi->getIncomingBlock(IncrementIndex));
901 Phi->removeIncomingValue(IncrementIndex);
902 Phi->removeIncomingValue(StartIndex);
903}
904
905void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi,
906 Value *IncrementPerRound,
907 Value *OffsSecondOperand,
908 unsigned LoopIncrement,
909 IRBuilder<> &Builder) {
910 LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");
911
912 // Create a new scalar add outside of the loop and transform it to a splat
913 // by which loop variable can be incremented
914 Instruction *InsertionPoint = &cast<Instruction>(
915 Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());
916
917 // Create a new index
918 Value *StartIndex =
920 Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
921 OffsSecondOperand, "PushedOutMul", InsertionPoint);
922
923 Instruction *Product =
924 BinaryOperator::Create((Instruction::BinaryOps)Opcode, IncrementPerRound,
925 OffsSecondOperand, "Product", InsertionPoint);
926 // Increment NewIndex by Product instead of the multiplication
927 Instruction *NewIncrement = BinaryOperator::Create(
928 Instruction::Add, Phi, Product, "IncrementPushedOutMul",
929 cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
930 .getPrevNode());
931
932 Phi->addIncoming(StartIndex,
933 Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
934 Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
935 Phi->removeIncomingValue((unsigned)0);
936 Phi->removeIncomingValue((unsigned)0);
937}
938
939// Check whether all usages of this instruction are as offsets of
940// gathers/scatters or simple arithmetics only used by gathers/scatters
942 if (I->hasNUses(0)) {
943 return false;
944 }
945 bool Gatscat = true;
946 for (User *U : I->users()) {
947 if (!isa<Instruction>(U))
948 return false;
949 if (isa<GetElementPtrInst>(U) ||
950 isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
951 return Gatscat;
952 } else {
953 unsigned OpCode = cast<Instruction>(U)->getOpcode();
954 if ((OpCode == Instruction::Add || OpCode == Instruction::Mul ||
955 OpCode == Instruction::Shl ||
956 isAddLikeOr(cast<Instruction>(U), DL)) &&
957 hasAllGatScatUsers(cast<Instruction>(U), DL)) {
958 continue;
959 }
960 return false;
961 }
962 }
963 return Gatscat;
964}
965
966bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
967 LoopInfo *LI) {
968 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize: "
969 << *Offsets << "\n");
970 // Optimise the addresses of gathers/scatters by moving invariant
971 // calculations out of the loop
972 if (!isa<Instruction>(Offsets))
973 return false;
974 Instruction *Offs = cast<Instruction>(Offsets);
975 if (Offs->getOpcode() != Instruction::Add && !isAddLikeOr(Offs, *DL) &&
976 Offs->getOpcode() != Instruction::Mul &&
977 Offs->getOpcode() != Instruction::Shl)
978 return false;
979 Loop *L = LI->getLoopFor(BB);
980 if (L == nullptr)
981 return false;
982 if (!Offs->hasOneUse()) {
983 if (!hasAllGatScatUsers(Offs, *DL))
984 return false;
985 }
986
987 // Find out which, if any, operand of the instruction
988 // is a phi node
989 PHINode *Phi;
990 int OffsSecondOp;
991 if (isa<PHINode>(Offs->getOperand(0))) {
992 Phi = cast<PHINode>(Offs->getOperand(0));
993 OffsSecondOp = 1;
994 } else if (isa<PHINode>(Offs->getOperand(1))) {
995 Phi = cast<PHINode>(Offs->getOperand(1));
996 OffsSecondOp = 0;
997 } else {
998 bool Changed = false;
999 if (isa<Instruction>(Offs->getOperand(0)) &&
1000 L->contains(cast<Instruction>(Offs->getOperand(0))))
1001 Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
1002 if (isa<Instruction>(Offs->getOperand(1)) &&
1003 L->contains(cast<Instruction>(Offs->getOperand(1))))
1004 Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
1005 if (!Changed)
1006 return false;
1007 if (isa<PHINode>(Offs->getOperand(0))) {
1008 Phi = cast<PHINode>(Offs->getOperand(0));
1009 OffsSecondOp = 1;
1010 } else if (isa<PHINode>(Offs->getOperand(1))) {
1011 Phi = cast<PHINode>(Offs->getOperand(1));
1012 OffsSecondOp = 0;
1013 } else {
1014 return false;
1015 }
1016 }
1017 // A phi node we want to perform this function on should be from the
1018 // loop header.
1019 if (Phi->getParent() != L->getHeader())
1020 return false;
1021
1022 // We're looking for a simple add recurrence.
1023 BinaryOperator *IncInstruction;
1024 Value *Start, *IncrementPerRound;
1025 if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) ||
1026 IncInstruction->getOpcode() != Instruction::Add)
1027 return false;
1028
1029 int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1;
1030
1031 // Get the value that is added to/multiplied with the phi
1032 Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);
1033
1034 if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
1035 !L->isLoopInvariant(OffsSecondOperand))
1036 // Something has gone wrong, abort
1037 return false;
1038
1039 // Only proceed if the increment per round is a constant or an instruction
1040 // which does not originate from within the loop
1041 if (!isa<Constant>(IncrementPerRound) &&
1042 !(isa<Instruction>(IncrementPerRound) &&
1043 !L->contains(cast<Instruction>(IncrementPerRound))))
1044 return false;
1045
1046 // If the phi is not used by anything else, we can just adapt it when
1047 // replacing the instruction; if it is, we'll have to duplicate it
1048 PHINode *NewPhi;
1049 if (Phi->getNumUses() == 2) {
1050 // No other users -> reuse existing phi (One user is the instruction
1051 // we're looking at, the other is the phi increment)
1052 if (IncInstruction->getNumUses() != 1) {
1053 // If the incrementing instruction does have more users than
1054 // our phi, we need to copy it
1055 IncInstruction = BinaryOperator::Create(
1056 Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
1057 IncrementPerRound, "LoopIncrement", IncInstruction);
1058 Phi->setIncomingValue(IncrementingBlock, IncInstruction);
1059 }
1060 NewPhi = Phi;
1061 } else {
1062 // There are other users -> create a new phi
1063 NewPhi = PHINode::Create(Phi->getType(), 2, "NewPhi", Phi);
1064 // Copy the incoming values of the old phi
1065 NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
1066 Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
1067 IncInstruction = BinaryOperator::Create(
1068 Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
1069 IncrementPerRound, "LoopIncrement", IncInstruction);
1070 NewPhi->addIncoming(IncInstruction,
1071 Phi->getIncomingBlock(IncrementingBlock));
1072 IncrementingBlock = 1;
1073 }
1074
1076 Builder.SetInsertPoint(Phi);
1077 Builder.SetCurrentDebugLocation(Offs->getDebugLoc());
1078
1079 switch (Offs->getOpcode()) {
1080 case Instruction::Add:
1081 case Instruction::Or:
1082 pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
1083 break;
1084 case Instruction::Mul:
1085 case Instruction::Shl:
1086 pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound,
1087 OffsSecondOperand, IncrementingBlock, Builder);
1088 break;
1089 default:
1090 return false;
1091 }
1092 LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable "
1093 << "add/mul\n");
1094
1095 // The instruction has now been "absorbed" into the phi value
1096 Offs->replaceAllUsesWith(NewPhi);
1097 if (Offs->hasNUses(0))
1098 Offs->eraseFromParent();
1099 // Clean up the old increment in case it's unused because we built a new
1100 // one
1101 if (IncInstruction->hasNUses(0))
1102 IncInstruction->eraseFromParent();
1103
1104 return true;
1105}
1106
1107static Value *CheckAndCreateOffsetAdd(Value *X, unsigned ScaleX, Value *Y,
1108 unsigned ScaleY, IRBuilder<> &Builder) {
1109 // Splat the non-vector value to a vector of the given type - if the value is
1110 // a constant (and its value isn't too big), we can even use this opportunity
1111 // to scale it to the size of the vector elements
1112 auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
1113 ConstantInt *Const;
1114 if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
1115 VT->getElementType() != NonVectorVal->getType()) {
1116 unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
1117 uint64_t N = Const->getZExtValue();
1118 if (N < (unsigned)(1 << (TargetElemSize - 1))) {
1119 NonVectorVal = Builder.CreateVectorSplat(
1120 VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
1121 return;
1122 }
1123 }
1124 NonVectorVal =
1125 Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
1126 };
1127
1128 FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
1129 FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
1130 // If one of X, Y is not a vector, we have to splat it in order
1131 // to add the two of them.
1132 if (XElType && !YElType) {
1133 FixSummands(XElType, Y);
1134 YElType = cast<FixedVectorType>(Y->getType());
1135 } else if (YElType && !XElType) {
1136 FixSummands(YElType, X);
1137 XElType = cast<FixedVectorType>(X->getType());
1138 }
1139 assert(XElType && YElType && "Unknown vector types");
1140 // Check that the summands are of compatible types
1141 if (XElType != YElType) {
1142 LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
1143 return nullptr;
1144 }
1145
1146 if (XElType->getElementType()->getScalarSizeInBits() != 32) {
1147 // Check that by adding the vectors we do not accidentally
1148 // create an overflow
1149 Constant *ConstX = dyn_cast<Constant>(X);
1150 Constant *ConstY = dyn_cast<Constant>(Y);
1151 if (!ConstX || !ConstY)
1152 return nullptr;
1153 unsigned TargetElemSize = 128 / XElType->getNumElements();
1154 for (unsigned i = 0; i < XElType->getNumElements(); i++) {
1155 ConstantInt *ConstXEl =
1156 dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
1157 ConstantInt *ConstYEl =
1158 dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
1159 if (!ConstXEl || !ConstYEl ||
1160 ConstXEl->getZExtValue() * ScaleX +
1161 ConstYEl->getZExtValue() * ScaleY >=
1162 (unsigned)(1 << (TargetElemSize - 1)))
1163 return nullptr;
1164 }
1165 }
1166
1167 Value *XScale = Builder.CreateVectorSplat(
1168 XElType->getNumElements(),
1169 Builder.getIntN(XElType->getScalarSizeInBits(), ScaleX));
1170 Value *YScale = Builder.CreateVectorSplat(
1171 YElType->getNumElements(),
1172 Builder.getIntN(YElType->getScalarSizeInBits(), ScaleY));
1173 Value *Add = Builder.CreateAdd(Builder.CreateMul(X, XScale),
1174 Builder.CreateMul(Y, YScale));
1175
1176 if (checkOffsetSize(Add, XElType->getNumElements()))
1177 return Add;
1178 else
1179 return nullptr;
1180}
1181
1182Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
1183 Value *&Offsets, unsigned &Scale,
1184 IRBuilder<> &Builder) {
1185 Value *GEPPtr = GEP->getPointerOperand();
1186 Offsets = GEP->getOperand(1);
1187 Scale = DL->getTypeAllocSize(GEP->getSourceElementType());
1188 // We only merge geps with constant offsets, because only for those
1189 // we can make sure that we do not cause an overflow
1190 if (GEP->getNumIndices() != 1 || !isa<Constant>(Offsets))
1191 return nullptr;
1192 if (GetElementPtrInst *BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr)) {
1193 // Merge the two geps into one
1194 Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Scale, Builder);
1195 if (!BaseBasePtr)
1196 return nullptr;
1198 Offsets, Scale, GEP->getOperand(1),
1199 DL->getTypeAllocSize(GEP->getSourceElementType()), Builder);
1200 if (Offsets == nullptr)
1201 return nullptr;
1202 Scale = 1; // Scale is always an i8 at this point.
1203 return BaseBasePtr;
1204 }
1205 return GEPPtr;
1206}
1207
1208bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
1209 LoopInfo *LI) {
1210 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
1211 if (!GEP)
1212 return false;
1213 bool Changed = false;
1214 if (GEP->hasOneUse() && isa<GetElementPtrInst>(GEP->getPointerOperand())) {
1215 IRBuilder<> Builder(GEP->getContext());
1216 Builder.SetInsertPoint(GEP);
1217 Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
1218 Value *Offsets;
1219 unsigned Scale;
1220 Value *Base = foldGEP(GEP, Offsets, Scale, Builder);
1221 // We only want to merge the geps if there is a real chance that they can be
1222 // used by an MVE gather; thus the offset has to have the correct size
1223 // (always i32 if it is not of vector type) and the base has to be a
1224 // pointer.
1225 if (Offsets && Base && Base != GEP) {
1226 assert(Scale == 1 && "Expected to fold GEP to a scale of 1");
1227 Type *BaseTy = Builder.getPtrTy();
1228 if (auto *VecTy = dyn_cast<FixedVectorType>(Base->getType()))
1231 Builder.getInt8Ty(), Builder.CreateBitCast(Base, BaseTy), Offsets,
1232 "gep.merged", GEP);
1233 LLVM_DEBUG(dbgs() << "Folded GEP: " << *GEP
1234 << "\n new : " << *NewAddress << "\n");
1235 GEP->replaceAllUsesWith(
1236 Builder.CreateBitCast(NewAddress, GEP->getType()));
1237 GEP = NewAddress;
1238 Changed = true;
1239 }
1240 }
1241 Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
1242 return Changed;
1243}
1244
1245bool MVEGatherScatterLowering::runOnFunction(Function &F) {
1247 return false;
1248 auto &TPC = getAnalysis<TargetPassConfig>();
1249 auto &TM = TPC.getTM<TargetMachine>();
1250 auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
1251 if (!ST->hasMVEIntegerOps())
1252 return false;
1253 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1254 DL = &F.getParent()->getDataLayout();
1257
1258 bool Changed = false;
1259
1260 for (BasicBlock &BB : F) {
1261 Changed |= SimplifyInstructionsInBlock(&BB);
1262
1263 for (Instruction &I : BB) {
1264 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
1265 if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
1266 isa<FixedVectorType>(II->getType())) {
1267 Gathers.push_back(II);
1268 Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
1269 } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
1270 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
1271 Scatters.push_back(II);
1272 Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
1273 }
1274 }
1275 }
1276 for (unsigned i = 0; i < Gathers.size(); i++) {
1277 IntrinsicInst *I = Gathers[i];
1278 Instruction *L = lowerGather(I);
1279 if (L == nullptr)
1280 continue;
1281
1282 // Get rid of any now dead instructions
1283 SimplifyInstructionsInBlock(L->getParent());
1284 Changed = true;
1285 }
1286
1287 for (unsigned i = 0; i < Scatters.size(); i++) {
1288 IntrinsicInst *I = Scatters[i];
1289 Instruction *S = lowerScatter(I);
1290 if (S == nullptr)
1291 continue;
1292
1293 // Get rid of any now dead instructions
1295 Changed = true;
1296 }
1297 return Changed;
1298}
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
assume Assume Builder
This file contains the declarations for the subclasses of Constant, which represent the different fla...
static Decomposition decomposeGEP(GEPOperator &GEP, SmallVectorImpl< ConditionTy > &Preconditions, bool IsSigned, const DataLayout &DL)
#define LLVM_DEBUG(X)
Definition: Debug.h:101
static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")
Hexagon Common GEP
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
static bool isAddLikeOr(Instruction *I, const DataLayout &DL)
static bool hasAllGatScatUsers(Instruction *I, const DataLayout &DL)
static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount)
static Value * CheckAndCreateOffsetAdd(Value *X, unsigned ScaleX, Value *Y, unsigned ScaleY, IRBuilder<> &Builder)
#define DEBUG_TYPE
cl::opt< bool > EnableMaskedGatherScatters("enable-arm-maskedgatscat", cl::Hidden, cl::init(true), cl::desc("Enable the generation of masked gathers and scatters"))
static GCMetadataPrinterRegistry::Add< OcamlGCMetadataPrinter > Y("ocaml", "ocaml 3.10-compatible collector")
const char LLVMTargetMachineRef TM
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:38
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file describes how to lower LLVM code to machine code.
Target-Independent Code Generator Pass Configuration Options pass.
This pass exposes codegen information to IR-level passes.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:269
LLVM Basic Block Representation.
Definition: BasicBlock.h:56
LLVMContext & getContext() const
Get the context in which this basic block lives.
Definition: BasicBlock.cpp:35
static BinaryOperator * Create(BinaryOps Op, Value *S1, Value *S2, const Twine &Name=Twine(), Instruction *InsertBefore=nullptr)
Construct a binary instruction, given the opcode and the two operands.
BinaryOps getOpcode() const
Definition: InstrTypes.h:391
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1357
Type * getDestTy() const
Return the destination type, as a convenience.
Definition: InstrTypes.h:675
This is the shared class of boolean and integer constants.
Definition: Constants.h:78
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:888
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
Definition: Constants.h:151
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition: Constants.h:145
This is an important base class in LLVM.
Definition: Constant.h:41
Constant * getAggregateElement(unsigned Elt) const
For aggregates (struct/array/vector) return the constant that corresponds to the specified element if...
Definition: Constants.cpp:418
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:110
static ExtractValueInst * Create(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
Class to represent fixed width SIMD vectors.
Definition: DerivedTypes.h:536
unsigned getNumElements() const
Definition: DerivedTypes.h:579
static FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition: Type.cpp:693
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:311
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
Definition: Instructions.h:940
static GetElementPtrInst * Create(Type *PointeeType, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
Definition: Instructions.h:966
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2628
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
Definition: Instruction.h:392
const BasicBlock * getParent() const
Definition: Instruction.h:90
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
Definition: Instruction.h:195
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:83
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:47
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
Definition: IntrinsicInst.h:54
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:594
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:47
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Pass interface - Implemented by all 'passes'.
Definition: Pass.h:94
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:98
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
static SelectInst * Create(Value *C, Value *S1, Value *S2, const Twine &NameStr="", Instruction *InsertBefore=nullptr, Instruction *MDFrom=nullptr)
size_t size() const
Definition: SmallVector.h:91
void push_back(const T &Elt)
Definition: SmallVector.h:416
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1200
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
Primary interface to the complete machine description for the target machine.
Definition: TargetMachine.h:78
Target-Independent Code Generator Pass Configuration Options.
This class represents a truncation of integer types.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isVectorTy() const
True if this is an instance of VectorType.
Definition: Type.h:265
bool isIntOrIntVectorTy() const
Return true if this is an integer type or a vector of integer types.
Definition: Type.h:234
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Type * getWithNewBitWidth(unsigned NewBitWidth) const
Given an integer or vector type, change the lane bitwidth to NewBitwidth, whilst keeping the old numb...
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Value * getOperand(unsigned i) const
Definition: User.h:169
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:535
bool hasNUses(unsigned N) const
Return true if this Value has exactly N uses.
Definition: Value.cpp:149
unsigned getNumUses() const
This method computes the number of uses of this Value.
Definition: Value.cpp:255
Type * getElementType() const
Definition: DerivedTypes.h:433
This class represents zero extension of integer types.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:119
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
cst_pred_ty< is_one > m_One()
Match an integer 1 or a vector with all elements equal to 1.
Definition: PatternMatch.h:525
is_zero m_Zero()
Match any null constant or a vector with all elements equal to 0.
Definition: PatternMatch.h:545
Offsets
Offsets in bytes from the start of the input buffer.
Definition: SIInstrInfo.h:1404
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:445
NodeAddr< PhiNode * > Phi
Definition: RDFGraph.h:390
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
Pass * createMVEGatherScatterLoweringPass()
bool SimplifyInstructionsInBlock(BasicBlock *BB, const TargetLibraryInfo *TLI=nullptr)
Scan the specified basic block and try to simplify any instructions in it and recursively delete dead...
Definition: Local.cpp:717
bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
bool isGatherScatter(IntrinsicInst *IntInst)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
bool haveNoCommonBitsSet(const Value *LHS, const Value *RHS, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Return true if LHS and RHS have no common bits set.
@ Add
Sum of integers.
void initializeMVEGatherScatterLoweringPass(PassRegistry &)
#define N
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39