LLVM 23.0.0git
VectorCombine.cpp
Go to the documentation of this file.
1//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass optimizes scalar/vector interactions using target cost models. The
10// transforms implemented here may not fit in traditional loop-based or SLP
11// vectorization passes.
12//
13//===----------------------------------------------------------------------===//
14
16#include "llvm/ADT/DenseMap.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/ScopeExit.h"
20#include "llvm/ADT/Statistic.h"
25#include "llvm/Analysis/Loads.h"
30#include "llvm/IR/Dominators.h"
31#include "llvm/IR/Function.h"
32#include "llvm/IR/IRBuilder.h"
40#include <numeric>
41#include <optional>
42#include <queue>
43#include <set>
44
45#define DEBUG_TYPE "vector-combine"
47
48using namespace llvm;
49using namespace llvm::PatternMatch;
50
// Counters recording how often each kind of transform in this file fires.
51STATISTIC(NumVecLoad, "Number of vector loads formed");
52STATISTIC(NumVecCmp, "Number of vector compares formed");
53STATISTIC(NumVecBO, "Number of vector binops formed");
54STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
55STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
56STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
57STATISTIC(NumScalarCmp, "Number of scalar compares formed");
58STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");
59
// Hidden command-line options tuning the pass. The declaring line of each
// option is not visible in this view. The "disable-binop-extract-shuffle" flag
// is presumably the DisableBinopExtractShuffle read in isExtractExtractCheap;
// TODO confirm the variable names of the other two.
61 "disable-vector-combine", cl::init(false), cl::Hidden,
62 cl::desc("Disable all vector combine transforms"));
63
65 "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
66 cl::desc("Disable binop extract to shuffle transforms"));
67
69 "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
70 cl::desc("Max number of instructions to scan for vector combining."));
71
// Sentinel meaning "no index". It is the default PreferredExtractIndex of
// getShuffleExtract and the initial InsertIndex in foldExtractExtract.
72static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();
73
74namespace {
/// Per-function driver for the vector-combine transforms. It bundles the
/// function, the analyses and cost kind it was constructed with, a
/// SimplifyQuery, an IR builder and a worklist. Each private fold routine
/// declared below attempts one kind of rewrite using that shared state.
75class VectorCombine {
76public:
// The Builder is created with an InstSimplifyFolder, so values it "creates"
// may fold to existing values or constants rather than new instructions.
77 VectorCombine(Function &F, const TargetTransformInfo &TTI,
80 bool TryEarlyFoldsOnly)
81 : F(F), Builder(F.getContext(), InstSimplifyFolder(*DL)), TTI(TTI),
82 DT(DT), AA(AA), DL(DL), CostKind(CostKind),
83 SQ(*DL, /*TLI=*/nullptr, &DT, &AC),
84 TryEarlyFoldsOnly(TryEarlyFoldsOnly) {}
85
/// Entry point for running the transforms on F. The definition is not in
/// this view; presumably it returns true when the IR changed.
86 bool run();
87
88private:
89 Function &F;
91 const TargetTransformInfo &TTI;
92 const DominatorTree &DT;
93 AAResults &AA;
94 const DataLayout *DL;
/// Cost kind passed to the TTI cost queries made by the folds.
95 TTI::TargetCostKind CostKind;
/// Query built from DL, DT and AC with no TargetLibraryInfo. Its AC and DT
/// members are used for the load-safety checks.
96 const SimplifyQuery SQ;
97
98 /// If true, only perform beneficial early IR transforms. Do not introduce new
99 /// vector operations.
100 bool TryEarlyFoldsOnly;
101
/// Instructions queued for (re)visiting after a fold changes the IR.
102 InstructionWorklist Worklist;
103
104 /// Next instruction to iterate. It will be updated when it is erased by
105 /// RecursivelyDeleteTriviallyDeadInstructions.
106 Instruction *NextInst;
107
108 // TODO: Direct calls from the top-level "run" loop use a plain "Instruction"
109 // parameter. That should be updated to specific sub-classes because the
110 // run loop was changed to dispatch on opcode.
111 bool vectorizeLoadInsert(Instruction &I);
112 bool widenSubvectorLoad(Instruction &I);
113 ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
114 ExtractElementInst *Ext1,
115 unsigned PreferredExtractIndex) const;
116 bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
117 const Instruction &I,
118 ExtractElementInst *&ConvertToShuffle,
119 unsigned PreferredExtractIndex);
120 Value *foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
121 Value *foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
122 bool foldExtractExtract(Instruction &I);
123 bool foldInsExtFNeg(Instruction &I);
124 bool foldInsExtBinop(Instruction &I);
125 bool foldInsExtVectorToShuffle(Instruction &I);
126 bool foldBitOpOfCastops(Instruction &I);
127 bool foldBitOpOfCastConstant(Instruction &I);
128 bool foldBitcastShuffle(Instruction &I);
129 bool scalarizeOpOrCmp(Instruction &I);
130 bool scalarizeVPIntrinsic(Instruction &I);
131 bool foldExtractedCmps(Instruction &I);
132 bool foldSelectsFromBitcast(Instruction &I);
133 bool foldBinopOfReductions(Instruction &I);
134 bool foldSingleElementStore(Instruction &I);
135 bool scalarizeLoad(Instruction &I);
136 bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
137 bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
138 bool scalarizeExtExtract(Instruction &I);
139 bool foldConcatOfBoolMasks(Instruction &I);
140 bool foldPermuteOfBinops(Instruction &I);
141 bool foldShuffleOfBinops(Instruction &I);
142 bool foldShuffleOfSelects(Instruction &I);
143 bool foldShuffleOfCastops(Instruction &I);
144 bool foldShuffleOfShuffles(Instruction &I);
145 bool foldPermuteOfIntrinsic(Instruction &I);
146 bool foldShufflesOfLengthChangingShuffles(Instruction &I);
147 bool foldShuffleOfIntrinsics(Instruction &I);
148 bool foldShuffleToIdentity(Instruction &I);
149 bool foldShuffleFromReductions(Instruction &I);
150 bool foldShuffleChainsToReduce(Instruction &I);
151 bool foldCastFromReductions(Instruction &I);
152 bool foldSignBitReductionCmp(Instruction &I);
153 bool foldICmpEqZeroVectorReduce(Instruction &I);
154 bool foldEquivalentReductionCmp(Instruction &I);
155 bool foldReduceAddCmpZero(Instruction &I);
156 bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
157 bool foldInterleaveIntrinsics(Instruction &I);
158 bool foldDeinterleaveIntrinsics(Instruction &I);
159 bool foldBitcastOfVPLoad(Instruction &I);
160 bool shrinkType(Instruction &I);
161 bool shrinkLoadForShuffles(Instruction &I);
162 bool shrinkPhiOfShuffles(Instruction &I);
163
/// Replace all uses of \p Old with \p New. If New is an instruction, it takes
/// Old's name, and both New and its users are queued for revisiting. Old is
/// erased when \p Erase is set and Old is now trivially dead. Otherwise Old
/// itself is queued.
164 void replaceValue(Instruction &Old, Value &New, bool Erase = true) {
165 LLVM_DEBUG(dbgs() << "VC: Replacing: " << Old << '\n');
166 LLVM_DEBUG(dbgs() << " With: " << New << '\n');
167 Old.replaceAllUsesWith(&New);
168 if (auto *NewI = dyn_cast<Instruction>(&New)) {
169 New.takeName(&Old);
170 Worklist.pushUsersToWorkList(*NewI);
171 Worklist.pushValue(NewI);
172 }
173 if (Erase && isInstructionTriviallyDead(&Old)) {
174 eraseInstruction(Old);
175 } else {
176 Worklist.push(&Old);
177 }
178 }
179
/// Erase \p I from the worklist and from its parent, then try to delete the
/// operands that became trivially dead. Surviving operands, and their users,
/// are re-queued because losing a use may unblock one-use-limited folds.
180 void eraseInstruction(Instruction &I) {
181 LLVM_DEBUG(dbgs() << "VC: Erasing: " << I << '\n');
182 SmallVector<Value *> Ops(I.operands());
183 Worklist.remove(&I);
184 I.eraseFromParent();
185
186 // Push remaining users of the operands and then the operand itself - allows
187 // further folds that were hindered by OneUse limits.
// Visited collects instructions already deleted by an earlier operand's dead
// chain, so such an operand is not touched again.
188 SmallPtrSet<Value *, 4> Visited;
189 for (Value *Op : Ops) {
190 if (!Visited.contains(Op)) {
191 if (auto *OpI = dyn_cast<Instruction>(Op)) {
// The call heading the next lines is not visible in this view; presumably it
// is RecursivelyDeleteTriviallyDeadInstructions. The callback runs for each
// deleted instruction and keeps NextInst valid if that instruction was the
// next one to visit. If OpI was deleted, there is nothing left to re-queue.
193 OpI, nullptr, nullptr, [&](Value *V) {
194 if (auto *I = dyn_cast<Instruction>(V)) {
195 LLVM_DEBUG(dbgs() << "VC: Erased: " << *I << '\n');
196 Worklist.remove(I);
197 if (I == NextInst)
198 NextInst = NextInst->getNextNode();
199 Visited.insert(I);
200 }
201 }))
202 continue;
203 Worklist.pushUsersToWorkList(*OpI);
204 Worklist.pushValue(OpI);
205 }
206 }
207 }
208 }
209};
210} // namespace
211
212/// Return the source operand of a potentially bitcasted value. If there is no
213/// bitcast, return the input value itself.
/// Chains of bitcasts are stripped repeatedly until V is no longer a bitcast
/// instruction. Constant-expression bitcasts are not looked through, because
/// dyn_cast<BitCastInst> only matches instructions.
215 while (auto *BitCast = dyn_cast<BitCastInst>(V))
216 V = BitCast->getOperand(0);
217 return V;
218}
219
/// Return true if \p Load is a candidate for replacement by a wider vector
/// load. Load may be null (for example a failed dyn_cast in the caller), in
/// which case this returns false. The load must be simple and have a single
/// use. Its function must not carry SanitizeMemTag, and the further condition
/// on a line not visible in this view must also hold. Its element size must be
/// a non-zero multiple of 8 bits that evenly divides the target's minimum
/// vector register width.
220static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) {
221 // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan.
222 // The widened load may load data from dirty regions or create data races
223 // non-existent in the source.
224 if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
225 Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
227 return false;
228
229 // We are potentially transforming byte-sized (8-bit) memory accesses, so make
230 // sure we have all of our type-based constraints in place for this target.
231 Type *ScalarTy = Load->getType()->getScalarType();
232 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
233 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
// A zero ScalarSize (e.g. a non-primitive type) or a zero minimum width means
// the ratio below is meaningless, so reject those first.
234 if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
235 ScalarSize % 8 != 0)
236 return false;
237
238 return true;
239}
240
/// Try to replace an insertelement at index 0 of a loaded scalar with a load
/// of a minimum-register-width vector followed by a shuffle. The scalar may
/// optionally be extracted from element 0 of a loaded vector. The shuffle
/// places the wanted element in lane 0 of the result type. If the pointer
/// cannot be widened directly, a non-negative, element-aligned constant
/// offset from a safe base is tried. Returns true if I was replaced.
241bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
242 // Match insert into fixed vector of scalar value.
243 // TODO: Handle non-zero insert index.
244 Value *Scalar;
245 if (!match(&I,
247 return false;
248
249 // Optionally match an extract from another vector.
250 Value *X;
251 bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
252 if (!HasExtract)
253 X = Scalar;
254
255 auto *Load = dyn_cast<LoadInst>(X);
256 if (!canWidenLoad(Load, TTI))
257 return false;
258
259 Type *ScalarTy = Scalar->getType();
260 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
261 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
262
263 // Check safety of replacing the scalar load with a larger vector load.
264 // We use minimal alignment (maximum flexibility) because we only care about
265 // the dereferenceable region. When calculating cost and creating a new op,
266 // we may use a larger value based on alignment attributes.
267 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
268 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
269
// canWidenLoad checked this same element size: Scalar is either the load
// itself or element 0 of it. So the size is non-zero and divides MinVectorSize.
270 unsigned MinVecNumElts = MinVectorSize / ScalarSize;
271 auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
272 unsigned OffsetEltIndex = 0;
273 Align Alignment = Load->getAlign();
274 if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, SQ.AC,
275 SQ.DT)) {
276 // It is not safe to load directly from the pointer, but we can still peek
277 // through gep offsets and check if it safe to load from a base address with
278 // updated alignment. If it is, we can shuffle the element(s) into place
279 // after loading.
280 unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(SrcPtr->getType());
281 APInt Offset(OffsetBitWidth, 0);
// NOTE(review): Offset, and presumably SrcPtr, are updated by a line not
// visible in this view. Without that update, the safety re-check below would
// simply repeat the one above.
283
284 // We want to shuffle the result down from a high element of a vector, so
285 // the offset must be positive.
286 if (Offset.isNegative())
287 return false;
288
289 // The offset must be a multiple of the scalar element to shuffle cleanly
290 // in the element's size.
291 uint64_t ScalarSizeInBytes = ScalarSize / 8;
292 if (Offset.urem(ScalarSizeInBytes) != 0)
293 return false;
294
295 // If we load MinVecNumElts, will our target element still be loaded?
296 OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue();
297 if (OffsetEltIndex >= MinVecNumElts)
298 return false;
299
300 if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load,
301 SQ.AC, SQ.DT))
302 return false;
303
304 // Update alignment with offset value. Note that the offset could be negated
305 // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but
306 // negation does not change the result of the alignment calculation.
307 Alignment = commonAlignment(Alignment, Offset.getZExtValue());
308 }
309
310 // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
311 // Use the greater of the alignment on the load or its source pointer.
312 Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
313 Type *LoadTy = Load->getType();
314 unsigned AS = Load->getPointerAddressSpace();
315 InstructionCost OldCost =
316 TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
317 APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
318 OldCost +=
319 TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
320 /* Insert */ true, HasExtract, CostKind);
321
322 // New pattern: load VecPtr
323 InstructionCost NewCost =
324 TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS, CostKind);
325 // Optionally, we are shuffling the loaded vector element(s) into place.
326 // For the mask set everything but element 0 to poison (PoisonMaskElem) so
327 // no value from the extra loaded memory reaches the result. This will also optionally
328 // shrink/grow the vector from the loaded size to the output size.
329 // We assume this operation has no cost in codegen if there was no offset.
330 // Note that we could use freeze to avoid poison problems, but then we might
331 // still need a shuffle to change the vector size.
332 auto *Ty = cast<FixedVectorType>(I.getType());
333 unsigned OutputNumElts = Ty->getNumElements();
334 SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
335 assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
336 Mask[0] = OffsetEltIndex;
337 if (OffsetEltIndex)
338 NewCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, Ty, MinVecTy, Mask,
339 CostKind);
340
341 // We can aggressively convert to the vector form because the backend can
342 // invert this transform if it does not result in a performance win.
343 if (OldCost < NewCost || !NewCost.isValid())
344 return false;
345
346 // It is safe and potentially profitable to load a vector directly:
347 // inselt undef, load Scalar, 0 --> load VecPtr
348 IRBuilder<> Builder(Load);
349 Value *CastedPtr =
350 Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
351 Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);
352 VecLd = Builder.CreateShuffleVector(VecLd, Mask);
353
354 replaceValue(I, *VecLd);
355 ++NumVecLoad;
356 return true;
357}
358
359/// If we are loading a vector and then inserting it into a larger vector with
360/// undefined elements, try to load the larger vector and eliminate the insert.
361/// This removes a shuffle in IR and may allow combining of other loaded values.
362bool VectorCombine::widenSubvectorLoad(Instruction &I) {
363 // Match subvector insert of fixed vector.
364 auto *Shuf = cast<ShuffleVectorInst>(&I);
365 if (!Shuf->isIdentityWithPadding())
366 return false;
367
368 // Allow a non-canonical shuffle mask that is choosing elements from op1.
369 unsigned NumOpElts =
370 cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
371 unsigned OpIndex = any_of(Shuf->getShuffleMask(), [&NumOpElts](int M) {
372 return M >= (int)(NumOpElts);
373 });
374
375 auto *Load = dyn_cast<LoadInst>(Shuf->getOperand(OpIndex));
376 if (!canWidenLoad(Load, TTI))
377 return false;
378
379 // We use minimal alignment (maximum flexibility) because we only care about
380 // the dereferenceable region. When calculating cost and creating a new op,
381 // we may use a larger value based on alignment attributes.
382 auto *Ty = cast<FixedVectorType>(I.getType());
383 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
384 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
385 Align Alignment = Load->getAlign();
386 if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), *DL, Load, SQ.AC,
387 SQ.DT))
388 return false;
389
390 Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
391 Type *LoadTy = Load->getType();
392 unsigned AS = Load->getPointerAddressSpace();
393
394 // Original pattern: insert_subvector (load PtrOp)
395 // This conservatively assumes that the cost of a subvector insert into an
396 // undef value is 0. We could add that cost if the cost model accurately
397 // reflects the real cost of that operation.
398 InstructionCost OldCost =
399 TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
400
401 // New pattern: load PtrOp
402 InstructionCost NewCost =
403 TTI.getMemoryOpCost(Instruction::Load, Ty, Alignment, AS, CostKind);
404
405 // We can aggressively convert to the vector form because the backend can
406 // invert this transform if it does not result in a performance win.
407 if (OldCost < NewCost || !NewCost.isValid())
408 return false;
409
410 IRBuilder<> Builder(Load);
411 Value *CastedPtr =
412 Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
413 Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment);
414 replaceValue(I, *VecLd);
415 ++NumVecLoad;
416 return true;
417}
418
419/// Determine which, if any, of the inputs should be replaced by a shuffle
420/// followed by extract from a different index.
421ExtractElementInst *VectorCombine::getShuffleExtract(
422 ExtractElementInst *Ext0, ExtractElementInst *Ext1,
423 unsigned PreferredExtractIndex = InvalidIndex) const {
424 auto *Index0C = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
425 auto *Index1C = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
426 assert(Index0C && Index1C && "Expected constant extract indexes");
427
428 unsigned Index0 = Index0C->getZExtValue();
429 unsigned Index1 = Index1C->getZExtValue();
430
431 // If the extract indexes are identical, no shuffle is needed.
432 if (Index0 == Index1)
433 return nullptr;
434
435 Type *VecTy = Ext0->getVectorOperand()->getType();
436 assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
437 InstructionCost Cost0 =
438 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
439 InstructionCost Cost1 =
440 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
441
442 // If both costs are invalid no shuffle is needed
443 if (!Cost0.isValid() && !Cost1.isValid())
444 return nullptr;
445
446 // We are extracting from 2 different indexes, so one operand must be shuffled
447 // before performing a vector operation and/or extract. The more expensive
448 // extract will be replaced by a shuffle.
449 if (Cost0 > Cost1)
450 return Ext0;
451 if (Cost1 > Cost0)
452 return Ext1;
453
454 // If the costs are equal and there is a preferred extract index, shuffle the
455 // opposite operand.
456 if (PreferredExtractIndex == Index0)
457 return Ext1;
458 if (PreferredExtractIndex == Index1)
459 return Ext0;
460
461 // Otherwise, replace the extract with the higher index.
462 return Index0 > Index1 ? Ext0 : Ext1;
463}
464
465/// Compare the relative costs of 2 extracts followed by scalar operation vs.
466/// vector operation(s) followed by extract. Return true if the existing
467/// instructions are cheaper than a vector alternative. Otherwise, return false
468/// and if one of the extracts should be transformed to a shufflevector, set
469/// \p ConvertToShuffle to that extract instruction.
/// \p I must be a binary operator or an ICmp/FCmp, and both extracts must use
/// constant indexes. ConvertToShuffle is always assigned, and is nullptr when
/// no shuffle is needed. Equal costs count as "not cheaper", which favors the
/// vector form. If binop extract-to-shuffle is disabled and a shuffle would be
/// needed, true is returned so that the fold is blocked.
470bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
471 ExtractElementInst *Ext1,
472 const Instruction &I,
473 ExtractElementInst *&ConvertToShuffle,
474 unsigned PreferredExtractIndex) {
475 auto *Ext0IndexC = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
476 auto *Ext1IndexC = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
477 assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes");
478
479 unsigned Opcode = I.getOpcode();
480 Value *Ext0Src = Ext0->getVectorOperand();
481 Value *Ext1Src = Ext1->getVectorOperand();
482 Type *ScalarTy = Ext0->getType();
483 auto *VecTy = cast<VectorType>(Ext0Src->getType());
484 InstructionCost ScalarOpCost, VectorOpCost;
485
486 // Get cost estimates for scalar and vector versions of the operation.
487 bool IsBinOp = Instruction::isBinaryOp(Opcode);
488 if (IsBinOp) {
489 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
490 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
491 } else {
492 assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
493 "Expected a compare");
494 CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
495 ScalarOpCost = TTI.getCmpSelInstrCost(
496 Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
497 VectorOpCost = TTI.getCmpSelInstrCost(
498 Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
499 }
500
501 // Get cost estimates for the extract elements. These costs will factor into
502 // both sequences.
503 unsigned Ext0Index = Ext0IndexC->getZExtValue();
504 unsigned Ext1Index = Ext1IndexC->getZExtValue();
505
506 InstructionCost Extract0Cost =
507 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Ext0Index);
508 InstructionCost Extract1Cost =
509 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Ext1Index);
510
511 // A more expensive extract will always be replaced by a splat shuffle.
512 // For example, if Ext0 is more expensive:
513 // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
514 // extelt (opcode (splat V0, Ext0), V1), Ext1
515 // TODO: Evaluate whether that always results in lowest cost. Alternatively,
516 // check the cost of creating a broadcast shuffle and shuffling both
517 // operands to element 0.
// NOTE(review): on a cost tie, these two lines model shuffling Ext1 (moving
// Ext1Index into lane Ext0Index). getShuffleExtract below may break the same
// tie toward Ext0 (preferred or higher index), so the mask costed here can
// differ from the shuffle that foldExtractExtract actually builds.
518 unsigned BestExtIndex = Extract0Cost > Extract1Cost ? Ext0Index : Ext1Index;
519 unsigned BestInsIndex = Extract0Cost > Extract1Cost ? Ext1Index : Ext0Index;
520 InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost);
521
522 // Extra uses of the extracts mean that we include those costs in the
523 // vector total because those instructions will not be eliminated.
524 InstructionCost OldCost, NewCost;
525 if (Ext0Src == Ext1Src && Ext0Index == Ext1Index) {
526 // Handle a special case. If the 2 extracts are identical, adjust the
527 // formulas to account for that. The extra use charge allows for either the
528 // CSE'd pattern or an unoptimized form with identical values:
529 // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
530 bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
531 : !Ext0->hasOneUse() || !Ext1->hasOneUse();
532 OldCost = CheapExtractCost + ScalarOpCost;
533 NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
534 } else {
535 // Handle the general case. Each extract is actually a different value:
536 // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
537 OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
538 NewCost = VectorOpCost + CheapExtractCost +
539 !Ext0->hasOneUse() * Extract0Cost +
540 !Ext1->hasOneUse() * Extract1Cost;
541 }
542
543 ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
544 if (ConvertToShuffle) {
// Disabling the binop shuffle form makes the existing code "cheaper", which
// blocks the fold.
545 if (IsBinOp && DisableBinopExtractShuffle)
546 return true;
547
548 // If we are extracting from 2 different indexes, then one operand must be
549 // shuffled before performing the vector operation. The shuffle mask is
550 // poison except for 1 lane that is being translated to the remaining
551 // extraction lane. Therefore, it is a splat shuffle. Ex:
552 // ShufMask = { poison, poison, 0, poison }
553 // TODO: The cost model has an option for a "broadcast" shuffle
554 // (splat-from-element-0), but no option for a more general splat.
// The shuffle's cost is added to NewCost by statements whose first lines are
// not visible in this view (presumably TTI.getShuffleCost calls). Scalable
// vectors are costed without a concrete mask.
555 if (auto *FixedVecTy = dyn_cast<FixedVectorType>(VecTy)) {
556 SmallVector<int> ShuffleMask(FixedVecTy->getNumElements(),
558 ShuffleMask[BestInsIndex] = BestExtIndex;
560 VecTy, VecTy, ShuffleMask, CostKind, 0,
561 nullptr, {ConvertToShuffle});
562 } else {
564 VecTy, VecTy, {}, CostKind, 0, nullptr,
565 {ConvertToShuffle});
566 }
567 }
568
569 // Aggressively form a vector op if the cost is equal because the transform
570 // may enable further optimization.
571 // Codegen can reverse this transform (scalarize) if it was not profitable.
572 return OldCost < NewCost;
573}
574
575/// Create a shuffle that translates (shifts) 1 element from the input vector
576/// to a new element location.
577static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
578 unsigned NewIndex, IRBuilderBase &Builder) {
579 // The shuffle mask is poison except for 1 lane that is being translated
580 // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
581 // ShufMask = { 2, poison, poison, poison }
582 auto *VecTy = cast<FixedVectorType>(Vec->getType());
583 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
584 ShufMask[NewIndex] = OldIndex;
585 return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
586}
587
588/// Given an extract element instruction with constant index operand, shuffle
589/// the source vector (shift the scalar element) to a NewIndex for extraction.
590/// Return null if the input can be constant folded, so that we are not creating
591/// unnecessary instructions.
592static Value *translateExtract(ExtractElementInst *ExtElt, unsigned NewIndex,
593 IRBuilderBase &Builder) {
594 // Shufflevectors can only be created for fixed-width vectors.
595 Value *X = ExtElt->getVectorOperand();
596 if (!isa<FixedVectorType>(X->getType()))
597 return nullptr;
598
599 // If the extract can be constant-folded, this code is unsimplified. Defer
600 // to other passes to handle that.
601 Value *C = ExtElt->getIndexOperand();
602 assert(isa<ConstantInt>(C) && "Expected a constant index operand");
603 if (isa<Constant>(X))
604 return nullptr;
605
606 Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
607 NewIndex, Builder);
608 return Shuf;
609}
610
611/// Try to reduce extract element costs by converting scalar compares to vector
612/// compares followed by extract.
613/// cmp (ext0 V0, ExtIndex), (ext1 V1, ExtIndex)
614Value *VectorCombine::foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex,
615 Instruction &I) {
616 assert(isa<CmpInst>(&I) && "Expected a compare");
617
618 // cmp Pred (extelt V0, ExtIndex), (extelt V1, ExtIndex)
619 // --> extelt (cmp Pred V0, V1), ExtIndex
620 ++NumVecCmp;
621 CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
622 Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
623 return Builder.CreateExtractElement(VecCmp, ExtIndex, "foldExtExtCmp");
624}
625
626/// Try to reduce extract element costs by converting scalar binops to vector
627/// binops followed by extract.
628/// bo (ext0 V0, ExtIndex), (ext1 V1, ExtIndex)
629Value *VectorCombine::foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex,
630 Instruction &I) {
631 assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
632
633 // bo (extelt V0, ExtIndex), (extelt V1, ExtIndex)
634 // --> extelt (bo V0, V1), ExtIndex
635 ++NumVecBO;
636 Value *VecBO = Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0,
637 V1, "foldExtExtBinop");
638
639 // All IR flags are safe to back-propagate because any potential poison
640 // created in unused vector elements is discarded by the extract.
641 if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
642 VecBOInst->copyIRFlags(&I);
643
644 return Builder.CreateExtractElement(VecBO, ExtIndex, "foldExtExtBinop");
645}
646
647/// Match an instruction with extracted vector operands.
/// This handles compares and binary operators whose two operands are
/// constant-index extractelements from vectors of the same type:
///   op (extelt V0, C0), (extelt V1, C1) --> extelt (op V0', V1'), C
/// When C0 != C1, one source is first shifted by a shuffle so that both
/// elements share a lane. Returns true if I was replaced.
648bool VectorCombine::foldExtractExtract(Instruction &I) {
649 // It is not safe to transform things like div, urem, etc. because we may
650 // create undefined behavior when executing those on unknown vector elements.
652 return false;
653
654 Instruction *I0, *I1;
// Pred stays BAD_ICMP_PREDICATE unless I matched as a compare. It later
// selects between the compare and binop rewrites.
655 CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
656 if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
658 return false;
659
660 Value *V0, *V1;
661 uint64_t C0, C1;
662 if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
664 V0->getType() != V1->getType())
665 return false;
666
667 // For fixed-width vectors, reject out-of-bounds extract indexes
668 if (auto *FixedVecTy = dyn_cast<FixedVectorType>(V0->getType())) {
669 unsigned NumElts = FixedVecTy->getNumElements();
670 if (C0 >= NumElts || C1 >= NumElts)
671 return false;
672 }
673
674 // If the scalar value 'I' is going to be re-inserted into a vector, then try
675 // to create an extract to that same element. The extract/insert can be
676 // reduced to a "select shuffle".
677 // TODO: If we add a larger pattern match that starts from an insert, this
678 // probably becomes unnecessary.
679 auto *Ext0 = cast<ExtractElementInst>(I0);
680 auto *Ext1 = cast<ExtractElementInst>(I1);
681 uint64_t InsertIndex = InvalidIndex;
682 if (I.hasOneUse())
683 match(I.user_back(),
684 m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));
685
686 ExtractElementInst *ExtractToChange;
687 if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
688 return false;
689
690 Value *ExtOp0 = Ext0->getVectorOperand();
691 Value *ExtOp1 = Ext1->getVectorOperand();
692
// Shift the chosen extract's source so its element lands in the other
// extract's lane. translateExtract declines (returns null, creating no IR)
// for scalable or constant sources, in which case we bail out unchanged.
693 if (ExtractToChange) {
694 unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
695 Value *NewExtOp =
696 translateExtract(ExtractToChange, CheapExtractIdx, Builder);
697 if (!NewExtOp)
698 return false;
699 if (ExtractToChange == Ext0)
700 ExtOp0 = NewExtOp;
701 else
702 ExtOp1 = NewExtOp;
703 }
704
// The surviving lane is the index of the extract that was not shuffled.
705 Value *ExtIndex = ExtractToChange == Ext0 ? Ext1->getIndexOperand()
706 : Ext0->getIndexOperand();
707 Value *NewExt = Pred != CmpInst::BAD_ICMP_PREDICATE
708 ? foldExtExtCmp(ExtOp0, ExtOp1, ExtIndex, I)
709 : foldExtExtBinop(ExtOp0, ExtOp1, ExtIndex, I);
// Re-queue the original extracts: once I is replaced they may have no uses
// left.
710 Worklist.push(Ext0);
711 Worklist.push(Ext1);
712 replaceValue(I, *NewExt);
713 return true;
714}
715
716/// Try to replace an extract + scalar fneg + insert with a vector fneg +
717/// shuffle.
/// Pattern: insertelt DstVec, (fneg (extractelt SrcVec, ExtIdx)), InsIdx, where
/// the fneg has one use and SrcVec has DstVec's element type. When the two
/// vector lengths differ, an extra length-changing shuffle is emitted.
/// Returns true if I was replaced.
718bool VectorCombine::foldInsExtFNeg(Instruction &I) {
719 // Match an insert (op (extract)) pattern.
720 Value *DstVec;
721 uint64_t ExtIdx, InsIdx;
722 Instruction *FNeg;
723 if (!match(&I, m_InsertElt(m_Value(DstVec), m_OneUse(m_Instruction(FNeg)),
724 m_ConstantInt(InsIdx))))
725 return false;
726
727 // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
728 Value *SrcVec;
729 Instruction *Extract;
730 if (!match(FNeg, m_FNeg(m_CombineAnd(
731 m_Instruction(Extract),
732 m_ExtractElt(m_Value(SrcVec), m_ConstantInt(ExtIdx))))))
733 return false;
734
// NOTE(review): cast<> asserts if DstVec is a scalable vector. This assumes
// the caller only dispatches fixed-width insertelements here; confirm in run().
735 auto *DstVecTy = cast<FixedVectorType>(DstVec->getType());
736 auto *DstVecScalarTy = DstVecTy->getScalarType();
737 auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcVec->getType());
738 if (!SrcVecTy || DstVecScalarTy != SrcVecTy->getScalarType())
739 return false;
740
741 // Ignore if insert/extract index is out of bounds or destination vector has
742 // one element
743 unsigned NumDstElts = DstVecTy->getNumElements();
744 unsigned NumSrcElts = SrcVecTy->getNumElements();
// NOTE(review): `ExtIdx > NumSrcElts` still accepts ExtIdx == NumSrcElts,
// which is out of bounds for SrcVec, while the InsIdx check uses `>=`. This
// likely should be `ExtIdx >= NumSrcElts`; confirm.
745 if (ExtIdx > NumSrcElts || InsIdx >= NumDstElts || NumDstElts == 1)
746 return false;
747
748 // Build a select-shuffle mask that keeps every DstVec lane except InsIdx.
749 // Lane InsIdx takes the negated element from the second shuffle operand,
750 // where it sits at position ExtIdx % NumDstElts (see SrcMask below).
751 SmallVector<int> Mask(NumDstElts);
752 std::iota(Mask.begin(), Mask.end(), 0);
753 Mask[InsIdx] = (ExtIdx % NumDstElts) + NumDstElts;
754 InstructionCost OldCost =
755 TTI.getArithmeticInstrCost(Instruction::FNeg, DstVecScalarTy, CostKind) +
756 TTI.getVectorInstrCost(I, DstVecTy, CostKind, InsIdx);
757
758 // If the extract has one use, it will be eliminated, so count it in the
759 // original cost. If it has more than one use, ignore the cost because it will
760 // be the same before/after.
761 if (Extract->hasOneUse())
762 OldCost += TTI.getVectorInstrCost(*Extract, SrcVecTy, CostKind, ExtIdx);
763
764 InstructionCost NewCost =
765 TTI.getArithmeticInstrCost(Instruction::FNeg, SrcVecTy, CostKind) +
767 DstVecTy, Mask, CostKind);
768
769 bool NeedLenChg = SrcVecTy->getNumElements() != NumDstElts;
770 // If the lengths of the two vectors are not equal,
771 // we need to add a length-change vector. Add this cost.
772 SmallVector<int> SrcMask;
773 if (NeedLenChg) {
774 SrcMask.assign(NumDstElts, PoisonMaskElem);
775 SrcMask[ExtIdx % NumDstElts] = ExtIdx;
777 DstVecTy, SrcVecTy, SrcMask, CostKind);
778 }
779
780 LLVM_DEBUG(dbgs() << "Found an insertion of (extract)fneg : " << I
781 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
782 << "\n");
783 if (NewCost > OldCost)
784 return false;
785
786 Value *NewShuf, *LenChgShuf = nullptr;
787 // insertelt DstVec, (fneg (extractelt SrcVec, Index)), Index
788 Value *VecFNeg = Builder.CreateFNegFMF(SrcVec, FNeg);
789 if (NeedLenChg) {
790 // shuffle DstVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask
791 LenChgShuf = Builder.CreateShuffleVector(VecFNeg, SrcMask);
792 NewShuf = Builder.CreateShuffleVector(DstVec, LenChgShuf, Mask);
793 Worklist.pushValue(LenChgShuf);
794 } else {
795 // shuffle DstVec, (fneg SrcVec), Mask
796 NewShuf = Builder.CreateShuffleVector(DstVec, VecFNeg, Mask);
797 }
798
799 Worklist.pushValue(VecFNeg);
800 replaceValue(I, *NewShuf);
801 return true;
802}
803
804/// Try to fold insert(binop(x,y),binop(a,b),idx)
805/// --> binop(insert(x,a,idx),insert(y,b,idx))
806bool VectorCombine::foldInsExtBinop(Instruction &I) {
807 BinaryOperator *VecBinOp, *SclBinOp;
808 uint64_t Index;
809 if (!match(&I,
810 m_InsertElt(m_OneUse(m_BinOp(VecBinOp)),
811 m_OneUse(m_BinOp(SclBinOp)), m_ConstantInt(Index))))
812 return false;
813
814 // TODO: Add support for addlike etc.
815 Instruction::BinaryOps BinOpcode = VecBinOp->getOpcode();
816 if (BinOpcode != SclBinOp->getOpcode())
817 return false;
818
819 auto *ResultTy = dyn_cast<FixedVectorType>(I.getType());
820 if (!ResultTy)
821 return false;
822
823 // TODO: Attempt to detect m_ExtractElt for scalar operands and convert to
824 // shuffle?
825
827 TTI.getInstructionCost(VecBinOp, CostKind) +
829 InstructionCost NewCost =
830 TTI.getArithmeticInstrCost(BinOpcode, ResultTy, CostKind) +
831 TTI.getVectorInstrCost(Instruction::InsertElement, ResultTy, CostKind,
832 Index, VecBinOp->getOperand(0),
833 SclBinOp->getOperand(0)) +
834 TTI.getVectorInstrCost(Instruction::InsertElement, ResultTy, CostKind,
835 Index, VecBinOp->getOperand(1),
836 SclBinOp->getOperand(1));
837
838 LLVM_DEBUG(dbgs() << "Found an insertion of two binops: " << I
839 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
840 << "\n");
841 if (NewCost > OldCost)
842 return false;
843
844 Value *NewIns0 = Builder.CreateInsertElement(VecBinOp->getOperand(0),
845 SclBinOp->getOperand(0), Index);
846 Value *NewIns1 = Builder.CreateInsertElement(VecBinOp->getOperand(1),
847 SclBinOp->getOperand(1), Index);
848 Value *NewBO = Builder.CreateBinOp(BinOpcode, NewIns0, NewIns1);
849
850 // Intersect flags from the old binops.
851 if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
852 NewInst->copyIRFlags(VecBinOp);
853 NewInst->andIRFlags(SclBinOp);
854 }
855
856 Worklist.pushValue(NewIns0);
857 Worklist.pushValue(NewIns1);
858 replaceValue(I, *NewBO);
859 return true;
860}
861
862/// Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
863/// Supports: bitcast, trunc, sext, zext
864bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
865 // Check if this is a bitwise logic operation
866 auto *BinOp = dyn_cast<BinaryOperator>(&I);
867 if (!BinOp || !BinOp->isBitwiseLogicOp())
868 return false;
869
870 // Get the cast instructions
871 auto *LHSCast = dyn_cast<CastInst>(BinOp->getOperand(0));
872 auto *RHSCast = dyn_cast<CastInst>(BinOp->getOperand(1));
873 if (!LHSCast || !RHSCast) {
874 LLVM_DEBUG(dbgs() << " One or both operands are not cast instructions\n");
875 return false;
876 }
877
878 // Both casts must be the same type
879 Instruction::CastOps CastOpcode = LHSCast->getOpcode();
880 if (CastOpcode != RHSCast->getOpcode())
881 return false;
882
883 // Only handle supported cast operations
884 switch (CastOpcode) {
885 case Instruction::BitCast:
886 case Instruction::Trunc:
887 case Instruction::SExt:
888 case Instruction::ZExt:
889 break;
890 default:
891 return false;
892 }
893
894 Value *LHSSrc = LHSCast->getOperand(0);
895 Value *RHSSrc = RHSCast->getOperand(0);
896
897 // Source types must match
898 if (LHSSrc->getType() != RHSSrc->getType())
899 return false;
900
901 auto *SrcTy = LHSSrc->getType();
902 auto *DstTy = I.getType();
903 // Bitcasts can handle scalar/vector mixes, such as i16 -> <16 x i1>.
904 // Other casts only handle vector types with integer elements.
905 if (CastOpcode != Instruction::BitCast &&
906 (!isa<FixedVectorType>(SrcTy) || !isa<FixedVectorType>(DstTy)))
907 return false;
908
909 // Only integer scalar/vector values are legal for bitwise logic operations.
910 if (!SrcTy->getScalarType()->isIntegerTy() ||
911 !DstTy->getScalarType()->isIntegerTy())
912 return false;
913
914 // Cost Check :
915 // OldCost = bitlogic + 2*casts
916 // NewCost = bitlogic + cast
917
918 // Calculate specific costs for each cast with instruction context
920 CastOpcode, DstTy, SrcTy, TTI::CastContextHint::None, CostKind, LHSCast);
922 CastOpcode, DstTy, SrcTy, TTI::CastContextHint::None, CostKind, RHSCast);
923
924 InstructionCost OldCost =
925 TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstTy, CostKind) +
926 LHSCastCost + RHSCastCost;
927
928 // For new cost, we can't provide an instruction (it doesn't exist yet)
929 InstructionCost GenericCastCost = TTI.getCastInstrCost(
930 CastOpcode, DstTy, SrcTy, TTI::CastContextHint::None, CostKind);
931
932 InstructionCost NewCost =
933 TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcTy, CostKind) +
934 GenericCastCost;
935
936 // Account for multi-use casts using specific costs
937 if (!LHSCast->hasOneUse())
938 NewCost += LHSCastCost;
939 if (!RHSCast->hasOneUse())
940 NewCost += RHSCastCost;
941
942 LLVM_DEBUG(dbgs() << "foldBitOpOfCastops: OldCost=" << OldCost
943 << " NewCost=" << NewCost << "\n");
944
945 if (NewCost > OldCost)
946 return false;
947
948 // Create the operation on the source type
949 Value *NewOp = Builder.CreateBinOp(BinOp->getOpcode(), LHSSrc, RHSSrc,
950 BinOp->getName() + ".inner");
951 if (auto *NewBinOp = dyn_cast<BinaryOperator>(NewOp))
952 NewBinOp->copyIRFlags(BinOp);
953
954 Worklist.pushValue(NewOp);
955
956 // Create the cast operation directly to ensure we get a new instruction
957 Instruction *NewCast = CastInst::Create(CastOpcode, NewOp, I.getType());
958
959 // Preserve cast instruction flags
960 NewCast->copyIRFlags(LHSCast);
961 NewCast->andIRFlags(RHSCast);
962
963 // Insert the new instruction
964 Value *Result = Builder.Insert(NewCast);
965
966 replaceValue(I, *Result);
967 return true;
968}
969
970/// Match:
971// bitop(castop(x), C) ->
972// bitop(castop(x), castop(InvC)) ->
973// castop(bitop(x, InvC))
974// Supports: bitcast
975bool VectorCombine::foldBitOpOfCastConstant(Instruction &I) {
977 Constant *C;
978
979 // Check if this is a bitwise logic operation
981 return false;
982
983 // Get the cast instructions
984 auto *LHSCast = dyn_cast<CastInst>(LHS);
985 if (!LHSCast)
986 return false;
987
988 Instruction::CastOps CastOpcode = LHSCast->getOpcode();
989
990 // Only handle supported cast operations
991 switch (CastOpcode) {
992 case Instruction::BitCast:
993 case Instruction::ZExt:
994 case Instruction::SExt:
995 case Instruction::Trunc:
996 break;
997 default:
998 return false;
999 }
1000
1001 Value *LHSSrc = LHSCast->getOperand(0);
1002
1003 auto *SrcTy = LHSSrc->getType();
1004 auto *DstTy = I.getType();
1005 // Bitcasts can handle scalar/vector mixes, such as i16 -> <16 x i1>.
1006 // Other casts only handle vector types with integer elements.
1007 if (CastOpcode != Instruction::BitCast &&
1008 (!isa<FixedVectorType>(SrcTy) || !isa<FixedVectorType>(DstTy)))
1009 return false;
1010
1011 // Only integer scalar/vector values are legal for bitwise logic operations.
1012 if (!SrcTy->getScalarType()->isIntegerTy() ||
1013 !DstTy->getScalarType()->isIntegerTy())
1014 return false;
1015
1016 // Find the constant InvC, such that castop(InvC) equals to C.
1017 PreservedCastFlags RHSFlags;
1018 Constant *InvC = getLosslessInvCast(C, SrcTy, CastOpcode, *DL, &RHSFlags);
1019 if (!InvC)
1020 return false;
1021
1022 // Cost Check :
1023 // OldCost = bitlogic + cast
1024 // NewCost = bitlogic + cast
1025
1026 // Calculate specific costs for each cast with instruction context
1027 InstructionCost LHSCastCost = TTI.getCastInstrCost(
1028 CastOpcode, DstTy, SrcTy, TTI::CastContextHint::None, CostKind, LHSCast);
1029
1030 InstructionCost OldCost =
1031 TTI.getArithmeticInstrCost(I.getOpcode(), DstTy, CostKind) + LHSCastCost;
1032
1033 // For new cost, we can't provide an instruction (it doesn't exist yet)
1034 InstructionCost GenericCastCost = TTI.getCastInstrCost(
1035 CastOpcode, DstTy, SrcTy, TTI::CastContextHint::None, CostKind);
1036
1037 InstructionCost NewCost =
1038 TTI.getArithmeticInstrCost(I.getOpcode(), SrcTy, CostKind) +
1039 GenericCastCost;
1040
1041 // Account for multi-use casts using specific costs
1042 if (!LHSCast->hasOneUse())
1043 NewCost += LHSCastCost;
1044
1045 LLVM_DEBUG(dbgs() << "foldBitOpOfCastConstant: OldCost=" << OldCost
1046 << " NewCost=" << NewCost << "\n");
1047
1048 if (NewCost > OldCost)
1049 return false;
1050
1051 // Create the operation on the source type
1052 Value *NewOp = Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(),
1053 LHSSrc, InvC, I.getName() + ".inner");
1054 if (auto *NewBinOp = dyn_cast<BinaryOperator>(NewOp))
1055 NewBinOp->copyIRFlags(&I);
1056
1057 Worklist.pushValue(NewOp);
1058
1059 // Create the cast operation directly to ensure we get a new instruction
1060 Instruction *NewCast = CastInst::Create(CastOpcode, NewOp, I.getType());
1061
1062 // Preserve cast instruction flags
1063 if (RHSFlags.NNeg)
1064 NewCast->setNonNeg();
1065 if (RHSFlags.NUW)
1066 NewCast->setHasNoUnsignedWrap();
1067 if (RHSFlags.NSW)
1068 NewCast->setHasNoSignedWrap();
1069
1070 NewCast->andIRFlags(LHSCast);
1071
1072 // Insert the new instruction
1073 Value *Result = Builder.Insert(NewCast);
1074
1075 replaceValue(I, *Result);
1076 return true;
1077}
1078
1079/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
1080/// destination type followed by shuffle. This can enable further transforms by
1081/// moving bitcasts or shuffles together.
1082bool VectorCombine::foldBitcastShuffle(Instruction &I) {
1083 Value *V0, *V1;
1084 ArrayRef<int> Mask;
1085 if (!match(&I, m_BitCast(m_OneUse(
1086 m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(Mask))))))
1087 return false;
1088
1089 // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
1090 // scalable type is unknown; Second, we cannot reason if the narrowed shuffle
1091 // mask for scalable type is a splat or not.
1092 // 2) Disallow non-vector casts.
1093 // TODO: We could allow any shuffle.
1094 auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
1095 auto *SrcTy = dyn_cast<FixedVectorType>(V0->getType());
1096 if (!DestTy || !SrcTy)
1097 return false;
1098
1099 unsigned DestEltSize = DestTy->getScalarSizeInBits();
1100 unsigned SrcEltSize = SrcTy->getScalarSizeInBits();
1101 if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
1102 return false;
1103
1104 bool IsUnary = isa<UndefValue>(V1);
1105
1106 // For binary shuffles, only fold bitcast(shuffle(X,Y))
1107 // if it won't increase the number of bitcasts.
1108 if (!IsUnary) {
1111 if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) &&
1112 !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType()))
1113 return false;
1114 }
1115
1116 SmallVector<int, 16> NewMask;
1117 if (DestEltSize <= SrcEltSize) {
1118 // The bitcast is from wide to narrow/equal elements. The shuffle mask can
1119 // always be expanded to the equivalent form choosing narrower elements.
1120 if (SrcEltSize % DestEltSize != 0)
1121 return false;
1122 unsigned ScaleFactor = SrcEltSize / DestEltSize;
1123 narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
1124 } else {
1125 // The bitcast is from narrow elements to wide elements. The shuffle mask
1126 // must choose consecutive elements to allow casting first.
1127 if (DestEltSize % SrcEltSize != 0)
1128 return false;
1129 unsigned ScaleFactor = DestEltSize / SrcEltSize;
1130 if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
1131 return false;
1132 }
1133
1134 // Bitcast the shuffle src - keep its original width but using the destination
1135 // scalar type.
1136 unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
1137 auto *NewShuffleTy =
1138 FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);
1139 auto *OldShuffleTy =
1140 FixedVectorType::get(SrcTy->getScalarType(), Mask.size());
1141 unsigned NumOps = IsUnary ? 1 : 2;
1142
1143 // The new shuffle must not cost more than the old shuffle.
1147
1148 InstructionCost NewCost =
1149 TTI.getShuffleCost(SK, DestTy, NewShuffleTy, NewMask, CostKind) +
1150 (NumOps * TTI.getCastInstrCost(Instruction::BitCast, NewShuffleTy, SrcTy,
1151 TargetTransformInfo::CastContextHint::None,
1152 CostKind));
1153 InstructionCost OldCost =
1154 TTI.getShuffleCost(SK, OldShuffleTy, SrcTy, Mask, CostKind) +
1155 TTI.getCastInstrCost(Instruction::BitCast, DestTy, OldShuffleTy,
1156 TargetTransformInfo::CastContextHint::None,
1157 CostKind);
1158
1159 LLVM_DEBUG(dbgs() << "Found a bitcasted shuffle: " << I << "\n OldCost: "
1160 << OldCost << " vs NewCost: " << NewCost << "\n");
1161
1162 if (NewCost > OldCost || !NewCost.isValid())
1163 return false;
1164
1165 // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC'
1166 ++NumShufOfBitcast;
1167 Value *CastV0 = Builder.CreateBitCast(peekThroughBitcasts(V0), NewShuffleTy);
1168 Value *CastV1 = Builder.CreateBitCast(peekThroughBitcasts(V1), NewShuffleTy);
1169 Value *Shuf = Builder.CreateShuffleVector(CastV0, CastV1, NewMask);
1170 replaceValue(I, *Shuf);
1171 return true;
1172}
1173
1174/// VP Intrinsics whose vector operands are both splat values may be simplified
1175/// into the scalar version of the operation and the result splatted. This
1176/// can lead to scalarization down the line.
1177bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
1178 if (!isa<VPIntrinsic>(I))
1179 return false;
1180 VPIntrinsic &VPI = cast<VPIntrinsic>(I);
1181 Value *Op0 = VPI.getArgOperand(0);
1182 Value *Op1 = VPI.getArgOperand(1);
1183
1184 if (!isSplatValue(Op0) || !isSplatValue(Op1))
1185 return false;
1186
1187 // Check getSplatValue early in this function, to avoid doing unnecessary
1188 // work.
1189 Value *ScalarOp0 = getSplatValue(Op0);
1190 Value *ScalarOp1 = getSplatValue(Op1);
1191 if (!ScalarOp0 || !ScalarOp1)
1192 return false;
1193
1194 // For the binary VP intrinsics supported here, the result on disabled lanes
1195 // is a poison value. For now, only do this simplification if all lanes
1196 // are active.
1197 // TODO: Relax the condition that all lanes are active by using insertelement
1198 // on inactive lanes.
1199 auto IsAllTrueMask = [](Value *MaskVal) {
1200 if (Value *SplattedVal = getSplatValue(MaskVal))
1201 if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
1202 return ConstValue->isAllOnesValue();
1203 return false;
1204 };
1205 if (!IsAllTrueMask(VPI.getArgOperand(2)))
1206 return false;
1207
1208 // Check to make sure we support scalarization of the intrinsic
1209 Intrinsic::ID IntrID = VPI.getIntrinsicID();
1210 if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
1211 return false;
1212
1213 // Calculate cost of splatting both operands into vectors and the vector
1214 // intrinsic
1215 VectorType *VecTy = cast<VectorType>(VPI.getType());
1216 SmallVector<int> Mask;
1217 if (auto *FVTy = dyn_cast<FixedVectorType>(VecTy))
1218 Mask.resize(FVTy->getNumElements(), 0);
1219 InstructionCost SplatCost =
1220 TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
1222 CostKind);
1223
1224 // Calculate the cost of the VP Intrinsic
1226 for (Value *V : VPI.args())
1227 Args.push_back(V->getType());
1228 IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
1229 InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
1230 InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
1231
1232 // Determine scalar opcode
1233 std::optional<unsigned> FunctionalOpcode =
1234 VPI.getFunctionalOpcode();
1235 std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
1236 if (!FunctionalOpcode) {
1237 ScalarIntrID = VPI.getFunctionalIntrinsicID();
1238 if (!ScalarIntrID)
1239 return false;
1240 }
1241
1242 // Calculate cost of scalarizing
1243 InstructionCost ScalarOpCost = 0;
1244 if (ScalarIntrID) {
1245 IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
1246 ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
1247 } else {
1248 ScalarOpCost = TTI.getArithmeticInstrCost(*FunctionalOpcode,
1249 VecTy->getScalarType(), CostKind);
1250 }
1251
1252 // The existing splats may be kept around if other instructions use them.
1253 InstructionCost CostToKeepSplats =
1254 (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
1255 InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
1256
1257 LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
1258 << "\n");
1259 LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
1260 << ", Cost of scalarizing:" << NewCost << "\n");
1261
1262 // We want to scalarize unless the vector variant actually has lower cost.
1263 if (OldCost < NewCost || !NewCost.isValid())
1264 return false;
1265
1266 // Scalarize the intrinsic
1267 ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount();
1268 Value *EVL = VPI.getArgOperand(3);
1269
1270 // If the VP op might introduce UB or poison, we can scalarize it provided
1271 // that we know the EVL > 0: If the EVL is zero, then the original VP op
1272 // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by
1273 // scalarizing it.
1274 bool SafeToSpeculate;
1275 if (ScalarIntrID)
1276 SafeToSpeculate = Intrinsic::getFnAttributes(I.getContext(), *ScalarIntrID)
1277 .hasAttribute(Attribute::AttrKind::Speculatable);
1278 else
1280 *FunctionalOpcode, &VPI, nullptr, SQ.AC, SQ.DT);
1281 if (!SafeToSpeculate &&
1282 !isKnownNonZero(EVL, SimplifyQuery(*DL, SQ.DT, SQ.AC, &VPI)))
1283 return false;
1284
1285 Value *ScalarVal =
1286 ScalarIntrID
1287 ? Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID,
1288 {ScalarOp0, ScalarOp1})
1289 : Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode),
1290 ScalarOp0, ScalarOp1);
1291
1292 replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal));
1293 return true;
1294}
1295
1296/// Match a vector op/compare/intrinsic with at least one
1297/// inserted scalar operand and convert to scalar op/cmp/intrinsic followed
1298/// by insertelement.
1299bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
1300 auto *UO = dyn_cast<UnaryOperator>(&I);
1301 auto *BO = dyn_cast<BinaryOperator>(&I);
1302 auto *CI = dyn_cast<CmpInst>(&I);
1303 auto *II = dyn_cast<IntrinsicInst>(&I);
1304 if (!UO && !BO && !CI && !II)
1305 return false;
1306
1307 // TODO: Allow intrinsics with different argument types
1308 if (II) {
1309 if (!isTriviallyVectorizable(II->getIntrinsicID()))
1310 return false;
1311 for (auto [Idx, Arg] : enumerate(II->args()))
1312 if (Arg->getType() != II->getType() &&
1313 !isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, &TTI))
1314 return false;
1315 }
1316
1317 // Do not convert the vector condition of a vector select into a scalar
1318 // condition. That may cause problems for codegen because of differences in
1319 // boolean formats and register-file transfers.
1320 // TODO: Can we account for that in the cost model?
1321 if (CI)
1322 for (User *U : I.users())
1323 if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
1324 return false;
1325
1326 // Match constant vectors or scalars being inserted into constant vectors:
1327 // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
1328 SmallVector<Value *> VecCs, ScalarOps;
1329 std::optional<uint64_t> Index;
1330
1331 auto Ops = II ? II->args() : I.operands();
1332 for (auto [OpNum, Op] : enumerate(Ops)) {
1333 Constant *VecC;
1334 Value *V;
1335 uint64_t InsIdx = 0;
1336 if (match(Op.get(), m_InsertElt(m_Constant(VecC), m_Value(V),
1337 m_ConstantInt(InsIdx)))) {
1338 // Bail if any inserts are out of bounds.
1339 VectorType *OpTy = cast<VectorType>(Op->getType());
1340 if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
1341 return false;
1342 // All inserts must have the same index.
1343 // TODO: Deal with mismatched index constants and variable indexes?
1344 if (!Index)
1345 Index = InsIdx;
1346 else if (InsIdx != *Index)
1347 return false;
1348 VecCs.push_back(VecC);
1349 ScalarOps.push_back(V);
1350 } else if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(),
1351 OpNum, &TTI)) {
1352 VecCs.push_back(Op.get());
1353 ScalarOps.push_back(Op.get());
1354 } else if (match(Op.get(), m_Constant(VecC))) {
1355 VecCs.push_back(VecC);
1356 ScalarOps.push_back(nullptr);
1357 } else {
1358 return false;
1359 }
1360 }
1361
1362 // Bail if all operands are constant.
1363 if (!Index.has_value())
1364 return false;
1365
1366 VectorType *VecTy = cast<VectorType>(I.getType());
1367 Type *ScalarTy = VecTy->getScalarType();
1368 assert(VecTy->isVectorTy() &&
1369 (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
1370 ScalarTy->isPointerTy()) &&
1371 "Unexpected types for insert element into binop or cmp");
1372
1373 unsigned Opcode = I.getOpcode();
1374 InstructionCost ScalarOpCost, VectorOpCost;
1375 if (CI) {
1376 CmpInst::Predicate Pred = CI->getPredicate();
1377 ScalarOpCost = TTI.getCmpSelInstrCost(
1378 Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
1379 VectorOpCost = TTI.getCmpSelInstrCost(
1380 Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
1381 } else if (UO || BO) {
1382 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
1383 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
1384 } else {
1385 IntrinsicCostAttributes ScalarICA(
1386 II->getIntrinsicID(), ScalarTy,
1387 SmallVector<Type *>(II->arg_size(), ScalarTy));
1388 ScalarOpCost = TTI.getIntrinsicInstrCost(ScalarICA, CostKind);
1389 IntrinsicCostAttributes VectorICA(
1390 II->getIntrinsicID(), VecTy,
1391 SmallVector<Type *>(II->arg_size(), VecTy));
1392 VectorOpCost = TTI.getIntrinsicInstrCost(VectorICA, CostKind);
1393 }
1394
1395 // Fold the vector constants in the original vectors into a new base vector to
1396 // get more accurate cost modelling.
1397 Value *NewVecC = nullptr;
1398 if (CI)
1399 NewVecC = simplifyCmpInst(CI->getPredicate(), VecCs[0], VecCs[1], SQ);
1400 else if (UO)
1401 NewVecC =
1402 simplifyUnOp(UO->getOpcode(), VecCs[0], UO->getFastMathFlags(), SQ);
1403 else if (BO)
1404 NewVecC = simplifyBinOp(BO->getOpcode(), VecCs[0], VecCs[1], SQ);
1405 else if (II)
1406 NewVecC = simplifyCall(II, II->getCalledOperand(), VecCs, SQ);
1407
1408 if (!NewVecC)
1409 return false;
1410
1411 // Get cost estimate for the insert element. This cost will factor into
1412 // both sequences.
1413 InstructionCost OldCost = VectorOpCost;
1414 InstructionCost NewCost =
1415 ScalarOpCost + TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
1416 CostKind, *Index, NewVecC);
1417
1418 for (auto [Idx, Op, VecC, Scalar] : enumerate(Ops, VecCs, ScalarOps)) {
1419 if (!Scalar || (II && isVectorIntrinsicWithScalarOpAtArg(
1420 II->getIntrinsicID(), Idx, &TTI)))
1421 continue;
1423 Instruction::InsertElement, VecTy, CostKind, *Index, VecC, Scalar);
1424 OldCost += InsertCost;
1425 NewCost += !Op->hasOneUse() * InsertCost;
1426 }
1427
1428 // We want to scalarize unless the vector variant actually has lower cost.
1429 if (OldCost < NewCost || !NewCost.isValid())
1430 return false;
1431
1432 // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
1433 // inselt NewVecC, (scalar_op V0, V1), Index
1434 if (CI)
1435 ++NumScalarCmp;
1436 else if (UO || BO)
1437 ++NumScalarOps;
1438 else
1439 ++NumScalarIntrinsic;
1440
1441 // For constant cases, extract the scalar element, this should constant fold.
1442 for (auto [OpIdx, Scalar, VecC] : enumerate(ScalarOps, VecCs))
1443 if (!Scalar)
1445 cast<Constant>(VecC), Builder.getInt64(*Index));
1446
1447 Value *Scalar;
1448 if (CI)
1449 Scalar = Builder.CreateCmp(CI->getPredicate(), ScalarOps[0], ScalarOps[1]);
1450 else if (UO || BO)
1451 Scalar = Builder.CreateNAryOp(Opcode, ScalarOps);
1452 else
1453 Scalar = Builder.CreateIntrinsic(ScalarTy, II->getIntrinsicID(), ScalarOps);
1454
1455 Scalar->setName(I.getName() + ".scalar");
1456
1457 // All IR flags are safe to back-propagate. There is no potential for extra
1458 // poison to be created by the scalar instruction.
1459 if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
1460 ScalarInst->copyIRFlags(&I);
1461
1462 Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, *Index);
1463 replaceValue(I, *Insert);
1464 return true;
1465}
1466
1467/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
1468/// a vector into vector operations followed by extract. Note: The SLP pass
1469/// may miss this pattern because of implementation problems.
1470bool VectorCombine::foldExtractedCmps(Instruction &I) {
1471 auto *BI = dyn_cast<BinaryOperator>(&I);
1472
1473 // We are looking for a scalar binop of booleans.
1474 // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
1475 if (!BI || !I.getType()->isIntegerTy(1))
1476 return false;
1477
1478 // The compare predicates should match, and each compare should have a
1479 // constant operand.
1480 Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
1481 Instruction *I0, *I1;
1482 Constant *C0, *C1;
1483 CmpPredicate P0, P1;
1484 if (!match(B0, m_Cmp(P0, m_Instruction(I0), m_Constant(C0))) ||
1485 !match(B1, m_Cmp(P1, m_Instruction(I1), m_Constant(C1))))
1486 return false;
1487
1488 auto MatchingPred = CmpPredicate::getMatching(P0, P1);
1489 if (!MatchingPred)
1490 return false;
1491
1492 // The compare operands must be extracts of the same vector with constant
1493 // extract indexes.
1494 Value *X;
1495 uint64_t Index0, Index1;
1496 if (!match(I0, m_ExtractElt(m_Value(X), m_ConstantInt(Index0))) ||
1497 !match(I1, m_ExtractElt(m_Specific(X), m_ConstantInt(Index1))))
1498 return false;
1499
1500 auto *Ext0 = cast<ExtractElementInst>(I0);
1501 auto *Ext1 = cast<ExtractElementInst>(I1);
1502 ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1, CostKind);
1503 if (!ConvertToShuf)
1504 return false;
1505 assert((ConvertToShuf == Ext0 || ConvertToShuf == Ext1) &&
1506 "Unknown ExtractElementInst");
1507
1508 // The original scalar pattern is:
1509 // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
1510 CmpInst::Predicate Pred = *MatchingPred;
1511 unsigned CmpOpcode =
1512 CmpInst::isFPPredicate(Pred) ? Instruction::FCmp : Instruction::ICmp;
1513 auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
1514 if (!VecTy)
1515 return false;
1516
1517 InstructionCost Ext0Cost =
1518 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
1519 InstructionCost Ext1Cost =
1520 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
1522 CmpOpcode, I0->getType(), CmpInst::makeCmpResultType(I0->getType()), Pred,
1523 CostKind);
1524
1525 InstructionCost OldCost =
1526 Ext0Cost + Ext1Cost + CmpCost * 2 +
1527 TTI.getArithmeticInstrCost(I.getOpcode(), I.getType(), CostKind);
1528
1529 // The proposed vector pattern is:
1530 // vcmp = cmp Pred X, VecC
1531 // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
1532 int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
1533 int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
1536 CmpOpcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
1537 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
1538 ShufMask[CheapIndex] = ExpensiveIndex;
1540 CmpTy, ShufMask, CostKind);
1541 NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy, CostKind);
1542 NewCost += TTI.getVectorInstrCost(*Ext0, CmpTy, CostKind, CheapIndex);
1543 NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost;
1544 NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost;
1545
1546 // Aggressively form vector ops if the cost is equal because the transform
1547 // may enable further optimization.
1548 // Codegen can reverse this transform (scalarize) if it was not profitable.
1549 if (OldCost < NewCost || !NewCost.isValid())
1550 return false;
1551
1552 // Create a vector constant from the 2 scalar constants.
1553 SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
1554 PoisonValue::get(VecTy->getElementType()));
1555 CmpC[Index0] = C0;
1556 CmpC[Index1] = C1;
1557 Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));
1558 Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
1559 Value *LHS = ConvertToShuf == Ext0 ? Shuf : VCmp;
1560 Value *RHS = ConvertToShuf == Ext0 ? VCmp : Shuf;
1561 Value *VecLogic = Builder.CreateBinOp(BI->getOpcode(), LHS, RHS);
1562 Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
1563 replaceValue(I, *NewExt);
1564 ++NumVecCmpBO;
1565 return true;
1566}
1567
1568/// Try to fold scalar selects that select between extracted elements and zero
1569/// into extracting from a vector select. This is rooted at the bitcast.
1570///
1571/// This pattern arises when a vector is bitcast to a smaller element type,
1572/// elements are extracted, and then conditionally selected with zero:
1573///
1574/// %bc = bitcast <4 x i32> %src to <16 x i8>
1575/// %e0 = extractelement <16 x i8> %bc, i32 0
1576/// %s0 = select i1 %cond, i8 %e0, i8 0
1577/// %e1 = extractelement <16 x i8> %bc, i32 1
1578/// %s1 = select i1 %cond, i8 %e1, i8 0
1579/// ...
1580///
1581/// Transforms to:
1582/// %sel = select i1 %cond, <4 x i32> %src, <4 x i32> zeroinitializer
1583/// %bc = bitcast <4 x i32> %sel to <16 x i8>
1584/// %e0 = extractelement <16 x i8> %bc, i32 0
1585/// %e1 = extractelement <16 x i8> %bc, i32 1
1586/// ...
1587///
1588/// This is profitable because vector select on wider types produces fewer
1589/// select/cndmask instructions than scalar selects on each element.
1590bool VectorCombine::foldSelectsFromBitcast(Instruction &I) {
1591 auto *BC = dyn_cast<BitCastInst>(&I);
1592 if (!BC)
1593 return false;
1594
1595 FixedVectorType *SrcVecTy = dyn_cast<FixedVectorType>(BC->getSrcTy());
1596 FixedVectorType *DstVecTy = dyn_cast<FixedVectorType>(BC->getDestTy());
1597 if (!SrcVecTy || !DstVecTy)
1598 return false;
1599
1600 // Source must be 32-bit or 64-bit elements, destination must be smaller
1601 // integer elements. Zero in all these types is all-bits-zero.
1602 Type *SrcEltTy = SrcVecTy->getElementType();
1603 Type *DstEltTy = DstVecTy->getElementType();
1604 unsigned SrcEltBits = SrcEltTy->getPrimitiveSizeInBits();
1605 unsigned DstEltBits = DstEltTy->getPrimitiveSizeInBits();
1606
1607 if (SrcEltBits != 32 && SrcEltBits != 64)
1608 return false;
1609
1610 if (!DstEltTy->isIntegerTy() || DstEltBits >= SrcEltBits)
1611 return false;
1612
1613 // Check profitability using TTI before collecting users.
1614 Type *CondTy = CmpInst::makeCmpResultType(DstEltTy);
1615 Type *VecCondTy = CmpInst::makeCmpResultType(SrcVecTy);
1616
1617 InstructionCost ScalarSelCost =
1618 TTI.getCmpSelInstrCost(Instruction::Select, DstEltTy, CondTy,
1620 InstructionCost VecSelCost =
1621 TTI.getCmpSelInstrCost(Instruction::Select, SrcVecTy, VecCondTy,
1623
1624 // We need at least this many selects for vectorization to be profitable.
1625 // VecSelCost < ScalarSelCost * NumSelects => NumSelects > VecSelCost /
1626 // ScalarSelCost
1627 if (!ScalarSelCost.isValid() || ScalarSelCost == 0)
1628 return false;
1629
1630 unsigned MinSelects = (VecSelCost.getValue() / ScalarSelCost.getValue()) + 1;
1631
1632 // Quick check: if bitcast doesn't have enough users, bail early.
1633 if (!BC->hasNUsesOrMore(MinSelects))
1634 return false;
1635
1636 // Collect all select users that match the pattern, grouped by condition.
1637 // Pattern: select i1 %cond, (extractelement %bc, idx), 0
1638 DenseMap<Value *, SmallVector<SelectInst *, 8>> CondToSelects;
1639
1640 for (User *U : BC->users()) {
1641 auto *Ext = dyn_cast<ExtractElementInst>(U);
1642 if (!Ext)
1643 continue;
1644
1645 for (User *ExtUser : Ext->users()) {
1646 Value *Cond;
1647 // Match: select i1 %cond, %ext, 0
1648 if (match(ExtUser, m_Select(m_Value(Cond), m_Specific(Ext), m_Zero())) &&
1649 Cond->getType()->isIntegerTy(1))
1650 CondToSelects[Cond].push_back(cast<SelectInst>(ExtUser));
1651 }
1652 }
1653
1654 if (CondToSelects.empty())
1655 return false;
1656
1657 bool MadeChange = false;
1658 Value *SrcVec = BC->getOperand(0);
1659
1660 // Process each group of selects with the same condition.
1661 for (auto [Cond, Selects] : CondToSelects) {
1662 // Only profitable if vector select cost < total scalar select cost.
1663 if (Selects.size() < MinSelects) {
1664 LLVM_DEBUG(dbgs() << "VectorCombine: foldSelectsFromBitcast not "
1665 << "profitable (VecCost=" << VecSelCost
1666 << ", ScalarCost=" << ScalarSelCost
1667 << ", NumSelects=" << Selects.size() << ")\n");
1668 continue;
1669 }
1670
1671 // Create the vector select and bitcast once for this condition.
1672 auto InsertPt = std::next(BC->getIterator());
1673
1674 if (auto *CondInst = dyn_cast<Instruction>(Cond))
1675 if (DT.dominates(BC, CondInst))
1676 InsertPt = std::next(CondInst->getIterator());
1677
1678 Builder.SetInsertPoint(InsertPt);
1679 Value *VecSel =
1680 Builder.CreateSelect(Cond, SrcVec, Constant::getNullValue(SrcVecTy));
1681 Value *NewBC = Builder.CreateBitCast(VecSel, DstVecTy);
1682
1683 // Replace each scalar select with an extract from the new bitcast.
1684 for (SelectInst *Sel : Selects) {
1685 auto *Ext = cast<ExtractElementInst>(Sel->getTrueValue());
1686 Value *Idx = Ext->getIndexOperand();
1687
1688 Builder.SetInsertPoint(Sel);
1689 Value *NewExt = Builder.CreateExtractElement(NewBC, Idx);
1690 replaceValue(*Sel, *NewExt);
1691 MadeChange = true;
1692 }
1693
1694 LLVM_DEBUG(dbgs() << "VectorCombine: folded " << Selects.size()
1695 << " selects into vector select\n");
1696 }
1697
1698 return MadeChange;
1699}
1700
1703 const TargetTransformInfo &TTI,
1704 InstructionCost &CostBeforeReduction,
1705 InstructionCost &CostAfterReduction) {
// Split the cost of reduction II into two parts:
//  - CostBeforeReduction: the cost of the extension (or ext/mul/ext chain)
//    that feeds the reduction, when one is recognized;
//  - CostAfterReduction: the cost of the reduction, queried as an extended
//    or multiply-accumulate reduction when such a feeder is recognized.
// Recognized shapes:
//  1. reduce(zext/sext(A))                   -> getExtendedReductionCost
//  2. reduce.add(ext(mul(ext(A), ext(B))))   -> getMulAccReductionCost
//  3. anything else                          -> getArithmeticReductionCost
// On path 3 CostBeforeReduction is NOT written, so callers must initialize it.
1706 Instruction *Op0, *Op1;
1707 auto *RedOp = dyn_cast<Instruction>(II.getOperand(0));
1708 auto *VecRedTy = cast<VectorType>(II.getOperand(0)->getType());
1709 unsigned ReductionOpc =
1710 getArithmeticReductionInstruction(II.getIntrinsicID());
1711 if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value()))) {
// Shape 1: the reduction input is a single zext/sext. Compare the standalone
// cast against an extended reduction over the narrow source type.
1712 bool IsUnsigned = isa<ZExtInst>(RedOp);
1713 auto *ExtType = cast<VectorType>(RedOp->getOperand(0)->getType());
1714
1715 CostBeforeReduction =
1716 TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, ExtType,
1718 CostAfterReduction =
1719 TTI.getExtendedReductionCost(ReductionOpc, IsUnsigned, II.getType(),
1720 ExtType, FastMathFlags(), CostKind);
1721 return;
1722 }
1723 if (RedOp && II.getIntrinsicID() == Intrinsic::vector_reduce_add &&
1724 match(RedOp,
1726 match(Op0, m_ZExtOrSExt(m_Value())) &&
1727 Op0->getOpcode() == Op1->getOpcode() &&
1728 Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
1729 (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
1730 // Matched reduce.add(ext(mul(ext(A), ext(B)))
// Both inner extensions share an opcode and source type. The outer extension
// must use the same opcode as the inner ones, unless both mul operands are
// the same instruction.
1731 bool IsUnsigned = isa<ZExtInst>(Op0);
1732 auto *ExtType = cast<VectorType>(Op0->getOperand(0)->getType());
1733 VectorType *MulType = VectorType::get(Op0->getType(), VecRedTy);
1734
1735 InstructionCost ExtCost =
1736 TTI.getCastInstrCost(Op0->getOpcode(), MulType, ExtType,
1738 InstructionCost MulCost =
1739 TTI.getArithmeticInstrCost(Instruction::Mul, MulType, CostKind);
1740 InstructionCost Ext2Cost =
1741 TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, MulType,
1743
// NOTE(review): when Op0 == Op1 only one inner extension instruction exists,
// but ExtCost is still counted twice here - confirm this is intended.
1744 CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
1745 CostAfterReduction = TTI.getMulAccReductionCost(
1746 IsUnsigned, ReductionOpc, II.getType(), ExtType, CostKind);
1747 return;
1748 }
// Shape 3: no foldable feeder; only the plain reduction is costed.
1749 CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,
1750 std::nullopt, CostKind);
1751}
1752
1753bool VectorCombine::foldBinopOfReductions(Instruction &I) {
1754 Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(&I)->getOpcode();
1755 Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
1756 if (BinOpOpc == Instruction::Sub)
1757 ReductionIID = Intrinsic::vector_reduce_add;
1758 if (ReductionIID == Intrinsic::not_intrinsic)
1759 return false;
1760 // FP reductions have a start-value operand that this fold doesn't handle.
1761 if (ReductionIID == Intrinsic::vector_reduce_fadd ||
1762 ReductionIID == Intrinsic::vector_reduce_fmul)
1763 return false;
1764
1765 auto checkIntrinsicAndGetItsArgument = [](Value *V,
1766 Intrinsic::ID IID) -> Value * {
1767 auto *II = dyn_cast<IntrinsicInst>(V);
1768 if (!II)
1769 return nullptr;
1770 if (II->getIntrinsicID() == IID && II->hasOneUse())
1771 return II->getArgOperand(0);
1772 return nullptr;
1773 };
1774
1775 Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(0), ReductionIID);
1776 if (!V0)
1777 return false;
1778 Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(1), ReductionIID);
1779 if (!V1)
1780 return false;
1781
1782 auto *VTy = cast<VectorType>(V0->getType());
1783 if (V1->getType() != VTy)
1784 return false;
1785 const auto &II0 = *cast<IntrinsicInst>(I.getOperand(0));
1786 const auto &II1 = *cast<IntrinsicInst>(I.getOperand(1));
1787 unsigned ReductionOpc =
1788 getArithmeticReductionInstruction(II0.getIntrinsicID());
1789
1790 InstructionCost OldCost = 0;
1791 InstructionCost NewCost = 0;
1792 InstructionCost CostOfRedOperand0 = 0;
1793 InstructionCost CostOfRed0 = 0;
1794 InstructionCost CostOfRedOperand1 = 0;
1795 InstructionCost CostOfRed1 = 0;
1796 analyzeCostOfVecReduction(II0, CostKind, TTI, CostOfRedOperand0, CostOfRed0);
1797 analyzeCostOfVecReduction(II1, CostKind, TTI, CostOfRedOperand1, CostOfRed1);
1798 OldCost = CostOfRed0 + CostOfRed1 + TTI.getInstructionCost(&I, CostKind);
1799 NewCost =
1800 CostOfRedOperand0 + CostOfRedOperand1 +
1801 TTI.getArithmeticInstrCost(BinOpOpc, VTy, CostKind) +
1802 TTI.getArithmeticReductionCost(ReductionOpc, VTy, std::nullopt, CostKind);
1803 if (NewCost >= OldCost || !NewCost.isValid())
1804 return false;
1805
1806 LLVM_DEBUG(dbgs() << "Found two mergeable reductions: " << I
1807 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
1808 << "\n");
1809 Value *VectorBO;
1810 if (BinOpOpc == Instruction::Or)
1811 VectorBO = Builder.CreateOr(V0, V1, "",
1812 cast<PossiblyDisjointInst>(I).isDisjoint());
1813 else
1814 VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
1815
1816 Instruction *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
1817 replaceValue(I, *Rdx);
1818 return true;
1819}
1820
1821// Check if memory loc modified between two instrs in the same BB
// Returns true if AA reports that any instruction in [Begin, End) may modify
// Loc. It also returns true, conservatively, once more than MaxInstrsToScan
// non-modifying instructions have been visited, so an overly long range is
// treated as "modified".
1824 const MemoryLocation &Loc, AAResults &AA) {
1825 unsigned NumScanned = 0;
1826 return std::any_of(Begin, End, [&](const Instruction &Instr) {
1827 return isModSet(AA.getModRefInfo(&Instr, Loc)) ||
1828 ++NumScanned > MaxInstrsToScan;
1829 });
1830}
1831
1832namespace {
1833/// Helper class to indicate whether a vector index can be safely scalarized and
1834/// if a freeze needs to be inserted.
1835class ScalarizationResult {
1836 enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
1837
1838 StatusTy Status;
1839 Value *ToFreeze;
1840
1841 ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
1842 : Status(Status), ToFreeze(ToFreeze) {}
1843
1844public:
1845 ScalarizationResult(const ScalarizationResult &Other) = default;
1846 ~ScalarizationResult() {
1847 assert(!ToFreeze && "freeze() not called with ToFreeze being set");
1848 }
1849
1850 static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
1851 static ScalarizationResult safe() { return {StatusTy::Safe}; }
1852 static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
1853 return {StatusTy::SafeWithFreeze, ToFreeze};
1854 }
1855
1856 /// Returns true if the index can be scalarize without requiring a freeze.
1857 bool isSafe() const { return Status == StatusTy::Safe; }
1858 /// Returns true if the index cannot be scalarized.
1859 bool isUnsafe() const { return Status == StatusTy::Unsafe; }
1860 /// Returns true if the index can be scalarize, but requires inserting a
1861 /// freeze.
1862 bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
1863
1864 /// Reset the state of Unsafe and clear ToFreze if set.
1865 void discard() {
1866 ToFreeze = nullptr;
1867 Status = StatusTy::Unsafe;
1868 }
1869
1870 /// Freeze the ToFreeze and update the use in \p User to use it.
1871 void freeze(IRBuilderBase &Builder, Instruction &UserI) {
1872 assert(isSafeWithFreeze() &&
1873 "should only be used when freezing is required");
1874 assert(is_contained(ToFreeze->users(), &UserI) &&
1875 "UserI must be a user of ToFreeze");
1876 IRBuilder<>::InsertPointGuard Guard(Builder);
1877 Builder.SetInsertPoint(cast<Instruction>(&UserI));
1878 Value *Frozen =
1879 Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
1880 for (Use &U : make_early_inc_range((UserI.operands())))
1881 if (U.get() == ToFreeze)
1882 U.set(Frozen);
1883
1884 ToFreeze = nullptr;
1885 }
1886};
1887} // namespace
1888
1889/// Check if it is legal to scalarize a memory access to \p VecTy at index \p
1890/// Idx. \p Idx must access a valid vector element.
1891static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx,
1892 const SimplifyQuery &SQ) {
1893 // We do checks for both fixed vector types and scalable vector types.
1894 // This is the number of elements of fixed vector types,
1895 // or the minimum number of elements of scalable vector types.
1896 uint64_t NumElements = VecTy->getElementCount().getKnownMinValue();
1897 unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
1898
1899 if (auto *C = dyn_cast<ConstantInt>(Idx)) {
1900 if (C->getValue().ult(NumElements))
1901 return ScalarizationResult::safe();
1902 return ScalarizationResult::unsafe();
1903 }
1904
1905 // Always unsafe if the index type can't handle all inbound values.
1906 if (!llvm::isUIntN(IntWidth, NumElements))
1907 return ScalarizationResult::unsafe();
1908
1909 APInt Zero(IntWidth, 0);
1910 APInt MaxElts(IntWidth, NumElements);
1911 ConstantRange ValidIndices(Zero, MaxElts);
1912 ConstantRange IdxRange(IntWidth, true);
1913
1914 if (isGuaranteedNotToBePoison(Idx, SQ.AC, SQ.CxtI, SQ.DT)) {
1915 if (ValidIndices.contains(
1916 computeConstantRange(Idx, /*ForSigned=*/false, SQ)))
1917 return ScalarizationResult::safe();
1918 return ScalarizationResult::unsafe();
1919 }
1920
1921 // If the index may be poison, check if we can insert a freeze before the
1922 // range of the index is restricted.
1923 Value *IdxBase;
1924 ConstantInt *CI;
1925 if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
1926 IdxRange = IdxRange.binaryAnd(CI->getValue());
1927 } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
1928 IdxRange = IdxRange.urem(CI->getValue());
1929 }
1930
1931 if (ValidIndices.contains(IdxRange))
1932 return ScalarizationResult::safeWithFreeze(IdxBase);
1933 return ScalarizationResult::unsafe();
1934}
1935
1936/// The memory operation on a vector of \p ScalarType had alignment of
1937/// \p VectorAlignment. Compute the maximal, but conservatively correct,
1938/// alignment that will be valid for the memory operation on a single scalar
1939/// element of the same type with index \p Idx.
1941 Type *ScalarType, Value *Idx,
1942 const DataLayout &DL) {
// Constant index: combine VectorAlignment with the element's exact byte
// offset (Idx * store size of ScalarType). Variable index: only the element
// store size is known, so combine with that instead (presumably every element
// offset is a multiple of it).
1943 if (auto *C = dyn_cast<ConstantInt>(Idx))
1944 return commonAlignment(VectorAlignment,
1945 C->getZExtValue() * DL.getTypeStoreSize(ScalarType));
1946 return commonAlignment(VectorAlignment, DL.getTypeStoreSize(ScalarType));
1947}
1948
1949// Combine patterns like:
1950// %0 = load <4 x i32>, <4 x i32>* %a
1951// %1 = insertelement <4 x i32> %0, i32 %b, i32 1
1952// store <4 x i32> %1, <4 x i32>* %a
1953// to:
1954// %0 = bitcast <4 x i32>* %a to i32*
1955// %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1
1956// store i32 %b, i32* %1
// NOTE: the example above uses the old typed-pointer IR syntax. The code
// below emits `getelementptr inbounds <stored vector type>, ptr %a,
// <idx type> 0, <idx type> %idx` (a GEP over the stored vector type with a
// leading zero index), followed by a scalar store of the inserted element.
1957bool VectorCombine::foldSingleElementStore(Instruction &I) {
1959 return false;
1960 auto *SI = cast<StoreInst>(&I);
// Only simple (non-atomic, non-volatile) stores of a vector value qualify.
1961 if (!SI->isSimple() || !isa<VectorType>(SI->getValueOperand()->getType()))
1962 return false;
1963
1964 // TODO: Combine more complicated patterns (multiple insert) by referencing
1965 // TargetTransformInfo.
1967 Value *NewElement;
1968 Value *Idx;
1969 if (!match(SI->getValueOperand(),
1970 m_InsertElt(m_Instruction(Source), m_Value(NewElement),
1971 m_Value(Idx))))
1972 return false;
1973
// The vector being inserted into must itself be a load from the same
// (pointer-cast-stripped) address that the store writes to.
1974 if (auto *Load = dyn_cast<LoadInst>(Source)) {
1975 auto VecTy = cast<VectorType>(SI->getValueOperand()->getType());
1976 Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
1977 // Don't optimize for atomic/volatile load or store. Ensure memory is not
1978 // modified between, vector type matches store size, and index is inbounds.
1979 if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
1980 !DL->typeSizeEqualsStoreSize(Load->getType()->getScalarType()) ||
1981 SrcAddr != SI->getPointerOperand()->stripPointerCasts())
1982 return false;
1983
1984 auto ScalarizableIdx =
1985 canScalarizeAccess(VecTy, Idx, SQ.getWithInstruction(Load));
1986 if (ScalarizableIdx.isUnsafe() ||
1987 isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
1988 MemoryLocation::get(SI), AA))
1989 return false;
1990
1991 // Ensure we add the load back to the worklist BEFORE its users so they can
1992 // be erased in the correct order.
1993 Worklist.push(Load);
1994
// A possibly-poison index whose range is bounded by an and/urem gets that
// operation's input frozen (see canScalarizeAccess) before it is used.
1995 if (ScalarizableIdx.isSafeWithFreeze())
1996 ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
1997 Value *GEP = Builder.CreateInBoundsGEP(
1998 SI->getValueOperand()->getType(), SI->getPointerOperand(),
1999 {ConstantInt::get(Idx->getType(), 0), Idx});
2000 StoreInst *NSI = Builder.CreateStore(NewElement, GEP);
2001 NSI->copyMetadata(*SI);
// Derive the scalar alignment from the larger of the store/load alignments.
2002 Align ScalarOpAlignment = computeAlignmentAfterScalarization(
2003 std::max(SI->getAlign(), Load->getAlign()), NewElement->getType(), Idx,
2004 *DL);
2005 NSI->setAlignment(ScalarOpAlignment);
2006 replaceValue(I, *NSI);
2008 return true;
2009 }
2010
2011 return false;
2012}
2013
2014/// Try to scalarize vector loads feeding extractelement or bitcast
2015/// instructions.
2016bool VectorCombine::scalarizeLoad(Instruction &I) {
2017 Value *Ptr;
2018 if (!match(&I, m_Load(m_Value(Ptr))))
2019 return false;
2020
2021 auto *LI = cast<LoadInst>(&I);
2022 auto *VecTy = cast<VectorType>(LI->getType());
2023
2024 // The isSimple() check could be isUnordered(), but for now we cowardly
2025 // refuse to handle even unordered atomics.
2026 if (!LI->isSimple() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
2027 return false;
2028
2029 bool AllExtracts = true;
2030 bool AllBitcasts = true;
2031 Instruction *LastCheckedInst = LI;
2032 unsigned NumInstChecked = 0;
2033
2034 // Check what type of users we have (must either all be extracts or
2035 // bitcasts) and ensure no memory modifications between the load and
2036 // its users.
2037 for (User *U : LI->users()) {
2038 auto *UI = dyn_cast<Instruction>(U);
2039 if (!UI || UI->getParent() != LI->getParent())
2040 return false;
2041
2042 // If any user is waiting to be erased, then bail out as this will
2043 // distort the cost calculation and possibly lead to infinite loops.
2044 if (UI->use_empty())
2045 return false;
2046
2047 if (!isa<ExtractElementInst>(UI))
2048 AllExtracts = false;
2049 if (!isa<BitCastInst>(UI))
2050 AllBitcasts = false;
2051
2052 // Check if any instruction between the load and the user may modify memory.
2053 if (LastCheckedInst->comesBefore(UI)) {
2054 for (Instruction &I :
2055 make_range(std::next(LI->getIterator()), UI->getIterator())) {
2056 // Bail out if we reached the check limit or the instruction may write
2057 // to memory.
2058 if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
2059 return false;
2060 NumInstChecked++;
2061 }
2062 LastCheckedInst = UI;
2063 }
2064 }
2065
2066 if (AllExtracts)
2067 return scalarizeLoadExtract(LI, VecTy, Ptr);
2068 if (AllBitcasts)
2069 return scalarizeLoadBitcast(LI, VecTy, Ptr);
2070 return false;
2071}
2072
2073/// Try to scalarize vector loads feeding extractelement instructions.
///
/// Every user of \p LI is cast<> to ExtractElementInst below, so this must
/// only be called when all users are extracts. Each extract is replaced by a
/// scalar load through an inbounds GEP into \p Ptr. This happens only if the
/// summed cost of the scalar loads plus address computations is lower than
/// the cost of the vector load plus all extracts.
2074bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
2075 Value *Ptr) {
2077 return false;
2078
// Pending freezes for extracts whose index may be poison, keyed by extract.
2079 DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
2080 llvm::scope_exit FailureGuard([&]() {
2081 // If the transform is aborted, discard the ScalarizationResults.
2082 for (auto &Pair : NeedFreeze)
2083 Pair.second.discard();
2084 });
2085
2086 InstructionCost OriginalCost =
2087 TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
2089 InstructionCost ScalarizedCost = 0;
2090
2091 for (User *U : LI->users()) {
2092 auto *UI = cast<ExtractElementInst>(U);
2093
2094 auto ScalarIdx = canScalarizeAccess(VecTy, UI->getIndexOperand(),
2095 SQ.getWithInstruction(LI));
2096 if (ScalarIdx.isUnsafe())
2097 return false;
2098 if (ScalarIdx.isSafeWithFreeze()) {
// The copy stored in NeedFreeze now carries the pending freeze. Clear the
// local one so its destructor's "freeze() not called" assertion does not fire.
2099 NeedFreeze.try_emplace(UI, ScalarIdx);
2100 ScalarIdx.discard();
2101 }
2102
// A non-constant index is passed to the cost model as -1 (presumably
// "unknown lane").
2103 auto *Index = dyn_cast<ConstantInt>(UI->getIndexOperand());
2104 OriginalCost +=
2105 TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
2106 Index ? Index->getZExtValue() : -1);
2107 ScalarizedCost +=
2108 TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
2110 ScalarizedCost += TTI.getAddressComputationCost(LI->getPointerOperandType(),
2111 nullptr, nullptr, CostKind);
2112 }
2113
2114 LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
2115 << "\n LoadExtractCost: " << OriginalCost
2116 << " vs ScalarizedCost: " << ScalarizedCost << "\n");
2117
2118 if (ScalarizedCost >= OriginalCost)
2119 return false;
2120
2121 // Ensure we add the load back to the worklist BEFORE its users so they can
2122 // be erased in the correct order.
2123 Worklist.push(LI);
2124
2125 Type *ElemType = VecTy->getElementType();
2126
2127 // Replace extracts with narrow scalar loads.
2128 for (User *U : LI->users()) {
2129 auto *EI = cast<ExtractElementInst>(U);
2130 Value *Idx = EI->getIndexOperand();
2131
2132 // Insert 'freeze' for poison indexes.
2133 auto It = NeedFreeze.find(EI);
2134 if (It != NeedFreeze.end())
2135 It->second.freeze(Builder, *cast<Instruction>(Idx));
2136
2137 Builder.SetInsertPoint(EI);
2138 Value *GEP =
2139 Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
2140 auto *NewLoad = cast<LoadInst>(
2141 Builder.CreateLoad(ElemType, GEP, EI->getName() + ".scalar"));
2142
2143 Align ScalarOpAlignment =
2144 computeAlignmentAfterScalarization(LI->getAlign(), ElemType, Idx, *DL);
2145 NewLoad->setAlignment(ScalarOpAlignment);
2146
// AA metadata is transferred (adjusted to the element's byte offset) only for
// a constant index; with a variable index the new load gets none.
2147 if (auto *ConstIdx = dyn_cast<ConstantInt>(Idx)) {
2148 size_t Offset = ConstIdx->getZExtValue() * DL->getTypeStoreSize(ElemType);
2149 AAMDNodes OldAAMD = LI->getAAMetadata();
2150 NewLoad->setAAMetadata(OldAAMD.adjustForAccess(Offset, ElemType, *DL));
2151 }
2152
2153 replaceValue(*EI, *NewLoad, false);
2154 }
2155
// Success: every pending freeze was consumed above, so disarm the guard.
2156 FailureGuard.release();
2157 return true;
2158}
2159
2160/// Try to scalarize vector loads feeding bitcast instructions.
///
/// Every user of \p LI is cast<> to BitCastInst below, so this must only be
/// called when all users are bitcasts. All bitcasts must produce the same
/// integer or floating-point type, with the same bit width as \p VecTy. They
/// are then all replaced by one scalar load of that type from \p Ptr, which
/// reuses LI's alignment and metadata.
2161bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
2162 Value *Ptr) {
2163 InstructionCost OriginalCost =
2164 TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
2166
2167 Type *TargetScalarType = nullptr;
2168 unsigned VecBitWidth = DL->getTypeSizeInBits(VecTy);
2169
2170 for (User *U : LI->users()) {
2171 auto *BC = cast<BitCastInst>(U);
2172
2173 Type *DestTy = BC->getDestTy();
2174 if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
2175 return false;
2176
2177 unsigned DestBitWidth = DL->getTypeSizeInBits(DestTy);
2178 if (DestBitWidth != VecBitWidth)
2179 return false;
2180
2181 // All bitcasts must target the same scalar type.
2182 if (!TargetScalarType)
2183 TargetScalarType = DestTy;
2184 else if (TargetScalarType != DestTy)
2185 return false;
2186
// Each bitcast is an instruction that the scalar load makes unnecessary.
2187 OriginalCost +=
2188 TTI.getCastInstrCost(Instruction::BitCast, TargetScalarType, VecTy,
2190 }
2191
2192 if (!TargetScalarType)
2193 return false;
2194
2195 assert(!LI->user_empty() && "Unexpected load without bitcast users");
2196 InstructionCost ScalarizedCost =
2197 TTI.getMemoryOpCost(Instruction::Load, TargetScalarType, LI->getAlign(),
2199
2200 LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
2201 << "\n OriginalCost: " << OriginalCost
2202 << " vs ScalarizedCost: " << ScalarizedCost << "\n");
2203
2204 if (ScalarizedCost >= OriginalCost)
2205 return false;
2206
2207 // Ensure we add the load back to the worklist BEFORE its users so they can
2208 // be erased in the correct order.
2209 Worklist.push(LI);
2210
// Emit the scalar load at the position of the original vector load.
2211 Builder.SetInsertPoint(LI);
2212 auto *ScalarLoad =
2213 Builder.CreateLoad(TargetScalarType, Ptr, LI->getName() + ".scalar");
2214 ScalarLoad->setAlignment(LI->getAlign());
2215 ScalarLoad->copyMetadata(*LI);
2216
2217 // Replace all bitcast users with the scalar load.
2218 for (User *U : LI->users()) {
2219 auto *BC = cast<BitCastInst>(U);
2220 replaceValue(*BC, *ScalarLoad, false);
2221 }
2222
2223 return true;
2224}
2225
/// Try to replace a zext of a fixed vector, whose users are all
/// constant-index extractelements, with scalar shift-and-mask operations on
/// the source vector bitcast to one wide integer. The source vector's total
/// bit width must equal the bit width of one destination element.
2226bool VectorCombine::scalarizeExtExtract(Instruction &I) {
2228 return false;
2229 auto *Ext = dyn_cast<ZExtInst>(&I);
2230 if (!Ext)
2231 return false;
2232
2233 // Try to convert a vector zext feeding only extracts to a set of scalar
2234 // (Src >> (ExtIdx * EltBits)) & EltMask operations (the shift is mirrored on
2235 // big-endian targets), if profitable.
2236 auto *SrcTy = dyn_cast<FixedVectorType>(Ext->getOperand(0)->getType());
2237 if (!SrcTy)
2238 return false;
2239 auto *DstTy = cast<FixedVectorType>(Ext->getType());
2240
2241 Type *ScalarDstTy = DstTy->getElementType();
2242 if (DL->getTypeSizeInBits(SrcTy) != DL->getTypeSizeInBits(ScalarDstTy))
2243 return false;
2244
2245 InstructionCost VectorCost =
2246 TTI.getCastInstrCost(Instruction::ZExt, DstTy, SrcTy,
2248 unsigned ExtCnt = 0;
2249 bool ExtLane0 = false;
// Every user must be an extract with a constant index. Extracts without uses
// are not costed. If lane 0 is extracted, one fewer shift is charged below.
// NOTE(review): on big-endian targets lane 0 needs the largest shift (see
// ShiftAmt below), so this discount looks little-endian specific - confirm.
2250 for (User *U : Ext->users()) {
2251 uint64_t Idx;
2252 if (!match(U, m_ExtractElt(m_Value(), m_ConstantInt(Idx))))
2253 return false;
2254 if (cast<Instruction>(U)->use_empty())
2255 continue;
2256 ExtCnt += 1;
2257 ExtLane0 |= !Idx;
2258 VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
2259 CostKind, Idx, U);
2260 }
2261
2262 InstructionCost ScalarCost =
2263 ExtCnt * TTI.getArithmeticInstrCost(
2264 Instruction::And, ScalarDstTy, CostKind,
2267 (ExtCnt - ExtLane0) *
2269 Instruction::LShr, ScalarDstTy, CostKind,
// The fold is also applied when the costs are equal.
2272 if (ScalarCost > VectorCost)
2273 return false;
2274
2275 Value *ScalarV = Ext->getOperand(0);
2276 if (!isGuaranteedNotToBePoison(ScalarV, SQ.AC, dyn_cast<Instruction>(ScalarV),
2277 SQ.DT)) {
2278 // Check whether all lanes are extracted, all extracts trigger UB
2279 // on poison, and the last extract (and hence all previous ones)
2280 // are guaranteed to execute if Ext executes. If so, we do not
2281 // need to insert a freeze.
// NOTE(review): lanes are keyed by ConstantInt pointer, so the same lane
// indexed with different integer types (e.g. i32 0 and i64 0) is counted
// twice. That could let the size check below pass without every lane being
// extracted. Consider keying by getZExtValue() instead.
2282 SmallDenseSet<ConstantInt *, 8> ExtractedLanes;
2283 bool AllExtractsTriggerUB = true;
2284 ExtractElementInst *LastExtract = nullptr;
2285 BasicBlock *ExtBB = Ext->getParent();
2286 for (User *U : Ext->users()) {
2287 auto *Extract = cast<ExtractElementInst>(U);
2288 if (Extract->getParent() != ExtBB || !programUndefinedIfPoison(Extract)) {
2289 AllExtractsTriggerUB = false;
2290 break;
2291 }
2292 ExtractedLanes.insert(cast<ConstantInt>(Extract->getIndexOperand()));
2293 if (!LastExtract || LastExtract->comesBefore(Extract))
2294 LastExtract = Extract;
2295 }
2296 if (ExtractedLanes.size() != DstTy->getNumElements() ||
2297 !AllExtractsTriggerUB ||
2299 LastExtract->getIterator()))
2300 ScalarV = Builder.CreateFreeze(ScalarV);
2301 }
// Reinterpret the whole source vector as one integer of the same bit width.
2302 ScalarV = Builder.CreateBitCast(
2303 ScalarV,
2304 IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));
2305 uint64_t SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType());
2306 uint64_t TotalBits = DL->getTypeSizeInBits(SrcTy);
2307 APInt EltBitMask = APInt::getLowBitsSet(TotalBits, SrcEltSizeInBits);
2308 Type *PackedTy = IntegerType::get(SrcTy->getContext(), TotalBits);
2309 Value *Mask = ConstantInt::get(PackedTy, EltBitMask);
// NOTE(review): this loop rewrites every extract, including those skipped by
// the cost loop above because they had no uses.
2310 for (User *U : Ext->users()) {
2311 auto *Extract = cast<ExtractElementInst>(U);
2312 uint64_t Idx =
2313 cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue();
// Lane Idx sits at bit offset Idx * EltBits on little-endian targets and at
// the mirrored offset from the top on big-endian targets.
2314 uint64_t ShiftAmt =
2315 DL->isBigEndian()
2316 ? (TotalBits - SrcEltSizeInBits - Idx * SrcEltSizeInBits)
2317 : (Idx * SrcEltSizeInBits);
2318 Value *LShr = Builder.CreateLShr(ScalarV, ShiftAmt);
2319 Value *And = Builder.CreateAnd(LShr, Mask);
2320 U->replaceAllUsesWith(And);
2321 }
2322 return true;
2323}
2324
2325/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
2326/// to "(bitcast (concat X, Y))"
2327/// where X/Y are bitcasted from i1 mask vectors.
///
/// Both operands may carry a left shift. The smaller shift is re-applied to
/// the result, and the difference between the two shifts must equal the mask
/// width. The fold is applied when NewCost <= OldCost.
2328bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
2329 Type *Ty = I.getType();
2330 if (!Ty->isIntegerTy())
2331 return false;
2332
2333 // TODO: Add big endian test coverage
2334 if (DL->isBigEndian())
2335 return false;
2336
2337 // Restrict to disjoint cases so the mask vectors aren't overlapping.
2338 Instruction *X, *Y;
2340 return false;
2341
2342 // Allow both sources to contain shl, to handle more generic pattern:
2343 // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
2344 Value *SrcX;
2345 uint64_t ShAmtX = 0;
2346 if (!match(X, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX)))))) &&
2347 !match(X, m_OneUse(
2349 m_ConstantInt(ShAmtX)))))
2350 return false;
2351
2352 Value *SrcY;
2353 uint64_t ShAmtY = 0;
2354 if (!match(Y, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY)))))) &&
2355 !match(Y, m_OneUse(
2357 m_ConstantInt(ShAmtY)))))
2358 return false;
2359
2360 // Canonicalize larger shift to the RHS.
2361 if (ShAmtX > ShAmtY) {
2362 std::swap(X, Y);
2363 std::swap(SrcX, SrcY);
2364 std::swap(ShAmtX, ShAmtY);
2365 }
2366
2367 // Ensure both sources are matching vXi1 bool mask types, and that the shift
2368 // difference is the mask width so they can be easily concatenated together.
2369 uint64_t ShAmtDiff = ShAmtY - ShAmtX;
// Number of shl instructions in the original pattern (0, 1 or 2).
2370 unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
2371 unsigned BitWidth = Ty->getPrimitiveSizeInBits();
2372 auto *MaskTy = dyn_cast<FixedVectorType>(SrcX->getType());
// The integer result must also be at least twice the mask width.
2373 if (!MaskTy || SrcX->getType() != SrcY->getType() ||
2374 !MaskTy->getElementType()->isIntegerTy(1) ||
2375 MaskTy->getNumElements() != ShAmtDiff ||
2376 MaskTy->getNumElements() > (BitWidth / 2))
2377 return false;
2378
2379 auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(MaskTy);
2380 auto *ConcatIntTy =
2381 Type::getIntNTy(Ty->getContext(), ConcatTy->getNumElements());
2382 auto *MaskIntTy = Type::getIntNTy(Ty->getContext(), ShAmtDiff);
2383
// Mask 0, 1, ..., 2N-1: all lanes of SrcX followed by all lanes of SrcY.
2384 SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
2385 std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
2386
2387 // TODO: Is it worth supporting multi use cases?
2388 InstructionCost OldCost = 0;
2389 OldCost += TTI.getArithmeticInstrCost(Instruction::Or, Ty, CostKind);
2390 OldCost +=
2391 NumSHL * TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
2392 OldCost += 2 * TTI.getCastInstrCost(Instruction::ZExt, Ty, MaskIntTy,
2394 OldCost += 2 * TTI.getCastInstrCost(Instruction::BitCast, MaskIntTy, MaskTy,
2396
2397 InstructionCost NewCost = 0;
2399 MaskTy, ConcatMask, CostKind);
2400 NewCost += TTI.getCastInstrCost(Instruction::BitCast, ConcatIntTy, ConcatTy,
2402 if (Ty != ConcatIntTy)
2403 NewCost += TTI.getCastInstrCost(Instruction::ZExt, Ty, ConcatIntTy,
2405 if (ShAmtX > 0)
2406 NewCost += TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
2407
2408 LLVM_DEBUG(dbgs() << "Found a concatenation of bitcasted bool masks: " << I
2409 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2410 << "\n");
2411
// Equal costs still fold.
2412 if (NewCost > OldCost)
2413 return false;
2414
2415 // Build bool mask concatenation, bitcast back to scalar integer, and perform
2416 // any residual zero-extension or shifting.
2417 Value *Concat = Builder.CreateShuffleVector(SrcX, SrcY, ConcatMask);
2418 Worklist.pushValue(Concat);
2419
2420 Value *Result = Builder.CreateBitCast(Concat, ConcatIntTy);
2421
2422 if (Ty != ConcatIntTy) {
2423 Worklist.pushValue(Result);
2424 Result = Builder.CreateZExt(Result, Ty);
2425 }
2426
2427 if (ShAmtX > 0) {
2428 Worklist.pushValue(Result);
2429 Result = Builder.CreateShl(Result, ShAmtX);
2430 }
2431
2432 replaceValue(I, *Result);
2433 return true;
2434}
2435
2436/// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
2437/// --> "binop (shuffle), (shuffle)".
///
/// At least one binop operand must be a shuffle. A non-shuffle operand is
/// treated as an identity shuffle of itself. The outer mask is composed with
/// each inner mask, and a new shuffle is emitted only for an operand whose
/// composed mask is not an identity of its source.
2438bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
2439 BinaryOperator *BinOp;
2440 ArrayRef<int> OuterMask;
2441 if (!match(&I, m_Shuffle(m_BinOp(BinOp), m_Undef(), m_Mask(OuterMask))))
2442 return false;
2443
2444 // Don't introduce poison into div/rem.
2445 if (BinOp->isIntDivRem() && llvm::is_contained(OuterMask, PoisonMaskElem))
2446 return false;
2447
2448 Value *Op00, *Op01, *Op10, *Op11;
2449 ArrayRef<int> Mask0, Mask1;
2450 bool Match0 = match(BinOp->getOperand(0),
2451 m_Shuffle(m_Value(Op00), m_Value(Op01), m_Mask(Mask0)));
2452 bool Match1 = match(BinOp->getOperand(1),
2453 m_Shuffle(m_Value(Op10), m_Value(Op11), m_Mask(Mask1)));
2454 if (!Match0 && !Match1)
2455 return false;
2456
2457 Op00 = Match0 ? Op00 : BinOp->getOperand(0);
2458 Op01 = Match0 ? Op01 : BinOp->getOperand(0);
2459 Op10 = Match1 ? Op10 : BinOp->getOperand(1);
2460 Op11 = Match1 ? Op11 : BinOp->getOperand(1);
2461
2462 Instruction::BinaryOps Opcode = BinOp->getOpcode();
2463 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
2464 auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType());
2465 auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType());
2466 auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType());
2467 if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
2468 return false;
2469
2470 unsigned NumSrcElts = BinOpTy->getNumElements();
2471
2472 // Don't accept shuffles that reference the second operand in
2473 // div/rem, or when the second shuffle operand is not poison (e.g. undef).
2474 if ((BinOp->isIntDivRem() || !isa<PoisonValue>(I.getOperand(1))) &&
2475 any_of(OuterMask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
2476 return false;
2477
2478 // Merge outer / inner (or identity if no match) shuffles.
2479 SmallVector<int> NewMask0, NewMask1;
2480 for (int M : OuterMask) {
2481 if (M < 0 || M >= (int)NumSrcElts) {
2482 NewMask0.push_back(PoisonMaskElem);
2483 NewMask1.push_back(PoisonMaskElem);
2484 } else {
2485 NewMask0.push_back(Match0 ? Mask0[M] : M);
2486 NewMask1.push_back(Match1 ? Mask1[M] : M);
2487 }
2488 }
2489
// NOTE(review): NumOpElts comes from Op0Ty but is also used for the Op1
// identity check below. If Op0Ty and Op1Ty can differ in element count,
// confirm whether Op1Ty's count was intended there.
2490 unsigned NumOpElts = Op0Ty->getNumElements();
2491 bool IsIdentity0 = ShuffleDstTy == Op0Ty &&
2492 all_of(NewMask0, [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
2493 ShuffleVectorInst::isIdentityMask(NewMask0, NumOpElts);
2494 bool IsIdentity1 = ShuffleDstTy == Op1Ty &&
2495 all_of(NewMask1, [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
2496 ShuffleVectorInst::isIdentityMask(NewMask1, NumOpElts);
2497
2498 InstructionCost NewCost = 0;
2499 // Try to merge shuffles across the binop if the new shuffles are not costly.
2500 InstructionCost BinOpCost =
2501 TTI.getArithmeticInstrCost(Opcode, BinOpTy, CostKind);
2502 InstructionCost OldCost =
2504 ShuffleDstTy, BinOpTy, OuterMask, CostKind,
2505 0, nullptr, {BinOp}, &I);
// Instructions with other users survive the fold, so their cost is also
// charged to the new sequence.
2506 if (!BinOp->hasOneUse())
2507 NewCost += BinOpCost;
2508
2509 if (Match0) {
2511 TargetTransformInfo::SK_PermuteTwoSrc, BinOpTy, Op0Ty, Mask0, CostKind,
2512 0, nullptr, {Op00, Op01}, cast<Instruction>(BinOp->getOperand(0)));
2513 OldCost += Shuf0Cost;
2514 if (!BinOp->hasOneUse() || !BinOp->getOperand(0)->hasOneUse())
2515 NewCost += Shuf0Cost;
2516 }
2517 if (Match1) {
2519 TargetTransformInfo::SK_PermuteTwoSrc, BinOpTy, Op1Ty, Mask1, CostKind,
2520 0, nullptr, {Op10, Op11}, cast<Instruction>(BinOp->getOperand(1)));
2521 OldCost += Shuf1Cost;
2522 if (!BinOp->hasOneUse() || !BinOp->getOperand(1)->hasOneUse())
2523 NewCost += Shuf1Cost;
2524 }
2525
2526 NewCost += TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy, CostKind);
2527
2528 if (!IsIdentity0)
2529 NewCost +=
2531 Op0Ty, NewMask0, CostKind, 0, nullptr, {Op00, Op01});
2532 if (!IsIdentity1)
2533 NewCost +=
2535 Op1Ty, NewMask1, CostKind, 0, nullptr, {Op10, Op11});
2536
2537 LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I
2538 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2539 << "\n");
2540
2541 // If costs are equal, still fold as we reduce instruction count.
2542 if (NewCost > OldCost)
2543 return false;
2544
2545 Value *LHS =
2546 IsIdentity0 ? Op00 : Builder.CreateShuffleVector(Op00, Op01, NewMask0);
2547 Value *RHS =
2548 IsIdentity1 ? Op10 : Builder.CreateShuffleVector(Op10, Op11, NewMask1);
2549 Value *NewBO = Builder.CreateBinOp(Opcode, LHS, RHS);
2550
2551 // Copy the IR flags of the original binop onto the new one.
2552 if (auto *NewInst = dyn_cast<Instruction>(NewBO))
2553 NewInst->copyIRFlags(BinOp);
2554
2555 Worklist.pushValue(LHS);
2556 Worklist.pushValue(RHS);
2557 replaceValue(I, *NewBO);
2558 return true;
2559}
2560
2561/// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
2562/// Try to convert "shuffle (cmpop), (cmpop)" into "cmpop (shuffle), (shuffle)".
///
/// LHS and RHS (the shuffle's operands) must have the same opcode and, for
/// compares, the same predicate. The fold happens only when the TTI cost model
/// rates the new sequence as cheaper. An equal cost is also accepted when the
/// fold is known to reduce the instruction count. Returns true if \p I was
/// replaced.
2563bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
2564 ArrayRef<int> OldMask;
2565 Instruction *LHS, *RHS;
2567 m_Mask(OldMask))))
2568 return false;
2569
2570 // TODO: Add support for addlike etc.
2571 if (LHS->getOpcode() != RHS->getOpcode())
2572 return false;
2573
 // X and Y receive the operands of LHS; Z and W receive the operands of RHS.
2574 Value *X, *Y, *Z, *W;
2575 bool IsCommutative = false;
 // The predicates stay BAD_ICMP_PREDICATE on the binop path. That value is
 // later used to tell the binop and compare paths apart.
2576 CmpPredicate PredLHS = CmpInst::BAD_ICMP_PREDICATE;
2577 CmpPredicate PredRHS = CmpInst::BAD_ICMP_PREDICATE;
2578 if (match(LHS, m_BinOp(m_Value(X), m_Value(Y))) &&
2579 match(RHS, m_BinOp(m_Value(Z), m_Value(W)))) {
2580 auto *BO = cast<BinaryOperator>(LHS);
2581 // Don't introduce poison into div/rem.
2582 if (llvm::is_contained(OldMask, PoisonMaskElem) && BO->isIntDivRem())
2583 return false;
2584 IsCommutative = BinaryOperator::isCommutative(BO->getOpcode());
2585 } else if (match(LHS, m_Cmp(PredLHS, m_Value(X), m_Value(Y))) &&
2586 match(RHS, m_Cmp(PredRHS, m_Value(Z), m_Value(W))) &&
2587 (CmpInst::Predicate)PredLHS == (CmpInst::Predicate)PredRHS) {
2588 IsCommutative = cast<CmpInst>(LHS)->isCommutative();
2589 } else
2590 return false;
2591
 // Bail out unless the shuffle result, the binop/cmp result and its operand
 // are all fixed-width vectors, and LHS and RHS take operands of one type.
2592 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
2593 auto *BinResTy = dyn_cast<FixedVectorType>(LHS->getType());
2594 auto *BinOpTy = dyn_cast<FixedVectorType>(X->getType());
2595 if (!ShuffleDstTy || !BinResTy || !BinOpTy || X->getType() != Z->getType())
2596 return false;
2597
 // LHS == RHS when the shuffle reads the same instruction on both sides.
2598 bool SameBinOp = LHS == RHS;
2599 unsigned NumSrcElts = BinOpTy->getNumElements();
2600
2601 // If we have something like "add X, Y" and "add Z, X", swap ops to match.
2602 if (IsCommutative && X != Z && Y != W && (X == W || Y == Z))
2603 std::swap(X, Y);
2604
 // Rewrite a two-source mask index that selects from the second source so it
 // selects the same lane of the first source. Used when both shuffle sources
 // are the same value.
2605 auto ConvertToUnary = [NumSrcElts](int &M) {
2606 if (M >= (int)NumSrcElts)
2607 M -= NumSrcElts;
2608 };
2609
 // NewMask0 shuffles the first operands (X, Z) and NewMask1 the second
 // operands (Y, W). When both values of a pair are identical, that mask is
 // made unary and the second source is replaced by poison.
2610 SmallVector<int> NewMask0(OldMask);
2612 TTI::OperandValueInfo Op0Info = TTI.commonOperandInfo(X, Z);
2613 if (X == Z) {
2614 llvm::for_each(NewMask0, ConvertToUnary);
2616 Z = PoisonValue::get(BinOpTy);
2617 }
2618
2619 SmallVector<int> NewMask1(OldMask);
2621 TTI::OperandValueInfo Op1Info = TTI.commonOperandInfo(Y, W);
2622 if (Y == W) {
2623 llvm::for_each(NewMask1, ConvertToUnary);
2625 W = PoisonValue::get(BinOpTy);
2626 }
2627
2628 // Try to replace a binop with a shuffle if the shuffle is not costly.
2629 // When SameBinOp, only count the binop cost once.
2632
2633 InstructionCost OldCost = LHSCost;
2634 if (!SameBinOp) {
2635 OldCost += RHSCost;
2636 }
2638 ShuffleDstTy, BinResTy, OldMask, CostKind, 0,
2639 nullptr, {LHS, RHS}, &I);
2640
2641 // Handle shuffle(binop(shuffle(x),y),binop(z,shuffle(w))) style patterns
2642 // where one use shuffles have gotten split across the binop/cmp. These
2643 // often allow a major reduction in total cost that wouldn't happen as
2644 // individual folds.
 // MergeInner applies when Op is a one-use single-source shuffle of a
 // same-typed value whose mask stays within that source. It then composes the
 // shuffle's mask into every entry of Mask in [Offset, Offset + NumSrcElts),
 // replaces Op with the shuffle's input, and returns true.
2645 auto MergeInner = [&](Value *&Op, int Offset, MutableArrayRef<int> Mask,
2646 TTI::TargetCostKind CostKind) -> bool {
2647 Value *InnerOp;
2648 ArrayRef<int> InnerMask;
2649 if (match(Op, m_OneUse(m_Shuffle(m_Value(InnerOp), m_Undef(),
2650 m_Mask(InnerMask)))) &&
2651 InnerOp->getType() == Op->getType() &&
2652 all_of(InnerMask,
2653 [NumSrcElts](int M) { return M < (int)NumSrcElts; })) {
2654 for (int &M : Mask)
2655 if (Offset <= M && M < (int)(Offset + NumSrcElts)) {
2656 M = InnerMask[M - Offset];
2657 M = 0 <= M ? M + Offset : M;
2658 }
2660 Op = InnerOp;
2661 return true;
2662 }
2663 return false;
2664 };
2665 bool ReducedInstCount = false;
2666 ReducedInstCount |= MergeInner(X, 0, NewMask0, CostKind);
2667 ReducedInstCount |= MergeInner(Y, 0, NewMask1, CostKind);
2668 ReducedInstCount |= MergeInner(Z, NumSrcElts, NewMask0, CostKind);
2669 ReducedInstCount |= MergeInner(W, NumSrcElts, NewMask1, CostKind);
 // When both new shuffles would be identical, a single shuffle can feed both
 // operands of the new binop/cmp.
2670 bool SingleSrcBinOp = (X == Y) && (Z == W) && (NewMask0 == NewMask1);
2671 // SingleSrcBinOp only reduces instruction count if we also eliminate the
2672 // original binop(s). If binops have multiple uses, they won't be eliminated.
2673 ReducedInstCount |= SingleSrcBinOp && LHS->hasOneUser() && RHS->hasOneUser();
2674
2675 auto *ShuffleCmpTy =
2676 FixedVectorType::get(BinOpTy->getElementType(), ShuffleDstTy);
2678 SK0, ShuffleCmpTy, BinOpTy, NewMask0, CostKind, 0, nullptr, {X, Z});
2679 if (!SingleSrcBinOp)
2680 NewCost += TTI.getShuffleCost(SK1, ShuffleCmpTy, BinOpTy, NewMask1,
2681 CostKind, 0, nullptr, {Y, W});
2682
2683 if (PredLHS == CmpInst::BAD_ICMP_PREDICATE) {
2684 NewCost += TTI.getArithmeticInstrCost(LHS->getOpcode(), ShuffleDstTy,
2685 CostKind, Op0Info, Op1Info);
2686 } else {
2687 NewCost +=
2688 TTI.getCmpSelInstrCost(LHS->getOpcode(), ShuffleCmpTy, ShuffleDstTy,
2689 PredLHS, CostKind, Op0Info, Op1Info);
2690 }
2691 // If LHS/RHS have other uses, we need to account for the cost of keeping
2692 // the original instructions. When SameBinOp, only add the cost once.
2693 if (!LHS->hasOneUser())
2694 NewCost += LHSCost;
2695 if (!SameBinOp && !RHS->hasOneUser())
2696 NewCost += RHSCost;
2697
2698 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
2699 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2700 << "\n");
2701
2702 // If either shuffle will constant fold away, then fold for the same cost as
2703 // we will reduce the instruction count.
2704 ReducedInstCount |= (isa<Constant>(X) && isa<Constant>(Z)) ||
2705 (isa<Constant>(Y) && isa<Constant>(W));
 // A cost tie is accepted only when the instruction count is known to drop.
2706 if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost))
2707 return false;
2708
2709 Value *Shuf0 = Builder.CreateShuffleVector(X, Z, NewMask0);
2710 Value *Shuf1 =
2711 SingleSrcBinOp ? Shuf0 : Builder.CreateShuffleVector(Y, W, NewMask1);
2712 Value *NewBO = PredLHS == CmpInst::BAD_ICMP_PREDICATE
2713 ? Builder.CreateBinOp(
2714 cast<BinaryOperator>(LHS)->getOpcode(), Shuf0, Shuf1)
2715 : Builder.CreateCmp(PredLHS, Shuf0, Shuf1);
2716
2717 // Intersect flags from the old binops.
2718 if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
2719 NewInst->copyIRFlags(LHS);
2720 NewInst->andIRFlags(RHS);
2721 }
2722
2723 Worklist.pushValue(Shuf0);
2724 Worklist.pushValue(Shuf1);
2725 replaceValue(I, *NewBO);
2726 return true;
2727}
2728
2729/// Try to convert,
2730/// (shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) into
2731/// (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))
///
/// Both selects must have vector conditions of the same fixed-width type. If
/// the selects are FPMathOperators, they must also carry identical fast-math
/// flags. The fold is skipped when the TTI cost model rates the new sequence
/// as more expensive. Returns true if \p I was replaced.
2732bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
2733 ArrayRef<int> Mask;
2734 Value *C1, *T1, *F1, *C2, *T2, *F2;
2735 if (!match(&I, m_Shuffle(m_Select(m_Value(C1), m_Value(T1), m_Value(F1)),
2736 m_Select(m_Value(C2), m_Value(T2), m_Value(F2)),
2737 m_Mask(Mask))))
2738 return false;
2739
2740 auto *Sel1 = cast<Instruction>(I.getOperand(0));
2741 auto *Sel2 = cast<Instruction>(I.getOperand(1));
2742
 // Reject scalar conditions and vector conditions of different types.
2743 auto *C1VecTy = dyn_cast<FixedVectorType>(C1->getType());
2744 auto *C2VecTy = dyn_cast<FixedVectorType>(C2->getType());
2745 if (!C1VecTy || !C2VecTy || C1VecTy != C2VecTy)
2746 return false;
2747
2748 auto *SI0FOp = dyn_cast<FPMathOperator>(I.getOperand(0));
2749 auto *SI1FOp = dyn_cast<FPMathOperator>(I.getOperand(1));
2750 // SelectInsts must have the same FMF.
2751 if (((SI0FOp == nullptr) != (SI1FOp == nullptr)) ||
2752 ((SI0FOp != nullptr) &&
2753 (SI0FOp->getFastMathFlags() != SI1FOp->getFastMathFlags())))
2754 return false;
2755
2756 auto *SrcVecTy = cast<FixedVectorType>(T1->getType());
2757 auto *DstVecTy = cast<FixedVectorType>(I.getType());
2759 auto SelOp = Instruction::Select;
2760
2762 SelOp, SrcVecTy, C1VecTy, CmpInst::BAD_ICMP_PREDICATE, CostKind);
2764 SelOp, SrcVecTy, C2VecTy, CmpInst::BAD_ICMP_PREDICATE, CostKind);
2765
 // Old cost: both selects plus the shuffle of their results.
2766 InstructionCost OldCost =
2767 CostSel1 + CostSel2 +
2768 TTI.getShuffleCost(SK, DstVecTy, SrcVecTy, Mask, CostKind, 0, nullptr,
2769 {I.getOperand(0), I.getOperand(1)}, &I);
 // New cost: shuffles of the conditions, true values and false values, plus
 // one select on the shuffled values.
2770
2772 SK, FixedVectorType::get(C1VecTy->getScalarType(), Mask.size()), C1VecTy,
2773 Mask, CostKind, 0, nullptr, {C1, C2});
2774 NewCost += TTI.getShuffleCost(SK, DstVecTy, SrcVecTy, Mask, CostKind, 0,
2775 nullptr, {T1, T2});
2776 NewCost += TTI.getShuffleCost(SK, DstVecTy, SrcVecTy, Mask, CostKind, 0,
2777 nullptr, {F1, F2});
2778 auto *C1C2ShuffledVecTy = FixedVectorType::get(
2779 Type::getInt1Ty(I.getContext()), DstVecTy->getNumElements());
2780 NewCost += TTI.getCmpSelInstrCost(SelOp, DstVecTy, C1C2ShuffledVecTy,
2782
 // A select with other users survives the fold, so its cost is still paid.
2783 if (!Sel1->hasOneUse())
2784 NewCost += CostSel1;
2785 if (!Sel2->hasOneUse())
2786 NewCost += CostSel2;
2787
2788 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two selects: " << I
2789 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2790 << "\n");
2791 if (NewCost > OldCost)
2792 return false;
2793
2794 Value *ShuffleCmp = Builder.CreateShuffleVector(C1, C2, Mask);
2795 Value *ShuffleTrue = Builder.CreateShuffleVector(T1, T2, Mask);
2796 Value *ShuffleFalse = Builder.CreateShuffleVector(F1, F2, Mask);
2797 Value *NewSel;
2798 // We presuppose that the SelectInsts have the same FMF.
2799 if (SI0FOp)
2800 NewSel = Builder.CreateSelectFMF(ShuffleCmp, ShuffleTrue, ShuffleFalse,
2801 SI0FOp->getFastMathFlags());
2802 else
2803 NewSel = Builder.CreateSelect(ShuffleCmp, ShuffleTrue, ShuffleFalse);
2804
2805 Worklist.pushValue(ShuffleCmp);
2806 Worklist.pushValue(ShuffleTrue);
2807 Worklist.pushValue(ShuffleFalse);
2808 replaceValue(I, *NewSel);
2809 return true;
2810}
2811
2812/// Try to convert "shuffle (castop), (castop)" with a shared castop operand
2813/// into "castop (shuffle)".
///
/// Handles two cases:
/// - binary shuffles of two casts that share a source type;
/// - single-source permutes (undef second operand) of one cast.
/// Returns true if \p I was replaced.
2814bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
2815 Value *V0, *V1;
2816 ArrayRef<int> OldMask;
2817 if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
2818 return false;
2819
2820 // Check whether this is a binary shuffle.
2821 bool IsBinaryShuffle = !isa<UndefValue>(V1);
2822
2823 auto *C0 = dyn_cast<CastInst>(V0);
2824 auto *C1 = dyn_cast<CastInst>(V1);
2825 if (!C0 || (IsBinaryShuffle && !C1))
2826 return false;
2827
2828 Instruction::CastOps Opcode = C0->getOpcode();
2829
2830 // If this is allowed, foldShuffleOfCastops can get stuck in a loop
2831 // with foldBitcastOfShuffle. Reject in favor of foldBitcastOfShuffle.
2832 if (!IsBinaryShuffle && Opcode == Instruction::BitCast)
2833 return false;
2834
2835 if (IsBinaryShuffle) {
2836 if (C0->getSrcTy() != C1->getSrcTy())
2837 return false;
2838 // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
2839 if (Opcode != C1->getOpcode()) {
2840 if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
2841 Opcode = Instruction::SExt;
2842 else
2843 return false;
2844 }
2845 }
2846
2847 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
2848 auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy());
2849 auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy());
2850 if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
2851 return false;
2852
2853 unsigned NumSrcElts = CastSrcTy->getNumElements();
2854 unsigned NumDstElts = CastDstTy->getNumElements();
2855 assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) &&
2856 "Only bitcasts expected to alter src/dst element counts");
2857
2858 // Check for bitcasting of unscalable vector types.
2859 // e.g. <32 x i40> -> <40 x i32>
2860 if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 &&
2861 (NumDstElts % NumSrcElts) != 0)
2862 return false;
2863
 // Translate OldMask, which indexes the cast results, into NewMask, which
 // indexes the cast sources.
2864 SmallVector<int, 16> NewMask;
2865 if (NumSrcElts >= NumDstElts) {
2866 // The cast source has at least as many elements as its result (only a
2867 // bitcast can change the count), so each result element maps to
 // ScaleFactor consecutive source elements. The shuffle mask can therefore
 // always be expanded to the equivalent form over the source elements.
2868 assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask");
2869 unsigned ScaleFactor = NumSrcElts / NumDstElts;
2870 narrowShuffleMaskElts(ScaleFactor, OldMask, NewMask);
2871 } else {
2872 // The cast source has fewer elements than its result (a bitcast to
2873 // narrower elements). The shuffle mask must choose consecutive groups of
 // ScaleFactor result elements so it can be applied before the cast.
2874 assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask");
2875 unsigned ScaleFactor = NumDstElts / NumSrcElts;
2876 if (!widenShuffleMaskElts(ScaleFactor, OldMask, NewMask))
2877 return false;
2878 }
2879
2880 auto *NewShuffleDstTy =
2881 FixedVectorType::get(CastSrcTy->getScalarType(), NewMask.size());
2882
2883 // Try to replace a castop with a shuffle if the shuffle is not costly.
2884 InstructionCost CostC0 =
2885 TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,
2887
2889 if (IsBinaryShuffle)
2891 else
2893
2894 InstructionCost OldCost = CostC0;
2895 OldCost += TTI.getShuffleCost(ShuffleKind, ShuffleDstTy, CastDstTy, OldMask,
2896 CostKind, 0, nullptr, {}, &I);
2897
2898 InstructionCost NewCost = TTI.getShuffleCost(ShuffleKind, NewShuffleDstTy,
2899 CastSrcTy, NewMask, CostKind);
2900 NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
 // A cast with other users survives the fold, so its cost is still paid.
2902 if (!C0->hasOneUse())
2903 NewCost += CostC0;
2904 if (IsBinaryShuffle) {
2905 InstructionCost CostC1 =
2906 TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
2908 OldCost += CostC1;
2909 if (!C1->hasOneUse())
2910 NewCost += CostC1;
2911 }
2912
2913 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
2914 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2915 << "\n");
2916 if (NewCost > OldCost)
2917 return false;
2918
2919 Value *Shuf;
2920 if (IsBinaryShuffle)
2921 Shuf = Builder.CreateShuffleVector(C0->getOperand(0), C1->getOperand(0),
2922 NewMask);
2923 else
2924 Shuf = Builder.CreateShuffleVector(C0->getOperand(0), NewMask);
2925
2926 Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
2927
2928 // Intersect flags from the old casts.
2929 if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
2930 NewInst->copyIRFlags(C0);
2931 if (IsBinaryShuffle)
2932 NewInst->andIRFlags(C1);
2933 }
2934
2935 Worklist.pushValue(Shuf);
2936 replaceValue(I, *Cast);
2937 return true;
2938}
2939
2940/// Try to convert any of:
2941/// "shuffle (shuffle x, y), (shuffle y, x)"
2942/// "shuffle (shuffle x, undef), (shuffle y, undef)"
2943/// "shuffle (shuffle x, undef), y"
2944/// "shuffle x, (shuffle y, undef)"
2945/// into "shuffle x, y".
///
/// The merged mask may reference at most two distinct source values.
/// - The merge is rejected if it would read a lane of a non-poison undef
///   value.
/// - An all-poison or identity result replaces \p I directly.
/// - Otherwise the merge is skipped when the TTI cost model rates it as more
///   expensive.
/// Returns true if \p I was replaced.
2946bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
2947 ArrayRef<int> OuterMask;
2948 Value *OuterV0, *OuterV1;
2949 if (!match(&I,
2950 m_Shuffle(m_Value(OuterV0), m_Value(OuterV1), m_Mask(OuterMask))))
2951 return false;
2952
2953 ArrayRef<int> InnerMask0, InnerMask1;
2954 Value *X0, *X1, *Y0, *Y1;
2955 bool Match0 =
2956 match(OuterV0, m_Shuffle(m_Value(X0), m_Value(Y0), m_Mask(InnerMask0)));
2957 bool Match1 =
2958 match(OuterV1, m_Shuffle(m_Value(X1), m_Value(Y1), m_Mask(InnerMask1)));
2959 if (!Match0 && !Match1)
2960 return false;
2961
2962 // If the outer shuffle is a permute, then create a fake inner all-poison
2963 // shuffle. This is easier than accounting for length-changing shuffles below.
2964 SmallVector<int, 16> PoisonMask1;
2965 if (!Match1 && isa<PoisonValue>(OuterV1)) {
2966 X1 = X0;
2967 Y1 = Y0;
2968 PoisonMask1.append(InnerMask0.size(), PoisonMaskElem);
2969 InnerMask1 = PoisonMask1;
2970 Match1 = true; // fake match
2971 }
2972
 // An outer operand that is not a shuffle serves as its own source on both
 // sides.
2973 X0 = Match0 ? X0 : OuterV0;
2974 Y0 = Match0 ? Y0 : OuterV0;
2975 X1 = Match1 ? X1 : OuterV1;
2976 Y1 = Match1 ? Y1 : OuterV1;
2977 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
2978 auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(X0->getType());
2979 auto *ShuffleImmTy = dyn_cast<FixedVectorType>(OuterV0->getType());
2980 if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
2981 X0->getType() != X1->getType())
2982 return false;
2983
2984 unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
2985 unsigned NumImmElts = ShuffleImmTy->getNumElements();
2986
2987 // Attempt to merge shuffles, matching upto 2 source operands.
2988 // Replace index to a poison arg with PoisonMaskElem.
2989 // Bail if either inner masks reference an undef arg.
2990 SmallVector<int, 16> NewMask(OuterMask);
2991 Value *NewX = nullptr, *NewY = nullptr;
2992 for (int &M : NewMask) {
2993 Value *Src = nullptr;
 // Resolve M to a lane of an innermost source. Src stays null for a
 // negative (undefined) outer mask element.
2994 if (0 <= M && M < (int)NumImmElts) {
2995 Src = OuterV0;
2996 if (Match0) {
2997 M = InnerMask0[M];
2998 Src = M >= (int)NumSrcElts ? Y0 : X0;
2999 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
3000 }
3001 } else if (M >= (int)NumImmElts) {
3002 Src = OuterV1;
3003 M -= NumImmElts;
3004 if (Match1) {
3005 M = InnerMask1[M];
3006 Src = M >= (int)NumSrcElts ? Y1 : X1;
3007 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
3008 }
3009 }
3010 if (Src && M != PoisonMaskElem) {
3011 assert(0 <= M && M < (int)NumSrcElts && "Unexpected shuffle mask index");
3012 if (isa<UndefValue>(Src)) {
3013 // We've referenced an undef element - if its poison, update the shuffle
3014 // mask, else bail.
3015 if (!isa<PoisonValue>(Src))
3016 return false;
3017 M = PoisonMaskElem;
3018 continue;
3019 }
 // The first distinct source becomes NewX and the second becomes NewY.
 // Indices into NewY are offset by NumSrcElts. A third distinct source
 // aborts the fold.
3020 if (!NewX || NewX == Src) {
3021 NewX = Src;
3022 continue;
3023 }
3024 if (!NewY || NewY == Src) {
3025 M += NumSrcElts;
3026 NewY = Src;
3027 continue;
3028 }
3029 return false;
3030 }
3031 }
3032
 // No lane reads an element of a non-poison source, so the result is poison.
3033 if (!NewX) {
3034 replaceValue(I, *PoisonValue::get(ShuffleDstTy));
3035 return true;
3036 }
3037
3038 if (!NewY)
3039 NewY = PoisonValue::get(ShuffleSrcTy);
3040
3041 // Have we folded to an Identity shuffle?
3042 if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) {
3043 replaceValue(I, *NewX);
3044 return true;
3045 }
3046
3047 // Try to merge the shuffles if the new shuffle is not costly.
3048 InstructionCost InnerCost0 = 0;
3049 if (Match0)
3050 InnerCost0 = TTI.getInstructionCost(cast<User>(OuterV0), CostKind);
3051
3052 InstructionCost InnerCost1 = 0;
3053 if (Match1)
3054 InnerCost1 = TTI.getInstructionCost(cast<User>(OuterV1), CostKind);
3055
3057
3058 InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost;
3059
3060 bool IsUnary = all_of(NewMask, [&](int M) { return M < (int)NumSrcElts; });
3064 InstructionCost NewCost =
3065 TTI.getShuffleCost(SK, ShuffleDstTy, ShuffleSrcTy, NewMask, CostKind, 0,
3066 nullptr, {NewX, NewY});
 // An outer operand with other users is kept, so its cost is still paid.
3067 if (!OuterV0->hasOneUse())
3068 NewCost += InnerCost0;
3069 if (!OuterV1->hasOneUse())
3070 NewCost += InnerCost1;
3071
3072 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
3073 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
3074 << "\n");
3075 if (NewCost > OldCost)
3076 return false;
3077
3078 Value *Shuf = Builder.CreateShuffleVector(NewX, NewY, NewMask);
3079 replaceValue(I, *Shuf);
3080 return true;
3081}
3082
3083/// Try to convert a chain of length-preserving shuffles that are fed by
3084/// length-changing shuffles from the same source, e.g. a chain of length 3:
3085///
3086/// "shuffle (shuffle (shuffle x, (shuffle y, undef)),
3087/// (shuffle y, undef)),
3088/// (shuffle y, undef)"
3089///
3090/// into a single shuffle fed by a length-changing shuffle:
3091///
3092/// "shuffle x, (shuffle y, undef)"
3093///
3094/// Such chains arise e.g. from folding extract/insert sequences.
///
/// The chain is walked from \p I through each link's non-leaf operand. Operands
/// are commuted as needed so that the leaf is always the second operand.
/// - Each link must have exactly one length-changing leaf shuffle operand.
/// - Every leaf must read the same non-undef value Y.
/// - The leaves' masks must agree on all lanes they both define.
/// Returns true if \p I was replaced.
3095bool VectorCombine::foldShufflesOfLengthChangingShuffles(Instruction &I) {
3096 FixedVectorType *TrunkType = dyn_cast<FixedVectorType>(I.getType());
3097 if (!TrunkType)
3098 return false;
3099
 // Mask is the combined mask of the links processed so far, applied to
 // (Trunk, leaf). YMask is the combined single-source mask applied to Y to
 // build that leaf.
3100 unsigned ChainLength = 0;
3101 SmallVector<int> Mask;
3102 SmallVector<int> YMask;
3103 InstructionCost OldCost = 0;
3104 InstructionCost NewCost = 0;
3105 Value *Trunk = &I;
3106 unsigned NumTrunkElts = TrunkType->getNumElements();
3107 Value *Y = nullptr;
3108
3109 for (;;) {
3110 // Match the current trunk against (commutations of) the pattern
3111 // "shuffle trunk', (shuffle y, undef)"
3112 ArrayRef<int> OuterMask;
3113 Value *OuterV0, *OuterV1;
 // Every link after the root must have a single use.
3114 if (ChainLength != 0 && !Trunk->hasOneUse())
3115 break;
3116 if (!match(Trunk, m_Shuffle(m_Value(OuterV0), m_Value(OuterV1),
3117 m_Mask(OuterMask))))
3118 break;
3119 if (OuterV0->getType() != TrunkType) {
3120 // This shuffle is not length-preserving, so it cannot be part of the
3121 // chain.
3122 break;
3123 }
3124
3125 ArrayRef<int> InnerMask0, InnerMask1;
3126 Value *A0, *A1, *B0, *B1;
3127 bool Match0 =
3128 match(OuterV0, m_Shuffle(m_Value(A0), m_Value(B0), m_Mask(InnerMask0)));
3129 bool Match1 =
3130 match(OuterV1, m_Shuffle(m_Value(A1), m_Value(B1), m_Mask(InnerMask1)));
3131 bool Match0Leaf = Match0 && A0->getType() != I.getType();
3132 bool Match1Leaf = Match1 && A1->getType() != I.getType();
3133 if (Match0Leaf == Match1Leaf) {
3134 // Only handle the case of exactly one leaf in each step. The "two leaves"
3135 // case is handled by foldShuffleOfShuffles.
3136 break;
3137 }
3138
 // Canonicalize so the leaf is always operand 1. The operand halves of the
 // outer mask are swapped to compensate.
3139 SmallVector<int> CommutedOuterMask;
3140 if (Match0Leaf) {
3141 std::swap(OuterV0, OuterV1);
3142 std::swap(InnerMask0, InnerMask1);
3143 std::swap(A0, A1);
3144 std::swap(B0, B1);
3145 llvm::append_range(CommutedOuterMask, OuterMask);
3146 for (int &M : CommutedOuterMask) {
3147 if (M == PoisonMaskElem)
3148 continue;
3149 if (M < (int)NumTrunkElts)
3150 M += NumTrunkElts;
3151 else
3152 M -= NumTrunkElts;
3153 }
3154 OuterMask = CommutedOuterMask;
3155 }
3156 if (!OuterV1->hasOneUse())
3157 break;
3158
 // All leaves must read from the same non-undef value Y.
3159 if (!isa<UndefValue>(A1)) {
3160 if (!Y)
3161 Y = A1;
3162 else if (Y != A1)
3163 break;
3164 }
3165 if (!isa<UndefValue>(B1)) {
3166 if (!Y)
3167 Y = B1;
3168 else if (Y != B1)
3169 break;
3170 }
3171
 // Fold the leaf's two-source mask into a single-source mask over Y.
3172 auto *YType = cast<FixedVectorType>(A1->getType());
3173 int NumLeafElts = YType->getNumElements();
3174 SmallVector<int> LocalYMask(InnerMask1);
3175 for (int &M : LocalYMask) {
3176 if (M >= NumLeafElts)
3177 M -= NumLeafElts;
3178 }
3179
3180 InstructionCost LocalOldCost =
3183
3184 // Handle the initial (start of chain) case.
3185 if (!ChainLength) {
3186 Mask.assign(OuterMask);
3187 YMask.assign(LocalYMask);
3188 OldCost = NewCost = LocalOldCost;
3189 Trunk = OuterV0;
3190 ChainLength++;
3191 continue;
3192 }
3193
3194 // For the non-root case, first attempt to combine masks.
 // Lanes already defined in YMask must agree with this leaf's mask. Lanes
 // that are still -1 may be filled in from the leaf.
3195 SmallVector<int> NewYMask(YMask);
3196 bool Valid = true;
3197 for (auto [CombinedM, LeafM] : llvm::zip(NewYMask, LocalYMask)) {
3198 if (LeafM == -1 || CombinedM == LeafM)
3199 continue;
3200 if (CombinedM == -1) {
3201 CombinedM = LeafM;
3202 } else {
3203 Valid = false;
3204 break;
3205 }
3206 }
3207 if (!Valid)
3208 break;
3209
 // Compose masks: entries of Mask that select from the trunk are routed
 // through this link's OuterMask. Other entries are kept unchanged.
3210 SmallVector<int> NewMask;
3211 NewMask.reserve(NumTrunkElts);
3212 for (int M : Mask) {
3213 if (M < 0 || M >= static_cast<int>(NumTrunkElts))
3214 NewMask.push_back(M);
3215 else
3216 NewMask.push_back(OuterMask[M]);
3217 }
3218
3219 // Break the chain if adding this new step complicates the shuffles such
3220 // that it would increase the new cost by more than the old cost of this
3221 // step.
3222 InstructionCost LocalNewCost =
3224 YType, NewYMask, CostKind) +
3226 TrunkType, NewMask, CostKind);
3227
3228 if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
3229 break;
3230
3231 LLVM_DEBUG({
3232 if (ChainLength == 1) {
3233 dbgs() << "Found chain of shuffles fed by length-changing shuffles: "
3234 << I << '\n';
3235 }
3236 dbgs() << " next chain link: " << *Trunk << '\n'
3237 << " old cost: " << (OldCost + LocalOldCost)
3238 << " new cost: " << LocalNewCost << '\n';
3239 });
3240
3241 Mask = NewMask;
3242 YMask = NewYMask;
3243 OldCost += LocalOldCost;
3244 NewCost = LocalNewCost;
3245 Trunk = OuterV0;
3246 ChainLength++;
3247 }
 // A chain of a single link has nothing to merge.
3248 if (ChainLength <= 1)
3249 return false;
3250
 // NOTE(review): Y is only assigned from non-undef leaf operands. If every
 // leaf read only undef values, Y would still be null here, and it is
 // dereferenced below. Confirm that case cannot reach this point.
3251 if (llvm::all_of(Mask, [&](int M) {
3252 return M < 0 || M >= static_cast<int>(NumTrunkElts);
3253 })) {
3254 // Produce a canonical simplified form if all elements are sourced from Y.
3255 for (int &M : Mask) {
3256 if (M >= static_cast<int>(NumTrunkElts))
3257 M = YMask[M - NumTrunkElts];
3258 }
3259 Value *Root =
3260 Builder.CreateShuffleVector(Y, PoisonValue::get(Y->getType()), Mask);
3261 replaceValue(I, *Root);
3262 return true;
3263 }
3264
3265 Value *Leaf =
3266 Builder.CreateShuffleVector(Y, PoisonValue::get(Y->getType()), YMask);
3267 Value *Root = Builder.CreateShuffleVector(Trunk, Leaf, Mask);
3268 replaceValue(I, *Root);
3269 return true;
3270}
3271
3272/// Try to convert
3273/// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
///
/// Both calls must use the same trivially vectorizable intrinsic ID. Scalar
/// arguments must be identical, and vector arguments at the same position must
/// share a type. Returns true if \p I was replaced.
3274bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
3275 Value *V0, *V1;
3276 ArrayRef<int> OldMask;
3277 if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
3278 return false;
3279
3280 auto *II0 = dyn_cast<IntrinsicInst>(V0);
3281 auto *II1 = dyn_cast<IntrinsicInst>(V1);
3282 if (!II0 || !II1)
3283 return false;
3284
3285 Intrinsic::ID IID = II0->getIntrinsicID();
3286 if (IID != II1->getIntrinsicID())
3287 return false;
3288 InstructionCost CostII0 =
3289 TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II0), CostKind);
3290 InstructionCost CostII1 =
3291 TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II1), CostKind);
3292
3293 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
3294 auto *II0Ty = dyn_cast<FixedVectorType>(II0->getType());
3295 if (!ShuffleDstTy || !II0Ty)
3296 return false;
3297
3298 if (!isTriviallyVectorizable(IID))
3299 return false;
3300
3301 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3302 Value *Arg0 = II0->getArgOperand(I);
3303 Value *Arg1 = II1->getArgOperand(I);
3305 // Scalar operands must be identical.
3306 if (Arg0 != Arg1)
3307 return false;
3308 } else if (Arg0->getType() != Arg1->getType()) {
3309 // The corresponding vector operands are shuffled together, so they must
3310 // share the same type. For intrinsics overloaded on their operand type
3311 // (e.g. llvm.fptosi.sat), two calls can produce the same result type
3312 // from different operand types; shuffling those would be invalid.
3313 return false;
3314 }
3315 }
3316
3317 InstructionCost OldCost =
3318 CostII0 + CostII1 +
3320 II0Ty, OldMask, CostKind, 0, nullptr, {II0, II1}, &I);
3321
 // New cost: one shuffle per distinct (II0 arg, II1 arg) pair of vector
 // arguments, plus the intrinsic at the shuffled width. Scalar arguments keep
 // their type.
3322 SmallVector<Type *> NewArgsTy;
3323 InstructionCost NewCost = 0;
3324 SmallDenseSet<std::pair<Value *, Value *>> SeenOperandPairs;
3325 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3327 NewArgsTy.push_back(II0->getArgOperand(I)->getType());
3328 } else {
3329 auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType());
3330 auto *ArgTy = FixedVectorType::get(VecTy->getElementType(),
3331 ShuffleDstTy->getNumElements());
3332 NewArgsTy.push_back(ArgTy);
3333 std::pair<Value *, Value *> OperandPair =
3334 std::make_pair(II0->getArgOperand(I), II1->getArgOperand(I));
3335 if (!SeenOperandPairs.insert(OperandPair).second) {
3336 // We've already computed the cost for this operand pair.
3337 continue;
3338 }
3339 NewCost += TTI.getShuffleCost(
3340 TargetTransformInfo::SK_PermuteTwoSrc, ArgTy, VecTy, OldMask,
3341 CostKind, 0, nullptr, {II0->getArgOperand(I), II1->getArgOperand(I)});
3342 }
3343 }
3344 IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
3345
3346 NewCost += TTI.getIntrinsicInstrCost(NewAttr, CostKind);
 // Intrinsics with other users are kept, so their cost is still paid. When
 // II0 and II1 are the same call, its cost is added only once.
3347 if (!II0->hasOneUse())
3348 NewCost += CostII0;
3349 if (II1 != II0 && !II1->hasOneUse())
3350 NewCost += CostII1;
3351
3352 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two intrinsics: " << I
3353 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
3354 << "\n");
3355
3356 if (NewCost > OldCost)
3357 return false;
3358
 // Build the new arguments, creating one shuffle per distinct operand pair.
3359 SmallVector<Value *> NewArgs;
3360 SmallDenseMap<std::pair<Value *, Value *>, Value *> ShuffleCache;
3361 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
3363 NewArgs.push_back(II0->getArgOperand(I));
3364 } else {
3365 std::pair<Value *, Value *> OperandPair =
3366 std::make_pair(II0->getArgOperand(I), II1->getArgOperand(I));
3367 auto It = ShuffleCache.find(OperandPair);
3368 if (It != ShuffleCache.end()) {
3369 // Reuse previously created shuffle for this operand pair.
3370 NewArgs.push_back(It->second);
3371 continue;
3372 }
3373 Value *Shuf = Builder.CreateShuffleVector(II0->getArgOperand(I),
3374 II1->getArgOperand(I), OldMask);
3375 ShuffleCache[OperandPair] = Shuf;
3376 NewArgs.push_back(Shuf);
3377 Worklist.pushValue(Shuf);
3378 }
3379 Value *NewIntrinsic = Builder.CreateIntrinsic(ShuffleDstTy, IID, NewArgs);
3380
3381 // Intersect flags from the old intrinsics.
3382 if (auto *NewInst = dyn_cast<Instruction>(NewIntrinsic)) {
3383 NewInst->copyIRFlags(II0);
3384 NewInst->andIRFlags(II1);
3385 }
3386
3387 replaceValue(I, *NewIntrinsic);
3388 return true;
3389}
3390
3391/// Try to convert
3392/// "shuffle (intrinsic), (poison/undef)" into "intrinsic (shuffle)".
///
/// The mask may only select lanes of the intrinsic result (a pure permute),
/// and the intrinsic must be trivially vectorizable. Returns true if \p I was
/// replaced.
3393bool VectorCombine::foldPermuteOfIntrinsic(Instruction &I) {
3394 Value *V0;
3395 ArrayRef<int> Mask;
3396 if (!match(&I, m_Shuffle(m_Value(V0), m_Undef(), m_Mask(Mask))))
3397 return false;
3398
3399 auto *II0 = dyn_cast<IntrinsicInst>(V0);
3400 if (!II0)
3401 return false;
3402
3403 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
3404 auto *IntrinsicSrcTy = dyn_cast<FixedVectorType>(II0->getType());
3405 if (!ShuffleDstTy || !IntrinsicSrcTy)
3406 return false;
3407
3408 // Validate it's a pure permute, mask should only reference the first vector
3409 unsigned NumSrcElts = IntrinsicSrcTy->getNumElements();
3410 if (any_of(Mask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
3411 return false;
3412
3413 Intrinsic::ID IID = II0->getIntrinsicID();
3414 if (!isTriviallyVectorizable(IID))
3415 return false;
3416
3417 // Cost analysis
3419 TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II0), CostKind);
3420 InstructionCost OldCost =
3423 IntrinsicSrcTy, Mask, CostKind, 0, nullptr, {V0}, &I);
3424
 // New cost: one shuffle per vector argument plus the intrinsic at the
 // shuffled width. Scalar arguments keep their type.
3425 SmallVector<Type *> NewArgsTy;
3426 InstructionCost NewCost = 0;
3427 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3429 NewArgsTy.push_back(II0->getArgOperand(I)->getType());
3430 } else {
3431 auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType());
3432 auto *ArgTy = FixedVectorType::get(VecTy->getElementType(),
3433 ShuffleDstTy->getNumElements());
3434 NewArgsTy.push_back(ArgTy);
3436 ArgTy, VecTy, Mask, CostKind, 0, nullptr,
3437 {II0->getArgOperand(I)});
3438 }
3439 }
3440 IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
3441 NewCost += TTI.getIntrinsicInstrCost(NewAttr, CostKind);
3442
3443 // If the intrinsic has multiple uses, we need to account for the cost of
3444 // keeping the original intrinsic around.
3445 if (!II0->hasOneUse())
3446 NewCost += IntrinsicCost;
3447
3448 LLVM_DEBUG(dbgs() << "Found a permute of intrinsic: " << I << "\n OldCost: "
3449 << OldCost << " vs NewCost: " << NewCost << "\n");
3450
3451 if (NewCost > OldCost)
3452 return false;
3453
3454 // Transform
 // Vector arguments are permuted with the same mask; other arguments are
 // passed through unchanged.
3455 SmallVector<Value *> NewArgs;
3456 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3458 NewArgs.push_back(II0->getArgOperand(I));
3459 } else {
3460 Value *Shuf = Builder.CreateShuffleVector(II0->getArgOperand(I), Mask);
3461 NewArgs.push_back(Shuf);
3462 Worklist.pushValue(Shuf);
3463 }
3464 }
3465
3466 Value *NewIntrinsic = Builder.CreateIntrinsic(ShuffleDstTy, IID, NewArgs);
3467
3468 if (auto *NewInst = dyn_cast<Instruction>(NewIntrinsic))
3469 NewInst->copyIRFlags(II0);
3470
3471 replaceValue(I, *NewIntrinsic);
3472 return true;
3473}
3474
/// A value paired with a lane index into it. lookThroughShuffles returns
/// {nullptr, PoisonMaskElem} for a lane that has no defined source.
3475using InstLane = std::pair<Value *, int>;
3476
3477static InstLane lookThroughShuffles(Value *V, int Lane) {
3478 while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
3479 unsigned NumElts =
3480 cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
3481 int M = SV->getMaskValue(Lane);
3482 if (M < 0)
3483 return {nullptr, PoisonMaskElem};
3484 if (static_cast<unsigned>(M) < NumElts) {
3485 V = SV->getOperand(0);
3486 Lane = M;
3487 } else {
3488 V = SV->getOperand(1);
3489 Lane = M - NumElts;
3490 }
3491 }
3492 return InstLane{V, Lane};
3493}
3494
// For every (value, lane) entry of Item, find the (value, lane) that feeds
// operand number Op of that value, looking through any shuffles on the way.
// The result has one entry per entry of Item, in the same order.
3498 for (InstLane IL : Item) {
3499 auto [U, Lane] = IL;
// A poison entry (null value) has no operands; it stays poison.
3500 InstLane OpLane =
3501 U ? lookThroughShuffles(cast<Instruction>(U)->getOperand(Op), Lane)
3502 : InstLane{nullptr, PoisonMaskElem};
3503 NItem.emplace_back(OpLane);
3504 }
3505 return NItem;
3506}
3507
// Returns true if Item is a concatenation of whole, in-order copies of
// vectors of one type Ty (slice S contributes lanes 0..NumElts-1 of a single
// value). It also requires a power-of-two slice count, and the target must
// report a two-source concatenation of Ty as free (cost 0).
// NOTE(review): Item.front().first is dereferenced unconditionally; callers
// are assumed to have rejected a null (poison) front lane.
3508/// Detect concat of multiple values into a vector
3510 const TargetTransformInfo &TTI) {
3511 auto *Ty = cast<FixedVectorType>(Item.front().first->getType());
3512 unsigned NumElts = Ty->getNumElements();
// Reject the single-slice case (not a concat), scalar-like single-element
// vectors, and lane counts that are not a whole number of slices.
3513 if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0)
3514 return false;
3515
3516 // Check that the concat is free, usually meaning that the type will be split
3517 // during legalization.
// Only the cost of concatenating two Ty values is queried; wider concats
// are built as a tree of such pairwise shuffles.
3518 SmallVector<int, 16> ConcatMask(NumElts * 2);
3519 std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
3520 if (TTI.getShuffleCost(TTI::SK_PermuteTwoSrc,
3521 FixedVectorType::get(Ty->getScalarType(), NumElts * 2),
3522 Ty, ConcatMask, CostKind) != 0)
3523 return false;
3524
3525 unsigned NumSlices = Item.size() / NumElts;
3526 // Currently we generate a tree of shuffles for the concats, which limits us
3527 // to a power2.
3528 if (!isPowerOf2_32(NumSlices))
3529 return false;
// Each slice must be one non-poison value of type Ty, read in lane order.
3530 for (unsigned Slice = 0; Slice < NumSlices; ++Slice) {
3531 Value *SliceV = Item[Slice * NumElts].first;
3532 if (!SliceV || SliceV->getType() != Ty)
3533 return false;
3534 for (unsigned Elt = 0; Elt < NumElts; ++Elt) {
3535 auto [V, Lane] = Item[Slice * NumElts + Elt];
3536 if (Lane != static_cast<int>(Elt) || SliceV != V)
3537 return false;
3538 }
3539 }
3540 return true;
3541}
3542
// Emit, at Builder's insertion point, a vector expression producing the lanes
// described by Item without the shuffles that originally fed them.
//
// Leaves are looked up by (front value, use) in the sets recorded by the
// scan:
// - identity leaves are reused as-is;
// - splat leaves become a splat shuffle of the front lane;
// - concat leaves become a tree of concatenating shuffles.
// Otherwise the front value is an instruction that is re-created on
// recursively rebuilt operands. IR flags are then propagated from every
// non-poison lane's original instruction via propagateIRFlags.
3543static Value *
3545 const DenseSet<std::pair<Value *, Use *>> &IdentityLeafs,
3546 const DenseSet<std::pair<Value *, Use *>> &SplatLeafs,
3547 const DenseSet<std::pair<Value *, Use *>> &ConcatLeafs,
3548 IRBuilderBase &Builder, const TargetTransformInfo *TTI) {
3549 auto [FrontV, FrontLane] = Item.front();
3550
3551 if (IdentityLeafs.contains(std::make_pair(FrontV, From))) {
3552 return FrontV;
3553 }
// Broadcast FrontLane of FrontV to all Ty->getNumElements() lanes.
3554 if (SplatLeafs.contains(std::make_pair(FrontV, From))) {
3555 SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
3556 return Builder.CreateShuffleVector(FrontV, Mask);
3557 }
3558 if (ConcatLeafs.contains(std::make_pair(FrontV, From))) {
3559 unsigned NumElts =
3560 cast<FixedVectorType>(FrontV->getType())->getNumElements();
// Gather one source value per slice.
3561 SmallVector<Value *> Values(Item.size() / NumElts, nullptr);
3562 for (unsigned S = 0; S < Values.size(); ++S)
3563 Values[S] = Item[S * NumElts].first;
3564
// Repeatedly concatenate adjacent pairs, doubling the width each round.
// This relies on the slice count being a power of two (isFreeConcat
// checks that).
3565 while (Values.size() > 1) {
3566 NumElts *= 2;
3567 SmallVector<int, 16> Mask(NumElts, 0);
3568 std::iota(Mask.begin(), Mask.end(), 0);
3569 SmallVector<Value *> NewValues(Values.size() / 2, nullptr);
3570 for (unsigned S = 0; S < NewValues.size(); ++S)
3571 NewValues[S] =
3572 Builder.CreateShuffleVector(Values[S * 2], Values[S * 2 + 1], Mask);
3573 Values = NewValues;
3574 }
3575 return Values[0];
3576 }
3577
3578 auto *I = cast<Instruction>(FrontV);
3579 auto *II = dyn_cast<IntrinsicInst>(I);
// For intrinsic calls the trailing operand is excluded (presumably the
// callee operand of the call).
3580 unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
3582 for (unsigned Idx = 0; Idx < NumOps; Idx++) {
// Scalar intrinsic arguments are reused unchanged rather than rebuilt.
3583 if (II &&
3584 isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, TTI)) {
3585 Ops[Idx] = II->getOperand(Idx);
3586 continue;
3587 }
3589 &I->getOperandUse(Idx), Ty, IdentityLeafs,
3590 SplatLeafs, ConcatLeafs, Builder, TTI);
3591 }
3592
// Collect every non-poison lane's original value so their IR flags can be
// propagated onto the new instruction.
3593 SmallVector<Value *, 8> ValueList;
3594 for (const auto &Lane : Item)
3595 if (Lane.first)
3596 ValueList.push_back(Lane.first);
3597
// Result type for casts/intrinsics: I's element type at Ty's lane count.
3598 Type *DstTy =
3599 FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements());
3600 if (auto *BI = dyn_cast<BinaryOperator>(I)) {
3601 auto *Value = Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(),
3602 Ops[0], Ops[1]);
3603 propagateIRFlags(Value, ValueList);
3604 return Value;
3605 }
3606 if (auto *CI = dyn_cast<CmpInst>(I)) {
3607 auto *Value = Builder.CreateCmp(CI->getPredicate(), Ops[0], Ops[1]);
3608 propagateIRFlags(Value, ValueList);
3609 return Value;
3610 }
// The original select is passed as IRBuilder's metadata source.
3611 if (auto *SI = dyn_cast<SelectInst>(I)) {
3612 auto *Value = Builder.CreateSelect(Ops[0], Ops[1], Ops[2], "", SI);
3613 propagateIRFlags(Value, ValueList);
3614 return Value;
3615 }
3616 if (auto *CI = dyn_cast<CastInst>(I)) {
3617 auto *Value = Builder.CreateCast(CI->getOpcode(), Ops[0], DstTy);
3618 propagateIRFlags(Value, ValueList);
3619 return Value;
3620 }
3621 if (II) {
3622 auto *Value = Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
3623 propagateIRFlags(Value, ValueList);
3624 return Value;
3625 }
// The scan only admits the kinds above plus unary operators.
3626 assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
3627 auto *Value =
3628 Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
3629 propagateIRFlags(Value, ValueList);
3630 return Value;
3631}
3632
3633// Starting from a shuffle, look up through operands tracking the shuffled index
3634// of each lane. If we can simplify away the shuffles to identities then
3635// do so.
//
// The scan keeps a worklist of (lanes, use) items. Each item is classified
// as one of the following:
// - an identity leaf: lanes 0..N-1 of a single value, allowing bitcasts,
//   with poison lanes ignored;
// - a splat leaf: a constant splat, or one repeated (value, lane);
// - a supported lane-wise operation whose operand lanes are scanned next;
// - a free concat leaf.
// Any other shape, a poison front lane, or more than MaxInstrsToScan items
// aborts the fold. On success generateNewInstTree rebuilds the expression
// without the shuffles, and the result replaces I.
3636bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
3637 auto *Ty = dyn_cast<FixedVectorType>(I.getType());
3638 if (!Ty || I.use_empty())
3639 return false;
3640
// Resolve each result lane of I to the (value, lane) it ultimately reads.
3641 SmallVector<InstLane> Start(Ty->getNumElements());
3642 for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
3643 Start[M] = lookThroughShuffles(&I, M);
3644
// Seed the scan with I's lanes, paired with I's first use. Leaf sets are
// keyed by (value, use).
3646 Worklist.push_back(std::make_pair(Start, &*I.use_begin()));
3647 DenseSet<std::pair<Value *, Use *>> IdentityLeafs, SplatLeafs, ConcatLeafs;
3648 unsigned NumVisited = 0;
3649
3650 while (!Worklist.empty()) {
// Bound compile time.
3651 if (++NumVisited > MaxInstrsToScan)
3652 return false;
3653
3654 auto ItemFrom = Worklist.pop_back_val();
3655 auto Item = ItemFrom.first;
3656 auto From = ItemFrom.second;
3657 auto [FrontV, FrontLane] = Item.front();
3658
3659 // If we found an undef first lane then bail out to keep things simple.
3660 if (!FrontV)
3661 return false;
3662
3663 // Helper to peek through bitcasts to the same value.
3664 auto IsEquiv = [&](Value *X, Value *Y) {
3665 return X->getType() == Y->getType() &&
3667 };
3668
3669 // Look for an identity value.
3670 if (FrontLane == 0 &&
3671 cast<FixedVectorType>(FrontV->getType())->getNumElements() ==
3672 Ty->getNumElements() &&
3673 all_of(drop_begin(enumerate(Item)), [IsEquiv, Item](const auto &E) {
3674 Value *FrontV = Item.front().first;
3675 return !E.value().first || (IsEquiv(E.value().first, FrontV) &&
3676 E.value().second == (int)E.index());
3677 })) {
3678 IdentityLeafs.insert(std::make_pair(FrontV, From));
3679 continue;
3680 }
3681 // Look for constants, for the moment only supporting constant splats.
3682 if (auto *C = dyn_cast<Constant>(FrontV);
3683 C && C->getSplatValue() &&
3684 all_of(drop_begin(Item), [Item](InstLane &IL) {
3685 Value *FrontV = Item.front().first;
3686 Value *V = IL.first;
3687 return !V || (isa<Constant>(V) &&
3688 cast<Constant>(V)->getSplatValue() ==
3689 cast<Constant>(FrontV)->getSplatValue());
3690 })) {
3691 SplatLeafs.insert(std::make_pair(FrontV, From));
3692 continue;
3693 }
3694 // Look for a splat value.
3695 if (all_of(drop_begin(Item), [Item](InstLane &IL) {
3696 auto [FrontV, FrontLane] = Item.front();
3697 auto [V, Lane] = IL;
3698 return !V || (V == FrontV && Lane == FrontLane);
3699 })) {
3700 SplatLeafs.insert(std::make_pair(FrontV, From));
3701 continue;
3702 }
3703
3704 // We need each element to be the same type of value, and check that each
3705 // element has a single use.
// Beyond the same ValueID, lanes must agree on the compare predicate, the
// cast source element type, the (vector) select condition type, and the
// intrinsic ID. Non-intrinsic calls and operand bundles are rejected.
3706 auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) {
3707 Value *FrontV = Item.front().first;
3708 if (!IL.first)
3709 return true;
3710 Value *V = IL.first;
3711 if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUser())
3712 return false;
3713 if (V->getValueID() != FrontV->getValueID())
3714 return false;
3715 if (auto *CI = dyn_cast<CmpInst>(V))
3716 if (CI->getPredicate() != cast<CmpInst>(FrontV)->getPredicate())
3717 return false;
3718 if (auto *CI = dyn_cast<CastInst>(V))
3719 if (CI->getSrcTy()->getScalarType() !=
3720 cast<CastInst>(FrontV)->getSrcTy()->getScalarType())
3721 return false;
3722 if (auto *SI = dyn_cast<SelectInst>(V))
3723 if (!isa<VectorType>(SI->getOperand(0)->getType()) ||
3724 SI->getOperand(0)->getType() !=
3725 cast<SelectInst>(FrontV)->getOperand(0)->getType())
3726 return false;
3727 if (isa<CallInst>(V) && !isa<IntrinsicInst>(V))
3728 return false;
3729 auto *II = dyn_cast<IntrinsicInst>(V);
3730 return !II || (isa<IntrinsicInst>(FrontV) &&
3731 II->getIntrinsicID() ==
3732 cast<IntrinsicInst>(FrontV)->getIntrinsicID() &&
3733 !II->hasOperandBundles());
3734 };
3735 if (all_of(drop_begin(Item), CheckLaneIsEquivalentToFirst)) {
3736 // Check the operator is one that we support.
// Each supported kind queues its vector operands' lanes, paired with the
// corresponding operand use, for further scanning.
3737 if (isa<BinaryOperator, CmpInst>(FrontV)) {
3738 // We exclude div/rem in case they hit UB from poison lanes.
3739 if (auto *BO = dyn_cast<BinaryOperator>(FrontV);
3740 BO && BO->isIntDivRem())
3741 return false;
3743 &cast<Instruction>(FrontV)->getOperandUse(0));
3745 &cast<Instruction>(FrontV)->getOperandUse(1));
3746 continue;
3747 } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst, FPToSIInst,
3748 FPToUIInst, SIToFPInst, UIToFPInst>(FrontV)) {
3750 &cast<Instruction>(FrontV)->getOperandUse(0));
3751 continue;
3752 } else if (auto *BitCast = dyn_cast<BitCastInst>(FrontV)) {
3753 // TODO: Handle vector widening/narrowing bitcasts.
3754 auto *DstTy = dyn_cast<FixedVectorType>(BitCast->getDestTy());
3755 auto *SrcTy = dyn_cast<FixedVectorType>(BitCast->getSrcTy());
3756 if (DstTy && SrcTy &&
3757 SrcTy->getNumElements() == DstTy->getNumElements()) {
3759 &BitCast->getOperandUse(0));
3760 continue;
3761 }
3762 } else if (auto *Sel = dyn_cast<SelectInst>(FrontV)) {
3764 &Sel->getOperandUse(0));
3766 &Sel->getOperandUse(1));
3768 &Sel->getOperandUse(2));
3769 continue;
3770 } else if (auto *II = dyn_cast<IntrinsicInst>(FrontV);
3771 II && isTriviallyVectorizable(II->getIntrinsicID()) &&
3772 !II->hasOperandBundles()) {
// The trailing operand is skipped (presumably the callee). Scalar
// arguments must be identical across lanes and are not scanned further.
3773 for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
3774 if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op,
3775 &TTI)) {
3776 if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
3777 Value *FrontV = Item.front().first;
3778 Value *V = IL.first;
3779 return !V || (cast<Instruction>(V)->getOperand(Op) ==
3780 cast<Instruction>(FrontV)->getOperand(Op));
3781 }))
3782 return false;
3783 continue;
3784 }
3786 &cast<Instruction>(FrontV)->getOperandUse(Op));
3787 }
3788 continue;
3789 }
3790 }
3791
// Otherwise accept a concatenation of whole values if the target reports
// it as free.
3792 if (isFreeConcat(Item, CostKind, TTI)) {
3793 ConcatLeafs.insert(std::make_pair(FrontV, From));
3794 continue;
3795 }
3796
3797 return false;
3798 }
3799
// If only the starting item was visited, it was itself a leaf and no
// operation was looked through, so there is nothing to rewrite.
3800 if (NumVisited <= 1)
3801 return false;
3802
3803 LLVM_DEBUG(dbgs() << "Found a superfluous identity shuffle: " << I << "\n");
3804
3805 // If we got this far, we know the shuffles are superfluous and can be
3806 // removed. Scan through again and generate the new tree of instructions.
3807 Builder.SetInsertPoint(&I);
3808 Value *V = generateNewInstTree(Start, &*I.use_begin(), Ty, IdentityLeafs,
3809 SplatLeafs, ConcatLeafs, Builder, &TTI);
3810 replaceValue(I, *V);
3811 return true;
3812}
3813
3814/// Given a commutative reduction, the order of the input lanes does not alter
3815/// the results. We can use this to remove certain shuffles feeding the
3816/// reduction, removing the need to shuffle at all.
///
/// Only the integer add/mul/and/or/xor/smin/smax/umin/umax reductions are
/// handled. The reduction operand may be reached through lane-wise binary
/// operators and splats, but through exactly one shuffle. That shuffle is
/// replaced by one whose mask is the sorted original mask (when cheaper);
/// otherwise foldSelectShuffle is tried on it.
3817bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
3818 auto *II = dyn_cast<IntrinsicInst>(&I);
3819 if (!II)
3820 return false;
3821 switch (II->getIntrinsicID()) {
3822 case Intrinsic::vector_reduce_add:
3823 case Intrinsic::vector_reduce_mul:
3824 case Intrinsic::vector_reduce_and:
3825 case Intrinsic::vector_reduce_or:
3826 case Intrinsic::vector_reduce_xor:
3827 case Intrinsic::vector_reduce_smin:
3828 case Intrinsic::vector_reduce_smax:
3829 case Intrinsic::vector_reduce_umin:
3830 case Intrinsic::vector_reduce_umax:
3831 break;
3832 default:
3833 return false;
3834 }
3835
3836 // Find all the inputs when looking through operations that do not alter the
3837 // lane order (binops, for example). Currently we look for a single shuffle,
3838 // and can ignore splat values.
3839 std::queue<Value *> Worklist;
3840 SmallPtrSet<Value *, 4> Visited;
3841 ShuffleVectorInst *Shuffle = nullptr;
3842 if (auto *Op = dyn_cast<Instruction>(I.getOperand(0)))
3843 Worklist.push(Op);
3844
3845 while (!Worklist.empty()) {
3846 Value *CV = Worklist.front();
3847 Worklist.pop();
3848 if (Visited.contains(CV))
3849 continue;
3850
3851 // Splats don't change the order, so can be safely ignored.
3852 if (isSplatValue(CV))
3853 continue;
3854
3855 Visited.insert(CV);
3856
3857 if (auto *CI = dyn_cast<Instruction>(CV)) {
3858 if (CI->isBinaryOp()) {
3859 for (auto *Op : CI->operand_values())
3860 Worklist.push(Op);
3861 continue;
3862 } else if (auto *SV = dyn_cast<ShuffleVectorInst>(CI)) {
// At most one distinct shuffle is supported.
3863 if (Shuffle && Shuffle != SV)
3864 return false;
3865 Shuffle = SV;
3866 continue;
3867 }
3868 }
3869
3870 // Anything else is currently an unknown node.
3871 return false;
3872 }
3873
3874 if (!Shuffle)
3875 return false;
3876
3877 // Check all uses of the binary ops and shuffles are also included in the
3878 // lane-invariant operations (Visited should be the list of lanewise
3879 // instructions, including the shuffle that we found).
3880 for (auto *V : Visited)
3881 for (auto *U : V->users())
3882 if (!Visited.contains(U) && U != &I)
3883 return false;
3884
3885 FixedVectorType *VecType =
3886 dyn_cast<FixedVectorType>(II->getOperand(0)->getType());
3887 if (!VecType)
3888 return false;
3889 FixedVectorType *ShuffleInputType =
3891 if (!ShuffleInputType)
3892 return false;
3893 unsigned NumInputElts = ShuffleInputType->getNumElements();
3894
3895 // Find the mask from sorting the lanes into order. This is most likely to
3896 // become an identity or concat mask. Undef elements are pushed to the end.
// Comparing as unsigned makes negative (poison) mask values sort last.
3897 SmallVector<int> ConcatMask;
3898 Shuffle->getShuffleMask(ConcatMask);
3899 sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
3900 bool UsesSecondVec =
3901 any_of(ConcatMask, [&](int M) { return M >= (int)NumInputElts; });
3902
// Compare the cost of the original mask against the sorted mask.
3904 UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType,
3905 ShuffleInputType, Shuffle->getShuffleMask(), CostKind);
3907 UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType,
3908 ShuffleInputType, ConcatMask, CostKind);
3909
3910 LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
3911 << "\n");
3912 LLVM_DEBUG(dbgs() << " OldCost: " << OldCost << " vs NewCost: " << NewCost
3913 << "\n");
3914 bool MadeChanges = false;
3915 if (NewCost < OldCost) {
3916 Builder.SetInsertPoint(Shuffle);
3917 Value *NewShuffle = Builder.CreateShuffleVector(
3918 Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask);
3919 LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
3920 replaceValue(*Shuffle, *NewShuffle);
3921 return true;
3922 }
3923
3924 // See if we can re-use foldSelectShuffle, getting it to reduce the size of
3925 // the shuffle into a nicer order, as it can ignore the order of the shuffles.
// NOTE(review): the second argument `true` presumably tells foldSelectShuffle
// that result lane order is irrelevant; confirm against its declaration.
3926 MadeChanges |= foldSelectShuffle(*Shuffle, true);
3927 return MadeChanges;
3928}
3929
3930/// For a given chain of patterns of the following form:
3931///
3932/// ```
3933/// %1 = shufflevector <n x ty1> %0, <n x ty1> poison <n x ty2> mask
3934///
3935/// %2 = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %0, <n x
3936/// ty1> %1)
3937/// OR
3938/// %2 = add/mul/or/and/xor <n x ty1> %0, %1
3939///
3940/// %3 = shufflevector <n x ty1> %2, <n x ty1> poison <n x ty2> mask
3941/// ...
3942/// ...
3943/// %(i - 1) = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %(i -
3944/// 3), <n x ty1> %(i - 2)
3945/// OR
3946/// %(i - 1) = add/mul/or/and/xor <n x ty1> %(i - 3), %(i - 2)
3947///
3948/// %(i) = extractelement <n x ty1> %(i - 1), 0
3949/// ```
3950///
3951/// Where:
3952/// `mask` follows a partition pattern:
3953///
3954/// Ex:
3955/// [n = 8, p = poison]
3956///
3957/// 4 5 6 7 | p p p p
3958/// 2 3 | p p p p p p
3959/// 1 | p p p p p p p
3960///
3961/// For powers of 2, there's a consistent pattern, but for other cases
3962/// the parity of the current half value at each step decides the
3963/// next partition half (see `ExpectedParityMask` for more logical details
3964/// in generalising this).
3965///
3966/// Ex:
3967/// [n = 6]
3968///
3969/// 3 4 5 | p p p
3970/// 1 2 | p p p p
3971/// 1 | p p p p p
3972bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
3973 // Going bottom-up for the pattern.
3974 std::queue<Value *> InstWorklist;
3975 InstructionCost OrigCost = 0;
3976
3977 // Common instruction operation after each shuffle op.
3978 std::optional<unsigned int> CommonCallOp = std::nullopt;
3979 std::optional<Instruction::BinaryOps> CommonBinOp = std::nullopt;
3980
3981 // For floating-point reductions, track FMF intersection across all binops.
3982 FastMathFlags CommonFMF;
3983 bool IsFloatReduction = false;
3984
3985 bool IsFirstCallOrBinInst = true;
3986 bool ShouldBeCallOrBinInst = true;
3987
3988 // This stores the last used instructions for shuffle/common op.
3989 //
3990 // PrevVecV[0] / PrevVecV[1] store the last two simultaneous
3991 // instructions from either shuffle/common op.
3992 SmallVector<Value *, 2> PrevVecV(2, nullptr);
3993
3994 Value *VecOpEE;
3995 if (!match(&I, m_ExtractElt(m_Value(VecOpEE), m_Zero())))
3996 return false;
3997
3998 auto *FVT = dyn_cast<FixedVectorType>(VecOpEE->getType());
3999 if (!FVT)
4000 return false;
4001
4002 int64_t VecSize = FVT->getNumElements();
4003 if (VecSize < 2)
4004 return false;
4005
4006 // Number of levels would be ~log2(n), considering we always partition
4007 // by half for this fold pattern.
4008 unsigned int NumLevels = Log2_64_Ceil(VecSize), VisitedCnt = 0;
4009 int64_t ShuffleMaskHalf = 1, ExpectedParityMask = 0;
4010
4011 // This is how we generalise for all element sizes.
4012 // At each step, if vector size is odd, we need non-poison
4013 // values to cover the dominant half so we don't miss out on any element.
4014 //
4015 // This mask will help us retrieve this as we go from bottom to top:
4016 //
4017 // Mask Set -> N = N * 2 - 1
4018 // Mask Unset -> N = N * 2
4019 for (int Cur = VecSize, Mask = NumLevels - 1; Cur > 1;
4020 Cur = (Cur + 1) / 2, --Mask) {
4021 if (Cur & 1)
4022 ExpectedParityMask |= (1ll << Mask);
4023 }
4024
4025 InstWorklist.push(VecOpEE);
4026
4027 bool IsPartialReduction = false;
4028 bool HasLaneDuplication = false;
4029
4030 while (!InstWorklist.empty()) {
4031 Value *CI = InstWorklist.front();
4032 InstWorklist.pop();
4033
4034 if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
4035 if (!ShouldBeCallOrBinInst)
4036 return false;
4037
4038 if (!IsFirstCallOrBinInst && any_of(PrevVecV, equal_to(nullptr)))
4039 return false;
4040
4041 // For the first found call/bin op, the vector has to come from the
4042 // extract element op.
4043 if (II != (IsFirstCallOrBinInst ? VecOpEE : PrevVecV[0]))
4044 return false;
4045 IsFirstCallOrBinInst = false;
4046
4047 if (!CommonCallOp)
4048 CommonCallOp = II->getIntrinsicID();
4049 if (II->getIntrinsicID() != *CommonCallOp)
4050 return false;
4051
4052 switch (II->getIntrinsicID()) {
4053 case Intrinsic::umin:
4054 case Intrinsic::umax:
4055 case Intrinsic::smin:
4056 case Intrinsic::smax: {
4057 auto *Op0 = II->getOperand(0);
4058 auto *Op1 = II->getOperand(1);
4059 PrevVecV[0] = Op0;
4060 PrevVecV[1] = Op1;
4061 break;
4062 }
4063 default:
4064 return false;
4065 }
4066 ShouldBeCallOrBinInst ^= 1;
4067
4068 IntrinsicCostAttributes ICA(
4069 *CommonCallOp, II->getType(),
4070 {PrevVecV[0]->getType(), PrevVecV[1]->getType()});
4071 OrigCost += TTI.getIntrinsicInstrCost(ICA, CostKind);
4072
4073 // We may need a swap here since it can be (a, b) or (b, a)
4074 // and accordingly change as we go up.
4075 if (!isa<ShuffleVectorInst>(PrevVecV[1]))
4076 std::swap(PrevVecV[0], PrevVecV[1]);
4077 InstWorklist.push(PrevVecV[1]);
4078 InstWorklist.push(PrevVecV[0]);
4079 } else if (auto *BinOp = dyn_cast<BinaryOperator>(CI)) {
4080 // Similar logic for bin ops.
4081
4082 if (!ShouldBeCallOrBinInst)
4083 return false;
4084
4085 if (!IsFirstCallOrBinInst && any_of(PrevVecV, equal_to(nullptr)))
4086 return false;
4087
4088 if (BinOp != (IsFirstCallOrBinInst ? VecOpEE : PrevVecV[0]))
4089 return false;
4090 IsFirstCallOrBinInst = false;
4091
4092 if (!CommonBinOp)
4093 CommonBinOp = BinOp->getOpcode();
4094
4095 if (BinOp->getOpcode() != *CommonBinOp)
4096 return false;
4097
4098 switch (*CommonBinOp) {
4099 case BinaryOperator::Add:
4100 case BinaryOperator::Mul:
4101 case BinaryOperator::Or:
4102 case BinaryOperator::And:
4103 case BinaryOperator::Xor:
4104 case BinaryOperator::FAdd:
4105 case BinaryOperator::FMul: {
4106 auto *Op0 = BinOp->getOperand(0);
4107 auto *Op1 = BinOp->getOperand(1);
4108 PrevVecV[0] = Op0;
4109 PrevVecV[1] = Op1;
4110 break;
4111 }
4112 default:
4113 return false;
4114 }
4115
4116 // For FP reductions, require reassoc on every binop and collect FMF.
4117 if (*CommonBinOp == Instruction::FAdd ||
4118 *CommonBinOp == Instruction::FMul) {
4119 if (!BinOp->hasAllowReassoc())
4120 return false;
4121 if (!IsFloatReduction) {
4122 CommonFMF = BinOp->getFastMathFlags();
4123 IsFloatReduction = true;
4124 } else {
4125 CommonFMF &= BinOp->getFastMathFlags();
4126 }
4127 }
4128
4129 ShouldBeCallOrBinInst ^= 1;
4130
4131 OrigCost +=
4132 TTI.getArithmeticInstrCost(*CommonBinOp, BinOp->getType(), CostKind);
4133
4134 if (!isa<ShuffleVectorInst>(PrevVecV[1]))
4135 std::swap(PrevVecV[0], PrevVecV[1]);
4136 InstWorklist.push(PrevVecV[1]);
4137 InstWorklist.push(PrevVecV[0]);
4138 } else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
4139 // We shouldn't have any null values in the previous vectors,
4140 // is so, there was a mismatch in pattern.
4141 if (ShouldBeCallOrBinInst || any_of(PrevVecV, equal_to(nullptr)))
4142 return false;
4143
4144 if (SVInst != PrevVecV[1])
4145 return false;
4146
4147 ArrayRef<int> CurMask;
4148 if (!match(SVInst, m_Shuffle(m_Specific(PrevVecV[0]), m_Poison(),
4149 m_Mask(CurMask))))
4150 return false;
4151
4152 // Subtract the parity mask when checking the condition.
4153 for (int Mask = 0, MaskSize = CurMask.size(); Mask != MaskSize; ++Mask) {
4154 if (Mask < ShuffleMaskHalf &&
4155 CurMask[Mask] != ShuffleMaskHalf + Mask - (ExpectedParityMask & 1))
4156 return false;
4157 if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1)
4158 return false;
4159 }
4160
4161 // Update mask values.
4162 ShuffleMaskHalf *= 2;
4163 ShuffleMaskHalf -= (ExpectedParityMask & 1);
4164 HasLaneDuplication |= (ExpectedParityMask & 1) != 0;
4165 ExpectedParityMask >>= 1;
4166
4168 SVInst->getType(), SVInst->getType(),
4169 CurMask, CostKind);
4170
4171 VisitedCnt += 1;
4172 if (!ExpectedParityMask && VisitedCnt == NumLevels)
4173 break;
4174
4175 ShouldBeCallOrBinInst ^= 1;
4176 } else {
4177 // Check if this is a partial reduction - the chain ended because
4178 // the source vector is not a recognized op/shuffle.
4179 // Reject non-power-of-2 vectors because parity-based masks cause
4180 // lane duplication in the reduction tree, making the partial result
4181 // not a simple subvector reduction.
4182 if (ShouldBeCallOrBinInst && VisitedCnt >= 1 && CI == PrevVecV[0] &&
4183 isPowerOf2_64(VecSize)) {
4184 IsPartialReduction = true;
4185 break;
4186 }
4187 return false;
4188 }
4189 }
4190
4191 // Full reduction pattern should end with a shuffle op.
4192 // Partial reduction ends when the source vector is reached.
4193 if (ShouldBeCallOrBinInst && !IsPartialReduction)
4194 return false;
4195
4196 // If the parity masks duplicated any lane, the fold only preserves semantics
4197 // for idempotent ops.
4198 if (HasLaneDuplication && CommonBinOp &&
4199 !Instruction::isIdempotent(*CommonBinOp))
4200 return false;
4201
4202 assert(VecSize != -1 && "Expected Match for Vector Size");
4203
4204 Value *FinalVecV = PrevVecV[0];
4205 if (!FinalVecV)
4206 return false;
4207
4208 auto *FinalVecVTy = cast<FixedVectorType>(FinalVecV->getType());
4209
4210 Intrinsic::ID ReducedOp =
4211 (CommonCallOp ? getMinMaxReductionIntrinsicID(*CommonCallOp)
4212 : getReductionForBinop(*CommonBinOp));
4213 if (!ReducedOp)
4214 return false;
4215
4216 InstructionCost NewCost = 0;
4217 FixedVectorType *ReduceVecTy = FinalVecVTy;
4218 SmallVector<int> ExtractMask;
4219
4220 if (IsPartialReduction) {
4221 unsigned SubVecSize = ShuffleMaskHalf;
4222 ReduceVecTy = FixedVectorType::get(FVT->getElementType(), SubVecSize);
4223 ExtractMask.resize(SubVecSize);
4224 std::iota(ExtractMask.begin(), ExtractMask.end(), 0);
4226 ReduceVecTy, FinalVecVTy, ExtractMask,
4227 CostKind, 0, ReduceVecTy);
4228 }
4229
4230 IntrinsicCostAttributes ICA(
4231 ReducedOp, ReduceVecTy->getElementType(),
4232 IsFloatReduction
4233 ? SmallVector<Type *, 2>{ReduceVecTy->getElementType(), ReduceVecTy}
4234 : SmallVector<Type *, 2>{ReduceVecTy},
4235 IsFloatReduction ? CommonFMF : FastMathFlags());
4236 NewCost += TTI.getIntrinsicInstrCost(ICA, CostKind);
4237
4238 LLVM_DEBUG(dbgs() << "Found reduction shuffle chain: " << I << "\n OldCost : "
4239 << OrigCost << " vs NewCost: " << NewCost << "\n");
4240
4241 if (VecOpEE->hasOneUse() ? (NewCost > OrigCost) : (NewCost >= OrigCost))
4242 return false;
4243
4244 Value *ReduceInput = FinalVecV;
4245 if (IsPartialReduction)
4246 ReduceInput = Builder.CreateShuffleVector(FinalVecV, ExtractMask);
4247
4248 CallInst *ReducedResult;
4249 if (IsFloatReduction) {
4251 *CommonBinOp, ReduceVecTy->getElementType(), /*AllowRHSConstant=*/false,
4252 CommonFMF.noSignedZeros());
4253 ReducedResult = Builder.CreateIntrinsic(ReducedOp, {ReduceVecTy},
4254 {Identity, ReduceInput});
4255 ReducedResult->setFastMathFlags(CommonFMF);
4256 } else {
4257 ReducedResult =
4258 Builder.CreateIntrinsic(ReducedOp, {ReduceVecTy}, {ReduceInput});
4259 }
4260 replaceValue(I, *ReducedResult);
4261
4262 return true;
4263}
4264
4265/// Determine if its more efficient to fold:
4266/// reduce(trunc(x)) -> trunc(reduce(x)).
4267/// reduce(sext(x)) -> sext(reduce(x)).
4268/// reduce(zext(x)) -> zext(reduce(x)).
4269bool VectorCombine::foldCastFromReductions(Instruction &I) {
4270 auto *II = dyn_cast<IntrinsicInst>(&I);
4271 if (!II)
4272 return false;
4273
4274 bool TruncOnly = false;
4275 Intrinsic::ID IID = II->getIntrinsicID();
4276 switch (IID) {
4277 case Intrinsic::vector_reduce_add:
4278 case Intrinsic::vector_reduce_mul:
4279 TruncOnly = true;
4280 break;
4281 case Intrinsic::vector_reduce_and:
4282 case Intrinsic::vector_reduce_or:
4283 case Intrinsic::vector_reduce_xor:
4284 break;
4285 default:
4286 return false;
4287 }
4288
4289 unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
4290 Value *ReductionSrc = I.getOperand(0);
4291
4292 Value *Src;
4293 if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) &&
4294 (TruncOnly || !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
4295 return false;
4296
4297 auto CastOpc =
4298 (Instruction::CastOps)cast<Instruction>(ReductionSrc)->getOpcode();
4299
4300 auto *SrcTy = cast<VectorType>(Src->getType());
4301 auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
4302 Type *ResultTy = I.getType();
4303
4305 ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
4306 OldCost += TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
4308 cast<CastInst>(ReductionSrc));
4309 InstructionCost NewCost =
4310 TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
4311 CostKind) +
4312 TTI.getCastInstrCost(CastOpc, ResultTy, ReductionSrcTy->getScalarType(),
4314
4315 if (OldCost <= NewCost || !NewCost.isValid())
4316 return false;
4317
4318 Value *NewReduction = Builder.CreateIntrinsic(SrcTy->getScalarType(),
4319 II->getIntrinsicID(), {Src});
4320 Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy);
4321 replaceValue(I, *NewCast);
4322 return true;
4323}
4324
4325/// Fold:
4326/// icmp pred (reduce.{add,or,and,umax,umin}(signbit_extract(x))), C
4327/// into:
4328/// icmp sgt/slt (reduce.{or,umax,and,umin}(x)), -1/0
4329///
4330/// Sign-bit reductions produce values with known semantics:
4331/// - reduce.{or,umax}: 0 if no element is negative, 1 if any is
4332/// - reduce.{and,umin}: 1 if all elements are negative, 0 if any isn't
4333/// - reduce.add: count of negative elements (0 to NumElts)
4334///
4335/// Both lshr and ashr are supported:
4336/// - lshr produces 0 or 1, so reduce.add range is [0, N]
4337/// - ashr produces 0 or -1, so reduce.add range is [-N, 0]
4338///
4339/// The fold generalizes to multiple source vectors combined with the same
4340/// operation as the reduction. For example:
4341/// reduce.or(or(shr A, shr B)) conceptually extends the vector
4342/// For reduce.add, this changes the count to M*N where M is the number of
4343/// source vectors.
4344///
4345/// We transform to a direct sign check on the original vector using
4346/// reduce.{or,umax} or reduce.{and,umin}.
4347///
4348/// In spirit, it's similar to foldSignBitCheck in InstCombine.
4349bool VectorCombine::foldSignBitReductionCmp(Instruction &I) {
4350 CmpPredicate Pred;
4351 IntrinsicInst *ReduceOp;
4352 const APInt *CmpVal;
4353 if (!match(&I,
4354 m_ICmp(Pred, m_OneUse(m_AnyIntrinsic(ReduceOp)), m_APInt(CmpVal))))
4355 return false;
4356
4357 Intrinsic::ID OrigIID = ReduceOp->getIntrinsicID();
4358 switch (OrigIID) {
4359 case Intrinsic::vector_reduce_or:
4360 case Intrinsic::vector_reduce_umax:
4361 case Intrinsic::vector_reduce_and:
4362 case Intrinsic::vector_reduce_umin:
4363 case Intrinsic::vector_reduce_add:
4364 break;
4365 default:
4366 return false;
4367 }
4368
4369 Value *ReductionSrc = ReduceOp->getArgOperand(0);
4370 auto *VecTy = dyn_cast<FixedVectorType>(ReductionSrc->getType());
4371 if (!VecTy)
4372 return false;
4373
4374 unsigned BitWidth = VecTy->getScalarSizeInBits();
4375 if (BitWidth == 1)
4376 return false;
4377
4378 unsigned NumElts = VecTy->getNumElements();
4379
4380 // Determine the expected tree opcode for multi-vector patterns.
4381 // The tree opcode must match the reduction's underlying operation.
4382 //
4383 // TODO: for pairs of equivalent operators, we should match both,
4384 // not only the most common.
4385 Instruction::BinaryOps TreeOpcode;
4386 switch (OrigIID) {
4387 case Intrinsic::vector_reduce_or:
4388 case Intrinsic::vector_reduce_umax:
4389 TreeOpcode = Instruction::Or;
4390 break;
4391 case Intrinsic::vector_reduce_and:
4392 case Intrinsic::vector_reduce_umin:
4393 TreeOpcode = Instruction::And;
4394 break;
4395 case Intrinsic::vector_reduce_add:
4396 TreeOpcode = Instruction::Add;
4397 break;
4398 default:
4399 llvm_unreachable("Unexpected intrinsic");
4400 }
4401
4402 // Collect sign-bit extraction leaves from an associative tree of TreeOpcode.
4403 // The tree conceptually extends the vector being reduced.
4404 SmallVector<Value *, 8> Worklist;
4405 SmallVector<Value *, 8> Sources; // Original vectors (X in shr X, BW-1)
4406 Worklist.push_back(ReductionSrc);
4407 std::optional<bool> IsAShr;
4408 constexpr unsigned MaxSources = 8;
4409
4410 // Calculate old cost: all shifts + tree ops + reduction
4411 InstructionCost OldCost = TTI.getInstructionCost(ReduceOp, CostKind);
4412
4413 while (!Worklist.empty() && Worklist.size() <= MaxSources &&
4414 Sources.size() <= MaxSources) {
4415 Value *V = Worklist.pop_back_val();
4416
4417 // Try to match sign-bit extraction: shr X, (bitwidth-1)
4418 Value *X;
4419 if (match(V, m_OneUse(m_Shr(m_Value(X), m_SpecificInt(BitWidth - 1))))) {
4420 auto *Shr = cast<Instruction>(V);
4421
4422 // All shifts must be the same type (all lshr or all ashr)
4423 bool ThisIsAShr = Shr->getOpcode() == Instruction::AShr;
4424 if (!IsAShr)
4425 IsAShr = ThisIsAShr;
4426 else if (*IsAShr != ThisIsAShr)
4427 return false;
4428
4429 Sources.push_back(X);
4430
4431 // As part of the fold, we remove all of the shifts, so we need to keep
4432 // track of their costs.
4433 OldCost += TTI.getInstructionCost(Shr, CostKind);
4434
4435 continue;
4436 }
4437
4438 // Try to extend through a tree node of the expected opcode
4439 Value *A, *B;
4440 if (!match(V, m_OneUse(m_BinOp(TreeOpcode, m_Value(A), m_Value(B)))))
4441 return false;
4442
4443 // We are potentially replacing these operations as well, so we add them
4444 // to the costs.
4446
4447 Worklist.push_back(A);
4448 Worklist.push_back(B);
4449 }
4450
4451 // Must have at least one source and not exceed limit
4452 if (Sources.empty() || Sources.size() > MaxSources ||
4453 Worklist.size() > MaxSources || !IsAShr)
4454 return false;
4455
4456 unsigned NumSources = Sources.size();
4457
4458 // For reduce.add, the total count must fit as a signed integer.
4459 // Range is [0, M*N] for lshr or [-M*N, 0] for ashr.
4460 if (OrigIID == Intrinsic::vector_reduce_add &&
4461 !isIntN(BitWidth, NumSources * NumElts))
4462 return false;
4463
4464 // Compute the boundary value when all elements are negative:
4465 // - Per-element contribution: 1 for lshr, -1 for ashr
4466 // - For add: M*N (total elements across all sources); for others: just 1
4467 unsigned Count =
4468 (OrigIID == Intrinsic::vector_reduce_add) ? NumSources * NumElts : 1;
4469 APInt NegativeVal(CmpVal->getBitWidth(), Count);
4470 if (*IsAShr)
4471 NegativeVal.negate();
4472
4473 // Range is [min(0, AllNegVal), max(0, AllNegVal)]
4474 APInt Zero = APInt::getZero(CmpVal->getBitWidth());
4475 APInt RangeLow = APIntOps::smin(Zero, NegativeVal);
4476 APInt RangeHigh = APIntOps::smax(Zero, NegativeVal);
4477
4478 // Determine comparison semantics:
4479 // - IsEq: true for equality test, false for inequality
4480 // - TestsNegative: true if testing against AllNegVal, false for zero
4481 //
4482 // In addition to EQ/NE against 0 or AllNegVal, we support inequalities
4483 // that fold to boundary tests given the narrow value range:
4484 // < RangeHigh -> != RangeHigh
4485 // > RangeHigh-1 -> == RangeHigh
4486 // > RangeLow -> != RangeLow
4487 // < RangeLow+1 -> == RangeLow
4488 //
4489 // For inequalities, we work with signed predicates only. Unsigned predicates
4490 // are canonicalized to signed when the range is non-negative (where they are
4491 // equivalent). When the range includes negative values, unsigned predicates
4492 // would have different semantics due to wrap-around, so we reject them.
4493 if (!ICmpInst::isEquality(Pred) && !ICmpInst::isSigned(Pred)) {
4494 if (RangeLow.isNegative())
4495 return false;
4496 Pred = ICmpInst::getSignedPredicate(Pred);
4497 }
4498
4499 bool IsEq;
4500 bool TestsNegative;
4501 if (ICmpInst::isEquality(Pred)) {
4502 if (CmpVal->isZero()) {
4503 TestsNegative = false;
4504 } else if (*CmpVal == NegativeVal) {
4505 TestsNegative = true;
4506 } else {
4507 return false;
4508 }
4509 IsEq = Pred == ICmpInst::ICMP_EQ;
4510 } else if (Pred == ICmpInst::ICMP_SLT && *CmpVal == RangeHigh) {
4511 IsEq = false;
4512 TestsNegative = (RangeHigh == NegativeVal);
4513 } else if (Pred == ICmpInst::ICMP_SGT && *CmpVal == RangeHigh - 1) {
4514 IsEq = true;
4515 TestsNegative = (RangeHigh == NegativeVal);
4516 } else if (Pred == ICmpInst::ICMP_SGT && *CmpVal == RangeLow) {
4517 IsEq = false;
4518 TestsNegative = (RangeLow == NegativeVal);
4519 } else if (Pred == ICmpInst::ICMP_SLT && *CmpVal == RangeLow + 1) {
4520 IsEq = true;
4521 TestsNegative = (RangeLow == NegativeVal);
4522 } else {
4523 return false;
4524 }
4525
4526 // For this fold we support four types of checks:
4527 //
4528 // 1. All lanes are negative - AllNeg
4529 // 2. All lanes are non-negative - AllNonNeg
4530 // 3. At least one negative lane - AnyNeg
4531 // 4. At least one non-negative lane - AnyNonNeg
4532 //
4533 // For each case, we can generate the following code:
4534 //
4535 // 1. AllNeg - reduce.and/umin(X) < 0
4536 // 2. AllNonNeg - reduce.or/umax(X) > -1
4537 // 3. AnyNeg - reduce.or/umax(X) < 0
4538 // 4. AnyNonNeg - reduce.and/umin(X) > -1
4539 //
4540 // The table below shows the aggregation of all supported cases
4541 // using these four cases.
4542 //
4543 // Reduction | == 0 | != 0 | == MAX | != MAX
4544 // ------------+-----------+-----------+-----------+-----------
4545 // or/umax | AllNonNeg | AnyNeg | AnyNeg | AllNonNeg
4546 // and/umin | AnyNonNeg | AllNeg | AllNeg | AnyNonNeg
4547 // add | AllNonNeg | AnyNeg | AllNeg | AnyNonNeg
4548 //
4549 // NOTE: MAX = 1 for or/and/umax/umin, and the vector size N for add
4550 //
4551 // For easier codegen and check inversion, we use the following encoding:
4552 //
4553 // 1. Bit-3 === requires or/umax (1) or and/umin (0) check
4554 // 2. Bit-2 === checks < 0 (1) or > -1 (0)
4555 // 3. Bit-1 === universal (1) or existential (0) check
4556 //
4557 // AnyNeg = 0b110: uses or/umax, checks negative, any-check
4558 // AllNonNeg = 0b101: uses or/umax, checks non-neg, all-check
4559 // AnyNonNeg = 0b000: uses and/umin, checks non-neg, any-check
4560 // AllNeg = 0b011: uses and/umin, checks negative, all-check
4561 //
4562 // XOR with 0b011 inverts the check (swaps all/any and neg/non-neg).
4563 //
4564 enum CheckKind : unsigned {
4565 AnyNonNeg = 0b000,
4566 AllNeg = 0b011,
4567 AllNonNeg = 0b101,
4568 AnyNeg = 0b110,
4569 };
4570 // Return true if we fold this check into or/umax and false for and/umin
4571 auto RequiresOr = [](CheckKind C) -> bool { return C & 0b100; };
4572 // Return true if we should check if result is negative and false otherwise
4573 auto IsNegativeCheck = [](CheckKind C) -> bool { return C & 0b010; };
4574 // Logically invert the check
4575 auto Invert = [](CheckKind C) { return CheckKind(C ^ 0b011); };
4576
4577 CheckKind Base;
4578 switch (OrigIID) {
4579 case Intrinsic::vector_reduce_or:
4580 case Intrinsic::vector_reduce_umax:
4581 Base = TestsNegative ? AnyNeg : AllNonNeg;
4582 break;
4583 case Intrinsic::vector_reduce_and:
4584 case Intrinsic::vector_reduce_umin:
4585 Base = TestsNegative ? AllNeg : AnyNonNeg;
4586 break;
4587 case Intrinsic::vector_reduce_add:
4588 Base = TestsNegative ? AllNeg : AllNonNeg;
4589 break;
4590 default:
4591 llvm_unreachable("Unexpected intrinsic");
4592 }
4593
4594 CheckKind Check = IsEq ? Base : Invert(Base);
4595
4596 auto PickCheaper = [&](Intrinsic::ID Arith, Intrinsic::ID MinMax) {
4597 InstructionCost ArithCost =
4599 VecTy, std::nullopt, CostKind);
4600 InstructionCost MinMaxCost =
4602 FastMathFlags(), CostKind);
4603 return ArithCost <= MinMaxCost ? std::make_pair(Arith, ArithCost)
4604 : std::make_pair(MinMax, MinMaxCost);
4605 };
4606
4607 // Choose output reduction based on encoding's MSB
4608 auto [NewIID, NewCost] = RequiresOr(Check)
4609 ? PickCheaper(Intrinsic::vector_reduce_or,
4610 Intrinsic::vector_reduce_umax)
4611 : PickCheaper(Intrinsic::vector_reduce_and,
4612 Intrinsic::vector_reduce_umin);
4613
4614 // Add cost of combining multiple sources with or/and
4615 if (NumSources > 1) {
4616 unsigned CombineOpc =
4617 RequiresOr(Check) ? Instruction::Or : Instruction::And;
4618 NewCost += TTI.getArithmeticInstrCost(CombineOpc, VecTy, CostKind) *
4619 (NumSources - 1);
4620 }
4621
4622 LLVM_DEBUG(dbgs() << "Found sign-bit reduction cmp: " << I << "\n OldCost: "
4623 << OldCost << " vs NewCost: " << NewCost << "\n");
4624
4625 if (NewCost > OldCost)
4626 return false;
4627
4628 // Generate the combined input and reduction
4629 Builder.SetInsertPoint(&I);
4630 Type *ScalarTy = VecTy->getScalarType();
4631
4632 Value *Input;
4633 if (NumSources == 1) {
4634 Input = Sources[0];
4635 } else {
4636 // Combine sources with or/and based on check type
4637 Input = RequiresOr(Check) ? Builder.CreateOr(Sources)
4638 : Builder.CreateAnd(Sources);
4639 }
4640
4641 Value *NewReduce = Builder.CreateIntrinsic(ScalarTy, NewIID, {Input});
4642 Value *NewCmp = IsNegativeCheck(Check) ? Builder.CreateIsNeg(NewReduce)
4643 : Builder.CreateIsNotNeg(NewReduce);
4644 replaceValue(I, *NewCmp);
4645 return true;
4646}
4647
4648/// vector.reduce.OP f(X_i) == 0 -> vector.reduce.OP X_i == 0
4649///
4650/// We can prove it for cases when:
4651///
4652/// 1. OP X_i == 0 <=> \forall i \in [1, N] X_i == 0
4653/// 1'. OP X_i == 0 <=> \exists j \in [1, N] X_j == 0
4654/// 2. f(x) == 0 <=> x == 0
4655///
4656/// From 1 and 2 (or 1' and 2), we can infer that
4657///
4658/// OP f(X_i) == 0 <=> OP X_i == 0.
4659///
4660/// (1)
4661/// OP f(X_i) == 0 <=> \forall i \in [1, N] f(X_i) == 0
4662/// (2)
4663/// <=> \forall i \in [1, N] X_i == 0
4664/// (1)
4665/// <=> OP(X_i) == 0
4666///
4667/// For some of the OP's and f's, we need to have domain constraints on X
4668/// to ensure properties 1 (or 1') and 2.
///
/// Returns true after replacing \p I with an ==/!= compare of the new
/// reduction (over X) against zero. Returns false, without creating any
/// instructions, when the pattern, the domain constraints, or (for zext/sext)
/// the cost check do not hold.
///
/// NOTE(review): unlike foldEquivalentReductionCmp, this function never calls
/// Builder.SetInsertPoint before creating instructions; presumably the caller
/// has already positioned the builder at \p I — confirm.
4669bool VectorCombine::foldICmpEqZeroVectorReduce(Instruction &I) {
4670 CmpPredicate Pred;
4671 Value *Op;
// Only equality compares (== / !=) of a value against zero are handled.
4672 if (!match(&I, m_ICmp(Pred, m_Value(Op), m_Zero())) ||
4673 !ICmpInst::isEquality(Pred))
4674 return false;
4675
4676 auto *II = dyn_cast<IntrinsicInst>(Op);
4677 if (!II)
4678 return false;
4679
// Reductions for which property 1 or 1' can hold; the per-reduction domain
// constraints are checked further below.
4680 switch (II->getIntrinsicID()) {
4681 case Intrinsic::vector_reduce_add:
4682 case Intrinsic::vector_reduce_or:
4683 case Intrinsic::vector_reduce_umin:
4684 case Intrinsic::vector_reduce_umax:
4685 case Intrinsic::vector_reduce_smin:
4686 case Intrinsic::vector_reduce_smax:
4687 break;
4688 default:
4689 return false;
4690 }
4691
4692 Value *InnerOp = II->getArgOperand(0);
4693
// I must be the reduction's only user: a new reduction is created and only I
// is replaced, so any other user would keep the old reduction alive.
4694 // TODO: fixed vector type might be too restrictive
4695 if (!II->hasOneUse() || !isa<FixedVectorType>(InnerOp->getType()))
4696 return false;
4697
4698 Value *X = nullptr;
4699
4700 // Check for zero-preserving operations where f(x) = 0 <=> x = 0
4701 //
4702 // 1. f(x) = shl nuw x, y for arbitrary y
4703 // 2. f(x) = mul nuw x, c for defined c != 0
4704 // 3. f(x) = zext x
4705 // 4. f(x) = sext x
4706 // 5. f(x) = neg x
4707 //
4708 if (!(match(InnerOp, m_NUWShl(m_Value(X), m_Value())) || // Case 1
4709 match(InnerOp, m_NUWMul(m_Value(X), m_NonZeroInt())) || // Case 2
4710 match(InnerOp, m_ZExt(m_Value(X))) || // Case 3
4711 match(InnerOp, m_SExt(m_Value(X))) || // Case 4
4712 match(InnerOp, m_Neg(m_Value(X))) // Case 5
4713 ))
4714 return false;
4715
4716 SimplifyQuery S = SQ.getWithInstruction(&I);
// For zext/sext, X has a narrower element type than f(X); otherwise the
// types match.
4717 auto *XTy = cast<FixedVectorType>(X->getType());
4718
4719 // Check for domain constraints for all supported reductions.
4720 //
4721 // a. OR X_i - has property 1 for every X
4722 // b. UMAX X_i - has property 1 for every X
4723 // c. UMIN X_i - has property 1' for every X
4724 // d. SMAX X_i - has property 1 for X >= 0
4725 // e. SMIN X_i - has property 1' for X >= 0
4726 // f. ADD X_i - has property 1 for X >= 0 && ADD X_i doesn't sign wrap
4727 //
4728 // In order for the proof to work, we need 1 (or 1') to be true for both
4729 // OP f(X_i) and OP X_i and that's why below we check constraints twice.
4730 //
4731 // NOTE: ADD X_i holds property 1 for a mirror case as well, i.e. when
4732 // X <= 0 && ADD X_i doesn't sign wrap. However, due to the nature
4733 // of known bits, we can't reasonably hold knowledge of "either 0
4734 // or negative".
4735 switch (II->getIntrinsicID()) {
4736 case Intrinsic::vector_reduce_add: {
4737 // We need to check that both X_i and f(X_i) have enough leading
4738 // zeros to not overflow.
4739 KnownBits KnownX = computeKnownBits(X, S);
4740 KnownBits KnownFX = computeKnownBits(InnerOp, S);
4741 unsigned NumElems = XTy->getNumElements();
4742 // Adding N elements loses at most ceil(log2(N)) leading bits.
4743 unsigned LostBits = Log2_32_Ceil(NumElems);
4744 unsigned LeadingZerosX = KnownX.countMinLeadingZeros();
4745 unsigned LeadingZerosFX = KnownFX.countMinLeadingZeros();
4746 // Need at least one leading zero left after summation to ensure no overflow
4747 if (LeadingZerosX <= LostBits || LeadingZerosFX <= LostBits)
4748 return false;
4749
4750 // We are not checking whether X or f(X) are non-negative explicitly because
4751 // we implicitly checked for it when we checked if both cases have enough
4752 // leading zeros to not wrap addition.
4753 break;
4754 }
4755 case Intrinsic::vector_reduce_smin:
4756 case Intrinsic::vector_reduce_smax:
4757 // Check whether X >= 0 and f(X) >= 0
4758 if (!isKnownNonNegative(InnerOp, S) || !isKnownNonNegative(X, S))
4759 return false;
4760
4761 break;
4762 default:
4763 break;
4764 };
4765
4766 LLVM_DEBUG(dbgs() << "Found a reduction to 0 comparison with removable op: "
4767 << *II << "\n");
4768
4769 // For zext/sext, check if the transform is profitable using cost model.
4770 // For other operations (shl, mul, neg), we're removing an instruction
4771 // while keeping the same reduction type, so it's always profitable.
4772 if (isa<ZExtInst>(InnerOp) || isa<SExtInst>(InnerOp)) {
4773 auto *FXTy = cast<FixedVectorType>(InnerOp->getType());
4774 Intrinsic::ID IID = II->getIntrinsicID();
4775
4777 cast<CastInst>(InnerOp)->getOpcode(), FXTy, XTy,
4779
// Cost the same reduction kind on the wide type (FXTy) versus the narrow
// source type (XTy).
4780 InstructionCost OldReduceCost, NewReduceCost;
4781 switch (IID) {
4782 case Intrinsic::vector_reduce_add:
4783 case Intrinsic::vector_reduce_or:
4784 OldReduceCost = TTI.getArithmeticReductionCost(
4785 getArithmeticReductionInstruction(IID), FXTy, std::nullopt, CostKind);
4786 NewReduceCost = TTI.getArithmeticReductionCost(
4787 getArithmeticReductionInstruction(IID), XTy, std::nullopt, CostKind);
4788 break;
4789 case Intrinsic::vector_reduce_umin:
4790 case Intrinsic::vector_reduce_umax:
4791 case Intrinsic::vector_reduce_smin:
4792 case Intrinsic::vector_reduce_smax:
4793 OldReduceCost = TTI.getMinMaxReductionCost(
4794 getMinMaxReductionIntrinsicOp(IID), FXTy, FastMathFlags(), CostKind);
4795 NewReduceCost = TTI.getMinMaxReductionCost(
4796 getMinMaxReductionIntrinsicOp(IID), XTy, FastMathFlags(), CostKind);
4797 break;
4798 default:
4799 llvm_unreachable("Unexpected reduction");
4800 }
4801
// The extension disappears only if the old reduction was its sole user;
// otherwise it survives and its cost is charged to the new form as well.
4802 InstructionCost OldCost = OldReduceCost + ExtCost;
4803 InstructionCost NewCost =
4804 NewReduceCost + (InnerOp->hasOneUse() ? 0 : ExtCost);
4805
4806 LLVM_DEBUG(dbgs() << "Found a removable extension before reduction: "
4807 << *InnerOp << "\n OldCost: " << OldCost
4808 << " vs NewCost: " << NewCost << "\n");
4809
4810 // We consider transformation to still be potentially beneficial even
4811 // when the costs are the same because we might remove a use from f(X)
4812 // and unlock other optimizations. Equal costs would just mean that we
4813 // didn't make it worse in the worst case.
4814 if (NewCost > OldCost)
4815 return false;
4816 }
4817
4818 // Since we support zext and sext as f, we might change the scalar type
4819 // of the intrinsic.
4820 Type *Ty = XTy->getScalarType();
4821 Value *NewReduce = Builder.CreateIntrinsic(Ty, II->getIntrinsicID(), {X});
4822 Value *NewCmp =
4823 Builder.CreateICmp(Pred, NewReduce, ConstantInt::getNullValue(Ty));
4824 replaceValue(I, *NewCmp);
4825 return true;
4826}
4827
4828/// Fold comparisons of reduce.or/reduce.and with reduce.umax/reduce.umin
4829/// based on cost, preserving the comparison semantics.
4830///
4831/// We use three fundamental properties for each pair:
4832///
4833/// 1. or(X) == 0 <=> umax(X) == 0
4834/// 2. or(X) == 1 <=> umax(X) == 1
4835/// 3. sign(or(X)) == sign(umax(X))
4836///
4837/// 1. and(X) == -1 <=> umin(X) == -1
4838/// 2. and(X) == -2 <=> umin(X) == -2
4839/// 3. sign(and(X)) == sign(umin(X))
4840///
4841/// From these we can infer the following transformations:
4842/// a. or(X) ==/!= 0 <-> umax(X) ==/!= 0
4843/// b. or(X) s< 0 <-> umax(X) s< 0
4844/// c. or(X) s> -1 <-> umax(X) s> -1
4845/// d. or(X) s< 1 <-> umax(X) s< 1
4846/// e. or(X) ==/!= 1 <-> umax(X) ==/!= 1
4847/// f. or(X) s< 2 <-> umax(X) s< 2
4848/// g. and(X) ==/!= -1 <-> umin(X) ==/!= -1
4849/// h. and(X) s< 0 <-> umin(X) s< 0
4850/// i. and(X) s> -1 <-> umin(X) s> -1
4851/// j. and(X) s> -2 <-> umin(X) s> -2
4852/// k. and(X) ==/!= -2 <-> umin(X) ==/!= -2
4853/// l. and(X) s> -3 <-> umin(X) s> -3
4854///
/// The rewrite is performed only when the alternative reduction is strictly
/// cheaper than the original one. Returns true after replacing \p I with the
/// same predicate and constant applied to the alternative reduction.
4855bool VectorCombine::foldEquivalentReductionCmp(Instruction &I) {
4856 CmpPredicate Pred;
4857 Value *ReduceOp;
4858 const APInt *CmpVal;
4859 if (!match(&I, m_ICmp(Pred, m_Value(ReduceOp), m_APInt(CmpVal))))
4860 return false;
4861
4862 auto *II = dyn_cast<IntrinsicInst>(ReduceOp);
// I must be the reduction's only user: a new reduction is built and only I is
// replaced, so another user would keep the old reduction alive.
4863 if (!II || !II->hasOneUse())
4864 return false;
4865
4866 const auto IsValidOrUmaxCmp = [&]() {
4867 // or === umax for i1
4868 if (CmpVal->getBitWidth() == 1)
4869 return true;
4870
4871 // Cases a and e
4872 bool IsEquality =
4873 (CmpVal->isZero() || CmpVal->isOne()) && ICmpInst::isEquality(Pred);
4874 // Case c
4875 bool IsPositive = CmpVal->isAllOnes() && Pred == ICmpInst::ICMP_SGT;
4876 // Cases b, d, and f
4877 bool IsNegative = (CmpVal->isZero() || CmpVal->isOne() || *CmpVal == 2) &&
4878 Pred == ICmpInst::ICMP_SLT;
4879 return IsEquality || IsPositive || IsNegative;
4880 };
4881
4882 const auto IsValidAndUminCmp = [&]() {
4883 // and === umin for i1
4884 if (CmpVal->getBitWidth() == 1)
4885 return true;
4886
// -1 is all ones; -2 is all ones except bit 0, i.e. BitWidth - 1 leading
// ones.
4887 const auto LeadingOnes = CmpVal->countl_one();
4888
4889 // Cases g and k
4890 bool IsEquality =
4891 (CmpVal->isAllOnes() || LeadingOnes + 1 == CmpVal->getBitWidth()) &&
4893 // Case h
4894 bool IsNegative = CmpVal->isZero() && Pred == ICmpInst::ICMP_SLT;
4895 // Cases i, j, and l
4896 bool IsPositive =
4897 // if the number has at least BitWidth - 2 leading ones
4898 // and its two low bits (bit 1, bit 0) are:
4899 // - 1 1 -> -1
4900 // - 1 0 -> -2
4901 // - 0 1 -> -3 (0 0, i.e. -4, is rejected)
4902 LeadingOnes + 2 >= CmpVal->getBitWidth() &&
4903 ((*CmpVal)[0] || (*CmpVal)[1]) && Pred == ICmpInst::ICMP_SGT;
4904 return IsEquality || IsNegative || IsPositive;
4905 };
4906
4907 Intrinsic::ID OriginalIID = II->getIntrinsicID();
4908 Intrinsic::ID AlternativeIID;
4909
4910 // Check if this is a valid comparison pattern and determine the alternate
4911 // reduction intrinsic.
4912 switch (OriginalIID) {
4913 case Intrinsic::vector_reduce_or:
4914 if (!IsValidOrUmaxCmp())
4915 return false;
4916 AlternativeIID = Intrinsic::vector_reduce_umax;
4917 break;
4918 case Intrinsic::vector_reduce_umax:
4919 if (!IsValidOrUmaxCmp())
4920 return false;
4921 AlternativeIID = Intrinsic::vector_reduce_or;
4922 break;
4923 case Intrinsic::vector_reduce_and:
4924 if (!IsValidAndUminCmp())
4925 return false;
4926 AlternativeIID = Intrinsic::vector_reduce_umin;
4927 break;
4928 case Intrinsic::vector_reduce_umin:
4929 if (!IsValidAndUminCmp())
4930 return false;
4931 AlternativeIID = Intrinsic::vector_reduce_and;
4932 break;
4933 default:
4934 return false;
4935 }
4936
4937 Value *X = II->getArgOperand(0);
4938 auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
4939 if (!VecTy)
4940 return false;
4941
4942 const auto GetReductionCost = [&](Intrinsic::ID IID) -> InstructionCost {
// or/and reductions map to an arithmetic opcode. umax/umin report ICmp and
// take the min/max costing path (FastMathFlags/CostKind arguments below).
4943 unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
4944 if (ReductionOpc != Instruction::ICmp)
4945 return TTI.getArithmeticReductionCost(ReductionOpc, VecTy, std::nullopt,
4946 CostKind);
4948 FastMathFlags(), CostKind);
4949 };
4950
4951 InstructionCost OrigCost = GetReductionCost(OriginalIID);
4952 InstructionCost AltCost = GetReductionCost(AlternativeIID);
4953
4954 LLVM_DEBUG(dbgs() << "Found equivalent reduction cmp: " << I
4955 << "\n OrigCost: " << OrigCost
4956 << " vs AltCost: " << AltCost << "\n");
4957
// Ties keep the original reduction.
4958 if (AltCost >= OrigCost)
4959 return false;
4960
4961 Builder.SetInsertPoint(&I);
4962 Type *ScalarTy = VecTy->getScalarType();
4963 Value *NewReduce = Builder.CreateIntrinsic(ScalarTy, AlternativeIID, {X});
4964 Value *NewCmp =
4965 Builder.CreateICmp(Pred, NewReduce, ConstantInt::get(ScalarTy, *CmpVal));
4966
4967 replaceValue(I, *NewCmp);
4968 return true;
4969}
4970
4971/// Used by foldReduceAddCmpZero to check if we can prove that a value is
4972/// non-positive.
4973/// KnownBits cannot see sext <? x i1> as non-positive: each top bit equals a
4974/// single unknown input bit, which a per-bit lattice cannot track. The fold's
4975/// target shape is popcount-style sums of <N x i1> valid/invalid masks (e.g.
4976/// ray-intersection hits) tested for any-hit.
4977/// Previous attempts to approximate the known bits of such expressions were
4978/// using a fully recursive value tracking approach to infer a constant range
4979/// but ultimately turned to be too expensive in compile time.
4980static bool isKnownNonPositive(const Value *V, const SimplifyQuery &SQ,
4981 unsigned Depth = 0) {
4982 constexpr unsigned MaxLocalDepth = 2;
4983 if (Depth > MaxLocalDepth)
4984 return false;
4985
4986 auto NumSignBits = [&](const Value *X) {
4987 return ComputeNumSignBits(X, SQ.DL, SQ.AC, SQ.CxtI, SQ.DT);
4988 };
4989 if (NumSignBits(V) == V->getType()->getScalarSizeInBits())
4990 return true;
4991
4992 Value *A, *B;
4993 if (match(V, m_Add(m_Value(A), m_Value(B))))
4994 return NumSignBits(A) >= 2 && NumSignBits(B) >= 2 &&
4995 isKnownNonPositive(A, SQ, Depth + 1) &&
4996 isKnownNonPositive(B, SQ, Depth + 1);
4997
4998 return computeKnownBits(V, SQ).isNonPositive();
4999}
5000
5001/// Fold (icmp pred (reduce.add X), 0) to (icmp pred' (reduce.or X), 0) when X
5002/// has lanes known to all be non-negative or all non-positive, so that
5003/// sum == 0 iff every lane is 0. Falls back to reduce.umax if reduce.or is
5004/// more expensive on the target.
///
/// The number of sign bits of X must leave enough headroom that the sum cannot
/// wrap. Signed inequalities need one extra sign bit so the sum cannot flip
/// sign, and signed compares that are tautologies for the known lane sign are
/// left to InstCombine. Returns true after replacing \p I with the new compare.
5005bool VectorCombine::foldReduceAddCmpZero(Instruction &I) {
5006 CmpPredicate Pred;
5007 Value *Vec;
5008 if (!match(&I, m_ICmp(Pred,
5010 m_Value(Vec))),
5011 m_Zero())))
5012 return false;
5013
5014 auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
5015 if (!VecTy || VecTy->getNumElements() < 2)
5016 return false;
5017
// The non-positive check is only consulted when non-negativity cannot be
// shown, so at most one of these flags is true.
5018 SimplifyQuery Q = SQ.getWithInstruction(&I);
5019 bool IsNonNegative = isKnownNonNegative(Vec, Q);
5020 bool IsNonPositive = !IsNonNegative && isKnownNonPositive(Vec, Q);
5021 if (!IsNonNegative && !IsNonPositive)
5022 return false;
5023
5024 // Summing NumElts lanes can consume up to log2(NumElts) sign bits. Require
5025 // strictly more headroom than that so the sum cannot wrap to zero.
5026 unsigned NumElts = VecTy->getNumElements();
5027 unsigned NumSignBits = ComputeNumSignBits(Vec, *DL, SQ.AC, &I, &DT);
5028 if (Log2_32(NumElts) >= NumSignBits)
5029 return false;
5030
// With no wrap, sum == 0 iff every lane is 0. Each supported predicate against
// zero therefore becomes an ==/!= test of the or/umax reduction. The signed
// cases are only valid for one lane sign and are filtered below.
5031 ICmpInst::Predicate NewPred;
5032 switch (Pred) {
5033 case ICmpInst::ICMP_EQ:
5034 case ICmpInst::ICMP_ULE:
5035 case ICmpInst::ICMP_SLE:
5036 case ICmpInst::ICMP_SGE:
5037 NewPred = ICmpInst::ICMP_EQ;
5038 break;
5039 case ICmpInst::ICMP_NE:
5040 case ICmpInst::ICMP_UGT:
5041 case ICmpInst::ICMP_SGT:
5042 case ICmpInst::ICMP_SLT:
5043 NewPred = ICmpInst::ICMP_NE;
5044 break;
5045 default:
5046 return false;
5047 }
5048
5049 // SGT and SLE on a non-positive tree, and SLT and SGE on a non-negative
5050 // tree, are tautologies (always true or always false). Leave those to
5051 // InstCombine rather than mapping them here. Remaining signed inequalities
5052 // also need one extra sign bit so the sum cannot flip sign.
5053 if (!IsNonNegative &&
5054 (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE))
5055 return false;
5056 if (!IsNonPositive &&
5057 (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE))
5058 return false;
5059 if ((Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE ||
5060 Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) &&
5061 Log2_32(NumElts) >= NumSignBits - 1)
5062 return false;
5063
5065 Instruction::Add, VecTy, std::nullopt, CostKind);
5067 Instruction::Or, VecTy, std::nullopt, CostKind);
5069 Intrinsic::umax, VecTy, FastMathFlags(), CostKind);
// Prefer reduce.or. Use umax only when the or cost is invalid or strictly
// higher; give up when neither cost is valid.
5070 if (!OrCost.isValid() && !UmaxCost.isValid())
5071 return false;
5072 bool UseOr = OrCost.isValid() && (!UmaxCost.isValid() || OrCost <= UmaxCost);
5073 InstructionCost AltCost = UseOr ? OrCost : UmaxCost;
5074 if (AltCost > OrigCost)
5075 return false;
5076
5077 Builder.SetInsertPoint(&I);
5078 Value *NewReduce = UseOr ? Builder.CreateOrReduce(Vec)
5079 : Builder.CreateIntrinsic(
5080 Intrinsic::vector_reduce_umax, {VecTy}, {Vec});
// Queue the new reduction on the pass worklist so it can be visited again.
5081 Worklist.pushValue(NewReduce);
5082 Value *NewCmp = Builder.CreateICmp(
5083 NewPred, NewReduce, ConstantInt::getNullValue(VecTy->getScalarType()));
5084 replaceValue(I, *NewCmp);
5085 return true;
5086}
5087
5088/// Returns true if this ShuffleVectorInst eventually feeds into a
5089/// vector reduction intrinsic (e.g., vector_reduce_add) by only following
5090/// chains of shuffles and binary operators (in any combination/order).
5091/// The walk is bounded by the size of the visited set, not by a depth. It
/// returns false in any of these cases:
///  - more than MaxVisited distinct users have been visited;
///  - an intrinsic user is reached after a reduction was already found;
///  - an intrinsic other than the listed integer reductions is reached;
///  - another disallowed user is found (presumably anything that is not a
///    shuffle or binary operator).
5093 constexpr unsigned MaxVisited = 32;
5096 bool FoundReduction = false;
5097
5098 WorkList.push_back(SVI);
5099 while (!WorkList.empty()) {
5100 Instruction *I = WorkList.pop_back_val();
5101 for (User *U : I->users()) {
// NOTE(review): cast<> asserts on a type mismatch rather than returning null,
// so the !UI test below can never fire; dyn_cast<> may have been intended.
5102 auto *UI = cast<Instruction>(U);
5103 if (!UI || !Visited.insert(UI).second)
5104 continue;
5105 if (Visited.size() > MaxVisited)
5106 return false;
5107 if (auto *II = dyn_cast<IntrinsicInst>(UI)) {
5108 // More than one reduction reached
5109 if (FoundReduction)
5110 return false;
5111 switch (II->getIntrinsicID()) {
5112 case Intrinsic::vector_reduce_add:
5113 case Intrinsic::vector_reduce_mul:
5114 case Intrinsic::vector_reduce_and:
5115 case Intrinsic::vector_reduce_or:
5116 case Intrinsic::vector_reduce_xor:
5117 case Intrinsic::vector_reduce_smin:
5118 case Intrinsic::vector_reduce_smax:
5119 case Intrinsic::vector_reduce_umin:
5120 case Intrinsic::vector_reduce_umax:
5121 FoundReduction = true;
5122 continue;
5123 default:
5124 return false;
5125 }
5126 }
5127
5129 return false;
5130
5131 WorkList.emplace_back(UI);
5132 }
5133 }
5134 return FoundReduction;
5135}
5136
5137/// This method looks for groups of shuffles acting on binops, of the form:
5138/// %x = shuffle ...
5139/// %y = shuffle ...
5140/// %a = binop %x, %y
5141/// %b = binop %x, %y
5142/// shuffle %a, %b, selectmask
5143/// We may, especially if the shuffle is wider than legal, be able to convert
5144/// the shuffle to a form where only parts of a and b need to be computed. On
5145/// architectures with no obvious "select" shuffle, this can reduce the total
5146/// number of operations if the target reports them as cheaper.
5147bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
5148 auto *SVI = cast<ShuffleVectorInst>(&I);
5149 auto *VT = cast<FixedVectorType>(I.getType());
5150 auto *Op0 = dyn_cast<Instruction>(SVI->getOperand(0));
5151 auto *Op1 = dyn_cast<Instruction>(SVI->getOperand(1));
5152 if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() ||
5153 VT != Op0->getType())
5154 return false;
5155
5156 auto *SVI0A = dyn_cast<Instruction>(Op0->getOperand(0));
5157 auto *SVI0B = dyn_cast<Instruction>(Op0->getOperand(1));
5158 auto *SVI1A = dyn_cast<Instruction>(Op1->getOperand(0));
5159 auto *SVI1B = dyn_cast<Instruction>(Op1->getOperand(1));
5160 SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B});
5161 auto checkSVNonOpUses = [&](Instruction *I) {
5162 if (!I || I->getOperand(0)->getType() != VT)
5163 return true;
5164 return any_of(I->users(), [&](User *U) {
5165 return U != Op0 && U != Op1 &&
5166 !(isa<ShuffleVectorInst>(U) &&
5167 (InputShuffles.contains(cast<Instruction>(U)) ||
5168 isInstructionTriviallyDead(cast<Instruction>(U))));
5169 });
5170 };
5171 if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) ||
5172 checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B))
5173 return false;
5174
5175 // Collect all the uses that are shuffles that we can transform together. We
5176 // may not have a single shuffle, but a group that can all be transformed
5177 // together profitably.
5179 auto collectShuffles = [&](Instruction *I) {
5180 for (auto *U : I->users()) {
5181 auto *SV = dyn_cast<ShuffleVectorInst>(U);
5182 if (!SV || SV->getType() != VT)
5183 return false;
5184 if ((SV->getOperand(0) != Op0 && SV->getOperand(0) != Op1) ||
5185 (SV->getOperand(1) != Op0 && SV->getOperand(1) != Op1))
5186 return false;
5187 if (!llvm::is_contained(Shuffles, SV))
5188 Shuffles.push_back(SV);
5189 }
5190 return true;
5191 };
5192 if (!collectShuffles(Op0) || !collectShuffles(Op1))
5193 return false;
5194 // From a reduction, we need to be processing a single shuffle, otherwise the
5195 // other uses will not be lane-invariant.
5196 if (FromReduction && Shuffles.size() > 1)
5197 return false;
5198
5199 // Add any shuffle uses for the shuffles we have found, to include them in our
5200 // cost calculations.
5201 if (!FromReduction) {
5202 for (size_t Idx = 0, E = Shuffles.size(); Idx != E; ++Idx) {
5203 for (auto *U : Shuffles[Idx]->users()) {
5204 ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(U);
5205 if (SSV && isa<UndefValue>(SSV->getOperand(1)) && SSV->getType() == VT)
5206 Shuffles.push_back(SSV);
5207 }
5208 }
5209 }
5210
5211 // For each of the output shuffles, we try to sort all the first vector
5212 // elements to the beginning, followed by the second array elements at the
5213 // end. If the binops are legalized to smaller vectors, this may reduce total
5214 // number of binops. We compute the ReconstructMask mask needed to convert
5215 // back to the original lane order.
5217 SmallVector<SmallVector<int>> OrigReconstructMasks;
5218 int MaxV1Elt = 0, MaxV2Elt = 0;
5219 unsigned NumElts = VT->getNumElements();
5220 for (ShuffleVectorInst *SVN : Shuffles) {
5221 SmallVector<int> Mask;
5222 SVN->getShuffleMask(Mask);
5223
5224 // Check the operands are the same as the original, or reversed (in which
5225 // case we need to commute the mask).
5226 Value *SVOp0 = SVN->getOperand(0);
5227 Value *SVOp1 = SVN->getOperand(1);
5228 if (isa<UndefValue>(SVOp1)) {
5229 auto *SSV = cast<ShuffleVectorInst>(SVOp0);
5230 SVOp0 = SSV->getOperand(0);
5231 SVOp1 = SSV->getOperand(1);
5232 for (int &Elem : Mask) {
5233 if (Elem >= static_cast<int>(SSV->getShuffleMask().size()))
5234 return false;
5235 Elem = Elem < 0 ? Elem : SSV->getMaskValue(Elem);
5236 }
5237 }
5238 if (SVOp0 == Op1 && SVOp1 == Op0) {
5239 std::swap(SVOp0, SVOp1);
5241 }
5242 if (SVOp0 != Op0 || SVOp1 != Op1)
5243 return false;
5244
5245 // Calculate the reconstruction mask for this shuffle, as the mask needed to
5246 // take the packed values from Op0/Op1 and reconstructing to the original
5247 // order.
5248 SmallVector<int> ReconstructMask;
5249 for (unsigned I = 0; I < Mask.size(); I++) {
5250 if (Mask[I] < 0) {
5251 ReconstructMask.push_back(-1);
5252 } else if (Mask[I] < static_cast<int>(NumElts)) {
5253 MaxV1Elt = std::max(MaxV1Elt, Mask[I]);
5254 auto It = find_if(V1, [&](const std::pair<int, int> &A) {
5255 return Mask[I] == A.first;
5256 });
5257 if (It != V1.end())
5258 ReconstructMask.push_back(It - V1.begin());
5259 else {
5260 ReconstructMask.push_back(V1.size());
5261 V1.emplace_back(Mask[I], V1.size());
5262 }
5263 } else {
5264 MaxV2Elt = std::max<int>(MaxV2Elt, Mask[I] - NumElts);
5265 auto It = find_if(V2, [&](const std::pair<int, int> &A) {
5266 return Mask[I] - static_cast<int>(NumElts) == A.first;
5267 });
5268 if (It != V2.end())
5269 ReconstructMask.push_back(NumElts + It - V2.begin());
5270 else {
5271 ReconstructMask.push_back(NumElts + V2.size());
5272 V2.emplace_back(Mask[I] - NumElts, NumElts + V2.size());
5273 }
5274 }
5275 }
5276
5277 // For reductions, we know that the lane ordering out doesn't alter the
5278 // result. In-order can help simplify the shuffle away.
5279 if (FromReduction)
5280 sort(ReconstructMask);
5281 OrigReconstructMasks.push_back(std::move(ReconstructMask));
5282 }
5283
5284 // If the Maximum element used from V1 and V2 are not larger than the new
5285 // vectors, the vectors are already packes and performing the optimization
5286 // again will likely not help any further. This also prevents us from getting
5287 // stuck in a cycle in case the costs do not also rule it out.
5288 if (V1.empty() || V2.empty() ||
5289 (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
5290 MaxV2Elt == static_cast<int>(V2.size()) - 1))
5291 return false;
5292
5293 // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
5294 // shuffle of another shuffle, or not a shuffle (that is treated like a
5295 // identity shuffle).
5296 auto GetBaseMaskValue = [&](Instruction *I, int M) {
5297 auto *SV = dyn_cast<ShuffleVectorInst>(I);
5298 if (!SV)
5299 return M;
5300 if (isa<UndefValue>(SV->getOperand(1)))
5301 if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
5302 if (InputShuffles.contains(SSV))
5303 return SSV->getMaskValue(SV->getMaskValue(M));
5304 return SV->getMaskValue(M);
5305 };
5306
5307 // Attempt to sort the inputs my ascending mask values to make simpler input
5308 // shuffles and push complex shuffles down to the uses. We sort on the first
5309 // of the two input shuffle orders, to try and get at least one input into a
5310 // nice order.
5311 auto SortBase = [&](Instruction *A, std::pair<int, int> X,
5312 std::pair<int, int> Y) {
5313 int MXA = GetBaseMaskValue(A, X.first);
5314 int MYA = GetBaseMaskValue(A, Y.first);
5315 return MXA < MYA;
5316 };
5317 stable_sort(V1, [&](std::pair<int, int> A, std::pair<int, int> B) {
5318 return SortBase(SVI0A, A, B);
5319 });
5320 stable_sort(V2, [&](std::pair<int, int> A, std::pair<int, int> B) {
5321 return SortBase(SVI1A, A, B);
5322 });
5323 // Calculate our ReconstructMasks from the OrigReconstructMasks and the
5324 // modified order of the input shuffles.
5325 SmallVector<SmallVector<int>> ReconstructMasks;
5326 for (const auto &Mask : OrigReconstructMasks) {
5327 SmallVector<int> ReconstructMask;
5328 for (int M : Mask) {
5329 auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) {
5330 auto It = find_if(V, [M](auto A) { return A.second == M; });
5331 assert(It != V.end() && "Expected all entries in Mask");
5332 return std::distance(V.begin(), It);
5333 };
5334 if (M < 0)
5335 ReconstructMask.push_back(-1);
5336 else if (M < static_cast<int>(NumElts)) {
5337 ReconstructMask.push_back(FindIndex(V1, M));
5338 } else {
5339 ReconstructMask.push_back(NumElts + FindIndex(V2, M));
5340 }
5341 }
5342 ReconstructMasks.push_back(std::move(ReconstructMask));
5343 }
5344
5345 // Calculate the masks needed for the new input shuffles, which get padded
5346 // with undef
5347 SmallVector<int> V1A, V1B, V2A, V2B;
5348 for (unsigned I = 0; I < V1.size(); I++) {
5349 V1A.push_back(GetBaseMaskValue(SVI0A, V1[I].first));
5350 V1B.push_back(GetBaseMaskValue(SVI0B, V1[I].first));
5351 }
5352 for (unsigned I = 0; I < V2.size(); I++) {
5353 V2A.push_back(GetBaseMaskValue(SVI1A, V2[I].first));
5354 V2B.push_back(GetBaseMaskValue(SVI1B, V2[I].first));
5355 }
5356 while (V1A.size() < NumElts) {
5359 }
5360 while (V2A.size() < NumElts) {
5363 }
5364
5365 auto AddShuffleCost = [&](InstructionCost C, Instruction *I) {
5366 auto *SV = dyn_cast<ShuffleVectorInst>(I);
5367 if (!SV)
5368 return C;
5369 return C + TTI.getShuffleCost(isa<UndefValue>(SV->getOperand(1))
5372 VT, VT, SV->getShuffleMask(), CostKind);
5373 };
5374 auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) {
5375 return C +
5377 };
5378
5379 unsigned ElementSize = VT->getElementType()->getPrimitiveSizeInBits();
5380 unsigned MaxVectorSize =
5382 unsigned MaxElementsInVector = MaxVectorSize / ElementSize;
5383 if (MaxElementsInVector == 0)
5384 return false;
5385 // When there are multiple shufflevector operations on the same input,
5386 // especially when the vector length is larger than the register size,
5387 // identical shuffle patterns may occur across different groups of elements.
5388 // To avoid overestimating the cost by counting these repeated shuffles more
5389 // than once, we only account for unique shuffle patterns. This adjustment
5390 // prevents inflated costs in the cost model for wide vectors split into
5391 // several register-sized groups.
5392 std::set<SmallVector<int, 4>> UniqueShuffles;
5393 auto AddShuffleMaskAdjustedCost = [&](InstructionCost C, ArrayRef<int> Mask) {
5394 // Compute the cost for performing the shuffle over the full vector.
5395 auto ShuffleCost =
5397 unsigned NumFullVectors = Mask.size() / MaxElementsInVector;
5398 if (NumFullVectors < 2)
5399 return C + ShuffleCost;
5400 SmallVector<int, 4> SubShuffle(MaxElementsInVector);
5401 unsigned NumUniqueGroups = 0;
5402 unsigned NumGroups = Mask.size() / MaxElementsInVector;
5403 // For each group of MaxElementsInVector contiguous elements,
5404 // collect their shuffle pattern and insert into the set of unique patterns.
5405 for (unsigned I = 0; I < NumFullVectors; ++I) {
5406 for (unsigned J = 0; J < MaxElementsInVector; ++J)
5407 SubShuffle[J] = Mask[MaxElementsInVector * I + J];
5408 if (UniqueShuffles.insert(SubShuffle).second)
5409 NumUniqueGroups += 1;
5410 }
5411 return C + ShuffleCost * NumUniqueGroups / NumGroups;
5412 };
5413 auto AddShuffleAdjustedCost = [&](InstructionCost C, Instruction *I) {
5414 auto *SV = dyn_cast<ShuffleVectorInst>(I);
5415 if (!SV)
5416 return C;
5417 SmallVector<int, 16> Mask;
5418 SV->getShuffleMask(Mask);
5419 return AddShuffleMaskAdjustedCost(C, Mask);
5420 };
5421 // Check that input consists of ShuffleVectors applied to the same input
5422 auto AllShufflesHaveSameOperands =
5423 [](SmallPtrSetImpl<Instruction *> &InputShuffles) {
5424 if (InputShuffles.size() < 2)
5425 return false;
5426 ShuffleVectorInst *FirstSV =
5427 dyn_cast<ShuffleVectorInst>(*InputShuffles.begin());
5428 if (!FirstSV)
5429 return false;
5430
5431 Value *In0 = FirstSV->getOperand(0), *In1 = FirstSV->getOperand(1);
5432 return std::all_of(
5433 std::next(InputShuffles.begin()), InputShuffles.end(),
5434 [&](Instruction *I) {
5435 ShuffleVectorInst *SV = dyn_cast<ShuffleVectorInst>(I);
5436 return SV && SV->getOperand(0) == In0 && SV->getOperand(1) == In1;
5437 });
5438 };
5439
5440 // Get the costs of the shuffles + binops before and after with the new
5441 // shuffle masks.
5442 InstructionCost CostBefore =
5443 TTI.getArithmeticInstrCost(Op0->getOpcode(), VT, CostKind) +
5444 TTI.getArithmeticInstrCost(Op1->getOpcode(), VT, CostKind);
5445 CostBefore += std::accumulate(Shuffles.begin(), Shuffles.end(),
5446 InstructionCost(0), AddShuffleCost);
5447 if (AllShufflesHaveSameOperands(InputShuffles)) {
5448 UniqueShuffles.clear();
5449 CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
5450 InstructionCost(0), AddShuffleAdjustedCost);
5451 } else {
5452 CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
5453 InstructionCost(0), AddShuffleCost);
5454 }
5455
5456 // The new binops will be unused for lanes past the used shuffle lengths.
5457 // These types attempt to get the correct cost for that from the target.
5458 FixedVectorType *Op0SmallVT =
5459 FixedVectorType::get(VT->getScalarType(), V1.size());
5460 FixedVectorType *Op1SmallVT =
5461 FixedVectorType::get(VT->getScalarType(), V2.size());
5462 InstructionCost CostAfter =
5463 TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT, CostKind) +
5464 TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT, CostKind);
5465 UniqueShuffles.clear();
5466 CostAfter += std::accumulate(ReconstructMasks.begin(), ReconstructMasks.end(),
5467 InstructionCost(0), AddShuffleMaskAdjustedCost);
5468 std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
5469 CostAfter +=
5470 std::accumulate(OutputShuffleMasks.begin(), OutputShuffleMasks.end(),
5471 InstructionCost(0), AddShuffleMaskCost);
5472
5473 LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
5474 LLVM_DEBUG(dbgs() << " CostBefore: " << CostBefore
5475 << " vs CostAfter: " << CostAfter << "\n");
5476 if (CostBefore < CostAfter ||
5477 (CostBefore == CostAfter && !feedsIntoVectorReduction(SVI)))
5478 return false;
5479
5480 // The cost model has passed, create the new instructions.
5481 auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * {
5482 auto *SV = dyn_cast<ShuffleVectorInst>(I);
5483 if (!SV)
5484 return I;
5485 if (isa<UndefValue>(SV->getOperand(1)))
5486 if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
5487 if (InputShuffles.contains(SSV))
5488 return SSV->getOperand(Op);
5489 return SV->getOperand(Op);
5490 };
5491 Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef());
5492 Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0),
5493 GetShuffleOperand(SVI0A, 1), V1A);
5494 Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef());
5495 Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0),
5496 GetShuffleOperand(SVI0B, 1), V1B);
5497 Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef());
5498 Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0),
5499 GetShuffleOperand(SVI1A, 1), V2A);
5500 Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef());
5501 Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0),
5502 GetShuffleOperand(SVI1B, 1), V2B);
5503 Builder.SetInsertPoint(Op0);
5504 Value *NOp0 = Builder.CreateBinOp((Instruction::BinaryOps)Op0->getOpcode(),
5505 NSV0A, NSV0B);
5506 if (auto *I = dyn_cast<Instruction>(NOp0))
5507 I->copyIRFlags(Op0, true);
5508 Builder.SetInsertPoint(Op1);
5509 Value *NOp1 = Builder.CreateBinOp((Instruction::BinaryOps)Op1->getOpcode(),
5510 NSV1A, NSV1B);
5511 if (auto *I = dyn_cast<Instruction>(NOp1))
5512 I->copyIRFlags(Op1, true);
5513
5514 for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
5515 Builder.SetInsertPoint(Shuffles[S]);
5516 Value *NSV = Builder.CreateShuffleVector(NOp0, NOp1, ReconstructMasks[S]);
5517 replaceValue(*Shuffles[S], *NSV, false);
5518 }
5519
5520 Worklist.pushValue(NSV0A);
5521 Worklist.pushValue(NSV0B);
5522 Worklist.pushValue(NSV1A);
5523 Worklist.pushValue(NSV1B);
5524 return true;
5525}
5526
5527/// Check if instruction depends on ZExt and this ZExt can be moved after the
5528/// instruction. Move ZExt if it is profitable. For example:
5529/// logic(zext(x),y) -> zext(logic(x,trunc(y)))
5530/// lshr((zext(x),y) -> zext(lshr(x,trunc(y)))
5531/// Cost model calculations takes into account if zext(x) has other users and
5532/// whether it can be propagated through them too.
5533bool VectorCombine::shrinkType(Instruction &I) {
5534 Value *ZExted, *OtherOperand;
5535 if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
5536 m_Value(OtherOperand))) &&
5537 !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
5538 return false;
5539
5540 Value *ZExtOperand = I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0);
5541
5542 auto *BigTy = cast<FixedVectorType>(I.getType());
5543 auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
5544 unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
5545
5546 if (I.getOpcode() == Instruction::LShr) {
5547 // Check that the shift amount is less than the number of bits in the
5548 // smaller type. Otherwise, the smaller lshr will return a poison value.
5549 KnownBits ShAmtKB = computeKnownBits(I.getOperand(1), *DL);
5550 if (ShAmtKB.getMaxValue().uge(BW))
5551 return false;
5552 } else {
5553 // Check that the expression overall uses at most the same number of bits as
5554 // ZExted
5555 KnownBits KB = computeKnownBits(&I, *DL);
5556 if (KB.countMaxActiveBits() > BW)
5557 return false;
5558 }
5559
5560 // Calculate costs of leaving current IR as it is and moving ZExt operation
5561 // later, along with adding truncates if needed
5563 Instruction::ZExt, BigTy, SmallTy,
5564 TargetTransformInfo::CastContextHint::None, CostKind);
5565 InstructionCost CurrentCost = ZExtCost;
5566 InstructionCost ShrinkCost = 0;
5567
5568 // Calculate total cost and check that we can propagate through all ZExt users
5569 for (User *U : ZExtOperand->users()) {
5570 auto *UI = cast<Instruction>(U);
5571 if (UI == &I) {
5572 CurrentCost +=
5573 TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
5574 ShrinkCost +=
5575 TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
5576 ShrinkCost += ZExtCost;
5577 continue;
5578 }
5579
5580 if (!Instruction::isBinaryOp(UI->getOpcode()))
5581 return false;
5582
5583 // Check if we can propagate ZExt through its other users
5584 KnownBits KB = computeKnownBits(UI, *DL);
5585 if (KB.countMaxActiveBits() > BW)
5586 return false;
5587
5588 CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
5589 ShrinkCost +=
5590 TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
5591 ShrinkCost += ZExtCost;
5592 }
5593
5594 // If the other instruction operand is not a constant, we'll need to
5595 // generate a truncate instruction. So we have to adjust cost
5596 if (!isa<Constant>(OtherOperand))
5597 ShrinkCost += TTI.getCastInstrCost(
5598 Instruction::Trunc, SmallTy, BigTy,
5599 TargetTransformInfo::CastContextHint::None, CostKind);
5600
5601 // If the cost of shrinking types and leaving the IR is the same, we'll lean
5602 // towards modifying the IR because shrinking opens opportunities for other
5603 // shrinking optimisations.
5604 if (ShrinkCost > CurrentCost)
5605 return false;
5606
5607 Builder.SetInsertPoint(&I);
5608 Value *Op0 = ZExted;
5609 Value *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
5610 // Keep the order of operands the same
5611 if (I.getOperand(0) == OtherOperand)
5612 std::swap(Op0, Op1);
5613 Value *NewBinOp =
5614 Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
5615 cast<Instruction>(NewBinOp)->copyIRFlags(&I);
5616 cast<Instruction>(NewBinOp)->copyMetadata(I);
5617 Value *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
5618 replaceValue(I, *NewZExtr);
5619 return true;
5620}
5621
5622/// insert (DstVec, (extract SrcVec, ExtIdx), InsIdx) -->
5623/// shuffle (DstVec, SrcVec, Mask)
5624bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
5625 Value *DstVec, *SrcVec;
5626 uint64_t ExtIdx, InsIdx;
5627 if (!match(&I,
5628 m_InsertElt(m_Value(DstVec),
5629 m_ExtractElt(m_Value(SrcVec), m_ConstantInt(ExtIdx)),
5630 m_ConstantInt(InsIdx))))
5631 return false;
5632
5633 auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
5634 auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcVec->getType());
5635 // We can try combining vectors with different element sizes.
5636 if (!DstVecTy || !SrcVecTy ||
5637 SrcVecTy->getElementType() != DstVecTy->getElementType())
5638 return false;
5639
5640 unsigned NumDstElts = DstVecTy->getNumElements();
5641 unsigned NumSrcElts = SrcVecTy->getNumElements();
5642 if (InsIdx >= NumDstElts || ExtIdx >= NumSrcElts || NumDstElts == 1)
5643 return false;
5644
5645 // Insertion into poison is a cheaper single operand shuffle.
5647 SmallVector<int> Mask(NumDstElts, PoisonMaskElem);
5648
5649 bool NeedExpOrNarrow = NumSrcElts != NumDstElts;
5650 bool NeedDstSrcSwap = isa<PoisonValue>(DstVec) && !isa<UndefValue>(SrcVec);
5651 if (NeedDstSrcSwap) {
5653 Mask[InsIdx] = ExtIdx % NumDstElts;
5654 std::swap(DstVec, SrcVec);
5655 } else {
5657 std::iota(Mask.begin(), Mask.end(), 0);
5658 Mask[InsIdx] = (ExtIdx % NumDstElts) + NumDstElts;
5659 }
5660
5661 // Cost
5662 auto *Ins = cast<InsertElementInst>(&I);
5663 auto *Ext = cast<ExtractElementInst>(I.getOperand(1));
5664 InstructionCost InsCost =
5665 TTI.getVectorInstrCost(*Ins, DstVecTy, CostKind, InsIdx);
5666 InstructionCost ExtCost =
5667 TTI.getVectorInstrCost(*Ext, DstVecTy, CostKind, ExtIdx);
5668 InstructionCost OldCost = ExtCost + InsCost;
5669
5670 InstructionCost NewCost = 0;
5671 SmallVector<int> ExtToVecMask;
5672 if (!NeedExpOrNarrow) {
5673 // Ignore 'free' identity insertion shuffle.
5674 // TODO: getShuffleCost should return TCC_Free for Identity shuffles.
5675 if (!ShuffleVectorInst::isIdentityMask(Mask, NumSrcElts))
5676 NewCost += TTI.getShuffleCost(SK, DstVecTy, DstVecTy, Mask, CostKind, 0,
5677 nullptr, {DstVec, SrcVec});
5678 } else {
5679 // When creating a length-changing-vector, always try to keep the relevant
5680 // element in an equivalent position, so that bulk shuffles are more likely
5681 // to be useful.
5682 ExtToVecMask.assign(NumDstElts, PoisonMaskElem);
5683 ExtToVecMask[ExtIdx % NumDstElts] = ExtIdx;
5684 // Add cost for expanding or narrowing
5686 DstVecTy, SrcVecTy, ExtToVecMask, CostKind);
5687 NewCost += TTI.getShuffleCost(SK, DstVecTy, DstVecTy, Mask, CostKind);
5688 }
5689
5690 if (!Ext->hasOneUse())
5691 NewCost += ExtCost;
5692
5693 LLVM_DEBUG(dbgs() << "Found a insert/extract shuffle-like pair: " << I
5694 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
5695 << "\n");
5696
5697 if (OldCost < NewCost)
5698 return false;
5699
5700 if (NeedExpOrNarrow) {
5701 if (!NeedDstSrcSwap)
5702 SrcVec = Builder.CreateShuffleVector(SrcVec, ExtToVecMask);
5703 else
5704 DstVec = Builder.CreateShuffleVector(DstVec, ExtToVecMask);
5705 }
5706
5707 // Canonicalize undef param to RHS to help further folds.
5708 if (isa<UndefValue>(DstVec) && !isa<UndefValue>(SrcVec)) {
5709 ShuffleVectorInst::commuteShuffleMask(Mask, NumDstElts);
5710 std::swap(DstVec, SrcVec);
5711 }
5712
5713 Value *Shuf = Builder.CreateShuffleVector(DstVec, SrcVec, Mask);
5714 replaceValue(I, *Shuf);
5715
5716 return true;
5717}
5718
5719/// If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32>
5720/// <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a
5721/// larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first
5722/// before casting it back into `<vscale x 16 x i32>`.
5723bool VectorCombine::foldInterleaveIntrinsics(Instruction &I) {
5724 const APInt *SplatVal0, *SplatVal1;
5726 m_APInt(SplatVal0), m_APInt(SplatVal1))))
5727 return false;
5728
5729 LLVM_DEBUG(dbgs() << "VC: Folding interleave2 with two splats: " << I
5730 << "\n");
5731
5732 auto *VTy =
5733 cast<VectorType>(cast<IntrinsicInst>(I).getArgOperand(0)->getType());
5734 auto *ExtVTy = VectorType::getExtendedElementVectorType(VTy);
5735 unsigned Width = VTy->getElementType()->getIntegerBitWidth();
5736
5737 // Just in case the cost of interleave2 intrinsic and bitcast are both
5738 // invalid, in which case we want to bail out, we use <= rather
5739 // than < here. Even they both have valid and equal costs, it's probably
5740 // not a good idea to emit a high-cost constant splat.
5742 TTI.getCastInstrCost(Instruction::BitCast, I.getType(), ExtVTy,
5744 LLVM_DEBUG(dbgs() << "VC: The cost to cast from " << *ExtVTy << " to "
5745 << *I.getType() << " is too high.\n");
5746 return false;
5747 }
5748
5749 APInt NewSplatVal = SplatVal1->zext(Width * 2);
5750 NewSplatVal <<= Width;
5751 NewSplatVal |= SplatVal0->zext(Width * 2);
5752 auto *NewSplat = ConstantVector::getSplat(
5753 ExtVTy->getElementCount(), ConstantInt::get(F.getContext(), NewSplatVal));
5754
5755 IRBuilder<> Builder(&I);
5756 replaceValue(I, *Builder.CreateBitCast(NewSplat, I.getType()));
5757 return true;
5758}
5759
5760/// Given this sequence:
5761/// ```
5762/// %d = llvm.vector.deinterleave2 <vscale x 16 x i32> %v
5763/// %f0 = extractvalue { <vscale x 8 x i32>, <vscale x 8 x i32> } %d, 0
5764/// %f1 = extractvalue { <vscale x 8 x i32>, <vscale x 8 x i32> } %d, 1
5765///
5766/// %low0 = and <vscale x 8 x i32> %f0, splat (i32 65535)
5767/// %low1 = shl <vscale x 8 x i32> %f1, splat (i32 16)
5768/// %merge0 = or disjoint <vscale x 8 x i32> %low0, %low1
5769///
5770/// %high0 = and <vscale x 8 x i32> %f1, splat (i32 -65536)
5771/// %high1 = lshr <vscale x 8 x i32> %f0, splat (i32 16)
5772/// %merge1 = or disjoint <vscale x 8 x i32> %high0, %high1
5773/// ```
5774/// It is actually just de-interleaving a 16-bit vector with double the
5775/// vector length. More generally speaking, it's de-interleaving on a vector
5776/// with half the element width as the original vector.
5777///
5778/// Therefore, we can turn it into:
5779/// ```
5780/// %narrow.v = bitcast <vscale x 16 x i32> %v to <vscale x 32 x i16>
5781/// %d = llvm.vector.deinterleave2 <vscale x 32 x i16> %narrow.v
5782/// %f0 = extractvalue { <vscale x 16 x i16>, <vscale x 16 x i16> } %d, 0
5783/// %f1 = extractvalue { <vscale x 16 x i16>, <vscale x 16 x i16> } %d, 1
5784///
5785/// %merge0 = bitcast <vscale x 16 x i16> %f0 to <vscale x 8 x i32>
5786/// %merge1 = bitcast <vscale x 16 x i16> %f1 to <vscale x 8 x i32>
5787/// ```
5788bool VectorCombine::foldDeinterleaveIntrinsics(Instruction &I) {
5789 // This pattern involves bitcast that is not compatible with big endian.
5790 if (DL->isBigEndian())
5791 return false;
5792
5793 using namespace PatternMatch;
5794 Value *DeinterleavedVal;
5795 if (!match(&I, m_Deinterleave2(m_Value(DeinterleavedVal))))
5796 return false;
5797
5798 VectorType *VecTy = cast<VectorType>(DeinterleavedVal->getType());
5799 IntegerType *ElementTy = dyn_cast<IntegerType>(VecTy->getElementType());
5800 if (!ElementTy)
5801 return false;
5802 unsigned ElementWidth = ElementTy->getBitWidth();
5803 if (ElementWidth < 2 || !isPowerOf2_32(ElementWidth))
5804 return false;
5805 unsigned HalfElementWidth = ElementWidth / 2;
5806
5807 if (!I.hasNUses(2))
5808 return false;
5809 std::array<ExtractValueInst *, 2> OrigFields{};
5810 for (User *Usr : I.users()) {
5811 auto *E = dyn_cast<ExtractValueInst>(Usr);
5812 // The deinterleave result can only be used by extractions.
5813 if (!E || E->getNumIndices() != 1)
5814 return false;
5815 unsigned Idx = *E->idx_begin();
5816 // A single field cannot be extracted more than once.
5817 if (Idx >= 2 || OrigFields[Idx] || !E->hasNUses(2))
5818 return false;
5819 OrigFields[Idx] = E;
5820 }
5821
5822 // Find the merge instruction (i.e. OR) first.
5823 SmallVector<Instruction *, 2> MergeInsts;
5824 for (auto *FieldUsr : OrigFields[0]->users()) {
5825 if (!FieldUsr->hasOneUse() || !isa<Instruction>(FieldUsr->user_back()))
5826 return false;
5827 MergeInsts.push_back(cast<Instruction>(FieldUsr->user_back()));
5828 }
5829 assert(MergeInsts.size() == 2);
5830
5831 // Pattern match bottom-up from the merge instructions.
5832 auto MatchMerge = [&](void) -> bool {
5833 APInt LoMask = APInt::getLowBitsSet(ElementWidth, HalfElementWidth);
5834 APInt HiMask = APInt::getHighBitsSet(ElementWidth, HalfElementWidth);
5835 return match(MergeInsts[0],
5836 m_c_Or(m_And(m_Specific(OrigFields[0]), m_SpecificInt(LoMask)),
5837 m_Shl(m_Specific(OrigFields[1]),
5838 m_SpecificInt(HalfElementWidth)))) &&
5839 match(MergeInsts[1],
5840 m_c_Or(m_And(m_Specific(OrigFields[1]), m_SpecificInt(HiMask)),
5841 m_LShr(m_Specific(OrigFields[0]),
5842 m_SpecificInt(HalfElementWidth))));
5843 };
5844 if (!MatchMerge()) {
5845 std::swap(MergeInsts[0], MergeInsts[1]);
5846 if (!MatchMerge())
5847 return false;
5848 }
5849
5850 // Profitability check.
5851 InstructionCost OldCost =
5852 TTI.getInstructionCost(MergeInsts[0], CostKind) +
5853 TTI.getInstructionCost(cast<Instruction>(MergeInsts[0]->getOperand(0)),
5854 CostKind) +
5855 TTI.getInstructionCost(cast<Instruction>(MergeInsts[0]->getOperand(1)),
5856 CostKind);
5857 // There are two fields (assuming SHL has the same cost as LSHR).
5858 OldCost *= 2;
5859
5860 auto *NewFieldTy = VecTy->getWithNewBitWidth(HalfElementWidth);
5861 auto *NewVecTy =
5862 VectorType::getDoubleElementsVectorType(cast<VectorType>(NewFieldTy));
5863 InstructionCost NewCost =
5864 TTI.getCastInstrCost(Instruction::BitCast, VecTy, NewVecTy,
5866 TTI.getCastInstrCost(Instruction::BitCast, NewFieldTy,
5867 MergeInsts[0]->getType(), TTI::CastContextHint::None,
5868 CostKind) *
5869 2;
5870 if (OldCost <= NewCost || !NewCost.isValid()) {
5871 LLVM_DEBUG(
5872 dbgs() << "VC: New deinterleave2 sequence cost (" << NewCost << ")"
5873 << " is higher than that of the old one (" << OldCost << ")\n");
5874 return false;
5875 }
5876
5877 // Do the replacement.
5878 IRBuilder<> Builder(&I);
5879 Value *NewVecCast = Builder.CreateBitCast(DeinterleavedVal, NewVecTy);
5880 Value *NewDeinterleave = Builder.CreateIntrinsic(
5881 Intrinsic::vector_deinterleave2, {NewVecTy}, {NewVecCast});
5882 for (auto [Idx, MergeInst] : enumerate(MergeInsts)) {
5883 Value *NewField = Builder.CreateExtractValue(NewDeinterleave, Idx);
5884 NewField = Builder.CreateBitCast(NewField, MergeInst->getType());
5885 replaceValue(*MergeInst, *NewField);
5886 }
5887
5888 return true;
5889}
5890
/// Fold a no-op cast of a vp.load result into a single vp.load of the cast's
/// destination vector type, scaling EVL by the element-count ratio.
///
/// Only applies when the load's mask is all-ones, the cast is a no-op cast to
/// a vector type, and the destination element count is a known integer
/// multiple of the original one with the same scalability. The new load
/// reuses the original pointer and alignment. It is emitted only if it is no
/// more expensive than the original vp.load plus the cast.
///
/// \returns true if the cast was replaced by the new vp.load.
///
/// NOTE(review): the new vp.load is created through the member Builder,
/// whose insertion point is not set in this function. If the caller
/// positions it at the cast, the load effectively moves from the original
/// vp.load's position down to the cast; confirm that no write to the loaded
/// memory can occur in between, or insert at II instead.
5891bool VectorCombine::foldBitcastOfVPLoad(Instruction &I) {
5892 const DataLayout &DL = I.getDataLayout();
5893 auto *Cast = dyn_cast<CastInst>(&I);
5894 if (!Cast || !Cast->isNoopCast(DL) || !isa<VectorType>(Cast->getDestTy()))
5895 return false;
5896
5897 // Fold away bit casts of the loaded value by loading the desired type,
5898 // if the mask is all-ones.
5899 Value *EVL;
5900 auto *II = dyn_cast<VPIntrinsic>(I.getOperand(0));
5902 m_Value(), m_AllOnes(), m_Value(EVL)))))
5903 return false;
5904
5905 VectorType *OrigVecTy = cast<VectorType>(II->getType());
5906 Align OrigAlign =
5907 DL.getValueOrABITypeAlignment(II->getPointerAlignment(), OrigVecTy);
5908 ElementCount OrigVecCnt = OrigVecTy->getElementCount();
5909 VectorType *NewVecTy = cast<VectorType>(Cast->getDestTy());
5910 ElementCount NewVecCnt = NewVecTy->getElementCount();
5911
5912 // Right now we only support cases where the NewVec is longer, because for
5913 // cases where it's shorter, we have to be sure that EVL can be exactly
5914 // divided, otherwise it might yield incorrect results or even page faults
5915 // (if we round-up during the division).
5916 if (!(OrigVecCnt.isScalable() == NewVecCnt.isScalable() &&
5917 NewVecCnt.hasKnownScalarFactor(OrigVecCnt)))
5918 return false;
5919
// Old: the original vp.load plus the cast. New: one vp.load of NewVecTy with
// the same pointer and alignment.
5920 InstructionCost OldCost =
5921 TTI.getMemIntrinsicInstrCost({Intrinsic::vp_load, OrigVecTy,
5922 II->getMemoryPointerParam(), false,
5923 OrigAlign},
5924 CostKind) +
5925 TTI.getCastInstrCost(Instruction::BitCast, Cast->getType(), OrigVecTy,
5928 {Intrinsic::vp_load, NewVecTy, II->getMemoryPointerParam(), false,
5929 OrigAlign},
5930 CostKind);
5931 LLVM_DEBUG(dbgs() << "foldBitcastOfVPLoad: OldCost=" << OldCost
5932 << " NewCost=" << NewCost << "\n");
5933 if (NewCost > OldCost || !NewCost.isValid())
5934 return false;
5935
// The cast is a no-op (same total size) and NewVecCnt is a known multiple of
// OrigVecCnt, so EVL * Factor narrower elements cover exactly the bytes the
// original EVL covered. The multiply is emitted with the nuw flag.
5936 unsigned Factor = NewVecCnt.getKnownScalarFactor(OrigVecCnt);
5937 Value *NewEVL = Builder.CreateNUWMul(EVL, Builder.getInt32(Factor));
5938 Value *NewMask = Builder.CreateVectorSplat(NewVecCnt, Builder.getTrue());
5939 CallInst *NewVP =
5940 Builder.CreateIntrinsic(NewVecTy, Intrinsic::vp_load,
5941 {II->getMemoryPointerParam(), NewMask, NewEVL});
5942 // Preserve the original alignment.
5943 NewVP->addParamAttrs(
5944 0, AttrBuilder(II->getContext()).addAlignmentAttr(OrigAlign));
5945 replaceValue(*Cast, *NewVP);
5946 return true;
5947}
5948
5949// Attempt to shrink loads that are only used by shufflevector instructions.
5950bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
5951 auto *OldLoad = dyn_cast<LoadInst>(&I);
5952 if (!OldLoad || !OldLoad->isSimple())
5953 return false;
5954
5955 auto *OldLoadTy = dyn_cast<FixedVectorType>(OldLoad->getType());
5956 if (!OldLoadTy)
5957 return false;
5958
5959 unsigned const OldNumElements = OldLoadTy->getNumElements();
5960
5961 // Search all uses of load. If all uses are shufflevector instructions, and
5962 // the second operands are all poison values, find the minimum and maximum
5963 // indices of the vector elements referenced by all shuffle masks.
5964 // Otherwise return `std::nullopt`.
5965 using IndexRange = std::pair<int, int>;
5966 auto GetIndexRangeInShuffles = [&]() -> std::optional<IndexRange> {
5967 IndexRange OutputRange = IndexRange(OldNumElements, -1);
5968 for (llvm::Use &Use : I.uses()) {
5969 // Ensure all uses match the required pattern.
5970 User *Shuffle = Use.getUser();
5971 ArrayRef<int> Mask;
5972
5973 if (!match(Shuffle,
5974 m_Shuffle(m_Specific(OldLoad), m_Undef(), m_Mask(Mask))))
5975 return std::nullopt;
5976
5977 // Ignore shufflevector instructions that have no uses.
5978 if (Shuffle->use_empty())
5979 continue;
5980
5981 // Find the min and max indices used by the shufflevector instruction.
5982 for (int Index : Mask) {
5983 if (Index >= 0 && Index < static_cast<int>(OldNumElements)) {
5984 OutputRange.first = std::min(Index, OutputRange.first);
5985 OutputRange.second = std::max(Index, OutputRange.second);
5986 }
5987 }
5988 }
5989
5990 if (OutputRange.second < OutputRange.first)
5991 return std::nullopt;
5992
5993 return OutputRange;
5994 };
5995
5996 // Get the range of vector elements used by shufflevector instructions.
5997 if (std::optional<IndexRange> Indices = GetIndexRangeInShuffles()) {
5998 unsigned const NewNumElements = Indices->second + 1u;
5999
6000 // If the range of vector elements is smaller than the full load, attempt
6001 // to create a smaller load.
6002 if (NewNumElements < OldNumElements) {
6003 IRBuilder Builder(&I);
6004 Builder.SetCurrentDebugLocation(I.getDebugLoc());
6005
6006 // Calculate costs of old and new ops.
6007 Type *ElemTy = OldLoadTy->getElementType();
6008 FixedVectorType *NewLoadTy = FixedVectorType::get(ElemTy, NewNumElements);
6009 Value *PtrOp = OldLoad->getPointerOperand();
6010
6012 Instruction::Load, OldLoad->getType(), OldLoad->getAlign(),
6013 OldLoad->getPointerAddressSpace(), CostKind);
6014 InstructionCost NewCost =
6015 TTI.getMemoryOpCost(Instruction::Load, NewLoadTy, OldLoad->getAlign(),
6016 OldLoad->getPointerAddressSpace(), CostKind);
6017
6018 using UseEntry = std::pair<ShuffleVectorInst *, std::vector<int>>;
6020 unsigned const MaxIndex = NewNumElements * 2u;
6021
6022 for (llvm::Use &Use : I.uses()) {
6023 auto *Shuffle = cast<ShuffleVectorInst>(Use.getUser());
6024
6025 // Ignore shufflevector instructions that have no uses.
6026 if (Shuffle->use_empty())
6027 continue;
6028
6029 ArrayRef<int> OldMask = Shuffle->getShuffleMask();
6030
6031 // Create entry for new use.
6032 NewUses.push_back({Shuffle, OldMask});
6033
6034 // Validate mask indices.
6035 for (int Index : OldMask) {
6036 if (Index >= static_cast<int>(MaxIndex))
6037 return false;
6038 }
6039
6040 // Update costs.
6041 OldCost +=
6043 OldLoadTy, OldMask, CostKind);
6044 NewCost +=
6046 NewLoadTy, OldMask, CostKind);
6047 }
6048
6049 LLVM_DEBUG(
6050 dbgs() << "Found a load used only by shufflevector instructions: "
6051 << I << "\n OldCost: " << OldCost
6052 << " vs NewCost: " << NewCost << "\n");
6053
6054 if (OldCost < NewCost || !NewCost.isValid())
6055 return false;
6056
6057 // Create new load of smaller vector.
6058 auto *NewLoad = cast<LoadInst>(
6059 Builder.CreateAlignedLoad(NewLoadTy, PtrOp, OldLoad->getAlign()));
6060 NewLoad->copyMetadata(I);
6061
6062 // Replace all uses.
6063 for (UseEntry &Use : NewUses) {
6064 ShuffleVectorInst *Shuffle = Use.first;
6065 std::vector<int> &NewMask = Use.second;
6066
6067 Builder.SetInsertPoint(Shuffle);
6068 Builder.SetCurrentDebugLocation(Shuffle->getDebugLoc());
6069 Value *NewShuffle = Builder.CreateShuffleVector(
6070 NewLoad, PoisonValue::get(NewLoadTy), NewMask);
6071
6072 replaceValue(*Shuffle, *NewShuffle, false);
6073 }
6074
6075 return true;
6076 }
6077 }
6078 return false;
6079}
6080
6081// Attempt to narrow a phi of shufflevector instructions where the two incoming
6082// values have the same operands but different masks. If the two shuffle masks
6083// are offsets of one another we can use one branch to rotate the incoming
6084// vector and perform one larger shuffle after the phi.
6085bool VectorCombine::shrinkPhiOfShuffles(Instruction &I) {
6086 auto *Phi = dyn_cast<PHINode>(&I);
6087 if (!Phi || Phi->getNumIncomingValues() != 2u)
6088 return false;
6089
6090 Value *Op = nullptr;
6091 ArrayRef<int> Mask0;
6092 ArrayRef<int> Mask1;
6093
6094 if (!match(Phi->getOperand(0u),
6095 m_OneUse(m_Shuffle(m_Value(Op), m_Poison(), m_Mask(Mask0)))) ||
6096 !match(Phi->getOperand(1u),
6097 m_OneUse(m_Shuffle(m_Specific(Op), m_Poison(), m_Mask(Mask1)))))
6098 return false;
6099
6100 auto *Shuf = cast<ShuffleVectorInst>(Phi->getOperand(0u));
6101
6102 // Ensure result vectors are wider than the argument vector.
6103 auto *InputVT = cast<FixedVectorType>(Op->getType());
6104 auto *ResultVT = cast<FixedVectorType>(Shuf->getType());
6105 auto const InputNumElements = InputVT->getNumElements();
6106
6107 if (InputNumElements >= ResultVT->getNumElements())
6108 return false;
6109
6110 // Take the difference of the two shuffle masks at each index. Ignore poison
6111 // values at the same index in both masks.
6112 SmallVector<int, 16> NewMask;
6113 NewMask.reserve(Mask0.size());
6114
6115 for (auto [M0, M1] : zip(Mask0, Mask1)) {
6116 if (M0 >= 0 && M1 >= 0)
6117 NewMask.push_back(M0 - M1);
6118 else if (M0 == -1 && M1 == -1)
6119 continue;
6120 else
6121 return false;
6122 }
6123
6124 // Ensure all elements of the new mask are equal. If the difference between
6125 // the incoming mask elements is the same, the two must be constant offsets
6126 // of one another.
6127 if (NewMask.empty() || !all_equal(NewMask))
6128 return false;
6129
6130 // Create new mask using difference of the two incoming masks.
6131 int MaskOffset = NewMask[0u];
6132 unsigned Index = (InputNumElements + MaskOffset) % InputNumElements;
6133 NewMask.clear();
6134
6135 for (unsigned I = 0u; I < InputNumElements; ++I) {
6136 NewMask.push_back(Index);
6137 Index = (Index + 1u) % InputNumElements;
6138 }
6139
6140 // Calculate costs for worst cases and compare.
6141 auto const Kind = TTI::SK_PermuteSingleSrc;
6142 auto OldCost =
6143 std::max(TTI.getShuffleCost(Kind, ResultVT, InputVT, Mask0, CostKind),
6144 TTI.getShuffleCost(Kind, ResultVT, InputVT, Mask1, CostKind));
6145 auto NewCost = TTI.getShuffleCost(Kind, InputVT, InputVT, NewMask, CostKind) +
6146 TTI.getShuffleCost(Kind, ResultVT, InputVT, Mask1, CostKind);
6147
6148 LLVM_DEBUG(dbgs() << "Found a phi of mergeable shuffles: " << I
6149 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
6150 << "\n");
6151
6152 if (NewCost > OldCost)
6153 return false;
6154
6155 // Create new shuffles and narrowed phi.
6156 auto Builder = IRBuilder(Shuf);
6157 Builder.SetCurrentDebugLocation(Shuf->getDebugLoc());
6158 auto *PoisonVal = PoisonValue::get(InputVT);
6159 auto *NewShuf0 = Builder.CreateShuffleVector(Op, PoisonVal, NewMask);
6160 Worklist.push(cast<Instruction>(NewShuf0));
6161
6162 Builder.SetInsertPoint(Phi);
6163 Builder.SetCurrentDebugLocation(Phi->getDebugLoc());
6164 auto *NewPhi = Builder.CreatePHI(NewShuf0->getType(), 2u);
6165 NewPhi->addIncoming(NewShuf0, Phi->getIncomingBlock(0u));
6166 NewPhi->addIncoming(Op, Phi->getIncomingBlock(1u));
6167
6168 Builder.SetInsertPoint(*NewPhi->getInsertionPointAfterDef());
6169 PoisonVal = PoisonValue::get(NewPhi->getType());
6170 auto *NewShuf1 = Builder.CreateShuffleVector(NewPhi, PoisonVal, Mask1);
6171
6172 replaceValue(*Phi, *NewShuf1);
6173 return true;
6174}
6175
6176/// This is the entry point for all transforms. Pass manager differences are
6177/// handled in the callers of this function.
6178bool VectorCombine::run() {
6180 return false;
6181
6182 // Don't attempt vectorization if the target does not support vectors.
6183 if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
6184 return false;
6185
6186 LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n");
6187
6188 auto FoldInst = [this](Instruction &I) {
6189 Builder.SetInsertPoint(&I);
6190 bool IsVectorType = isa<VectorType>(I.getType());
6191 bool IsFixedVectorType = isa<FixedVectorType>(I.getType());
6192 auto Opcode = I.getOpcode();
6193
6194 LLVM_DEBUG(dbgs() << "VC: Visiting: " << I << '\n');
6195
6196 // These folds should be beneficial regardless of when this pass is run
6197 // in the optimization pipeline.
6198 // The type checking is for run-time efficiency. We can avoid wasting time
6199 // dispatching to folding functions if there's no chance of matching.
6200 if (IsFixedVectorType) {
6201 switch (Opcode) {
6202 case Instruction::InsertElement:
6203 if (vectorizeLoadInsert(I))
6204 return true;
6205 break;
6206 case Instruction::ShuffleVector:
6207 if (widenSubvectorLoad(I))
6208 return true;
6209 break;
6210 default:
6211 break;
6212 }
6213 }
6214
6215 // This transform works with scalable and fixed vectors
6216 // TODO: Identify and allow other scalable transforms
6217 if (IsVectorType) {
6218 if (scalarizeOpOrCmp(I))
6219 return true;
6220 if (scalarizeLoad(I))
6221 return true;
6222 if (scalarizeExtExtract(I))
6223 return true;
6224 if (scalarizeVPIntrinsic(I))
6225 return true;
6226 if (foldInterleaveIntrinsics(I))
6227 return true;
6228 if (foldBitcastOfVPLoad(I))
6229 return true;
6230 }
6231
6232 if (foldDeinterleaveIntrinsics(I))
6233 return true;
6234
6235 if (Opcode == Instruction::Store)
6236 if (foldSingleElementStore(I))
6237 return true;
6238
6239 // If this is an early pipeline invocation of this pass, we are done.
6240 if (TryEarlyFoldsOnly)
6241 return false;
6242
6243 // Otherwise, try folds that improve codegen but may interfere with
6244 // early IR canonicalizations.
6245 // The type checking is for run-time efficiency. We can avoid wasting time
6246 // dispatching to folding functions if there's no chance of matching.
6247 if (IsFixedVectorType) {
6248 switch (Opcode) {
6249 case Instruction::InsertElement:
6250 if (foldInsExtFNeg(I))
6251 return true;
6252 if (foldInsExtBinop(I))
6253 return true;
6254 if (foldInsExtVectorToShuffle(I))
6255 return true;
6256 break;
6257 case Instruction::ShuffleVector:
6258 if (foldPermuteOfBinops(I))
6259 return true;
6260 if (foldShuffleOfBinops(I))
6261 return true;
6262 if (foldShuffleOfSelects(I))
6263 return true;
6264 if (foldShuffleOfCastops(I))
6265 return true;
6266 if (foldShuffleOfShuffles(I))
6267 return true;
6268 if (foldPermuteOfIntrinsic(I))
6269 return true;
6270 if (foldShufflesOfLengthChangingShuffles(I))
6271 return true;
6272 if (foldShuffleOfIntrinsics(I))
6273 return true;
6274 if (foldSelectShuffle(I))
6275 return true;
6276 if (foldShuffleToIdentity(I))
6277 return true;
6278 break;
6279 case Instruction::Load:
6280 if (shrinkLoadForShuffles(I))
6281 return true;
6282 break;
6283 case Instruction::BitCast:
6284 if (foldBitcastShuffle(I))
6285 return true;
6286 if (foldSelectsFromBitcast(I))
6287 return true;
6288 break;
6289 case Instruction::And:
6290 case Instruction::Or:
6291 case Instruction::Xor:
6292 if (foldBitOpOfCastops(I))
6293 return true;
6294 if (foldBitOpOfCastConstant(I))
6295 return true;
6296 break;
6297 case Instruction::PHI:
6298 if (shrinkPhiOfShuffles(I))
6299 return true;
6300 break;
6301 default:
6302 if (shrinkType(I))
6303 return true;
6304 break;
6305 }
6306 } else {
6307 switch (Opcode) {
6308 case Instruction::Call:
6309 if (foldShuffleFromReductions(I))
6310 return true;
6311 if (foldCastFromReductions(I))
6312 return true;
6313 break;
6314 case Instruction::ExtractElement:
6315 if (foldShuffleChainsToReduce(I))
6316 return true;
6317 break;
6318 case Instruction::ICmp:
6319 if (foldSignBitReductionCmp(I))
6320 return true;
6321 if (foldICmpEqZeroVectorReduce(I))
6322 return true;
6323 if (foldEquivalentReductionCmp(I))
6324 return true;
6325 if (foldReduceAddCmpZero(I))
6326 return true;
6327 [[fallthrough]];
6328 case Instruction::FCmp:
6329 if (foldExtractExtract(I))
6330 return true;
6331 break;
6332 case Instruction::Or:
6333 if (foldConcatOfBoolMasks(I))
6334 return true;
6335 [[fallthrough]];
6336 default:
6337 if (Instruction::isBinaryOp(Opcode)) {
6338 if (foldExtractExtract(I))
6339 return true;
6340 if (foldExtractedCmps(I))
6341 return true;
6342 if (foldBinopOfReductions(I))
6343 return true;
6344 }
6345 break;
6346 }
6347 }
6348 return false;
6349 };
6350
6351 bool MadeChange = false;
6352 for (BasicBlock &BB : F) {
6353 // Ignore unreachable basic blocks.
6354 if (!DT.isReachableFromEntry(&BB))
6355 continue;
6356 // Use early increment range so that we can erase instructions in loop.
6357 // make_early_inc_range is not applicable here, as the next iterator may
6358 // be invalidated by RecursivelyDeleteTriviallyDeadInstructions.
6359 // We manually maintain the next instruction and update it when it is about
6360 // to be deleted.
6361 Instruction *I = &BB.front();
6362 while (I) {
6363 NextInst = I->getNextNode();
6364 if (!I->isDebugOrPseudoInst())
6365 MadeChange |= FoldInst(*I);
6366 I = NextInst;
6367 }
6368 }
6369
6370 NextInst = nullptr;
6371
6372 while (!Worklist.isEmpty()) {
6373 Instruction *I = Worklist.removeOne();
6374 if (!I)
6375 continue;
6376
6379 continue;
6380 }
6381
6382 MadeChange |= FoldInst(*I);
6383 }
6384
6385 return MadeChange;
6386}
6387
6390 auto &AC = FAM.getResult<AssumptionAnalysis>(F);
6392 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
6393 AAResults &AA = FAM.getResult<AAManager>(F);
6394 const DataLayout *DL = &F.getDataLayout();
6397 VectorCombine Combiner(F, TTI, DT, AA, AC, DL, CostKind, TryEarlyFoldsOnly);
6398 if (!Combiner.run())
6399 return PreservedAnalyses::all();
6402 return PA;
6403}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static cl::opt< unsigned > MaxInstrsToScan("aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden, cl::desc("Max number of instructions to scan for aggressive instcombine."))
This is the interface for LLVM's primary stateless and local alias analysis.
#define X(NUM, ENUM, NAME)
Definition ELF.h:853
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static cl::opt< OutputCostKind > CostKind("cost-kind", cl::desc("Target cost kind"), cl::init(OutputCostKind::RecipThroughput), cl::values(clEnumValN(OutputCostKind::RecipThroughput, "throughput", "Reciprocal throughput"), clEnumValN(OutputCostKind::Latency, "latency", "Instruction latency"), clEnumValN(OutputCostKind::CodeSize, "code-size", "Code size"), clEnumValN(OutputCostKind::SizeAndLatency, "size-latency", "Code size and latency"), clEnumValN(OutputCostKind::All, "all", "Print all cost kinds")))
static cl::opt< IntrinsicCostStrategy > IntrinsicCost("intrinsic-cost-strategy", cl::desc("Costing strategy for intrinsic instructions"), cl::init(IntrinsicCostStrategy::InstructionCost), cl::values(clEnumValN(IntrinsicCostStrategy::InstructionCost, "instruction-cost", "Use TargetTransformInfo::getInstructionCost"), clEnumValN(IntrinsicCostStrategy::IntrinsicCost, "intrinsic-cost", "Use TargetTransformInfo::getIntrinsicInstrCost"), clEnumValN(IntrinsicCostStrategy::TypeBasedIntrinsicCost, "type-based-intrinsic-cost", "Calculate the intrinsic cost based only on argument types")))
This file defines the DenseMap class.
#define Check(C,...)
This is the interface for a simple mod/ref and alias analysis over globals.
Hexagon Common GEP
iv users
Definition IVUsers.cpp:48
static Value * getOpcode(Value &V, Type &Ty, InstrumentationConfig &IConf, InstrumentorIRBuilderTy &IIRB)
const size_t AbstractManglingParser< Derived, Alloc >::NumOps
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo, MemorySSAUpdater &MSSAU)
Definition LICM.cpp:1457
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define T1
MachineInstr unsigned OpIdx
uint64_t IntrinsicInst * II
FunctionAnalysisManager FAM
const SmallVectorImpl< MachineOperand > & Cond
unsigned OpIndex
This file contains some templates that are useful if you are working with the STL at all.
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:119
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
This pass exposes codegen information to IR-level passes.
static bool isFreeConcat(ArrayRef< InstLane > Item, TTI::TargetCostKind CostKind, const TargetTransformInfo &TTI)
Detect concat of multiple values into a vector.
static void analyzeCostOfVecReduction(const IntrinsicInst &II, TTI::TargetCostKind CostKind, const TargetTransformInfo &TTI, InstructionCost &CostBeforeReduction, InstructionCost &CostAfterReduction)
static SmallVector< InstLane > generateInstLaneVectorFromOperand(ArrayRef< InstLane > Item, int Op)
static Value * createShiftShuffle(Value *Vec, unsigned OldIndex, unsigned NewIndex, IRBuilderBase &Builder)
Create a shuffle that translates (shifts) 1 element from the input vector to a new element location.
static Value * generateNewInstTree(ArrayRef< InstLane > Item, Use *From, FixedVectorType *Ty, const DenseSet< std::pair< Value *, Use * > > &IdentityLeafs, const DenseSet< std::pair< Value *, Use * > > &SplatLeafs, const DenseSet< std::pair< Value *, Use * > > &ConcatLeafs, IRBuilderBase &Builder, const TargetTransformInfo *TTI)
std::pair< Value *, int > InstLane
static bool isKnownNonPositive(const Value *V, const SimplifyQuery &SQ, unsigned Depth=0)
Used by foldReduceAddCmpZero to check if we can prove that a value is non-positive.
static Align computeAlignmentAfterScalarization(Align VectorAlignment, Type *ScalarType, Value *Idx, const DataLayout &DL)
The memory operation on a vector of ScalarType had alignment of VectorAlignment.
static bool feedsIntoVectorReduction(ShuffleVectorInst *SVI)
Returns true if this ShuffleVectorInst eventually feeds into a vector reduction intrinsic (e....
static cl::opt< bool > DisableVectorCombine("disable-vector-combine", cl::init(false), cl::Hidden, cl::desc("Disable all vector combine transforms"))
static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI)
static const unsigned InvalidIndex
static Value * translateExtract(ExtractElementInst *ExtElt, unsigned NewIndex, IRBuilderBase &Builder)
Given an extract element instruction with constant index operand, shuffle the source vector (shift th...
static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx, const SimplifyQuery &SQ)
Check if it is legal to scalarize a memory access to VecTy at index Idx.
static cl::opt< unsigned > MaxInstrsToScan("vector-combine-max-scan-instrs", cl::init(30), cl::Hidden, cl::desc("Max number of instructions to scan for vector combining."))
static cl::opt< bool > DisableBinopExtractShuffle("disable-binop-extract-shuffle", cl::init(false), cl::Hidden, cl::desc("Disable binop extract to shuffle transforms"))
static InstLane lookThroughShuffles(Value *V, int Lane)
static bool isMemModifiedBetween(BasicBlock::iterator Begin, BasicBlock::iterator End, const MemoryLocation &Loc, AAResults &AA)
static constexpr int Concat[]
Value * RHS
Value * LHS
A manager for alias analyses.
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1055
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1511
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
unsigned countl_one() const
Count the number of leading one bits.
Definition APInt.h:1638
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
static APInt getHighBitsSet(unsigned numBits, unsigned hiBitsSet)
Constructs an APInt value that has the top hiBitsSet bits set.
Definition APInt.h:297
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isOne() const
Determine if this is a value of 1.
Definition APInt.h:390
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1228
Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
const T & front() const
Get the first element.
Definition ArrayRef.h:144
size_t size() const
Get the array size.
Definition ArrayRef.h:141
A function analysis which provides an AssumptionCache.
A cache of @llvm.assume calls within a function.
LLVM_ABI bool hasAttribute(Attribute::AttrKind Kind) const
Return true if the attribute exists in this set.
InstListType::iterator iterator
Instruction iterators...
Definition BasicBlock.h:170
BinaryOps getOpcode() const
Definition InstrTypes.h:409
Represents analyses that only rely on functions' control flow.
Definition Analysis.h:73
Value * getArgOperand(unsigned i) const
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
void addParamAttrs(unsigned ArgNo, const AttrBuilder &B)
Adds attributes to the indicated argument.
static LLVM_ABI CastInst * Create(Instruction::CastOps, Value *S, Type *Ty, const Twine &Name="", InsertPosition InsertBefore=nullptr)
Provides a way to construct any of the CastInst subclasses using an opcode instead of the subclass's ...
static Type * makeCmpResultType(Type *opnd_type)
Create a result type for fcmp/icmp.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:740
bool isFPPredicate() const
Definition InstrTypes.h:845
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
Combiner implementation.
Definition Combiner.h:33
static LLVM_ABI Constant * getExtractElement(Constant *Vec, Constant *Idx, Type *OnlyIfReducedTy=nullptr)
static LLVM_ABI Constant * getBinOpIdentity(unsigned Opcode, Type *Ty, bool AllowRHSConstant=false, bool NSZ=false)
Return the identity constant for a binary opcode.
This is the shared class of boolean and integer constants.
Definition Constants.h:87
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
This class represents a range of values.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI ConstantRange binaryAnd(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a binary-and of a value in this ra...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
static LLVM_ABI Constant * getSplat(ElementCount EC, Constant *Elt)
Return a ConstantVector with the specified constant in each element.
static LLVM_ABI Constant * get(ArrayRef< Constant * > V)
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:225
bool empty() const
Definition DenseMap.h:173
iterator end()
Definition DenseMap.h:143
Implements a dense probed hash-table based set.
Definition DenseSet.h:289
Analysis pass which computes a DominatorTree.
Definition Dominators.h:270
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:151
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
This instruction extracts a single (scalar) element from a VectorType value.
Convenience struct for specifying and reasoning about fast-math flags.
Definition FMF.h:23
bool noSignedZeros() const
Definition FMF.h:67
Class to represent fixed width SIMD vectors.
unsigned getNumElements() const
static FixedVectorType * getDoubleElementsVectorType(FixedVectorType *VTy)
static LLVM_ABI FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition Type.cpp:869
Predicate getSignedPredicate() const
For example, EQ->EQ, SLE->SLE, UGT->SGT, etc.
bool isEquality() const
Return true if this predicate is either EQ or NE.
Common base class shared among various IRBuilders.
Definition IRBuilder.h:114
Value * CreateNUWMul(Value *LHS, Value *RHS, const Twine &Name="")
Definition IRBuilder.h:1491
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Definition IRBuilder.h:2637
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Definition IRBuilder.h:2625
LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)
Definition IRBuilder.h:1945
LLVM_ABI Value * CreateSelectFMF(Value *C, Value *True, Value *False, FMFSource FMFSource, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")
Return a vector value that contains V broadcast to NumElts elements.
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Definition IRBuilder.h:2684
ConstantInt * getTrue()
Get the constant value for i1 true.
Definition IRBuilder.h:509
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > OverloadTypes, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="", ArrayRef< OperandBundleDef > OpBundles={})
Create a call to intrinsic ID with Args, mangled using OverloadTypes.
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
Value * CreateFreeze(Value *V, const Twine &Name="")
Definition IRBuilder.h:2703
void SetCurrentDebugLocation(const DebugLoc &L)
Set location information used by debugging information.
Definition IRBuilder.h:247
Value * CreateLShr(Value *LHS, Value *RHS, const Twine &Name="", bool isExact=false)
Definition IRBuilder.h:1554
Value * CreateCast(Instruction::CastOps Op, Value *V, Type *DestTy, const Twine &Name="", MDNode *FPMathTag=nullptr, FMFSource FMFSource={})
Definition IRBuilder.h:2286
Value * CreateIsNotNeg(Value *Arg, const Twine &Name="")
Return a boolean value testing if Arg > -1.
Definition IRBuilder.h:2727
Value * CreateInBoundsGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="")
Definition IRBuilder.h:2028
Value * CreatePointerBitCastOrAddrSpaceCast(Value *V, Type *DestTy, const Twine &Name="")
Definition IRBuilder.h:2311
ConstantInt * getInt64(uint64_t C)
Get a constant 64-bit value.
Definition IRBuilder.h:534
LLVM_ABI CallInst * CreateOrReduce(Value *Src)
Create a vector int OR reduction intrinsic of the source vector.
ConstantInt * getInt32(uint32_t C)
Get a constant 32-bit value.
Definition IRBuilder.h:529
Value * CreateCmp(CmpInst::Predicate Pred, Value *LHS, Value *RHS, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:2518
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Definition IRBuilder.h:2549
InstTy * Insert(InstTy *I, const Twine &Name="") const
Insert and return the specified instruction.
Definition IRBuilder.h:172
Value * CreateIsNeg(Value *Arg, const Twine &Name="")
Return a boolean value testing if Arg < 0.
Definition IRBuilder.h:2722
Value * CreateBitCast(Value *V, Type *DestTy, const Twine &Name="")
Definition IRBuilder.h:2252
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition IRBuilder.h:1928
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition IRBuilder.h:1533
LLVM_ABI Value * CreateNAryOp(unsigned Opc, ArrayRef< Value * > Ops, const Twine &Name="", MDNode *FPMathTag=nullptr)
Create either a UnaryOperator or BinaryOperator depending on Opc.
Value * CreateZExt(Value *V, Type *DestTy, const Twine &Name="", bool IsNonNeg=false)
Definition IRBuilder.h:2130
Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")
Definition IRBuilder.h:2659
Value * CreateAnd(Value *LHS, Value *RHS, const Twine &Name="")
Definition IRBuilder.h:1592
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Definition IRBuilder.h:1941
Value * CreateTrunc(Value *V, Type *DestTy, const Twine &Name="", bool IsNUW=false, bool IsNSW=false)
Definition IRBuilder.h:2116
PointerType * getPtrTy(unsigned AddrSpace=0)
Fetch the type representing a pointer.
Definition IRBuilder.h:629
Value * CreateBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:1753
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition IRBuilder.h:207
Value * CreateFNegFMF(Value *V, FMFSource FMFSource, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:1866
Value * CreateICmp(CmpInst::Predicate P, Value *LHS, Value *RHS, const Twine &Name="")
Definition IRBuilder.h:2494
Value * CreateOr(Value *LHS, Value *RHS, const Twine &Name="", bool IsDisjoint=false)
Definition IRBuilder.h:1614
InstSimplifyFolder - Use InstructionSimplify to fold operations to existing values.
CostType getValue() const
This function is intended to be used as sparingly as possible, since the class provides the full rang...
void push(Instruction *I)
Push the instruction onto the worklist stack.
LLVM_ABI void setHasNoUnsignedWrap(bool b=true)
Set or clear the nuw flag on this instruction, which must be an operator which supports this flag.
LLVM_ABI void copyIRFlags(const Value *V, bool IncludeWrapFlags=true)
Convenience method to copy supported exact, fast-math, and (optionally) wrapping flags from V to this...
LLVM_ABI void setHasNoSignedWrap(bool b=true)
Set or clear the nsw flag on this instruction, which must be an operator which supports this flag.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
LLVM_ABI void andIRFlags(const Value *V)
Logical 'and' of any supported wrapping, exact, and fast-math flags of V and this instruction.
bool isBinaryOp() const
LLVM_ABI void setFastMathFlags(FastMathFlags FMF)
Convenience function for setting multiple fast-math flags on this instruction, which must be an opera...
LLVM_ABI void setNonNeg(bool b=true)
Set or clear the nneg flag on this instruction, which must be a zext instruction.
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instructi...
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
LLVM_ABI AAMDNodes getAAMetadata() const
Returns the AA metadata for this instruction.
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
bool isIdempotent() const
Return true if the instruction is idempotent:
LLVM_ABI void copyMetadata(const Instruction &SrcInst, ArrayRef< unsigned > WL=ArrayRef< unsigned >())
Copy metadata from SrcInst to this instruction.
LLVM_ABI bool hasAllowReassoc() const LLVM_READONLY
Determine whether the allow-reassociation flag is set.
bool isIntDivRem() const
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:350
unsigned getBitWidth() const
Get the number of bits in this IntegerType.
A wrapper class for inspecting calls to intrinsic functions.
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
An instruction for reading from memory.
unsigned getPointerAddressSpace() const
Returns the address space of the pointer operand.
void setAlignment(Align Align)
Type * getPointerOperandType() const
Align getAlign() const
Return the alignment of the access that is being performed.
Representation for a specific memory location.
static LLVM_ABI MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserveSet()
Mark an analysis set as preserved.
Definition Analysis.h:151
const SDValue & getOperand(unsigned Num) const
This instruction constructs a fixed permutation of two input vectors.
int getMaskValue(unsigned Elt) const
Return the shuffle mask value of this instruction for the given element index.
VectorType * getType() const
Overload to return most specific vector type.
static LLVM_ABI void getShuffleMask(const Constant *Mask, SmallVectorImpl< int > &Result)
Convert the input shuffle mask operand to a vector of integers.
static LLVM_ABI bool isIdentityMask(ArrayRef< int > Mask, int NumSrcElts)
Return true if this shuffle mask chooses elements from exactly one source vector without lane crossin...
static void commuteShuffleMask(MutableArrayRef< int > Mask, unsigned InVecNumElts)
Change values in a shuffle permute mask assuming the two vector operands of length InVecNumElts have ...
size_type size() const
Definition SmallPtrSet.h:99
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
void assign(size_type NumElts, ValueParamT Elt)
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void resize(size_type N)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
void setAlignment(Align Align)
Analysis pass providing the TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
static LLVM_ABI CastContextHint getCastContextHint(const Instruction *I)
Calculates a CastContextHint from I.
@ None
The insert/extract is not used with a load/store.
LLVM_ABI InstructionCost getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, OperandValueInfo Op1Info={OK_AnyValue, OP_None}, OperandValueInfo Op2Info={OK_AnyValue, OP_None}, const Instruction *I=nullptr) const
LLVM_ABI TypeSize getRegisterBitWidth(RegisterKind K) const
static LLVM_ABI OperandValueInfo commonOperandInfo(const Value *X, const Value *Y)
Collect common data between two OperandValueInfo inputs.
LLVM_ABI InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, OperandValueInfo OpdInfo={OK_AnyValue, OP_None}, const Instruction *I=nullptr) const
LLVM_ABI bool allowVectorElementIndexingUsingGEP() const
Returns true if GEP should not be used to index into vectors for this target.
LLVM_ABI InstructionCost getShuffleCost(ShuffleKind Kind, VectorType *DstTy, VectorType *SrcTy, ArrayRef< int > Mask={}, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, int Index=0, VectorType *SubTp=nullptr, ArrayRef< const Value * > Args={}, const Instruction *CxtI=nullptr) const
LLVM_ABI InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, TTI::TargetCostKind CostKind) const
LLVM_ABI InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, std::optional< FastMathFlags > FMF, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput) const
Calculate the cost of vector reduction intrinsics.
LLVM_ABI InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, TTI::CastContextHint CCH, TTI::TargetCostKind CostKind=TTI::TCK_SizeAndLatency, const Instruction *I=nullptr) const
LLVM_ABI InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index=-1, const Value *Op0=nullptr, const Value *Op1=nullptr, TTI::VectorInstrContext VIC=TTI::VectorInstrContext::None) const
LLVM_ABI unsigned getRegisterClassForType(bool Vector, Type *Ty=nullptr) const
LLVM_ABI InstructionCost getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty, FastMathFlags FMF=FastMathFlags(), TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput) const
TargetCostKind
The kind of cost model.
@ TCK_RecipThroughput
Reciprocal throughput.
@ TCK_CodeSize
Instruction code size.
LLVM_ABI InstructionCost getArithmeticInstrCost(unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, TTI::OperandValueInfo Opd1Info={TTI::OK_AnyValue, TTI::OP_None}, TTI::OperandValueInfo Opd2Info={TTI::OK_AnyValue, TTI::OP_None}, ArrayRef< const Value * > Args={}, const Instruction *CxtI=nullptr, const TargetLibraryInfo *TLibInfo=nullptr) const
This is an approximation of reciprocal throughput of a math/logic op.
LLVM_ABI InstructionCost getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA, TTI::TargetCostKind CostKind) const
LLVM_ABI unsigned getMinVectorRegisterBitWidth() const
LLVM_ABI InstructionCost getAddressComputationCost(Type *PtrTy, ScalarEvolution *SE, const SCEV *Ptr, TTI::TargetCostKind CostKind) const
LLVM_ABI unsigned getNumberOfRegisters(unsigned ClassID) const
LLVM_ABI InstructionCost getInstructionCost(const User *U, ArrayRef< const Value * > Operands, TargetCostKind CostKind) const
Estimate the cost of a given IR user when lowered.
LLVM_ABI InstructionCost getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract, TTI::TargetCostKind CostKind, bool ForPoisonSrc=true, ArrayRef< Value * > VL={}, TTI::VectorInstrContext VIC=TTI::VectorInstrContext::None) const
Estimate the overhead of scalarizing an instruction.
ShuffleKind
The various kinds of shuffle patterns for vector queries.
@ SK_PermuteSingleSrc
Shuffle elements of single source vector with any shuffle mask.
@ SK_Broadcast
Broadcast element 0 to all other elements.
@ SK_PermuteTwoSrc
Merge elements from two source vectors into one with any shuffle mask.
@ SK_ExtractSubvector
ExtractSubvector Index indicates start offset.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:282
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition Type.h:368
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:197
LLVMContext & getContext() const
Return the LLVMContext in which this type was uniqued.
Definition Type.h:130
LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Definition Type.cpp:232
bool isFloatingPointTy() const
Return true if this is one of the floating-point types.
Definition Type.h:186
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:257
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:267
Value * getOperand(unsigned i) const
Definition User.h:207
static LLVM_ABI bool isVPBinOp(Intrinsic::ID ID)
std::optional< unsigned > getFunctionalIntrinsicID() const
std::optional< unsigned > getFunctionalOpcode() const
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
const Value * stripAndAccumulateInBoundsConstantOffsets(const DataLayout &DL, APInt &Offset) const
This is a wrapper around stripAndAccumulateConstantOffsets with the in-bounds requirement set to fals...
Definition Value.h:727
LLVM_ABI bool hasOneUser() const
Return true if there is exactly one user of this value.
Definition Value.cpp:163
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition Value.h:439
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:553
iterator_range< user_iterator > users()
Definition Value.h:426
LLVM_ABI Align getPointerAlignment(const DataLayout &DL) const
Returns an alignment of the pointer value.
Definition Value.cpp:989
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI bool hasNUses(unsigned N) const
Return true if this Value has exactly N uses.
Definition Value.cpp:147
bool use_empty() const
Definition Value.h:346
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:319
bool user_empty() const
Definition Value.h:389
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &)
static LLVM_ABI VectorType * get(Type *ElementType, ElementCount EC)
This static method is the primary way to construct an VectorType.
Type * getElementType() const
std::pair< iterator, bool > insert(const ValueT &V)
Definition DenseSet.h:212
size_type size() const
Definition DenseSet.h:87
constexpr bool hasKnownScalarFactor(const FixedOrScalableQuantity &RHS) const
Returns true if there exists a value X where RHS.multiplyCoefficientBy(X) will result in a value whos...
Definition TypeSize.h:269
constexpr ScalarTy getKnownScalarFactor(const FixedOrScalableQuantity &RHS) const
Returns a value X where RHS.multiplyCoefficientBy(X) will result in a value whose quantity matches ou...
Definition TypeSize.h:277
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
const ParentTy * getParent() const
Definition ilist_node.h:34
self_iterator getIterator()
Definition ilist_node.h:123
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
Abstract Attribute helper functions.
Definition Attributor.h:165
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
constexpr char Attrs[]
Key for Kernel::Metadata::mAttrs.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2277
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2282
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI AttributeSet getFnAttributes(LLVMContext &C, ID id)
Return the function attributes for an intrinsic.
SpecificConstantMatch m_ZeroInt()
Convenience matchers for specific integer values.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
OneUse_match< SubPat > m_OneUse(const SubPat &SP)
match_combine_and< Ty... > m_CombineAnd(const Ty &...Ps)
Combine pattern matchers matching all of Ps patterns.
cst_pred_ty< is_all_ones > m_AllOnes()
Match an integer or vector with all bits set.
BinaryOp_match< LHS, RHS, Instruction::And > m_And(const LHS &L, const RHS &R)
auto m_Cmp()
Matches any compare instruction and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::URem > m_URem(const LHS &L, const RHS &R)
auto m_Poison()
Match an arbitrary poison constant.
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
CastInst_match< OpTy, TruncInst > m_Trunc(const OpTy &Op)
Matches Trunc.
specific_intval< false > m_SpecificInt(const APInt &V)
Match a specific integer value or vector with all elements equal to the value.
bool match(Val *V, const Pattern &P)
match_bind< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
DisjointOr_match< LHS, RHS > m_DisjointOr(const LHS &L, const RHS &R)
BinOpPred_match< LHS, RHS, is_right_shift_op > m_Shr(const LHS &L, const RHS &R)
Matches logical shift operations.
TwoOps_match< Val_t, Idx_t, Instruction::ExtractElement > m_ExtractElt(const Val_t &Val, const Idx_t &Idx)
Matches ExtractElementInst.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
auto m_BinOp()
Match an arbitrary binary operation and ignore it.
auto m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
auto m_Constant()
Match an arbitrary Constant and ignore it.
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
cst_pred_ty< is_non_zero_int > m_NonZeroInt()
Match a non-zero integer or a vector with all non-zero elements.
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
CastInst_match< OpTy, ZExtInst > m_ZExt(const OpTy &Op)
Matches ZExt.
OverflowingBinaryOp_match< LHS, RHS, Instruction::Shl, OverflowingBinaryOperator::NoUnsignedWrap > m_NUWShl(const LHS &L, const RHS &R)
auto m_AnyIntrinsic()
Matches any intrinsic call and ignore it.
OverflowingBinaryOp_match< LHS, RHS, Instruction::Mul, OverflowingBinaryOperator::NoUnsignedWrap > m_NUWMul(const LHS &L, const RHS &R)
BinOpPred_match< LHS, RHS, is_bitwiselogic_op, true > m_c_BitwiseLogic(const LHS &L, const RHS &R)
Matches bitwise logic operations in either order.
CastOperator_match< OpTy, Instruction::BitCast > m_BitCast(const OpTy &Op)
Matches BitCast.
match_combine_or< CastInst_match< OpTy, SExtInst >, NNegZExt_match< OpTy > > m_SExtLike(const OpTy &Op)
Match either "sext" or "zext nneg".
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
CmpClass_match< LHS, RHS, ICmpInst > m_ICmp(CmpPredicate &Pred, const LHS &L, const RHS &R)
match_combine_or< CastInst_match< OpTy, ZExtInst >, CastInst_match< OpTy, SExtInst > > m_ZExtOrSExt(const OpTy &Op)
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_Undef()
Match an arbitrary undef constant.
CastInst_match< OpTy, SExtInst > m_SExt(const OpTy &Op)
Matches SExt.
is_zero m_Zero()
Match any null constant or a vector with all elements equal to 0.
BinaryOp_match< LHS, RHS, Instruction::Or, true > m_c_Or(const LHS &L, const RHS &R)
Matches an Or with LHS and RHS in either order.
ThreeOps_match< Val_t, Elt_t, Idx_t, Instruction::InsertElement > m_InsertElt(const Val_t &Val, const Elt_t &Elt, const Idx_t &Idx)
Matches InsertElementInst.
m_Intrinsic_Ty< Opnd >::Ty m_Deinterleave2(const Opnd &Op)
auto m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
DXILDebugInfoMap run(Module &M)
@ User
could "use" a pointer
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
NodeAddr< UseNode * > Use
Definition RDFGraph.h:385
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:315
unsigned Log2_32_Ceil(uint32_t Value)
Return the ceil log base 2 of the specified value, 32 if the value is zero.
Definition MathExtras.h:344
@ Offset
Definition DWP.cpp:558
detail::zippy< detail::zip_shortest, T, U, Args... > zip(T &&t, U &&u, Args &&...args)
zip iterator for two or more iteratable types.
Definition STLExtras.h:830
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
void stable_sort(R &&Range)
Definition STLExtras.h:2115
UnaryFunction for_each(R &&Range, UnaryFunction F)
Provide wrappers to std::for_each which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1731
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1738
LLVM_ABI Intrinsic::ID getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID)
Returns the min/max intrinsic used when expanding a min/max reduction.
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition Local.cpp:535
LLVM_ABI SDValue peekThroughBitcasts(SDValue V)
Return the non-bitcasted source operand of V if it exists.
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B,...
Definition STLExtras.h:2553
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
unsigned Log2_64_Ceil(uint64_t Value)
Return the ceil log base 2 of the specified value, 64 if the value is zero.
Definition MathExtras.h:350
LLVM_ABI Value * simplifyUnOp(unsigned Opcode, Value *Op, const SimplifyQuery &Q)
Given operand for a UnaryOperator, fold the result or return null.
scope_exit(Callable) -> scope_exit< Callable >
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
LLVM_ABI unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID)
Returns the arithmetic instruction opcode used when expanding a reduction.
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2207
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Value * simplifyCall(CallBase *Call, Value *Callee, ArrayRef< Value * > Args, const SimplifyQuery &Q)
Given a callsite, callee, and arguments, fold the result or return null.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition STLExtras.h:633
LLVM_ABI bool mustSuppressSpeculation(const LoadInst &LI)
Return true if speculation of the given load must be suppressed to avoid ordering or interfering with...
Definition Loads.cpp:441
constexpr bool isPowerOf2_64(uint64_t Value)
Return true if the argument is a power of two > 0 (64 bit edition.)
Definition MathExtras.h:284
LLVM_ABI bool widenShuffleMaskElts(int Scale, ArrayRef< int > Mask, SmallVectorImpl< int > &ScaledMask)
Try to transform a shuffle mask by replacing elements with the scaled index for an equivalent mask of...
LLVM_ABI bool isSafeToSpeculativelyExecute(const Instruction *I, const Instruction *CtxI=nullptr, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr, bool UseVariableInfo=true, bool IgnoreUBImplyingAttrs=true)
Return true if the instruction does not have any effects besides calculating the result and does not ...
LLVM_ABI Value * getSplatValue(const Value *V)
Get splat value if the input is a splat vector or return nullptr.
constexpr auto equal_to(T &&Arg)
Functor variant of std::equal_to that can be used as a UnaryPredicate in functional algorithms like a...
Definition STLExtras.h:2172
unsigned M1(unsigned Val)
Definition VE.h:377
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1745
LLVM_ABI bool isInstructionTriviallyDead(Instruction *I, const TargetLibraryInfo *TLI=nullptr)
Return true if the result produced by the instruction is not used, and the instruction will return.
Definition Local.cpp:403
LLVM_ABI bool isSplatValue(const Value *V, int Index=-1, unsigned Depth=0)
Return true if each element of the vector value V is poisoned or equal to every other non-poisoned el...
unsigned Log2_32(uint32_t Value)
Return the floor log base 2 of the specified value, -1 if the value is zero.
Definition MathExtras.h:331
constexpr bool isPowerOf2_32(uint32_t Value)
Return true if the argument is a power of two > 0.
Definition MathExtras.h:279
bool isModSet(const ModRefInfo MRI)
Definition ModRef.h:49
void sort(IteratorTy Start, IteratorTy End)
Definition STLExtras.h:1635
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI bool isSafeToLoadUnconditionally(Value *V, Align Alignment, const APInt &Size, const DataLayout &DL, Instruction *ScanFrom, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr)
Return true if we know that executing a load from this value cannot trap.
Definition Loads.cpp:445
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ABI void propagateIRFlags(Value *I, ArrayRef< Value * > VL, Value *OpValue=nullptr, bool IncludeWrapFlags=true)
Get the intersection (logical and) of all of the potential IR flags of each scalar operation (VL) tha...
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
MutableArrayRef(T &OneElt) -> MutableArrayRef< T >
constexpr int PoisonMaskElem
LLVM_ABI bool isSafeToSpeculativelyExecuteWithOpcode(unsigned Opcode, const Instruction *Inst, const Instruction *CtxI=nullptr, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr, bool UseVariableInfo=true, bool IgnoreUBImplyingAttrs=true)
This returns the same result as isSafeToSpeculativelyExecute if Opcode is the actual opcode of Inst.
@ Other
Any other memory.
Definition ModRef.h:68
TargetTransformInfo TTI
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
LLVM_ABI Value * simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, const SimplifyQuery &Q)
Given operands for a BinaryOperator, fold the result or return null.
LLVM_ABI void narrowShuffleMaskElts(int Scale, ArrayRef< int > Mask, SmallVectorImpl< int > &ScaledMask)
Replace each shuffle mask index with the scaled sequential indices for an equivalent mask of narrowed...
LLVM_ABI Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc)
Returns the reduction intrinsic id corresponding to the binary operation.
@ And
Bitwise or logical AND of integers.
LLVM_ABI bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx, const TargetTransformInfo *TTI)
Identifies if the vector form of the intrinsic has a scalar operand.
DWARFExpression::Operation Op
unsigned M0(unsigned Val)
Definition VE.h:376
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
LLVM_ABI Constant * getLosslessInvCast(Constant *C, Type *InvCastTo, unsigned CastOp, const DataLayout &DL, PreservedCastFlags *Flags=nullptr)
Try to cast C to InvC losslessly, satisfying CastOp(InvC) equals C, or CastOp(InvC) is a refined valu...
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1771
constexpr bool isIntN(unsigned N, int64_t x)
Checks if an signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1946
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition Alignment.h:201
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer lists are equal or the list.
Definition STLExtras.h:2165
LLVM_ABI Value * simplifyCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS, const SimplifyQuery &Q)
Given operands for a CmpInst, fold the result or return null.
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI bool isKnownNonNegative(const Value *V, const SimplifyQuery &SQ, unsigned Depth=0)
Returns true if the give value is known to be non-negative.
LLVM_ABI bool isTriviallyVectorizable(Intrinsic::ID ID)
Identify if the intrinsic is trivially vectorizable.
LLVM_ABI Intrinsic::ID getMinMaxReductionIntrinsicID(Intrinsic::ID IID)
Returns the llvm.vector.reduce min/max intrinsic that corresponds to the intrinsic op.
LLVM_ABI ConstantRange computeConstantRange(const Value *V, bool ForSigned, const SimplifyQuery &SQ, unsigned Depth=0)
Determine the possible constant range of an integer or vector of integer value.
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:862
LLVM_ABI AAMDNodes adjustForAccess(unsigned AccessSize)
Create a new AAMDNode for accessing AccessSize bytes of this AAMDNode.
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition Alignment.h:39
unsigned countMaxActiveBits() const
Returns the maximum number of bits needed to represent all possible unsigned values with these known ...
Definition KnownBits.h:310
unsigned countMinLeadingZeros() const
Returns the minimum number of leading zero bits.
Definition KnownBits.h:262
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:146
const DataLayout & DL
const Instruction * CxtI
const DominatorTree * DT
SimplifyQuery getWithInstruction(const Instruction *I) const
AssumptionCache * AC