LLVM 23.0.0git
VectorCombine.cpp
Go to the documentation of this file.
1//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass optimizes scalar/vector interactions using target cost models. The
10// transforms implemented here may not fit in traditional loop-based or SLP
11// vectorization passes.
12//
13//===----------------------------------------------------------------------===//
14
16#include "llvm/ADT/DenseMap.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/ScopeExit.h"
20#include "llvm/ADT/Statistic.h"
25#include "llvm/Analysis/Loads.h"
30#include "llvm/IR/Dominators.h"
31#include "llvm/IR/Function.h"
32#include "llvm/IR/IRBuilder.h"
40#include <numeric>
41#include <optional>
42#include <queue>
43#include <set>
44
45#define DEBUG_TYPE "vector-combine"
47
48using namespace llvm;
49using namespace llvm::PatternMatch;
50
// Pass statistics. Each counter is bumped by the transform that forms the
// named construct; e.g. NumVecLoad by the load-widening folds and
// NumVecCmp / NumVecBO by foldExtExtCmp / foldExtExtBinop.
51STATISTIC(NumVecLoad, "Number of vector loads formed");
52STATISTIC(NumVecCmp, "Number of vector compares formed");
53STATISTIC(NumVecBO, "Number of vector binops formed");
54STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
55STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
56STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
57STATISTIC(NumScalarCmp, "Number of scalar compares formed");
58STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");
59
61 "disable-vector-combine", cl::init(false), cl::Hidden,
62 cl::desc("Disable all vector combine transforms"));
63
65 "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
66 cl::desc("Disable binop extract to shuffle transforms"));
67
69 "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
70 cl::desc("Max number of instructions to scan for vector combining."));
71
// Sentinel meaning "no index". It is the default "no preferred lane" value for
// getShuffleExtract and the initial InsertIndex in foldExtractExtract when the
// result is not re-inserted at a constant lane.
72static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();
73
74namespace {
/// Per-function driver for the vector-combine transforms. It holds the
/// analyses and cost kind used for profitability queries, a worklist of
/// instructions to (re)visit, and the individual fold routines declared below.
75class VectorCombine {
76public:
77 VectorCombine(Function &F, const TargetTransformInfo &TTI,
80 bool TryEarlyFoldsOnly)
81 : F(F), Builder(F.getContext(), InstSimplifyFolder(*DL)), TTI(TTI),
82 DT(DT), AA(AA), DL(DL), CostKind(CostKind),
83 SQ(*DL, /*TLI=*/nullptr, &DT, &AC),
84 TryEarlyFoldsOnly(TryEarlyFoldsOnly) {}
85
86 bool run();
87
88private:
89 Function &F;
91 const TargetTransformInfo &TTI;
92 const DominatorTree &DT;
93 AAResults &AA;
94 const DataLayout *DL;
95 TTI::TargetCostKind CostKind;
96 const SimplifyQuery SQ;
97
98 /// If true, only perform beneficial early IR transforms. Do not introduce new
99 /// vector operations.
100 bool TryEarlyFoldsOnly;
101
 /// Instructions queued for (re)visiting. replaceValue and eraseInstruction
 /// push affected users and operands here.
102 InstructionWorklist Worklist;
103
104 /// Next instruction to iterate. It will be updated when it is erased by
105 /// RecursivelyDeleteTriviallyDeadInstructions.
106 Instruction *NextInst;
107
108 // TODO: Direct calls from the top-level "run" loop use a plain "Instruction"
109 // parameter. That should be updated to specific sub-classes because the
110 // run loop was changed to dispatch on opcode.
111 bool vectorizeLoadInsert(Instruction &I);
112 bool widenSubvectorLoad(Instruction &I);
113 ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
114 ExtractElementInst *Ext1,
115 unsigned PreferredExtractIndex) const;
116 bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
117 const Instruction &I,
118 ExtractElementInst *&ConvertToShuffle,
119 unsigned PreferredExtractIndex);
120 Value *foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
121 Value *foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
122 bool foldExtractExtract(Instruction &I);
123 bool foldInsExtFNeg(Instruction &I);
124 bool foldInsExtBinop(Instruction &I);
125 bool foldInsExtVectorToShuffle(Instruction &I);
126 bool foldBitOpOfCastops(Instruction &I);
127 bool foldBitOpOfCastConstant(Instruction &I);
128 bool foldBitcastShuffle(Instruction &I);
129 bool scalarizeOpOrCmp(Instruction &I);
130 bool scalarizeVPIntrinsic(Instruction &I);
131 bool foldExtractedCmps(Instruction &I);
132 bool foldSelectsFromBitcast(Instruction &I);
133 bool foldBinopOfReductions(Instruction &I);
134 bool foldSingleElementStore(Instruction &I);
135 bool scalarizeLoad(Instruction &I);
136 bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
137 bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
138 bool scalarizeExtExtract(Instruction &I);
139 bool foldConcatOfBoolMasks(Instruction &I);
140 bool foldPermuteOfBinops(Instruction &I);
141 bool foldShuffleOfBinops(Instruction &I);
142 bool foldShuffleOfSelects(Instruction &I);
143 bool foldShuffleOfCastops(Instruction &I);
144 bool foldShuffleOfShuffles(Instruction &I);
145 bool foldPermuteOfIntrinsic(Instruction &I);
146 bool foldShufflesOfLengthChangingShuffles(Instruction &I);
147 bool foldShuffleOfIntrinsics(Instruction &I);
148 bool foldShuffleToIdentity(Instruction &I);
149 bool foldShuffleFromReductions(Instruction &I);
150 bool foldShuffleChainsToReduce(Instruction &I);
151 bool foldCastFromReductions(Instruction &I);
152 bool foldSignBitReductionCmp(Instruction &I);
153 bool foldICmpEqZeroVectorReduce(Instruction &I);
154 bool foldEquivalentReductionCmp(Instruction &I);
155 bool foldReduceAddCmpZero(Instruction &I);
156 bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
157 bool foldInterleaveIntrinsics(Instruction &I);
158 bool foldDeinterleaveIntrinsics(Instruction &I);
159 bool foldBitcastOfVPLoad(Instruction &I);
160 bool shrinkType(Instruction &I);
161 bool shrinkLoadForShuffles(Instruction &I);
162 bool shrinkPhiOfShuffles(Instruction &I);
163
 /// Replace all uses of \p Old with \p New.
 ///
 /// If \p New is an instruction, it takes \p Old's name, and both it and its
 /// users are queued for revisiting. \p Old is erased when \p Erase is set and
 /// it has become trivially dead. Otherwise \p Old is pushed back onto the
 /// worklist.
164 void replaceValue(Instruction &Old, Value &New, bool Erase = true) {
165 LLVM_DEBUG(dbgs() << "VC: Replacing: " << Old << '\n');
166 LLVM_DEBUG(dbgs() << " With: " << New << '\n');
167 Old.replaceAllUsesWith(&New);
168 if (auto *NewI = dyn_cast<Instruction>(&New)) {
169 New.takeName(&Old);
170 Worklist.pushUsersToWorkList(*NewI);
171 Worklist.pushValue(NewI);
172 }
173 if (Erase && isInstructionTriviallyDead(&Old)) {
174 eraseInstruction(Old);
175 } else {
176 Worklist.push(&Old);
177 }
178 }
179
 /// Erase \p I, then revisit its former operands.
 ///
 /// For each instruction operand, a recursive dead-instruction cleanup is
 /// attempted; the NextInst comment names
 /// RecursivelyDeleteTriviallyDeadInstructions. Its callback is invoked for
 /// each erased instruction: that instruction is dropped from the worklist,
 /// NextInst is advanced past it, and it is recorded in Visited. Later operands
 /// found in Visited are therefore never touched after being freed. When the
 /// cleanup reports success the operand is skipped. Otherwise the operand and
 /// its users are re-queued, because losing a use may unblock one-use folds.
180 void eraseInstruction(Instruction &I) {
181 LLVM_DEBUG(dbgs() << "VC: Erasing: " << I << '\n');
182 SmallVector<Value *> Ops(I.operands());
183 Worklist.remove(&I);
184 I.eraseFromParent();
185
186 // Push remaining users of the operands and then the operand itself - allows
187 // further folds that were hindered by OneUse limits.
188 SmallPtrSet<Value *, 4> Visited;
189 for (Value *Op : Ops) {
190 if (!Visited.contains(Op)) {
191 if (auto *OpI = dyn_cast<Instruction>(Op)) {
193 OpI, nullptr, nullptr, [&](Value *V) {
194 if (auto *I = dyn_cast<Instruction>(V)) {
195 LLVM_DEBUG(dbgs() << "VC: Erased: " << *I << '\n');
196 Worklist.remove(I);
197 if (I == NextInst)
198 NextInst = NextInst->getNextNode();
199 Visited.insert(I);
200 }
201 }))
202 continue;
203 Worklist.pushUsersToWorkList(*OpI);
204 Worklist.pushValue(OpI);
205 }
206 }
207 }
208 }
209};
210} // namespace
211
212/// Return the source operand of a potentially bitcasted value. If there is no
213/// bitcast, return the input value itself.
/// Only BitCastInst instructions are looked through. A chain of them is peeled
/// one link at a time, so the returned value is never itself a BitCastInst.
215 while (auto *BitCast = dyn_cast<BitCastInst>(V))
216 V = BitCast->getOperand(0);
217 return V;
218}
219
/// Check whether \p Load may be replaced by a wider load.
///
/// Returns false for any of the following:
///   - a null load;
///   - a non-simple (atomic/volatile) load;
///   - a load with more than one use;
///   - a load in a function carrying the SanitizeMemTag attribute;
///   - a load whose scalar element size is zero, is not a whole number of
///     bytes, or does not evenly divide the target's (non-zero) minimum
///     vector register width.
220static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) {
221 // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan.
222 // The widened load may load data from dirty regions or create data races
223 // non-existent in the source.
224 if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
225 Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
227 return false;
228
229 // We are potentially transforming byte-sized (8-bit) memory accesses, so make
230 // sure we have all of our type-based constraints in place for this target.
 // Example with a 128-bit minimum vector width: i32 passes; i24 fails because
 // 128 % 24 != 0; i1 fails because it is not a whole number of bytes.
231 Type *ScalarTy = Load->getType()->getScalarType();
232 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
233 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
234 if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
235 ScalarSize % 8 != 0)
236 return false;
237
238 return true;
239}
240
/// Fold "insertelt undef, (load Ptr), 0" into a vector load plus a shuffle.
///
/// The inserted scalar may also be lane 0 of an extract from a loaded vector.
/// The replacement loads a vector of the target's minimum vector register
/// width, followed by a shuffle that moves the wanted element into lane 0 and
/// resizes the result to the output type.
///
/// If the wide load is not known to be safe at Ptr, the body peeks through the
/// GEP offset to a base address. That retry requires the offset to be a
/// non-negative whole number of elements that still lands inside the loaded
/// vector. The transform is done unless the new sequence is costlier (or has
/// an invalid cost). Returns true if \p I was replaced.
241bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
242 // Match insert into fixed vector of scalar value.
243 // TODO: Handle non-zero insert index.
244 Value *Scalar;
245 if (!match(&I,
247 return false;
248
249 // Optionally match an extract from another vector.
250 Value *X;
251 bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
252 if (!HasExtract)
253 X = Scalar;
254
255 auto *Load = dyn_cast<LoadInst>(X);
256 if (!canWidenLoad(Load, TTI))
257 return false;
258
259 Type *ScalarTy = Scalar->getType();
260 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
261 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
262
263 // Check safety of replacing the scalar load with a larger vector load.
264 // We use minimal alignment (maximum flexibility) because we only care about
265 // the dereferenceable region. When calculating cost and creating a new op,
266 // we may use a larger value based on alignment attributes.
267 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
268 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
269
 // canWidenLoad required MinVectorSize to be a multiple of the load's element
 // size. That element type is ScalarTy (Scalar is the load or lane 0 of it),
 // so this division is exact.
270 unsigned MinVecNumElts = MinVectorSize / ScalarSize;
271 auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
272 unsigned OffsetEltIndex = 0;
273 Align Alignment = Load->getAlign();
274 if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, SQ.AC,
275 SQ.DT)) {
276 // It is not safe to load directly from the pointer, but we can still peek
277 // through gep offsets and check if it safe to load from a base address with
278 // updated alignment. If it is, we can shuffle the element(s) into place
279 // after loading.
280 unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(SrcPtr->getType());
281 APInt Offset(OffsetBitWidth, 0);
283
284 // We want to shuffle the result down from a high element of a vector, so
285 // the offset must be positive.
286 if (Offset.isNegative())
287 return false;
288
289 // The offset must be a multiple of the scalar element to shuffle cleanly
290 // in the element's size.
291 uint64_t ScalarSizeInBytes = ScalarSize / 8;
292 if (Offset.urem(ScalarSizeInBytes) != 0)
293 return false;
294
295 // If we load MinVecNumElts, will our target element still be loaded?
296 OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue();
297 if (OffsetEltIndex >= MinVecNumElts)
298 return false;
299
300 if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load,
301 SQ.AC, SQ.DT))
302 return false;
303
304 // Update alignment with offset value. Note that the offset could be negated
305 // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but
306 // negation does not change the result of the alignment calculation.
307 Alignment = commonAlignment(Alignment, Offset.getZExtValue());
308 }
309
310 // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
311 // Use the greater of the alignment on the load or its source pointer.
312 Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
313 Type *LoadTy = Load->getType();
314 unsigned AS = Load->getPointerAddressSpace();
315 InstructionCost OldCost =
316 TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
 // The old sequence also pays for inserting the scalar into lane 0. When the
 // scalar came from an extractelement, it pays for that extract as well.
317 APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
318 OldCost +=
319 TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
320 /* Insert */ true, HasExtract, CostKind);
321
322 // New pattern: load VecPtr
323 InstructionCost NewCost =
324 TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS, CostKind);
325 // Optionally, we are shuffling the loaded vector element(s) into place.
326 // For the mask set everything but element 0 to undef to prevent poison from
327 // propagating from the extra loaded memory. This will also optionally
328 // shrink/grow the vector from the loaded size to the output size.
329 // We assume this operation has no cost in codegen if there was no offset.
330 // Note that we could use freeze to avoid poison problems, but then we might
331 // still need a shuffle to change the vector size.
332 auto *Ty = cast<FixedVectorType>(I.getType());
333 unsigned OutputNumElts = Ty->getNumElements();
334 SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
335 assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
336 Mask[0] = OffsetEltIndex;
337 if (OffsetEltIndex)
338 NewCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, Ty, MinVecTy, Mask,
339 CostKind);
340
341 // We can aggressively convert to the vector form because the backend can
342 // invert this transform if it does not result in a performance win.
343 if (OldCost < NewCost || !NewCost.isValid())
344 return false;
345
346 // It is safe and potentially profitable to load a vector directly:
347 // inselt undef, load Scalar, 0 --> load VecPtr
348 IRBuilder<> Builder(Load);
349 Value *CastedPtr =
350 Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
351 Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);
352 VecLd = Builder.CreateShuffleVector(VecLd, Mask);
353
354 replaceValue(I, *VecLd);
355 ++NumVecLoad;
356 return true;
357}
358
359/// If we are loading a vector and then inserting it into a larger vector with
360/// undefined elements, try to load the larger vector and eliminate the insert.
361/// This removes a shuffle in IR and may allow combining of other loaded values.
362bool VectorCombine::widenSubvectorLoad(Instruction &I) {
363 // Match subvector insert of fixed vector.
364 auto *Shuf = cast<ShuffleVectorInst>(&I);
365 if (!Shuf->isIdentityWithPadding())
366 return false;
367
368 // Allow a non-canonical shuffle mask that is choosing elements from op1.
369 unsigned NumOpElts =
370 cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
371 unsigned OpIndex = any_of(Shuf->getShuffleMask(), [&NumOpElts](int M) {
372 return M >= (int)(NumOpElts);
373 });
374
375 auto *Load = dyn_cast<LoadInst>(Shuf->getOperand(OpIndex));
376 if (!canWidenLoad(Load, TTI))
377 return false;
378
379 // We use minimal alignment (maximum flexibility) because we only care about
380 // the dereferenceable region. When calculating cost and creating a new op,
381 // we may use a larger value based on alignment attributes.
382 auto *Ty = cast<FixedVectorType>(I.getType());
383 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
384 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
385 Align Alignment = Load->getAlign();
386 if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), *DL, Load, SQ.AC,
387 SQ.DT))
388 return false;
389
390 Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
391 Type *LoadTy = Load->getType();
392 unsigned AS = Load->getPointerAddressSpace();
393
394 // Original pattern: insert_subvector (load PtrOp)
395 // This conservatively assumes that the cost of a subvector insert into an
396 // undef value is 0. We could add that cost if the cost model accurately
397 // reflects the real cost of that operation.
398 InstructionCost OldCost =
399 TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
400
401 // New pattern: load PtrOp
402 InstructionCost NewCost =
403 TTI.getMemoryOpCost(Instruction::Load, Ty, Alignment, AS, CostKind);
404
405 // We can aggressively convert to the vector form because the backend can
406 // invert this transform if it does not result in a performance win.
407 if (OldCost < NewCost || !NewCost.isValid())
408 return false;
409
410 IRBuilder<> Builder(Load);
411 Value *CastedPtr =
412 Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
413 Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment);
414 replaceValue(I, *VecLd);
415 ++NumVecLoad;
416 return true;
417}
418
419/// Determine which, if any, of the inputs should be replaced by a shuffle
420/// followed by extract from a different index.
421ExtractElementInst *VectorCombine::getShuffleExtract(
422 ExtractElementInst *Ext0, ExtractElementInst *Ext1,
423 unsigned PreferredExtractIndex = InvalidIndex) const {
424 auto *Index0C = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
425 auto *Index1C = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
426 assert(Index0C && Index1C && "Expected constant extract indexes");
427
428 unsigned Index0 = Index0C->getZExtValue();
429 unsigned Index1 = Index1C->getZExtValue();
430
431 // If the extract indexes are identical, no shuffle is needed.
432 if (Index0 == Index1)
433 return nullptr;
434
435 Type *VecTy = Ext0->getVectorOperand()->getType();
436 assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
437 InstructionCost Cost0 =
438 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
439 InstructionCost Cost1 =
440 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
441
442 // If both costs are invalid no shuffle is needed
443 if (!Cost0.isValid() && !Cost1.isValid())
444 return nullptr;
445
446 // We are extracting from 2 different indexes, so one operand must be shuffled
447 // before performing a vector operation and/or extract. The more expensive
448 // extract will be replaced by a shuffle.
449 if (Cost0 > Cost1)
450 return Ext0;
451 if (Cost1 > Cost0)
452 return Ext1;
453
454 // If the costs are equal and there is a preferred extract index, shuffle the
455 // opposite operand.
456 if (PreferredExtractIndex == Index0)
457 return Ext1;
458 if (PreferredExtractIndex == Index1)
459 return Ext0;
460
461 // Otherwise, replace the extract with the higher index.
462 return Index0 > Index1 ? Ext0 : Ext1;
463}
464
465/// Compare the relative costs of 2 extracts followed by scalar operation vs.
466/// vector operation(s) followed by extract. Return true if the existing
467/// instructions are cheaper than a vector alternative. Otherwise, return false
468/// and if one of the extracts should be transformed to a shufflevector, set
469/// \p ConvertToShuffle to that extract instruction.
/// \p ConvertToShuffle is assigned on every path (possibly to nullptr).
/// \p PreferredExtractIndex is only forwarded to getShuffleExtract, where it
/// breaks cost ties.
470bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
471 ExtractElementInst *Ext1,
472 const Instruction &I,
473 ExtractElementInst *&ConvertToShuffle,
474 unsigned PreferredExtractIndex) {
475 auto *Ext0IndexC = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
476 auto *Ext1IndexC = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
477 assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes");
478
479 unsigned Opcode = I.getOpcode();
480 Value *Ext0Src = Ext0->getVectorOperand();
481 Value *Ext1Src = Ext1->getVectorOperand();
482 Type *ScalarTy = Ext0->getType();
483 auto *VecTy = cast<VectorType>(Ext0Src->getType());
484 InstructionCost ScalarOpCost, VectorOpCost;
485
486 // Get cost estimates for scalar and vector versions of the operation.
487 bool IsBinOp = Instruction::isBinaryOp(Opcode);
488 if (IsBinOp) {
489 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
490 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
491 } else {
492 assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
493 "Expected a compare");
494 CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
495 ScalarOpCost = TTI.getCmpSelInstrCost(
496 Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
497 VectorOpCost = TTI.getCmpSelInstrCost(
498 Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
499 }
500
501 // Get cost estimates for the extract elements. These costs will factor into
502 // both sequences.
503 unsigned Ext0Index = Ext0IndexC->getZExtValue();
504 unsigned Ext1Index = Ext1IndexC->getZExtValue();
505
506 InstructionCost Extract0Cost =
507 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Ext0Index);
508 InstructionCost Extract1Cost =
509 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Ext1Index);
510
511 // A more expensive extract will always be replaced by a splat shuffle.
512 // For example, if Ext0 is more expensive:
513 // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
514 // extelt (opcode (splat V0, Ext0), V1), Ext1
515 // TODO: Evaluate whether that always results in lowest cost. Alternatively,
516 // check the cost of creating a broadcast shuffle and shuffling both
517 // operands to element 0.
 // NOTE(review): on an extract-cost tie these two indexes assume Ext1 is the
 // extract that gets shuffled (Ext1's lane moved to Ext0's lane). However,
 // getShuffleExtract breaks ties using PreferredExtractIndex and then the
 // higher lane, so it can return Ext0 instead. The mask costed below would
 // then describe the opposite lane move from the one actually emitted. This
 // affects only the cost estimate; confirm whether it is intended.
518 unsigned BestExtIndex = Extract0Cost > Extract1Cost ? Ext0Index : Ext1Index;
519 unsigned BestInsIndex = Extract0Cost > Extract1Cost ? Ext1Index : Ext0Index;
520 InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost);
521
522 // Extra uses of the extracts mean that we include those costs in the
523 // vector total because those instructions will not be eliminated.
524 InstructionCost OldCost, NewCost;
525 if (Ext0Src == Ext1Src && Ext0Index == Ext1Index) {
526 // Handle a special case. If the 2 extracts are identical, adjust the
527 // formulas to account for that. The extra use charge allows for either the
528 // CSE'd pattern or an unoptimized form with identical values:
529 // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
530 bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
531 : !Ext0->hasOneUse() || !Ext1->hasOneUse();
532 OldCost = CheapExtractCost + ScalarOpCost;
533 NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
534 } else {
535 // Handle the general case. Each extract is actually a different value:
536 // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
537 OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
538 NewCost = VectorOpCost + CheapExtractCost +
539 !Ext0->hasOneUse() * Extract0Cost +
540 !Ext1->hasOneUse() * Extract1Cost;
541 }
542
543 ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
544 if (ConvertToShuffle) {
 // With DisableBinopExtractShuffle set, a binop that would need a shuffle is
 // reported as "cheap", which makes the caller skip the transform.
545 if (IsBinOp && DisableBinopExtractShuffle)
546 return true;
547
548 // If we are extracting from 2 different indexes, then one operand must be
549 // shuffled before performing the vector operation. The shuffle mask is
550 // poison except for 1 lane that is being translated to the remaining
551 // extraction lane. Therefore, it is a splat shuffle. Ex:
552 // ShufMask = { poison, poison, 0, poison }
553 // TODO: The cost model has an option for a "broadcast" shuffle
554 // (splat-from-element-0), but no option for a more general splat.
555 if (auto *FixedVecTy = dyn_cast<FixedVectorType>(VecTy)) {
556 SmallVector<int> ShuffleMask(FixedVecTy->getNumElements(),
558 ShuffleMask[BestInsIndex] = BestExtIndex;
560 VecTy, VecTy, ShuffleMask, CostKind, 0,
561 nullptr, {ConvertToShuffle});
562 } else {
564 VecTy, VecTy, {}, CostKind, 0, nullptr,
565 {ConvertToShuffle});
566 }
567 }
568
569 // Aggressively form a vector op if the cost is equal because the transform
570 // may enable further optimization.
571 // Codegen can reverse this transform (scalarize) if it was not profitable.
572 return OldCost < NewCost;
573}
574
575/// Create a shuffle that translates (shifts) 1 element from the input vector
576/// to a new element location.
577static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
578 unsigned NewIndex, IRBuilderBase &Builder) {
579 // The shuffle mask is poison except for 1 lane that is being translated
580 // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
581 // ShufMask = { 2, poison, poison, poison }
582 auto *VecTy = cast<FixedVectorType>(Vec->getType());
583 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
584 ShufMask[NewIndex] = OldIndex;
585 return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
586}
587
588/// Given an extract element instruction with constant index operand, shuffle
589/// the source vector (shift the scalar element) to a NewIndex for extraction.
590/// Return null if the input can be constant folded, so that we are not creating
591/// unnecessary instructions.
592static Value *translateExtract(ExtractElementInst *ExtElt, unsigned NewIndex,
593 IRBuilderBase &Builder) {
594 // Shufflevectors can only be created for fixed-width vectors.
595 Value *X = ExtElt->getVectorOperand();
596 if (!isa<FixedVectorType>(X->getType()))
597 return nullptr;
598
599 // If the extract can be constant-folded, this code is unsimplified. Defer
600 // to other passes to handle that.
601 Value *C = ExtElt->getIndexOperand();
602 assert(isa<ConstantInt>(C) && "Expected a constant index operand");
603 if (isa<Constant>(X))
604 return nullptr;
605
606 Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
607 NewIndex, Builder);
608 return Shuf;
609}
610
611/// Try to reduce extract element costs by converting scalar compares to vector
612/// compares followed by extract.
613/// cmp (ext0 V0, ExtIndex), (ext1 V1, ExtIndex)
614Value *VectorCombine::foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex,
615 Instruction &I) {
616 assert(isa<CmpInst>(&I) && "Expected a compare");
617
618 // cmp Pred (extelt V0, ExtIndex), (extelt V1, ExtIndex)
619 // --> extelt (cmp Pred V0, V1), ExtIndex
620 ++NumVecCmp;
621 CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
622 Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
623 return Builder.CreateExtractElement(VecCmp, ExtIndex, "foldExtExtCmp");
624}
625
626/// Try to reduce extract element costs by converting scalar binops to vector
627/// binops followed by extract.
628/// bo (ext0 V0, ExtIndex), (ext1 V1, ExtIndex)
629Value *VectorCombine::foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex,
630 Instruction &I) {
631 assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
632
633 // bo (extelt V0, ExtIndex), (extelt V1, ExtIndex)
634 // --> extelt (bo V0, V1), ExtIndex
635 ++NumVecBO;
636 Value *VecBO = Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0,
637 V1, "foldExtExtBinop");
638
639 // All IR flags are safe to back-propagate because any potential poison
640 // created in unused vector elements is discarded by the extract.
641 if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
642 VecBOInst->copyIRFlags(&I);
643
644 return Builder.CreateExtractElement(VecBO, ExtIndex, "foldExtExtBinop");
645}
646
647/// Match an instruction with extracted vector operands.
/// Handles a compare or binary operator whose two operands are extractelements
/// with constant lanes C0/C1 from vectors of the same type. For fixed-width
/// vectors the lanes must be in bounds. The pattern
///   op (extelt V0, C0), (extelt V1, C1)
/// is rewritten as a vector op followed by a single extract, unless
/// isExtractExtractCheap says the scalar form is cheaper. When a lane shift is
/// needed, one source is first shuffled via translateExtract. Returns true if
/// \p I was replaced.
648bool VectorCombine::foldExtractExtract(Instruction &I) {
649 // It is not safe to transform things like div, urem, etc. because we may
650 // create undefined behavior when executing those on unknown vector elements.
652 return false;
653
654 Instruction *I0, *I1;
655 CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
656 if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
658 return false;
659
660 Value *V0, *V1;
661 uint64_t C0, C1;
662 if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
663 !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
664 V0->getType() != V1->getType())
665 return false;
666
667 // For fixed-width vectors, reject out-of-bounds extract indexes
668 if (auto *FixedVecTy = dyn_cast<FixedVectorType>(V0->getType())) {
669 unsigned NumElts = FixedVecTy->getNumElements();
670 if (C0 >= NumElts || C1 >= NumElts)
671 return false;
672 }
673
674 // If the scalar value 'I' is going to be re-inserted into a vector, then try
675 // to create an extract to that same element. The extract/insert can be
676 // reduced to a "select shuffle".
677 // TODO: If we add a larger pattern match that starts from an insert, this
678 // probably becomes unnecessary.
679 auto *Ext0 = cast<ExtractElementInst>(I0);
680 auto *Ext1 = cast<ExtractElementInst>(I1);
681 uint64_t InsertIndex = InvalidIndex;
682 if (I.hasOneUse())
683 match(I.user_back(),
684 m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));
685
686 ExtractElementInst *ExtractToChange;
687 if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
688 return false;
689
690 Value *ExtOp0 = Ext0->getVectorOperand();
691 Value *ExtOp1 = Ext1->getVectorOperand();
692
 // Shift the costlier extract's source so its element sits in the other
 // extract's lane. Bail out (no IR created yet) if no shuffle can be formed.
693 if (ExtractToChange) {
694 unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
695 Value *NewExtOp =
696 translateExtract(ExtractToChange, CheapExtractIdx, Builder);
697 if (!NewExtOp)
698 return false;
699 if (ExtractToChange == Ext0)
700 ExtOp0 = NewExtOp;
701 else
702 ExtOp1 = NewExtOp;
703 }
704
 // The result is extracted from the lane that was not shifted. A null
 // ExtractToChange falls through to Ext0's lane, which is only correct when
 // C0 == C1.
 // NOTE(review): getShuffleExtract also returns null when both extract costs
 // are invalid, even if C0 != C1. Confirm that isExtractExtractCheap always
 // rejects that case; otherwise V1 would be read from lane C0 here.
705 Value *ExtIndex = ExtractToChange == Ext0 ? Ext1->getIndexOperand()
706 : Ext0->getIndexOperand();
707 Value *NewExt = Pred != CmpInst::BAD_ICMP_PREDICATE
708 ? foldExtExtCmp(ExtOp0, ExtOp1, ExtIndex, I)
709 : foldExtExtBinop(ExtOp0, ExtOp1, ExtIndex, I);
710 Worklist.push(Ext0);
711 Worklist.push(Ext1);
712 replaceValue(I, *NewExt);
713 return true;
714}
715
716/// Try to replace an extract + scalar fneg + insert with a vector fneg +
717/// shuffle.
/// Matches
///   insertelt DstVec, (fneg (extractelt SrcVec, ExtIdx)), InsIdx
/// where ExtIdx and InsIdx are constant and the fneg has a single use.
/// SrcVec and DstVec must be fixed vectors with the same element type. Their
/// lengths may differ, in which case an extra length-changing shuffle is
/// emitted. Returns true if \p I was replaced.
718bool VectorCombine::foldInsExtFNeg(Instruction &I) {
719 // Match an insert (op (extract)) pattern.
720 Value *DstVec;
721 uint64_t ExtIdx, InsIdx;
722 Instruction *FNeg;
723 if (!match(&I, m_InsertElt(m_Value(DstVec), m_OneUse(m_Instruction(FNeg)),
724 m_ConstantInt(InsIdx))))
725 return false;
726
727 // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
728 Value *SrcVec;
729 Instruction *Extract;
730 if (!match(FNeg, m_FNeg(m_CombineAnd(
731 m_Instruction(Extract),
732 m_ExtractElt(m_Value(SrcVec), m_ConstantInt(ExtIdx))))))
733 return false;
734
735 auto *DstVecTy = cast<FixedVectorType>(DstVec->getType());
736 auto *DstVecScalarTy = DstVecTy->getScalarType();
737 auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcVec->getType());
738 if (!SrcVecTy || DstVecScalarTy != SrcVecTy->getScalarType())
739 return false;
740
741 // Ignore if insert/extract index is out of bounds or destination vector has
742 // one element
743 unsigned NumDstElts = DstVecTy->getNumElements();
744 unsigned NumSrcElts = SrcVecTy->getNumElements();
 // NOTE(review): `ExtIdx > NumSrcElts` still admits ExtIdx == NumSrcElts,
 // which is out of bounds for SrcVec. `>=` looks like what the comment above
 // intends; confirm before changing.
745 if (ExtIdx > NumSrcElts || InsIdx >= NumDstElts || NumDstElts == 1)
746 return false;
747
748 // We are inserting the negated element into the same lane that we extracted
749 // from. This is equivalent to a select-shuffle that chooses all but the
750 // negated element from the destination vector.
 // NOTE(review): ExtIdx and InsIdx are not required to be equal. The mask
 // below places fneg(SrcVec)[ExtIdx] into lane InsIdx (routed through SrcMask
 // when the lengths differ), so the "same lane" wording above is stale.
751 SmallVector<int> Mask(NumDstElts);
752 std::iota(Mask.begin(), Mask.end(), 0);
753 Mask[InsIdx] = (ExtIdx % NumDstElts) + NumDstElts;
754 InstructionCost OldCost =
755 TTI.getArithmeticInstrCost(Instruction::FNeg, DstVecScalarTy, CostKind) +
756 TTI.getVectorInstrCost(I, DstVecTy, CostKind, InsIdx);
757
758 // If the extract has one use, it will be eliminated, so count it in the
759 // original cost. If it has more than one use, ignore the cost because it will
760 // be the same before/after.
761 if (Extract->hasOneUse())
762 OldCost += TTI.getVectorInstrCost(*Extract, SrcVecTy, CostKind, ExtIdx);
763
764 InstructionCost NewCost =
765 TTI.getArithmeticInstrCost(Instruction::FNeg, SrcVecTy, CostKind) +
767 DstVecTy, Mask, CostKind);
768
769 bool NeedLenChg = SrcVecTy->getNumElements() != NumDstElts;
770 // If the lengths of the two vectors are not equal,
771 // we need to add a length-change vector. Add this cost.
772 SmallVector<int> SrcMask;
773 if (NeedLenChg) {
774 SrcMask.assign(NumDstElts, PoisonMaskElem);
775 SrcMask[ExtIdx % NumDstElts] = ExtIdx;
777 DstVecTy, SrcVecTy, SrcMask, CostKind);
778 }
779
780 LLVM_DEBUG(dbgs() << "Found an insertion of (extract)fneg : " << I
781 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
782 << "\n");
783 if (NewCost > OldCost)
784 return false;
785
786 Value *NewShuf, *LenChgShuf = nullptr;
787 // insertelt DstVec, (fneg (extractelt SrcVec, Index)), Index
 // The vector fneg copies fast-math flags from the original scalar fneg.
788 Value *VecFNeg = Builder.CreateFNegFMF(SrcVec, FNeg);
789 if (NeedLenChg) {
790 // shuffle DstVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask
791 LenChgShuf = Builder.CreateShuffleVector(VecFNeg, SrcMask);
792 NewShuf = Builder.CreateShuffleVector(DstVec, LenChgShuf, Mask);
793 Worklist.pushValue(LenChgShuf);
794 } else {
795 // shuffle DstVec, (fneg SrcVec), Mask
796 NewShuf = Builder.CreateShuffleVector(DstVec, VecFNeg, Mask);
797 }
798
799 Worklist.pushValue(VecFNeg);
800 replaceValue(I, *NewShuf);
801 return true;
802}
803
804/// Try to fold insert(binop(x,y),binop(a,b),idx)
805/// --> binop(insert(x,a,idx),insert(y,b,idx))
///
/// Both binops must have a single use and the same opcode, idx must be a
/// constant, and the result must be a fixed-width vector. On success \p I is
/// replaced and true is returned.
806bool VectorCombine::foldInsExtBinop(Instruction &I) {
807 BinaryOperator *VecBinOp, *SclBinOp;
808 uint64_t Index;
809 if (!match(&I,
810 m_InsertElt(m_OneUse(m_BinOp(VecBinOp)),
811 m_OneUse(m_BinOp(SclBinOp)), m_ConstantInt(Index))))
812 return false;
813
814 // TODO: Add support for addlike etc.
815 Instruction::BinaryOps BinOpcode = VecBinOp->getOpcode();
816 if (BinOpcode != SclBinOp->getOpcode())
817 return false;
818
819 auto *ResultTy = dyn_cast<FixedVectorType>(I.getType());
820 if (!ResultTy)
821 return false;
822
823 // TODO: Attempt to detect m_ExtractElt for scalar operands and convert to
824 // shuffle?
825
827 TTI.getInstructionCost(VecBinOp, CostKind) +
  // New form: one vector binop plus two insertelements, one into each of the
  // vector binop's original operands.
829 InstructionCost NewCost =
830 TTI.getArithmeticInstrCost(BinOpcode, ResultTy, CostKind) +
831 TTI.getVectorInstrCost(Instruction::InsertElement, ResultTy, CostKind,
832 Index, VecBinOp->getOperand(0),
833 SclBinOp->getOperand(0)) +
834 TTI.getVectorInstrCost(Instruction::InsertElement, ResultTy, CostKind,
835 Index, VecBinOp->getOperand(1),
836 SclBinOp->getOperand(1));
837
838 LLVM_DEBUG(dbgs() << "Found an insertion of two binops: " << I
839 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
840 << "\n");
841 if (NewCost > OldCost)
842 return false;
843
844 Value *NewIns0 = Builder.CreateInsertElement(VecBinOp->getOperand(0),
845 SclBinOp->getOperand(0), Index);
846 Value *NewIns1 = Builder.CreateInsertElement(VecBinOp->getOperand(1),
847 SclBinOp->getOperand(1), Index);
848 Value *NewBO = Builder.CreateBinOp(BinOpcode, NewIns0, NewIns1);
849
850 // Intersect flags from the old binops.
  // The new binop computes both the old vector lanes and the inserted scalar
  // lane, so only flags present on both original binops are kept.
851 if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
852 NewInst->copyIRFlags(VecBinOp);
853 NewInst->andIRFlags(SclBinOp);
854 }
855
856 Worklist.pushValue(NewIns0);
857 Worklist.pushValue(NewIns1);
858 replaceValue(I, *NewBO);
859 return true;
860}
861
862/// Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
863/// Supports: bitcast, trunc, sext, zext
///
/// \p I must be a bitwise logic operation (isBitwiseLogicOp) whose two
/// operands are casts with the same cast opcode and the same source type. The
/// rewrite is applied only if the target cost model says it is not more
/// expensive.
864bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
865 // Check if this is a bitwise logic operation
866 auto *BinOp = dyn_cast<BinaryOperator>(&I);
867 if (!BinOp || !BinOp->isBitwiseLogicOp())
868 return false;
869
870 // Get the cast instructions
871 auto *LHSCast = dyn_cast<CastInst>(BinOp->getOperand(0));
872 auto *RHSCast = dyn_cast<CastInst>(BinOp->getOperand(1));
873 if (!LHSCast || !RHSCast) {
874 LLVM_DEBUG(dbgs() << " One or both operands are not cast instructions\n");
875 return false;
876 }
877
878 // Both casts must use the same cast opcode
879 Instruction::CastOps CastOpcode = LHSCast->getOpcode();
880 if (CastOpcode != RHSCast->getOpcode())
881 return false;
882
883 // Only handle supported cast operations
884 switch (CastOpcode) {
885 case Instruction::BitCast:
886 case Instruction::Trunc:
887 case Instruction::SExt:
888 case Instruction::ZExt:
889 break;
890 default:
891 return false;
892 }
893
894 Value *LHSSrc = LHSCast->getOperand(0);
895 Value *RHSSrc = RHSCast->getOperand(0);
896
897 // Source types must match
898 if (LHSSrc->getType() != RHSSrc->getType())
899 return false;
900
901 auto *SrcTy = LHSSrc->getType();
902 auto *DstTy = I.getType();
903 // Bitcasts can handle scalar/vector mixes, such as i16 -> <16 x i1>.
904 // Other casts only handle vector types with integer elements.
905 if (CastOpcode != Instruction::BitCast &&
906 (!isa<FixedVectorType>(SrcTy) || !isa<FixedVectorType>(DstTy)))
907 return false;
908
909 // Only integer scalar/vector values are legal for bitwise logic operations.
910 if (!SrcTy->getScalarType()->isIntegerTy() ||
911 !DstTy->getScalarType()->isIntegerTy())
912 return false;
913
914 // Cost Check :
915 // OldCost = bitlogic + 2*casts
916 // NewCost = bitlogic + cast
917
918 // Calculate specific costs for each cast with instruction context
920 CastOpcode, DstTy, SrcTy, TTI::CastContextHint::None, CostKind, LHSCast);
922 CastOpcode, DstTy, SrcTy, TTI::CastContextHint::None, CostKind, RHSCast);
923
924 InstructionCost OldCost =
925 TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstTy, CostKind) +
926 LHSCastCost + RHSCastCost;
927
928 // For new cost, we can't provide an instruction (it doesn't exist yet)
929 InstructionCost GenericCastCost = TTI.getCastInstrCost(
930 CastOpcode, DstTy, SrcTy, TTI::CastContextHint::None, CostKind);
931
932 InstructionCost NewCost =
933 TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcTy, CostKind) +
934 GenericCastCost;
935
936 // Account for multi-use casts using specific costs
  // (a cast with other users is not erased, so its cost stays in NewCost).
937 if (!LHSCast->hasOneUse())
938 NewCost += LHSCastCost;
939 if (!RHSCast->hasOneUse())
940 NewCost += RHSCastCost;
941
942 LLVM_DEBUG(dbgs() << "foldBitOpOfCastops: OldCost=" << OldCost
943 << " NewCost=" << NewCost << "\n");
944
945 if (NewCost > OldCost)
946 return false;
947
948 // Create the operation on the source type
949 Value *NewOp = Builder.CreateBinOp(BinOp->getOpcode(), LHSSrc, RHSSrc,
950 BinOp->getName() + ".inner");
  // NOTE(review): copyIRFlags presumably also carries flags such as `disjoint`
  // from the narrow op onto the new one. For Trunc, the new op also covers the
  // bits that the truncation discards, which the original flag says nothing
  // about. Confirm this cannot introduce poison; dropping `disjoint` for Trunc
  // may be required.
951 if (auto *NewBinOp = dyn_cast<BinaryOperator>(NewOp))
952 NewBinOp->copyIRFlags(BinOp);
953
954 Worklist.pushValue(NewOp);
955
956 // Create the cast operation directly to ensure we get a new instruction
957 Instruction *NewCast = CastInst::Create(CastOpcode, NewOp, I.getType());
958
959 // Preserve cast instruction flags
  // (only flags present on both original casts survive the intersection).
960 NewCast->copyIRFlags(LHSCast);
961 NewCast->andIRFlags(RHSCast);
962
963 // Insert the new instruction
964 Value *Result = Builder.Insert(NewCast);
965
966 replaceValue(I, *Result);
967 return true;
968}
969
970/// Match:
971/// bitop(castop(x), C) ->
972/// bitop(castop(x), castop(InvC)) ->
973/// castop(bitop(x, InvC))
974/// Supports: bitcast, zext, sext, trunc
///
/// InvC is a constant for which castop(InvC) == C, as computed by
/// getLosslessInvCast. The fold is abandoned if no such constant exists or if
/// the cost model says the rewrite is more expensive.
975bool VectorCombine::foldBitOpOfCastConstant(Instruction &I) {
977 Constant *C;
978
979 // Check if this is a bitwise logic operation
981 return false;
982
983 // Get the cast instructions
984 auto *LHSCast = dyn_cast<CastInst>(LHS);
985 if (!LHSCast)
986 return false;
987
988 Instruction::CastOps CastOpcode = LHSCast->getOpcode();
989
990 // Only handle supported cast operations
991 switch (CastOpcode) {
992 case Instruction::BitCast:
993 case Instruction::ZExt:
994 case Instruction::SExt:
995 case Instruction::Trunc:
996 break;
997 default:
998 return false;
999 }
1000
1001 Value *LHSSrc = LHSCast->getOperand(0);
1002
1003 auto *SrcTy = LHSSrc->getType();
1004 auto *DstTy = I.getType();
1005 // Bitcasts can handle scalar/vector mixes, such as i16 -> <16 x i1>.
1006 // Other casts only handle vector types with integer elements.
1007 if (CastOpcode != Instruction::BitCast &&
1008 (!isa<FixedVectorType>(SrcTy) || !isa<FixedVectorType>(DstTy)))
1009 return false;
1010
1011 // Only integer scalar/vector values are legal for bitwise logic operations.
1012 if (!SrcTy->getScalarType()->isIntegerTy() ||
1013 !DstTy->getScalarType()->isIntegerTy())
1014 return false;
1015
1016 // Find the constant InvC, such that castop(InvC) equals to C.
1017 PreservedCastFlags RHSFlags;
1018 Constant *InvC = getLosslessInvCast(C, SrcTy, CastOpcode, *DL, &RHSFlags);
1019 if (!InvC)
1020 return false;
1021
1022 // Cost Check :
1023 // OldCost = bitlogic + cast
1024 // NewCost = bitlogic + cast
1025
1026 // Calculate the cost of the existing cast with instruction context
1027 InstructionCost LHSCastCost = TTI.getCastInstrCost(
1028 CastOpcode, DstTy, SrcTy, TTI::CastContextHint::None, CostKind, LHSCast);
1029
1030 InstructionCost OldCost =
1031 TTI.getArithmeticInstrCost(I.getOpcode(), DstTy, CostKind) + LHSCastCost;
1032
1033 // For new cost, we can't provide an instruction (it doesn't exist yet)
1034 InstructionCost GenericCastCost = TTI.getCastInstrCost(
1035 CastOpcode, DstTy, SrcTy, TTI::CastContextHint::None, CostKind);
1036
1037 InstructionCost NewCost =
1038 TTI.getArithmeticInstrCost(I.getOpcode(), SrcTy, CostKind) +
1039 GenericCastCost;
1040
1041 // Account for multi-use casts using specific costs
  // (if the cast has other users it is not erased, so its cost remains).
1042 if (!LHSCast->hasOneUse())
1043 NewCost += LHSCastCost;
1044
1045 LLVM_DEBUG(dbgs() << "foldBitOpOfCastConstant: OldCost=" << OldCost
1046 << " NewCost=" << NewCost << "\n");
1047
1048 if (NewCost > OldCost)
1049 return false;
1050
1051 // Create the operation on the source type
1052 Value *NewOp = Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(),
1053 LHSSrc, InvC, I.getName() + ".inner");
  // NOTE(review): as in foldBitOpOfCastops, flags such as `disjoint` are
  // presumably copied from the narrow op. For Trunc, the wide op also covers
  // the bits above the truncation width, which the original flag did not
  // constrain. Confirm this is sound.
1054 if (auto *NewBinOp = dyn_cast<BinaryOperator>(NewOp))
1055 NewBinOp->copyIRFlags(&I);
1056
1057 Worklist.pushValue(NewOp);
1058
1059 // Create the cast operation directly to ensure we get a new instruction
1060 Instruction *NewCast = CastInst::Create(CastOpcode, NewOp, I.getType());
1061
1062 // Preserve cast instruction flags
  // Start from the flags that castop(InvC) was reported to satisfy (RHSFlags),
  // then intersect them with the flags of the original LHS cast.
1063 if (RHSFlags.NNeg)
1064 NewCast->setNonNeg();
1065 if (RHSFlags.NUW)
1066 NewCast->setHasNoUnsignedWrap();
1067 if (RHSFlags.NSW)
1068 NewCast->setHasNoSignedWrap();
1069
1070 NewCast->andIRFlags(LHSCast);
1071
1072 // Insert the new instruction
1073 Value *Result = Builder.Insert(NewCast);
1074
1075 replaceValue(I, *Result);
1076 return true;
1077}
1078
1079/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
1080/// destination type followed by shuffle. This can enable further transforms by
1081/// moving bitcasts or shuffles together.
///
/// Both the shuffle source and the bitcast destination must be fixed-width
/// vectors, and the element sizes must divide evenly so the shuffle mask can
/// be narrowed or widened to the destination element type.
1082bool VectorCombine::foldBitcastShuffle(Instruction &I) {
1083 Value *V0, *V1;
1084 ArrayRef<int> Mask;
1085 if (!match(&I, m_BitCast(m_OneUse(
1086 m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(Mask))))))
1087 return false;
1088
1089 // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
1090 // scalable type is unknown; Second, we cannot reason if the narrowed shuffle
1091 // mask for scalable type is a splat or not.
1092 // 2) Disallow non-vector casts.
1093 // TODO: We could allow any shuffle.
1094 auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
1095 auto *SrcTy = dyn_cast<FixedVectorType>(V0->getType());
1096 if (!DestTy || !SrcTy)
1097 return false;
1098
1099 unsigned DestEltSize = DestTy->getScalarSizeInBits();
1100 unsigned SrcEltSize = SrcTy->getScalarSizeInBits();
  // The shuffle source must be expressible as a whole number of
  // destination-sized elements.
1101 if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
1102 return false;
1103
  // An undef/poison second operand is treated as a unary shuffle, so only one
  // operand bitcast is counted in the new cost (see NumOps below).
1104 bool IsUnary = isa<UndefValue>(V1);
1105
1106 // For binary shuffles, only fold bitcast(shuffle(X,Y))
1107 // if it won't increase the number of bitcasts.
1108 if (!IsUnary) {
1111 if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) &&
1112 !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType()))
1113 return false;
1114 }
1115
1116 SmallVector<int, 16> NewMask;
1117 if (DestEltSize <= SrcEltSize) {
1118 // The bitcast is from wide to narrow/equal elements. The shuffle mask can
1119 // always be expanded to the equivalent form choosing narrower elements.
1120 if (SrcEltSize % DestEltSize != 0)
1121 return false;
1122 unsigned ScaleFactor = SrcEltSize / DestEltSize;
1123 narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
1124 } else {
1125 // The bitcast is from narrow elements to wide elements. The shuffle mask
1126 // must choose consecutive elements to allow casting first.
1127 if (DestEltSize % SrcEltSize != 0)
1128 return false;
1129 unsigned ScaleFactor = DestEltSize / SrcEltSize;
1130 if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
1131 return false;
1132 }
1133
1134 // Bitcast the shuffle src - keep its original width but using the destination
1135 // scalar type.
1136 unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
1137 auto *NewShuffleTy =
1138 FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);
1139 auto *OldShuffleTy =
1140 FixedVectorType::get(SrcTy->getScalarType(), Mask.size());
1141 unsigned NumOps = IsUnary ? 1 : 2;
1142
1143 // The new shuffle must not cost more than the old shuffle.
1147
  // NewCost charges one bitcast per shuffle operand (NumOps); OldCost charges
  // the single bitcast of the shuffle result.
1148 InstructionCost NewCost =
1149 TTI.getShuffleCost(SK, DestTy, NewShuffleTy, NewMask, CostKind) +
1150 (NumOps * TTI.getCastInstrCost(Instruction::BitCast, NewShuffleTy, SrcTy,
1151 TargetTransformInfo::CastContextHint::None,
1152 CostKind));
1153 InstructionCost OldCost =
1154 TTI.getShuffleCost(SK, OldShuffleTy, SrcTy, Mask, CostKind) +
1155 TTI.getCastInstrCost(Instruction::BitCast, DestTy, OldShuffleTy,
1156 TargetTransformInfo::CastContextHint::None,
1157 CostKind);
1158
1159 LLVM_DEBUG(dbgs() << "Found a bitcasted shuffle: " << I << "\n OldCost: "
1160 << OldCost << " vs NewCost: " << NewCost << "\n");
1161
1162 if (NewCost > OldCost || !NewCost.isValid())
1163 return false;
1164
1165 // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC'
1166 ++NumShufOfBitcast;
  // peekThroughBitcasts presumably strips any existing bitcasts on the
  // operands, so the new bitcasts replace them rather than stacking on top.
1167 Value *CastV0 = Builder.CreateBitCast(peekThroughBitcasts(V0), NewShuffleTy);
1168 Value *CastV1 = Builder.CreateBitCast(peekThroughBitcasts(V1), NewShuffleTy);
1169 Value *Shuf = Builder.CreateShuffleVector(CastV0, CastV1, NewMask);
1170 replaceValue(I, *Shuf);
1171 return true;
1172}
1173
1174/// VP Intrinsics whose vector operands are both splat values may be simplified
1175/// into the scalar version of the operation and the result splatted. This
1176/// can lead to scalarization down the line.
///
/// Only binary VP intrinsics whose mask (argument 2) is an all-true splat are
/// handled. Argument 3 is the explicit vector length (EVL); it must be known
/// non-zero unless the scalar operation is safe to speculate.
1177bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
1178 if (!isa<VPIntrinsic>(I))
1179 return false;
1180 VPIntrinsic &VPI = cast<VPIntrinsic>(I);
1181 Value *Op0 = VPI.getArgOperand(0);
1182 Value *Op1 = VPI.getArgOperand(1);
1183
1184 if (!isSplatValue(Op0) || !isSplatValue(Op1))
1185 return false;
1186
1187 // Check getSplatValue early in this function, to avoid doing unnecessary
1188 // work.
1189 Value *ScalarOp0 = getSplatValue(Op0);
1190 Value *ScalarOp1 = getSplatValue(Op1);
1191 if (!ScalarOp0 || !ScalarOp1)
1192 return false;
1193
1194 // For the binary VP intrinsics supported here, the result on disabled lanes
1195 // is a poison value. For now, only do this simplification if all lanes
1196 // are active.
1197 // TODO: Relax the condition that all lanes are active by using insertelement
1198 // on inactive lanes.
1199 auto IsAllTrueMask = [](Value *MaskVal) {
1200 if (Value *SplattedVal = getSplatValue(MaskVal))
1201 if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
1202 return ConstValue->isAllOnesValue();
1203 return false;
1204 };
1205 if (!IsAllTrueMask(VPI.getArgOperand(2)))
1206 return false;
1207
1208 // Check to make sure we support scalarization of the intrinsic
1209 Intrinsic::ID IntrID = VPI.getIntrinsicID();
1210 if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
1211 return false;
1212
1213 // Calculate cost of splatting both operands into vectors and the vector
1214 // intrinsic
1215 VectorType *VecTy = cast<VectorType>(VPI.getType());
  // All-zero splat shuffle mask; it stays empty for scalable vector types.
1216 SmallVector<int> Mask;
1217 if (auto *FVTy = dyn_cast<FixedVectorType>(VecTy))
1218 Mask.resize(FVTy->getNumElements(), 0);
1219 InstructionCost SplatCost =
1220 TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
1222 CostKind);
1223
1224 // Calculate the cost of the VP Intrinsic
1226 for (Value *V : VPI.args())
1227 Args.push_back(V->getType());
1228 IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
1229 InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
1230 InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
1231
1232 // Determine scalar opcode
1233 std::optional<unsigned> FunctionalOpcode =
1234 VPI.getFunctionalOpcode();
1235 std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
1236 if (!FunctionalOpcode) {
1237 ScalarIntrID = VPI.getFunctionalIntrinsicID();
1238 if (!ScalarIntrID)
1239 return false;
1240 }
1241
1242 // Calculate cost of scalarizing
1243 InstructionCost ScalarOpCost = 0;
1244 if (ScalarIntrID) {
    // NOTE(review): Args still holds the VP call's argument types (vector
    // operands, mask, EVL), so the scalar intrinsic is costed with a
    // vector-typed argument list. Confirm whether the scalar operand types
    // were intended here.
1245 IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
1246 ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
1247 } else {
1248 ScalarOpCost = TTI.getArithmeticInstrCost(*FunctionalOpcode,
1249 VecTy->getScalarType(), CostKind);
1250 }
1251
1252 // The existing splats may be kept around if other instructions use them.
1253 InstructionCost CostToKeepSplats =
1254 (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
1255 InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
1256
1257 LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
1258 << "\n");
1259 LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
1260 << ", Cost of scalarizing:" << NewCost << "\n");
1261
1262 // We want to scalarize unless the vector variant actually has lower cost.
1263 if (OldCost < NewCost || !NewCost.isValid())
1264 return false;
1265
1266 // Scalarize the intrinsic
1267 ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount();
1268 Value *EVL = VPI.getArgOperand(3);
1269
1270 // If the VP op might introduce UB or poison, we can scalarize it provided
1271 // that we know the EVL > 0: If the EVL is zero, then the original VP op
1272 // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by
1273 // scalarizing it.
1274 bool SafeToSpeculate;
1275 if (ScalarIntrID)
1276 SafeToSpeculate = Intrinsic::getFnAttributes(I.getContext(), *ScalarIntrID)
1277 .hasAttribute(Attribute::AttrKind::Speculatable);
1278 else
1280 *FunctionalOpcode, &VPI, nullptr, SQ.AC, SQ.DT);
1281 if (!SafeToSpeculate &&
1282 !isKnownNonZero(EVL, SimplifyQuery(*DL, SQ.DT, SQ.AC, &VPI)))
1283 return false;
1284
1285 Value *ScalarVal =
1286 ScalarIntrID
1287 ? Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID,
1288 {ScalarOp0, ScalarOp1})
1289 : Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode),
1290 ScalarOp0, ScalarOp1);
1291
1292 replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal));
1293 return true;
1294}
1295
1296/// Match a vector op/compare/intrinsic with at least one
1297/// inserted scalar operand and convert to scalar op/cmp/intrinsic followed
1298/// by insertelement.
///
/// Each operand must be one of:
///   - a constant vector;
///   - an insertelement into a constant vector at a constant index shared by
///     all such operands;
///   - for intrinsics, an argument the intrinsic already takes as a scalar.
/// At least one insertelement operand is required.
1299bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
1300 auto *UO = dyn_cast<UnaryOperator>(&I);
1301 auto *BO = dyn_cast<BinaryOperator>(&I);
1302 auto *CI = dyn_cast<CmpInst>(&I);
1303 auto *II = dyn_cast<IntrinsicInst>(&I);
1304 if (!UO && !BO && !CI && !II)
1305 return false;
1306
1307 // TODO: Allow intrinsics with different argument types
1308 if (II) {
1309 if (!isTriviallyVectorizable(II->getIntrinsicID()))
1310 return false;
1311 for (auto [Idx, Arg] : enumerate(II->args()))
1312 if (Arg->getType() != II->getType() &&
1313 !isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, &TTI))
1314 return false;
1315 }
1316
1317 // Do not convert the vector condition of a vector select into a scalar
1318 // condition. That may cause problems for codegen because of differences in
1319 // boolean formats and register-file transfers.
1320 // TODO: Can we account for that in the cost model?
1321 if (CI)
1322 for (User *U : I.users())
1323 if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
1324 return false;
1325
1326 // Match constant vectors or scalars being inserted into constant vectors:
1327 // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
1328 SmallVector<Value *> VecCs, ScalarOps;
1329 std::optional<uint64_t> Index;
1330
1331 auto Ops = II ? II->args() : I.operands();
1332 for (auto [OpNum, Op] : enumerate(Ops)) {
1333 Constant *VecC;
1334 Value *V;
1335 uint64_t InsIdx = 0;
1336 if (match(Op.get(), m_InsertElt(m_Constant(VecC), m_Value(V),
1337 m_ConstantInt(InsIdx)))) {
1338 // Bail if any inserts are out of bounds.
1339 VectorType *OpTy = cast<VectorType>(Op->getType());
1340 if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
1341 return false;
1342 // All inserts must have the same index.
1343 // TODO: Deal with mismatched index constants and variable indexes?
1344 if (!Index)
1345 Index = InsIdx;
1346 else if (InsIdx != *Index)
1347 return false;
1348 VecCs.push_back(VecC);
1349 ScalarOps.push_back(V);
1350 } else if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(),
1351 OpNum, &TTI)) {
      // An argument the intrinsic already takes as a scalar is used unchanged
      // on both the vector side and the scalar side.
1352 VecCs.push_back(Op.get());
1353 ScalarOps.push_back(Op.get());
1354 } else if (match(Op.get(), m_Constant(VecC))) {
      // Plain constant vector: nullptr marks the scalar operand as pending.
      // It is filled in with an extracted element once the fold is committed.
1355 VecCs.push_back(VecC);
1356 ScalarOps.push_back(nullptr);
1357 } else {
1358 return false;
1359 }
1360 }
1361
1362 // Bail if all operands are constant.
1363 if (!Index.has_value())
1364 return false;
1365
  // NOTE(review): VecTy is the type of I. For a compare this is the compare's
  // boolean result vector, not the operand vector type. The compare costs and
  // the operand insertelement costs below are therefore queried on the result
  // type. Confirm this is intended.
1366 VectorType *VecTy = cast<VectorType>(I.getType());
1367 Type *ScalarTy = VecTy->getScalarType();
1368 assert(VecTy->isVectorTy() &&
1369 (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
1370 ScalarTy->isPointerTy()) &&
1371 "Unexpected types for insert element into binop or cmp");
1372
1373 unsigned Opcode = I.getOpcode();
1374 InstructionCost ScalarOpCost, VectorOpCost;
1375 if (CI) {
1376 CmpInst::Predicate Pred = CI->getPredicate();
1377 ScalarOpCost = TTI.getCmpSelInstrCost(
1378 Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
1379 VectorOpCost = TTI.getCmpSelInstrCost(
1380 Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
1381 } else if (UO || BO) {
1382 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
1383 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
1384 } else {
1385 IntrinsicCostAttributes ScalarICA(
1386 II->getIntrinsicID(), ScalarTy,
1387 SmallVector<Type *>(II->arg_size(), ScalarTy));
1388 ScalarOpCost = TTI.getIntrinsicInstrCost(ScalarICA, CostKind);
1389 IntrinsicCostAttributes VectorICA(
1390 II->getIntrinsicID(), VecTy,
1391 SmallVector<Type *>(II->arg_size(), VecTy));
1392 VectorOpCost = TTI.getIntrinsicInstrCost(VectorICA, CostKind);
1393 }
1394
1395 // Fold the vector constants in the original vectors into a new base vector to
1396 // get more accurate cost modelling.
1397 Value *NewVecC = nullptr;
1398 if (CI)
1399 NewVecC = simplifyCmpInst(CI->getPredicate(), VecCs[0], VecCs[1], SQ);
1400 else if (UO)
1401 NewVecC =
1402 simplifyUnOp(UO->getOpcode(), VecCs[0], UO->getFastMathFlags(), SQ);
1403 else if (BO)
1404 NewVecC = simplifyBinOp(BO->getOpcode(), VecCs[0], VecCs[1], SQ);
1405 else if (II)
1406 NewVecC = simplifyCall(II, II->getCalledOperand(), VecCs, SQ);
1407
1408 if (!NewVecC)
1409 return false;
1410
1411 // Get cost estimate for the insert element. This cost will factor into
1412 // both sequences.
1413 InstructionCost OldCost = VectorOpCost;
1414 InstructionCost NewCost =
1415 ScalarOpCost + TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
1416 CostKind, *Index, NewVecC);
1417
  // Each original insertelement operand is removed unless it has other users.
  // In that case its cost is also charged to the new sequence.
1418 for (auto [Idx, Op, VecC, Scalar] : enumerate(Ops, VecCs, ScalarOps)) {
1419 if (!Scalar || (II && isVectorIntrinsicWithScalarOpAtArg(
1420 II->getIntrinsicID(), Idx, &TTI)))
1421 continue;
1423 Instruction::InsertElement, VecTy, CostKind, *Index, VecC, Scalar);
1424 OldCost += InsertCost;
1425 NewCost += !Op->hasOneUse() * InsertCost;
1426 }
1427
1428 // We want to scalarize unless the vector variant actually has lower cost.
1429 if (OldCost < NewCost || !NewCost.isValid())
1430 return false;
1431
1432 // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
1433 // inselt NewVecC, (scalar_op V0, V1), Index
1434 if (CI)
1435 ++NumScalarCmp;
1436 else if (UO || BO)
1437 ++NumScalarOps;
1438 else
1439 ++NumScalarIntrinsic;
1440
1441 // For constant cases, extract the scalar element, this should constant fold.
1442 for (auto [OpIdx, Scalar, VecC] : enumerate(ScalarOps, VecCs))
1443 if (!Scalar)
1445 cast<Constant>(VecC), Builder.getInt64(*Index));
1446
1447 Value *Scalar;
1448 if (CI)
1449 Scalar = Builder.CreateCmp(CI->getPredicate(), ScalarOps[0], ScalarOps[1]);
1450 else if (UO || BO)
1451 Scalar = Builder.CreateNAryOp(Opcode, ScalarOps);
1452 else
1453 Scalar = Builder.CreateIntrinsic(ScalarTy, II->getIntrinsicID(), ScalarOps);
1454
1455 Scalar->setName(I.getName() + ".scalar");
1456
1457 // All IR flags are safe to back-propagate. There is no potential for extra
1458 // poison to be created by the scalar instruction.
1459 if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
1460 ScalarInst->copyIRFlags(&I);
1461
1462 Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, *Index);
1463 replaceValue(I, *Insert);
1464 return true;
1465}
1466
1467/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
1468/// a vector into vector operations followed by extract. Note: The SLP pass
1469/// may miss this pattern because of implementation problems.
///
/// Both compares must use a matching predicate, and each must compare a
/// constant-index extract of the same fixed-width vector X against a constant.
1470bool VectorCombine::foldExtractedCmps(Instruction &I) {
1471 auto *BI = dyn_cast<BinaryOperator>(&I);
1472
1473 // We are looking for a scalar binop of booleans.
1474 // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
1475 if (!BI || !I.getType()->isIntegerTy(1))
1476 return false;
1477
1478 // The compare predicates should match, and each compare should have a
1479 // constant operand.
1480 Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
1481 Instruction *I0, *I1;
1482 Constant *C0, *C1;
1483 CmpPredicate P0, P1;
  // NOTE(review): B0/B1 are not required to be single-use, yet OldCost below
  // counts both compares as removed (CmpCost * 2). Confirm that multi-use
  // compares are intended to be accepted.
1484 if (!match(B0, m_Cmp(P0, m_Instruction(I0), m_Constant(C0))) ||
1485 !match(B1, m_Cmp(P1, m_Instruction(I1), m_Constant(C1))))
1486 return false;
1487
1488 auto MatchingPred = CmpPredicate::getMatching(P0, P1);
1489 if (!MatchingPred)
1490 return false;
1491
1492 // The compare operands must be extracts of the same vector with constant
1493 // extract indexes.
1494 Value *X;
1495 uint64_t Index0, Index1;
1496 if (!match(I0, m_ExtractElt(m_Value(X), m_ConstantInt(Index0))) ||
1497 !match(I1, m_ExtractElt(m_Specific(X), m_ConstantInt(Index1))))
1498 return false;
1499
1500 auto *Ext0 = cast<ExtractElementInst>(I0);
1501 auto *Ext1 = cast<ExtractElementInst>(I1);
  // getShuffleExtract chooses which of the two extracts becomes a shuffle. A
  // null result means no choice was made, and the fold is abandoned.
1502 ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1, CostKind);
1503 if (!ConvertToShuf)
1504 return false;
1505 assert((ConvertToShuf == Ext0 || ConvertToShuf == Ext1) &&
1506 "Unknown ExtractElementInst");
1507
1508 // The original scalar pattern is:
1509 // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
1510 CmpInst::Predicate Pred = *MatchingPred;
1511 unsigned CmpOpcode =
1512 CmpInst::isFPPredicate(Pred) ? Instruction::FCmp : Instruction::ICmp;
1513 auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
1514 if (!VecTy)
1515 return false;
1516
  // NOTE(review): Index0/Index1 later index ShufMask and CmpC, which have
  // VecTy->getNumElements() entries, but this function has no visible bounds
  // check against that count. Confirm that getShuffleExtract rejects:
  //   - out-of-range constant indices;
  //   - equal indices, for which CmpC[Index1] would overwrite C0.
1517 InstructionCost Ext0Cost =
1518 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
1519 InstructionCost Ext1Cost =
1520 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
1522 CmpOpcode, I0->getType(), CmpInst::makeCmpResultType(I0->getType()), Pred,
1523 CostKind);
1524
1525 InstructionCost OldCost =
1526 Ext0Cost + Ext1Cost + CmpCost * 2 +
1527 TTI.getArithmeticInstrCost(I.getOpcode(), I.getType(), CostKind);
1528
1529 // The proposed vector pattern is:
1530 // vcmp = cmp Pred X, VecC
1531 // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
1532 int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
1533 int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
1536 CmpOpcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
1537 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
1538 ShufMask[CheapIndex] = ExpensiveIndex;
1540 CmpTy, ShufMask, CostKind);
1541 NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy, CostKind);
1542 NewCost += TTI.getVectorInstrCost(*Ext0, CmpTy, CostKind, CheapIndex);
1543 NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost;
1544 NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost;
1545
1546 // Aggressively form vector ops if the cost is equal because the transform
1547 // may enable further optimization.
1548 // Codegen can reverse this transform (scalarize) if it was not profitable.
1549 if (OldCost < NewCost || !NewCost.isValid())
1550 return false;
1551
1552 // Create a vector constant from the 2 scalar constants.
1553 SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
1554 PoisonValue::get(VecTy->getElementType()));
1555 CmpC[Index0] = C0;
1556 CmpC[Index1] = C1;
1557 Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));
1558 Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
  // Keep the shuffled value on the same side of the binop as the extract it
  // replaces, preserving the original operand order.
1559 Value *LHS = ConvertToShuf == Ext0 ? Shuf : VCmp;
1560 Value *RHS = ConvertToShuf == Ext0 ? VCmp : Shuf;
1561 Value *VecLogic = Builder.CreateBinOp(BI->getOpcode(), LHS, RHS);
1562 Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
1563 replaceValue(I, *NewExt);
1564 ++NumVecCmpBO;
1565 return true;
1566}
1567
1568/// Try to fold scalar selects that select between extracted elements and zero
1569/// into extracting from a vector select. This is rooted at the bitcast.
1570///
1571/// This pattern arises when a vector is bitcast to a smaller element type,
1572/// elements are extracted, and then conditionally selected with zero:
1573///
1574/// %bc = bitcast <4 x i32> %src to <16 x i8>
1575/// %e0 = extractelement <16 x i8> %bc, i32 0
1576/// %s0 = select i1 %cond, i8 %e0, i8 0
1577/// %e1 = extractelement <16 x i8> %bc, i32 1
1578/// %s1 = select i1 %cond, i8 %e1, i8 0
1579/// ...
1580///
1581/// Transforms to:
1582/// %sel = select i1 %cond, <4 x i32> %src, <4 x i32> zeroinitializer
1583/// %bc = bitcast <4 x i32> %sel to <16 x i8>
1584/// %e0 = extractelement <16 x i8> %bc, i32 0
1585/// %e1 = extractelement <16 x i8> %bc, i32 1
1586/// ...
1587///
1588/// This is profitable because vector select on wider types produces fewer
1589/// select/cndmask instructions than scalar selects on each element.
///
/// Matching selects are grouped by their i1 condition. A group is rewritten
/// only if it has at least MinSelects members, as derived from the TTI select
/// costs. Returns true if any select was replaced.
1590bool VectorCombine::foldSelectsFromBitcast(Instruction &I) {
1591 auto *BC = dyn_cast<BitCastInst>(&I);
1592 if (!BC)
1593 return false;
1594
1595 FixedVectorType *SrcVecTy = dyn_cast<FixedVectorType>(BC->getSrcTy());
1596 FixedVectorType *DstVecTy = dyn_cast<FixedVectorType>(BC->getDestTy());
1597 if (!SrcVecTy || !DstVecTy)
1598 return false;
1599
1600 // Source must be 32-bit or 64-bit elements, destination must be smaller
1601 // integer elements. Zero in all these types is all-bits-zero.
1602 Type *SrcEltTy = SrcVecTy->getElementType();
1603 Type *DstEltTy = DstVecTy->getElementType();
1604 unsigned SrcEltBits = SrcEltTy->getPrimitiveSizeInBits();
1605 unsigned DstEltBits = DstEltTy->getPrimitiveSizeInBits();
1606
1607 if (SrcEltBits != 32 && SrcEltBits != 64)
1608 return false;
1609
1610 if (!DstEltTy->isIntegerTy() || DstEltBits >= SrcEltBits)
1611 return false;
1612
1613 // Check profitability using TTI before collecting users.
1614 Type *CondTy = CmpInst::makeCmpResultType(DstEltTy);
1615 Type *VecCondTy = CmpInst::makeCmpResultType(SrcVecTy);
1616
1617 InstructionCost ScalarSelCost =
1618 TTI.getCmpSelInstrCost(Instruction::Select, DstEltTy, CondTy,
1620 InstructionCost VecSelCost =
1621 TTI.getCmpSelInstrCost(Instruction::Select, SrcVecTy, VecCondTy,
1623
1624 // We need at least this many selects for vectorization to be profitable.
1625 // VecSelCost < ScalarSelCost * NumSelects => NumSelects > VecSelCost /
1626 // ScalarSelCost
1627 if (!ScalarSelCost.isValid() || ScalarSelCost == 0)
1628 return false;
  // NOTE(review): only ScalarSelCost is checked for validity, but
  // VecSelCost.getValue() is used below as well. Confirm that an invalid
  // vector select cost cannot reach this point, or add a check.
1629
1630 unsigned MinSelects = (VecSelCost.getValue() / ScalarSelCost.getValue()) + 1;
1631
1632 // Quick check: if bitcast doesn't have enough users, bail early.
1633 if (!BC->hasNUsesOrMore(MinSelects))
1634 return false;
1635
1636 // Collect all select users that match the pattern, grouped by condition.
1637 // Pattern: select i1 %cond, (extractelement %bc, idx), 0
  // NOTE(review): DenseMap iteration order depends on the key pointer values.
  // The loop below creates instructions in that order, and several groups can
  // share the insertion point right after BC, so the emitted instruction order
  // may vary between runs. A MapVector would make it deterministic.
1638 DenseMap<Value *, SmallVector<SelectInst *, 8>> CondToSelects;
1639
1640 for (User *U : BC->users()) {
1641 auto *Ext = dyn_cast<ExtractElementInst>(U);
1642 if (!Ext)
1643 continue;
1644
1645 for (User *ExtUser : Ext->users()) {
1646 Value *Cond;
1647 // Match: select i1 %cond, %ext, 0
1648 if (match(ExtUser, m_Select(m_Value(Cond), m_Specific(Ext), m_Zero())) &&
1649 Cond->getType()->isIntegerTy(1))
1650 CondToSelects[Cond].push_back(cast<SelectInst>(ExtUser));
1651 }
1652 }
1653
1654 if (CondToSelects.empty())
1655 return false;
1656
1657 bool MadeChange = false;
1658 Value *SrcVec = BC->getOperand(0);
1659
1660 // Process each group of selects with the same condition.
  // (The by-value structured binding copies each SmallVector; binding by
  // reference would avoid the copy.)
1661 for (auto [Cond, Selects] : CondToSelects) {
1662 // Only profitable if vector select cost < total scalar select cost.
1663 if (Selects.size() < MinSelects) {
1664 LLVM_DEBUG(dbgs() << "VectorCombine: foldSelectsFromBitcast not "
1665 << "profitable (VecCost=" << VecSelCost
1666 << ", ScalarCost=" << ScalarSelCost
1667 << ", NumSelects=" << Selects.size() << ")\n");
1668 continue;
1669 }
1670
1671 // Create the vector select and bitcast once for this condition.
    // Insert after BC by default, or after Cond when BC dominates it, so the
    // new select comes after the definitions of both of its operands.
    // NOTE(review): if Cond is a PHI (or an invoke), std::next of it may not
    // be a legal insertion point. Instruction::getInsertionPointAfterDef
    // handles those cases; confirm.
1672 auto InsertPt = std::next(BC->getIterator());
1673
1674 if (auto *CondInst = dyn_cast<Instruction>(Cond))
1675 if (DT.dominates(BC, CondInst))
1676 InsertPt = std::next(CondInst->getIterator());
1677
1678 Builder.SetInsertPoint(InsertPt);
1679 Value *VecSel =
1680 Builder.CreateSelect(Cond, SrcVec, Constant::getNullValue(SrcVecTy));
1681 Value *NewBC = Builder.CreateBitCast(VecSel, DstVecTy);
1682
1683 // Replace each scalar select with an extract from the new bitcast.
1684 for (SelectInst *Sel : Selects) {
1685 auto *Ext = cast<ExtractElementInst>(Sel->getTrueValue());
1686 Value *Idx = Ext->getIndexOperand();
1687
1688 Builder.SetInsertPoint(Sel);
1689 Value *NewExt = Builder.CreateExtractElement(NewBC, Idx);
1690 replaceValue(*Sel, *NewExt);
1691 MadeChange = true;
1692 }
1693
1694 LLVM_DEBUG(dbgs() << "VectorCombine: folded " << Selects.size()
1695 << " selects into vector select\n");
1696 }
1697
1698 return MadeChange;
1699}
1700
1703 const TargetTransformInfo &TTI,
1704 InstructionCost &CostBeforeReduction,
1705 InstructionCost &CostAfterReduction) {
1706 Instruction *Op0, *Op1;
1707 auto *RedOp = dyn_cast<Instruction>(II.getOperand(0));
1708 auto *VecRedTy = cast<VectorType>(II.getOperand(0)->getType());
1709 unsigned ReductionOpc =
1710 getArithmeticReductionInstruction(II.getIntrinsicID());
1711 if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value()))) {
1712 bool IsUnsigned = isa<ZExtInst>(RedOp);
1713 auto *ExtType = cast<VectorType>(RedOp->getOperand(0)->getType());
1714
1715 CostBeforeReduction =
1716 TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, ExtType,
1718 CostAfterReduction =
1719 TTI.getExtendedReductionCost(ReductionOpc, IsUnsigned, II.getType(),
1720 ExtType, FastMathFlags(), CostKind);
1721 return;
1722 }
1723 if (RedOp && II.getIntrinsicID() == Intrinsic::vector_reduce_add &&
1724 match(RedOp,
1726 match(Op0, m_ZExtOrSExt(m_Value())) &&
1727 Op0->getOpcode() == Op1->getOpcode() &&
1728 Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
1729 (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
1730 // Matched reduce.add(ext(mul(ext(A), ext(B)))
1731 bool IsUnsigned = isa<ZExtInst>(Op0);
1732 auto *ExtType = cast<VectorType>(Op0->getOperand(0)->getType());
1733 VectorType *MulType = VectorType::get(Op0->getType(), VecRedTy);
1734
1735 InstructionCost ExtCost =
1736 TTI.getCastInstrCost(Op0->getOpcode(), MulType, ExtType,
1738 InstructionCost MulCost =
1739 TTI.getArithmeticInstrCost(Instruction::Mul, MulType, CostKind);
1740 InstructionCost Ext2Cost =
1741 TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, MulType,
1743
1744 CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
1745 CostAfterReduction = TTI.getMulAccReductionCost(
1746 IsUnsigned, ReductionOpc, II.getType(), ExtType, CostKind);
1747 return;
1748 }
1749 CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,
1750 std::nullopt, CostKind);
1751}
1752
1753bool VectorCombine::foldBinopOfReductions(Instruction &I) {
1754 Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(&I)->getOpcode();
1755 Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
1756 if (BinOpOpc == Instruction::Sub)
1757 ReductionIID = Intrinsic::vector_reduce_add;
1758 if (ReductionIID == Intrinsic::not_intrinsic)
1759 return false;
1760
1761 auto checkIntrinsicAndGetItsArgument = [](Value *V,
1762 Intrinsic::ID IID) -> Value * {
1763 auto *II = dyn_cast<IntrinsicInst>(V);
1764 if (!II)
1765 return nullptr;
1766 if (II->getIntrinsicID() == IID && II->hasOneUse())
1767 return II->getArgOperand(0);
1768 return nullptr;
1769 };
1770
1771 Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(0), ReductionIID);
1772 if (!V0)
1773 return false;
1774 Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(1), ReductionIID);
1775 if (!V1)
1776 return false;
1777
1778 auto *VTy = cast<VectorType>(V0->getType());
1779 if (V1->getType() != VTy)
1780 return false;
1781 const auto &II0 = *cast<IntrinsicInst>(I.getOperand(0));
1782 const auto &II1 = *cast<IntrinsicInst>(I.getOperand(1));
1783 unsigned ReductionOpc =
1784 getArithmeticReductionInstruction(II0.getIntrinsicID());
1785
1786 InstructionCost OldCost = 0;
1787 InstructionCost NewCost = 0;
1788 InstructionCost CostOfRedOperand0 = 0;
1789 InstructionCost CostOfRed0 = 0;
1790 InstructionCost CostOfRedOperand1 = 0;
1791 InstructionCost CostOfRed1 = 0;
1792 analyzeCostOfVecReduction(II0, CostKind, TTI, CostOfRedOperand0, CostOfRed0);
1793 analyzeCostOfVecReduction(II1, CostKind, TTI, CostOfRedOperand1, CostOfRed1);
1794 OldCost = CostOfRed0 + CostOfRed1 + TTI.getInstructionCost(&I, CostKind);
1795 NewCost =
1796 CostOfRedOperand0 + CostOfRedOperand1 +
1797 TTI.getArithmeticInstrCost(BinOpOpc, VTy, CostKind) +
1798 TTI.getArithmeticReductionCost(ReductionOpc, VTy, std::nullopt, CostKind);
1799 if (NewCost >= OldCost || !NewCost.isValid())
1800 return false;
1801
1802 LLVM_DEBUG(dbgs() << "Found two mergeable reductions: " << I
1803 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
1804 << "\n");
1805 Value *VectorBO;
1806 if (BinOpOpc == Instruction::Or)
1807 VectorBO = Builder.CreateOr(V0, V1, "",
1808 cast<PossiblyDisjointInst>(I).isDisjoint());
1809 else
1810 VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
1811
1812 Instruction *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
1813 replaceValue(I, *Rdx);
1814 return true;
1815}
1816
1817// Check if memory loc modified between two instrs in the same BB
1820 const MemoryLocation &Loc, AAResults &AA) {
1821 unsigned NumScanned = 0;
1822 return std::any_of(Begin, End, [&](const Instruction &Instr) {
1823 return isModSet(AA.getModRefInfo(&Instr, Loc)) ||
1824 ++NumScanned > MaxInstrsToScan;
1825 });
1826}
1827
1828namespace {
1829/// Helper class to indicate whether a vector index can be safely scalarized and
1830/// if a freeze needs to be inserted.
1831class ScalarizationResult {
1832 enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
1833
1834 StatusTy Status;
1835 Value *ToFreeze;
1836
1837 ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
1838 : Status(Status), ToFreeze(ToFreeze) {}
1839
1840public:
1841 ScalarizationResult(const ScalarizationResult &Other) = default;
1842 ~ScalarizationResult() {
1843 assert(!ToFreeze && "freeze() not called with ToFreeze being set");
1844 }
1845
1846 static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
1847 static ScalarizationResult safe() { return {StatusTy::Safe}; }
1848 static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
1849 return {StatusTy::SafeWithFreeze, ToFreeze};
1850 }
1851
1852 /// Returns true if the index can be scalarize without requiring a freeze.
1853 bool isSafe() const { return Status == StatusTy::Safe; }
1854 /// Returns true if the index cannot be scalarized.
1855 bool isUnsafe() const { return Status == StatusTy::Unsafe; }
1856 /// Returns true if the index can be scalarize, but requires inserting a
1857 /// freeze.
1858 bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
1859
1860 /// Reset the state of Unsafe and clear ToFreze if set.
1861 void discard() {
1862 ToFreeze = nullptr;
1863 Status = StatusTy::Unsafe;
1864 }
1865
1866 /// Freeze the ToFreeze and update the use in \p User to use it.
1867 void freeze(IRBuilderBase &Builder, Instruction &UserI) {
1868 assert(isSafeWithFreeze() &&
1869 "should only be used when freezing is required");
1870 assert(is_contained(ToFreeze->users(), &UserI) &&
1871 "UserI must be a user of ToFreeze");
1872 IRBuilder<>::InsertPointGuard Guard(Builder);
1873 Builder.SetInsertPoint(cast<Instruction>(&UserI));
1874 Value *Frozen =
1875 Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
1876 for (Use &U : make_early_inc_range((UserI.operands())))
1877 if (U.get() == ToFreeze)
1878 U.set(Frozen);
1879
1880 ToFreeze = nullptr;
1881 }
1882};
1883} // namespace
1884
1885/// Check if it is legal to scalarize a memory access to \p VecTy at index \p
1886/// Idx. \p Idx must access a valid vector element.
1887static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx,
1888 const SimplifyQuery &SQ) {
1889 // We do checks for both fixed vector types and scalable vector types.
1890 // This is the number of elements of fixed vector types,
1891 // or the minimum number of elements of scalable vector types.
1892 uint64_t NumElements = VecTy->getElementCount().getKnownMinValue();
1893 unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
1894
1895 if (auto *C = dyn_cast<ConstantInt>(Idx)) {
1896 if (C->getValue().ult(NumElements))
1897 return ScalarizationResult::safe();
1898 return ScalarizationResult::unsafe();
1899 }
1900
1901 // Always unsafe if the index type can't handle all inbound values.
1902 if (!llvm::isUIntN(IntWidth, NumElements))
1903 return ScalarizationResult::unsafe();
1904
1905 APInt Zero(IntWidth, 0);
1906 APInt MaxElts(IntWidth, NumElements);
1907 ConstantRange ValidIndices(Zero, MaxElts);
1908 ConstantRange IdxRange(IntWidth, true);
1909
1910 if (isGuaranteedNotToBePoison(Idx, SQ.AC, SQ.CxtI, SQ.DT)) {
1911 if (ValidIndices.contains(
1912 computeConstantRange(Idx, /*ForSigned=*/false, SQ)))
1913 return ScalarizationResult::safe();
1914 return ScalarizationResult::unsafe();
1915 }
1916
1917 // If the index may be poison, check if we can insert a freeze before the
1918 // range of the index is restricted.
1919 Value *IdxBase;
1920 ConstantInt *CI;
1921 if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
1922 IdxRange = IdxRange.binaryAnd(CI->getValue());
1923 } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
1924 IdxRange = IdxRange.urem(CI->getValue());
1925 }
1926
1927 if (ValidIndices.contains(IdxRange))
1928 return ScalarizationResult::safeWithFreeze(IdxBase);
1929 return ScalarizationResult::unsafe();
1930}
1931
1932/// The memory operation on a vector of \p ScalarType had alignment of
1933/// \p VectorAlignment. Compute the maximal, but conservatively correct,
1934/// alignment that will be valid for the memory operation on a single scalar
1935/// element of the same type with index \p Idx.
1937 Type *ScalarType, Value *Idx,
1938 const DataLayout &DL) {
1939 if (auto *C = dyn_cast<ConstantInt>(Idx))
1940 return commonAlignment(VectorAlignment,
1941 C->getZExtValue() * DL.getTypeStoreSize(ScalarType));
1942 return commonAlignment(VectorAlignment, DL.getTypeStoreSize(ScalarType));
1943}
1944
// Combine patterns like:
//   %0 = load <4 x i32>, ptr %a
//   %1 = insertelement <4 x i32> %0, i32 %b, i32 1
//   store <4 x i32> %1, ptr %a
// to:
//   %0 = getelementptr inbounds <4 x i32>, ptr %a, i32 0, i32 1
//   store i32 %b, ptr %0
1953bool VectorCombine::foldSingleElementStore(Instruction &I) {
1955 return false;
1956 auto *SI = cast<StoreInst>(&I);
1957 if (!SI->isSimple() || !isa<VectorType>(SI->getValueOperand()->getType()))
1958 return false;
1959
1960 // TODO: Combine more complicated patterns (multiple insert) by referencing
1961 // TargetTransformInfo.
1963 Value *NewElement;
1964 Value *Idx;
1965 if (!match(SI->getValueOperand(),
1966 m_InsertElt(m_Instruction(Source), m_Value(NewElement),
1967 m_Value(Idx))))
1968 return false;
1969
1970 if (auto *Load = dyn_cast<LoadInst>(Source)) {
1971 auto VecTy = cast<VectorType>(SI->getValueOperand()->getType());
1972 Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
1973 // Don't optimize for atomic/volatile load or store. Ensure memory is not
1974 // modified between, vector type matches store size, and index is inbounds.
1975 if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
1976 !DL->typeSizeEqualsStoreSize(Load->getType()->getScalarType()) ||
1977 SrcAddr != SI->getPointerOperand()->stripPointerCasts())
1978 return false;
1979
1980 auto ScalarizableIdx =
1981 canScalarizeAccess(VecTy, Idx, SQ.getWithInstruction(Load));
1982 if (ScalarizableIdx.isUnsafe() ||
1983 isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
1984 MemoryLocation::get(SI), AA))
1985 return false;
1986
    // Ensure we add the load back to the worklist BEFORE its users so they can
    // be erased in the correct order.
1989 Worklist.push(Load);
1990
1991 if (ScalarizableIdx.isSafeWithFreeze())
1992 ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
1993 Value *GEP = Builder.CreateInBoundsGEP(
1994 SI->getValueOperand()->getType(), SI->getPointerOperand(),
1995 {ConstantInt::get(Idx->getType(), 0), Idx});
1996 StoreInst *NSI = Builder.CreateStore(NewElement, GEP);
1997 NSI->copyMetadata(*SI);
1998 Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1999 std::max(SI->getAlign(), Load->getAlign()), NewElement->getType(), Idx,
2000 *DL);
2001 NSI->setAlignment(ScalarOpAlignment);
2002 replaceValue(I, *NSI);
2004 return true;
2005 }
2006
2007 return false;
2008}
2009
2010/// Try to scalarize vector loads feeding extractelement or bitcast
2011/// instructions.
2012bool VectorCombine::scalarizeLoad(Instruction &I) {
2013 Value *Ptr;
2014 if (!match(&I, m_Load(m_Value(Ptr))))
2015 return false;
2016
2017 auto *LI = cast<LoadInst>(&I);
2018 auto *VecTy = cast<VectorType>(LI->getType());
2019
2020 // The isSimple() check could be isUnordered(), but for now we cowardly
2021 // refuse to handle even unordered atomics.
2022 if (!LI->isSimple() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
2023 return false;
2024
2025 bool AllExtracts = true;
2026 bool AllBitcasts = true;
2027 Instruction *LastCheckedInst = LI;
2028 unsigned NumInstChecked = 0;
2029
2030 // Check what type of users we have (must either all be extracts or
2031 // bitcasts) and ensure no memory modifications between the load and
2032 // its users.
2033 for (User *U : LI->users()) {
2034 auto *UI = dyn_cast<Instruction>(U);
2035 if (!UI || UI->getParent() != LI->getParent())
2036 return false;
2037
2038 // If any user is waiting to be erased, then bail out as this will
2039 // distort the cost calculation and possibly lead to infinite loops.
2040 if (UI->use_empty())
2041 return false;
2042
2043 if (!isa<ExtractElementInst>(UI))
2044 AllExtracts = false;
2045 if (!isa<BitCastInst>(UI))
2046 AllBitcasts = false;
2047
2048 // Check if any instruction between the load and the user may modify memory.
2049 if (LastCheckedInst->comesBefore(UI)) {
2050 for (Instruction &I :
2051 make_range(std::next(LI->getIterator()), UI->getIterator())) {
2052 // Bail out if we reached the check limit or the instruction may write
2053 // to memory.
2054 if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
2055 return false;
2056 NumInstChecked++;
2057 }
2058 LastCheckedInst = UI;
2059 }
2060 }
2061
2062 if (AllExtracts)
2063 return scalarizeLoadExtract(LI, VecTy, Ptr);
2064 if (AllBitcasts)
2065 return scalarizeLoadBitcast(LI, VecTy, Ptr);
2066 return false;
2067}
2068
2069/// Try to scalarize vector loads feeding extractelement instructions.
2070bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
2071 Value *Ptr) {
2073 return false;
2074
2075 DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
2076 llvm::scope_exit FailureGuard([&]() {
2077 // If the transform is aborted, discard the ScalarizationResults.
2078 for (auto &Pair : NeedFreeze)
2079 Pair.second.discard();
2080 });
2081
2082 InstructionCost OriginalCost =
2083 TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
2085 InstructionCost ScalarizedCost = 0;
2086
2087 for (User *U : LI->users()) {
2088 auto *UI = cast<ExtractElementInst>(U);
2089
2090 auto ScalarIdx = canScalarizeAccess(VecTy, UI->getIndexOperand(),
2091 SQ.getWithInstruction(LI));
2092 if (ScalarIdx.isUnsafe())
2093 return false;
2094 if (ScalarIdx.isSafeWithFreeze()) {
2095 NeedFreeze.try_emplace(UI, ScalarIdx);
2096 ScalarIdx.discard();
2097 }
2098
2099 auto *Index = dyn_cast<ConstantInt>(UI->getIndexOperand());
2100 OriginalCost +=
2101 TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
2102 Index ? Index->getZExtValue() : -1);
2103 ScalarizedCost +=
2104 TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
2106 ScalarizedCost += TTI.getAddressComputationCost(LI->getPointerOperandType(),
2107 nullptr, nullptr, CostKind);
2108 }
2109
2110 LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
2111 << "\n LoadExtractCost: " << OriginalCost
2112 << " vs ScalarizedCost: " << ScalarizedCost << "\n");
2113
2114 if (ScalarizedCost >= OriginalCost)
2115 return false;
2116
  // Ensure we add the load back to the worklist BEFORE its users so they can
  // be erased in the correct order.
2119 Worklist.push(LI);
2120
2121 Type *ElemType = VecTy->getElementType();
2122
2123 // Replace extracts with narrow scalar loads.
2124 for (User *U : LI->users()) {
2125 auto *EI = cast<ExtractElementInst>(U);
2126 Value *Idx = EI->getIndexOperand();
2127
2128 // Insert 'freeze' for poison indexes.
2129 auto It = NeedFreeze.find(EI);
2130 if (It != NeedFreeze.end())
2131 It->second.freeze(Builder, *cast<Instruction>(Idx));
2132
2133 Builder.SetInsertPoint(EI);
2134 Value *GEP =
2135 Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
2136 auto *NewLoad = cast<LoadInst>(
2137 Builder.CreateLoad(ElemType, GEP, EI->getName() + ".scalar"));
2138
2139 Align ScalarOpAlignment =
2140 computeAlignmentAfterScalarization(LI->getAlign(), ElemType, Idx, *DL);
2141 NewLoad->setAlignment(ScalarOpAlignment);
2142
2143 if (auto *ConstIdx = dyn_cast<ConstantInt>(Idx)) {
2144 size_t Offset = ConstIdx->getZExtValue() * DL->getTypeStoreSize(ElemType);
2145 AAMDNodes OldAAMD = LI->getAAMetadata();
2146 NewLoad->setAAMetadata(OldAAMD.adjustForAccess(Offset, ElemType, *DL));
2147 }
2148
2149 replaceValue(*EI, *NewLoad, false);
2150 }
2151
2152 FailureGuard.release();
2153 return true;
2154}
2155
2156/// Try to scalarize vector loads feeding bitcast instructions.
2157bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
2158 Value *Ptr) {
2159 InstructionCost OriginalCost =
2160 TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
2162
2163 Type *TargetScalarType = nullptr;
2164 unsigned VecBitWidth = DL->getTypeSizeInBits(VecTy);
2165
2166 for (User *U : LI->users()) {
2167 auto *BC = cast<BitCastInst>(U);
2168
2169 Type *DestTy = BC->getDestTy();
2170 if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
2171 return false;
2172
2173 unsigned DestBitWidth = DL->getTypeSizeInBits(DestTy);
2174 if (DestBitWidth != VecBitWidth)
2175 return false;
2176
2177 // All bitcasts must target the same scalar type.
2178 if (!TargetScalarType)
2179 TargetScalarType = DestTy;
2180 else if (TargetScalarType != DestTy)
2181 return false;
2182
2183 OriginalCost +=
2184 TTI.getCastInstrCost(Instruction::BitCast, TargetScalarType, VecTy,
2186 }
2187
2188 if (!TargetScalarType)
2189 return false;
2190
2191 assert(!LI->user_empty() && "Unexpected load without bitcast users");
2192 InstructionCost ScalarizedCost =
2193 TTI.getMemoryOpCost(Instruction::Load, TargetScalarType, LI->getAlign(),
2195
2196 LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
2197 << "\n OriginalCost: " << OriginalCost
2198 << " vs ScalarizedCost: " << ScalarizedCost << "\n");
2199
2200 if (ScalarizedCost >= OriginalCost)
2201 return false;
2202
  // Ensure we add the load back to the worklist BEFORE its users so they can
  // be erased in the correct order.
2205 Worklist.push(LI);
2206
2207 Builder.SetInsertPoint(LI);
2208 auto *ScalarLoad =
2209 Builder.CreateLoad(TargetScalarType, Ptr, LI->getName() + ".scalar");
2210 ScalarLoad->setAlignment(LI->getAlign());
2211 ScalarLoad->copyMetadata(*LI);
2212
2213 // Replace all bitcast users with the scalar load.
2214 for (User *U : LI->users()) {
2215 auto *BC = cast<BitCastInst>(U);
2216 replaceValue(*BC, *ScalarLoad, false);
2217 }
2218
2219 return true;
2220}
2221
2222bool VectorCombine::scalarizeExtExtract(Instruction &I) {
2224 return false;
2225 auto *Ext = dyn_cast<ZExtInst>(&I);
2226 if (!Ext)
2227 return false;
2228
  // Try to convert a vector zext feeding only extracts into scalar ops on the
  // source vector bitcast to an integer:
  //   (Src >> (ExtIdx * SrcEltBits)) & LowBitsMask(SrcEltBits)
  // (shift amount mirrored on big-endian targets), if profitable.
2232 auto *SrcTy = dyn_cast<FixedVectorType>(Ext->getOperand(0)->getType());
2233 if (!SrcTy)
2234 return false;
2235 auto *DstTy = cast<FixedVectorType>(Ext->getType());
2236
2237 Type *ScalarDstTy = DstTy->getElementType();
2238 if (DL->getTypeSizeInBits(SrcTy) != DL->getTypeSizeInBits(ScalarDstTy))
2239 return false;
2240
2241 InstructionCost VectorCost =
2242 TTI.getCastInstrCost(Instruction::ZExt, DstTy, SrcTy,
2244 unsigned ExtCnt = 0;
2245 bool ExtLane0 = false;
2246 for (User *U : Ext->users()) {
2247 uint64_t Idx;
2248 if (!match(U, m_ExtractElt(m_Value(), m_ConstantInt(Idx))))
2249 return false;
2250 if (cast<Instruction>(U)->use_empty())
2251 continue;
2252 ExtCnt += 1;
2253 ExtLane0 |= !Idx;
2254 VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
2255 CostKind, Idx, U);
2256 }
2257
2258 InstructionCost ScalarCost =
2259 ExtCnt * TTI.getArithmeticInstrCost(
2260 Instruction::And, ScalarDstTy, CostKind,
2263 (ExtCnt - ExtLane0) *
2265 Instruction::LShr, ScalarDstTy, CostKind,
2268 if (ScalarCost > VectorCost)
2269 return false;
2270
2271 Value *ScalarV = Ext->getOperand(0);
2272 if (!isGuaranteedNotToBePoison(ScalarV, SQ.AC, dyn_cast<Instruction>(ScalarV),
2273 SQ.DT)) {
    // Check whether all lanes are extracted, all extracts trigger UB
    // on poison, and the last extract (and hence all previous ones)
    // are guaranteed to execute if Ext executes. If so, we do not
    // need to insert a freeze.
2278 SmallDenseSet<ConstantInt *, 8> ExtractedLanes;
2279 bool AllExtractsTriggerUB = true;
2280 ExtractElementInst *LastExtract = nullptr;
2281 BasicBlock *ExtBB = Ext->getParent();
2282 for (User *U : Ext->users()) {
2283 auto *Extract = cast<ExtractElementInst>(U);
2284 if (Extract->getParent() != ExtBB || !programUndefinedIfPoison(Extract)) {
2285 AllExtractsTriggerUB = false;
2286 break;
2287 }
2288 ExtractedLanes.insert(cast<ConstantInt>(Extract->getIndexOperand()));
2289 if (!LastExtract || LastExtract->comesBefore(Extract))
2290 LastExtract = Extract;
2291 }
2292 if (ExtractedLanes.size() != DstTy->getNumElements() ||
2293 !AllExtractsTriggerUB ||
2295 LastExtract->getIterator()))
2296 ScalarV = Builder.CreateFreeze(ScalarV);
2297 }
2298 ScalarV = Builder.CreateBitCast(
2299 ScalarV,
2300 IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));
2301 uint64_t SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType());
2302 uint64_t TotalBits = DL->getTypeSizeInBits(SrcTy);
2303 APInt EltBitMask = APInt::getLowBitsSet(TotalBits, SrcEltSizeInBits);
2304 Type *PackedTy = IntegerType::get(SrcTy->getContext(), TotalBits);
2305 Value *Mask = ConstantInt::get(PackedTy, EltBitMask);
2306 for (User *U : Ext->users()) {
2307 auto *Extract = cast<ExtractElementInst>(U);
2308 uint64_t Idx =
2309 cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue();
2310 uint64_t ShiftAmt =
2311 DL->isBigEndian()
2312 ? (TotalBits - SrcEltSizeInBits - Idx * SrcEltSizeInBits)
2313 : (Idx * SrcEltSizeInBits);
2314 Value *LShr = Builder.CreateLShr(ScalarV, ShiftAmt);
2315 Value *And = Builder.CreateAnd(LShr, Mask);
2316 U->replaceAllUsesWith(And);
2317 }
2318 return true;
2319}
2320
2321/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
2322/// to "(bitcast (concat X, Y))"
2323/// where X/Y are bitcasted from i1 mask vectors.
2324bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
2325 Type *Ty = I.getType();
2326 if (!Ty->isIntegerTy())
2327 return false;
2328
2329 // TODO: Add big endian test coverage
2330 if (DL->isBigEndian())
2331 return false;
2332
2333 // Restrict to disjoint cases so the mask vectors aren't overlapping.
2334 Instruction *X, *Y;
2336 return false;
2337
2338 // Allow both sources to contain shl, to handle more generic pattern:
2339 // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
2340 Value *SrcX;
2341 uint64_t ShAmtX = 0;
2342 if (!match(X, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX)))))) &&
2343 !match(X, m_OneUse(
2345 m_ConstantInt(ShAmtX)))))
2346 return false;
2347
2348 Value *SrcY;
2349 uint64_t ShAmtY = 0;
2350 if (!match(Y, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY)))))) &&
2351 !match(Y, m_OneUse(
2353 m_ConstantInt(ShAmtY)))))
2354 return false;
2355
2356 // Canonicalize larger shift to the RHS.
2357 if (ShAmtX > ShAmtY) {
2358 std::swap(X, Y);
2359 std::swap(SrcX, SrcY);
2360 std::swap(ShAmtX, ShAmtY);
2361 }
2362
2363 // Ensure both sources are matching vXi1 bool mask types, and that the shift
2364 // difference is the mask width so they can be easily concatenated together.
2365 uint64_t ShAmtDiff = ShAmtY - ShAmtX;
2366 unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
2367 unsigned BitWidth = Ty->getPrimitiveSizeInBits();
2368 auto *MaskTy = dyn_cast<FixedVectorType>(SrcX->getType());
2369 if (!MaskTy || SrcX->getType() != SrcY->getType() ||
2370 !MaskTy->getElementType()->isIntegerTy(1) ||
2371 MaskTy->getNumElements() != ShAmtDiff ||
2372 MaskTy->getNumElements() > (BitWidth / 2))
2373 return false;
2374
2375 auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(MaskTy);
2376 auto *ConcatIntTy =
2377 Type::getIntNTy(Ty->getContext(), ConcatTy->getNumElements());
2378 auto *MaskIntTy = Type::getIntNTy(Ty->getContext(), ShAmtDiff);
2379
2380 SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
2381 std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
2382
2383 // TODO: Is it worth supporting multi use cases?
2384 InstructionCost OldCost = 0;
2385 OldCost += TTI.getArithmeticInstrCost(Instruction::Or, Ty, CostKind);
2386 OldCost +=
2387 NumSHL * TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
2388 OldCost += 2 * TTI.getCastInstrCost(Instruction::ZExt, Ty, MaskIntTy,
2390 OldCost += 2 * TTI.getCastInstrCost(Instruction::BitCast, MaskIntTy, MaskTy,
2392
2393 InstructionCost NewCost = 0;
2395 MaskTy, ConcatMask, CostKind);
2396 NewCost += TTI.getCastInstrCost(Instruction::BitCast, ConcatIntTy, ConcatTy,
2398 if (Ty != ConcatIntTy)
2399 NewCost += TTI.getCastInstrCost(Instruction::ZExt, Ty, ConcatIntTy,
2401 if (ShAmtX > 0)
2402 NewCost += TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
2403
2404 LLVM_DEBUG(dbgs() << "Found a concatenation of bitcasted bool masks: " << I
2405 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2406 << "\n");
2407
2408 if (NewCost > OldCost)
2409 return false;
2410
2411 // Build bool mask concatenation, bitcast back to scalar integer, and perform
2412 // any residual zero-extension or shifting.
2413 Value *Concat = Builder.CreateShuffleVector(SrcX, SrcY, ConcatMask);
2414 Worklist.pushValue(Concat);
2415
2416 Value *Result = Builder.CreateBitCast(Concat, ConcatIntTy);
2417
2418 if (Ty != ConcatIntTy) {
2419 Worklist.pushValue(Result);
2420 Result = Builder.CreateZExt(Result, Ty);
2421 }
2422
2423 if (ShAmtX > 0) {
2424 Worklist.pushValue(Result);
2425 Result = Builder.CreateShl(Result, ShAmtX);
2426 }
2427
2428 replaceValue(I, *Result);
2429 return true;
2430}
2431
2432/// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
2433/// --> "binop (shuffle), (shuffle)".
2434bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
2435 BinaryOperator *BinOp;
2436 ArrayRef<int> OuterMask;
2437 if (!match(&I, m_Shuffle(m_BinOp(BinOp), m_Undef(), m_Mask(OuterMask))))
2438 return false;
2439
2440 // Don't introduce poison into div/rem.
2441 if (BinOp->isIntDivRem() && llvm::is_contained(OuterMask, PoisonMaskElem))
2442 return false;
2443
2444 Value *Op00, *Op01, *Op10, *Op11;
2445 ArrayRef<int> Mask0, Mask1;
2446 bool Match0 = match(BinOp->getOperand(0),
2447 m_Shuffle(m_Value(Op00), m_Value(Op01), m_Mask(Mask0)));
2448 bool Match1 = match(BinOp->getOperand(1),
2449 m_Shuffle(m_Value(Op10), m_Value(Op11), m_Mask(Mask1)));
2450 if (!Match0 && !Match1)
2451 return false;
2452
2453 Op00 = Match0 ? Op00 : BinOp->getOperand(0);
2454 Op01 = Match0 ? Op01 : BinOp->getOperand(0);
2455 Op10 = Match1 ? Op10 : BinOp->getOperand(1);
2456 Op11 = Match1 ? Op11 : BinOp->getOperand(1);
2457
2458 Instruction::BinaryOps Opcode = BinOp->getOpcode();
2459 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
2460 auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType());
2461 auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType());
2462 auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType());
2463 if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
2464 return false;
2465
2466 unsigned NumSrcElts = BinOpTy->getNumElements();
2467
2468 // Don't accept shuffles that reference the second operand in
2469 // div/rem or if its an undef arg.
2470 if ((BinOp->isIntDivRem() || !isa<PoisonValue>(I.getOperand(1))) &&
2471 any_of(OuterMask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
2472 return false;
2473
2474 // Merge outer / inner (or identity if no match) shuffles.
2475 SmallVector<int> NewMask0, NewMask1;
2476 for (int M : OuterMask) {
2477 if (M < 0 || M >= (int)NumSrcElts) {
2478 NewMask0.push_back(PoisonMaskElem);
2479 NewMask1.push_back(PoisonMaskElem);
2480 } else {
2481 NewMask0.push_back(Match0 ? Mask0[M] : M);
2482 NewMask1.push_back(Match1 ? Mask1[M] : M);
2483 }
2484 }
2485
2486 unsigned NumOpElts = Op0Ty->getNumElements();
2487 bool IsIdentity0 = ShuffleDstTy == Op0Ty &&
2488 all_of(NewMask0, [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
2489 ShuffleVectorInst::isIdentityMask(NewMask0, NumOpElts);
2490 bool IsIdentity1 = ShuffleDstTy == Op1Ty &&
2491 all_of(NewMask1, [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
2492 ShuffleVectorInst::isIdentityMask(NewMask1, NumOpElts);
2493
2494 InstructionCost NewCost = 0;
2495 // Try to merge shuffles across the binop if the new shuffles are not costly.
2496 InstructionCost BinOpCost =
2497 TTI.getArithmeticInstrCost(Opcode, BinOpTy, CostKind);
2498 InstructionCost OldCost =
2500 ShuffleDstTy, BinOpTy, OuterMask, CostKind,
2501 0, nullptr, {BinOp}, &I);
2502 if (!BinOp->hasOneUse())
2503 NewCost += BinOpCost;
2504
2505 if (Match0) {
2507 TargetTransformInfo::SK_PermuteTwoSrc, BinOpTy, Op0Ty, Mask0, CostKind,
2508 0, nullptr, {Op00, Op01}, cast<Instruction>(BinOp->getOperand(0)));
2509 OldCost += Shuf0Cost;
2510 if (!BinOp->hasOneUse() || !BinOp->getOperand(0)->hasOneUse())
2511 NewCost += Shuf0Cost;
2512 }
2513 if (Match1) {
2515 TargetTransformInfo::SK_PermuteTwoSrc, BinOpTy, Op1Ty, Mask1, CostKind,
2516 0, nullptr, {Op10, Op11}, cast<Instruction>(BinOp->getOperand(1)));
2517 OldCost += Shuf1Cost;
2518 if (!BinOp->hasOneUse() || !BinOp->getOperand(1)->hasOneUse())
2519 NewCost += Shuf1Cost;
2520 }
2521
2522 NewCost += TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy, CostKind);
2523
2524 if (!IsIdentity0)
2525 NewCost +=
2527 Op0Ty, NewMask0, CostKind, 0, nullptr, {Op00, Op01});
2528 if (!IsIdentity1)
2529 NewCost +=
2531 Op1Ty, NewMask1, CostKind, 0, nullptr, {Op10, Op11});
2532
2533 LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I
2534 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2535 << "\n");
2536
2537 // If costs are equal, still fold as we reduce instruction count.
2538 if (NewCost > OldCost)
2539 return false;
2540
2541 Value *LHS =
2542 IsIdentity0 ? Op00 : Builder.CreateShuffleVector(Op00, Op01, NewMask0);
2543 Value *RHS =
2544 IsIdentity1 ? Op10 : Builder.CreateShuffleVector(Op10, Op11, NewMask1);
2545 Value *NewBO = Builder.CreateBinOp(Opcode, LHS, RHS);
2546
2547 // Intersect flags from the old binops.
2548 if (auto *NewInst = dyn_cast<Instruction>(NewBO))
2549 NewInst->copyIRFlags(BinOp);
2550
2551 Worklist.pushValue(LHS);
2552 Worklist.pushValue(RHS);
2553 replaceValue(I, *NewBO);
2554 return true;
2555}
2556
/// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
/// Try to convert "shuffle (cmpop), (cmpop)" into "cmpop (shuffle), (shuffle)".
///
/// Both shuffle operands must have the same opcode (and, for compares, the
/// same predicate). The outer mask is pushed onto each operand position,
/// forming two new shuffles, (X, Z) and (Y, W), that feed one new binop/cmp.
/// The fold happens only if the new cost is lower than the old cost, or equal
/// to it when the instruction count is known to drop.
/// \returns true if \p I was replaced.
bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
  ArrayRef<int> OldMask;
  Instruction *LHS, *RHS;
                           m_Mask(OldMask))))
    return false;

  // TODO: Add support for addlike etc.
  if (LHS->getOpcode() != RHS->getOpcode())
    return false;

  // X/Y are the operands of LHS, Z/W the operands of RHS.
  Value *X, *Y, *Z, *W;
  bool IsCommutative = false;
  // PredLHS stays BAD_ICMP_PREDICATE for binops; it is used below to tell the
  // binop and compare forms apart.
  CmpPredicate PredLHS = CmpInst::BAD_ICMP_PREDICATE;
  CmpPredicate PredRHS = CmpInst::BAD_ICMP_PREDICATE;
  if (match(LHS, m_BinOp(m_Value(X), m_Value(Y))) &&
      match(RHS, m_BinOp(m_Value(Z), m_Value(W)))) {
    auto *BO = cast<BinaryOperator>(LHS);
    // Don't introduce poison into div/rem.
    if (llvm::is_contained(OldMask, PoisonMaskElem) && BO->isIntDivRem())
      return false;
    IsCommutative = BinaryOperator::isCommutative(BO->getOpcode());
  } else if (match(LHS, m_Cmp(PredLHS, m_Value(X), m_Value(Y))) &&
             match(RHS, m_Cmp(PredRHS, m_Value(Z), m_Value(W))) &&
             (CmpInst::Predicate)PredLHS == (CmpInst::Predicate)PredRHS) {
    IsCommutative = cast<CmpInst>(LHS)->isCommutative();
  } else
    return false;

  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
  auto *BinResTy = dyn_cast<FixedVectorType>(LHS->getType());
  auto *BinOpTy = dyn_cast<FixedVectorType>(X->getType());
  if (!ShuffleDstTy || !BinResTy || !BinOpTy || X->getType() != Z->getType())
    return false;

  bool SameBinOp = LHS == RHS;
  unsigned NumSrcElts = BinOpTy->getNumElements();

  // If we have something like "add X, Y" and "add Z, X", swap ops to match.
  if (IsCommutative && X != Z && Y != W && (X == W || Y == Z))
    std::swap(X, Y);

  // Remap second-source mask indices onto the first source. Used when both
  // inputs of a new shuffle are the same value.
  auto ConvertToUnary = [NumSrcElts](int &M) {
    if (M >= (int)NumSrcElts)
      M -= NumSrcElts;
  };

  SmallVector<int> NewMask0(OldMask);
  TTI::OperandValueInfo Op0Info = TTI.commonOperandInfo(X, Z);
  if (X == Z) {
    llvm::for_each(NewMask0, ConvertToUnary);
    Z = PoisonValue::get(BinOpTy);
  }

  SmallVector<int> NewMask1(OldMask);
  TTI::OperandValueInfo Op1Info = TTI.commonOperandInfo(Y, W);
  if (Y == W) {
    llvm::for_each(NewMask1, ConvertToUnary);
    W = PoisonValue::get(BinOpTy);
  }

  // Try to replace a binop with a shuffle if the shuffle is not costly.
  // When SameBinOp, only count the binop cost once.

  InstructionCost OldCost = LHSCost;
  if (!SameBinOp) {
    OldCost += RHSCost;
  }
                                ShuffleDstTy, BinResTy, OldMask, CostKind, 0,
                                nullptr, {LHS, RHS}, &I);

  // Handle shuffle(binop(shuffle(x),y),binop(z,shuffle(w))) style patterns
  // where one use shuffles have gotten split across the binop/cmp. These
  // often allow a major reduction in total cost that wouldn't happen as
  // individual folds.
  // Offset is the mask range occupied by Op: 0 for the first source,
  // NumSrcElts for the second. On success, Op is replaced by the inner
  // shuffle's source, the matching Mask entries are rewritten through
  // InnerMask, and true is returned.
  auto MergeInner = [&](Value *&Op, int Offset, MutableArrayRef<int> Mask,
                        TTI::TargetCostKind CostKind) -> bool {
    Value *InnerOp;
    ArrayRef<int> InnerMask;
    if (match(Op, m_OneUse(m_Shuffle(m_Value(InnerOp), m_Undef(),
                                     m_Mask(InnerMask)))) &&
        InnerOp->getType() == Op->getType() &&
        all_of(InnerMask,
               [NumSrcElts](int M) { return M < (int)NumSrcElts; })) {
      for (int &M : Mask)
        if (Offset <= M && M < (int)(Offset + NumSrcElts)) {
          M = InnerMask[M - Offset];
          M = 0 <= M ? M + Offset : M;
        }
      Op = InnerOp;
      return true;
    }
    return false;
  };
  bool ReducedInstCount = false;
  ReducedInstCount |= MergeInner(X, 0, NewMask0, CostKind);
  ReducedInstCount |= MergeInner(Y, 0, NewMask1, CostKind);
  ReducedInstCount |= MergeInner(Z, NumSrcElts, NewMask0, CostKind);
  ReducedInstCount |= MergeInner(W, NumSrcElts, NewMask1, CostKind);
  // Both new shuffles would be identical (same sources, same mask), so one
  // shuffle can feed both operands of the new binop/cmp.
  bool SingleSrcBinOp = (X == Y) && (Z == W) && (NewMask0 == NewMask1);
  // SingleSrcBinOp only reduces instruction count if we also eliminate the
  // original binop(s). If binops have multiple uses, they won't be eliminated.
  ReducedInstCount |= SingleSrcBinOp && LHS->hasOneUser() && RHS->hasOneUser();

  auto *ShuffleCmpTy =
      FixedVectorType::get(BinOpTy->getElementType(), ShuffleDstTy);
      SK0, ShuffleCmpTy, BinOpTy, NewMask0, CostKind, 0, nullptr, {X, Z});
  if (!SingleSrcBinOp)
    NewCost += TTI.getShuffleCost(SK1, ShuffleCmpTy, BinOpTy, NewMask1,
                                  CostKind, 0, nullptr, {Y, W});

  if (PredLHS == CmpInst::BAD_ICMP_PREDICATE) {
    NewCost += TTI.getArithmeticInstrCost(LHS->getOpcode(), ShuffleDstTy,
                                          CostKind, Op0Info, Op1Info);
  } else {
    NewCost +=
        TTI.getCmpSelInstrCost(LHS->getOpcode(), ShuffleCmpTy, ShuffleDstTy,
                               PredLHS, CostKind, Op0Info, Op1Info);
  }
  // If LHS/RHS have other uses, we need to account for the cost of keeping
  // the original instructions. When SameBinOp, only add the cost once.
  if (!LHS->hasOneUser())
    NewCost += LHSCost;
  if (!SameBinOp && !RHS->hasOneUser())
    NewCost += RHSCost;

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  // If either shuffle will constant fold away, then fold for the same cost as
  // we will reduce the instruction count.
  ReducedInstCount |= (isa<Constant>(X) && isa<Constant>(Z)) ||
                      (isa<Constant>(Y) && isa<Constant>(W));
  // Accept equal cost only when the instruction count is known to drop.
  if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost))
    return false;

  Value *Shuf0 = Builder.CreateShuffleVector(X, Z, NewMask0);
  Value *Shuf1 =
      SingleSrcBinOp ? Shuf0 : Builder.CreateShuffleVector(Y, W, NewMask1);
  Value *NewBO = PredLHS == CmpInst::BAD_ICMP_PREDICATE
                     ? Builder.CreateBinOp(
                           cast<BinaryOperator>(LHS)->getOpcode(), Shuf0, Shuf1)
                     : Builder.CreateCmp(PredLHS, Shuf0, Shuf1);

  // Intersect flags from the old binops.
  if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
    NewInst->copyIRFlags(LHS);
    NewInst->andIRFlags(RHS);
  }

  Worklist.pushValue(Shuf0);
  Worklist.pushValue(Shuf1);
  replaceValue(I, *NewBO);
  return true;
}
2724
/// Try to convert,
/// (shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) into
/// (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))
///
/// Both selects must have fixed-vector conditions of the same type. They must
/// either both be FP math operators with identical fast-math flags, or both
/// not be. The new cost is three shuffles plus one select, plus the cost of any
/// original select kept alive by other uses. The fold happens when that cost
/// does not exceed the old cost.
/// \returns true if \p I was replaced.
bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
  ArrayRef<int> Mask;
  Value *C1, *T1, *F1, *C2, *T2, *F2;
  if (!match(&I, m_Shuffle(m_Select(m_Value(C1), m_Value(T1), m_Value(F1)),
                           m_Select(m_Value(C2), m_Value(T2), m_Value(F2)),
                           m_Mask(Mask))))
    return false;

  auto *Sel1 = cast<Instruction>(I.getOperand(0));
  auto *Sel2 = cast<Instruction>(I.getOperand(1));

  // Only fixed-vector conditions are handled; e.g. scalar-condition selects
  // fail the dyn_cast and are rejected.
  auto *C1VecTy = dyn_cast<FixedVectorType>(C1->getType());
  auto *C2VecTy = dyn_cast<FixedVectorType>(C2->getType());
  if (!C1VecTy || !C2VecTy || C1VecTy != C2VecTy)
    return false;

  auto *SI0FOp = dyn_cast<FPMathOperator>(I.getOperand(0));
  auto *SI1FOp = dyn_cast<FPMathOperator>(I.getOperand(1));
  // SelectInsts must have the same FMF.
  if (((SI0FOp == nullptr) != (SI1FOp == nullptr)) ||
      ((SI0FOp != nullptr) &&
       (SI0FOp->getFastMathFlags() != SI1FOp->getFastMathFlags())))
    return false;

  auto *SrcVecTy = cast<FixedVectorType>(T1->getType());
  auto *DstVecTy = cast<FixedVectorType>(I.getType());
  auto SelOp = Instruction::Select;

      SelOp, SrcVecTy, C1VecTy, CmpInst::BAD_ICMP_PREDICATE, CostKind);
      SelOp, SrcVecTy, C2VecTy, CmpInst::BAD_ICMP_PREDICATE, CostKind);

  // Old form: two selects plus the outer shuffle.
  InstructionCost OldCost =
      CostSel1 + CostSel2 +
      TTI.getShuffleCost(SK, DstVecTy, SrcVecTy, Mask, CostKind, 0, nullptr,
                         {I.getOperand(0), I.getOperand(1)}, &I);

      SK, FixedVectorType::get(C1VecTy->getScalarType(), Mask.size()), C1VecTy,
      Mask, CostKind, 0, nullptr, {C1, C2});
  NewCost += TTI.getShuffleCost(SK, DstVecTy, SrcVecTy, Mask, CostKind, 0,
                                nullptr, {T1, T2});
  NewCost += TTI.getShuffleCost(SK, DstVecTy, SrcVecTy, Mask, CostKind, 0,
                                nullptr, {F1, F2});
  auto *C1C2ShuffledVecTy = FixedVectorType::get(
      Type::getInt1Ty(I.getContext()), DstVecTy->getNumElements());
  NewCost += TTI.getCmpSelInstrCost(SelOp, DstVecTy, C1C2ShuffledVecTy,

  // Selects with other users stay alive after the fold.
  if (!Sel1->hasOneUse())
    NewCost += CostSel1;
  if (!Sel2->hasOneUse())
    NewCost += CostSel2;

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two selects: " << I
                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");
  if (NewCost > OldCost)
    return false;

  Value *ShuffleCmp = Builder.CreateShuffleVector(C1, C2, Mask);
  Value *ShuffleTrue = Builder.CreateShuffleVector(T1, T2, Mask);
  Value *ShuffleFalse = Builder.CreateShuffleVector(F1, F2, Mask);
  Value *NewSel;
  // We presuppose that the SelectInsts have the same FMF.
  if (SI0FOp)
    NewSel = Builder.CreateSelectFMF(ShuffleCmp, ShuffleTrue, ShuffleFalse,
                                     SI0FOp->getFastMathFlags());
  else
    NewSel = Builder.CreateSelect(ShuffleCmp, ShuffleTrue, ShuffleFalse);

  Worklist.pushValue(ShuffleCmp);
  Worklist.pushValue(ShuffleTrue);
  Worklist.pushValue(ShuffleFalse);
  replaceValue(I, *NewSel);
  return true;
}
2807
/// Try to convert "shuffle (castop), (castop)" into "castop (shuffle)".
///
/// For a binary shuffle, both casts must have the same source type. They must
/// also have the same cast opcode, or both be sext-like; the latter becomes a
/// sext. A unary shuffle of a single castop is also handled, except for
/// bitcasts (see below). For bitcasts that change the element count, the mask
/// is rescaled to the cast's source element width.
/// \returns true if \p I was replaced.
bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
  Value *V0, *V1;
  ArrayRef<int> OldMask;
  if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
    return false;

  // Check whether this is a binary shuffle.
  bool IsBinaryShuffle = !isa<UndefValue>(V1);

  auto *C0 = dyn_cast<CastInst>(V0);
  auto *C1 = dyn_cast<CastInst>(V1);
  if (!C0 || (IsBinaryShuffle && !C1))
    return false;

  Instruction::CastOps Opcode = C0->getOpcode();

  // If this is allowed, foldShuffleOfCastops can get stuck in a loop
  // with foldBitcastOfShuffle. Reject in favor of foldBitcastOfShuffle.
  if (!IsBinaryShuffle && Opcode == Instruction::BitCast)
    return false;

  if (IsBinaryShuffle) {
    if (C0->getSrcTy() != C1->getSrcTy())
      return false;
    // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
    if (Opcode != C1->getOpcode()) {
      if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
        Opcode = Instruction::SExt;
      else
        return false;
    }
  }

  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
  auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy());
  auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy());
  if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
    return false;

  unsigned NumSrcElts = CastSrcTy->getNumElements();
  unsigned NumDstElts = CastDstTy->getNumElements();
  assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) &&
         "Only bitcasts expected to alter src/dst element counts");

  // Check for bitcasting of unscalable vector types.
  // e.g. <32 x i40> -> <40 x i32>
  if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 &&
      (NumDstElts % NumSrcElts) != 0)
    return false;

  // Rewrite OldMask (which indexes CastDstTy elements) into NewMask, which
  // indexes CastSrcTy elements, so the shuffle can be done before the cast.
  SmallVector<int, 16> NewMask;
  if (NumSrcElts >= NumDstElts) {
    // The cast's source elements are narrower than (or equal to) its
    // destination elements. Each destination lane maps to ScaleFactor
    // consecutive source lanes, so the mask can always be expanded.
    assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = NumSrcElts / NumDstElts;
    narrowShuffleMaskElts(ScaleFactor, OldMask, NewMask);
  } else {
    // The cast's source elements are wider than its destination elements.
    // The mask must select aligned runs of consecutive destination lanes so
    // that it can be expressed over whole source elements.
    assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = NumDstElts / NumSrcElts;
    if (!widenShuffleMaskElts(ScaleFactor, OldMask, NewMask))
      return false;
  }

  auto *NewShuffleDstTy =
      FixedVectorType::get(CastSrcTy->getScalarType(), NewMask.size());

  // Try to replace a castop with a shuffle if the shuffle is not costly.
  InstructionCost CostC0 =
      TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,

  if (IsBinaryShuffle)
  else

  InstructionCost OldCost = CostC0;
  OldCost += TTI.getShuffleCost(ShuffleKind, ShuffleDstTy, CastDstTy, OldMask,
                                CostKind, 0, nullptr, {}, &I);

  InstructionCost NewCost = TTI.getShuffleCost(ShuffleKind, NewShuffleDstTy,
                                               CastSrcTy, NewMask, CostKind);
  NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
  // Casts with other users stay alive after the fold.
  if (!C0->hasOneUse())
    NewCost += CostC0;
  if (IsBinaryShuffle) {
    InstructionCost CostC1 =
        TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
    OldCost += CostC1;
    if (!C1->hasOneUse())
      NewCost += CostC1;
  }

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");
  if (NewCost > OldCost)
    return false;

  Value *Shuf;
  if (IsBinaryShuffle)
    Shuf = Builder.CreateShuffleVector(C0->getOperand(0), C1->getOperand(0),
                                       NewMask);
  else
    Shuf = Builder.CreateShuffleVector(C0->getOperand(0), NewMask);

  Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);

  // Intersect flags from the old casts.
  if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
    NewInst->copyIRFlags(C0);
    if (IsBinaryShuffle)
      NewInst->andIRFlags(C1);
  }

  Worklist.pushValue(Shuf);
  replaceValue(I, *Cast);
  return true;
}
2935
/// Try to convert any of:
/// "shuffle (shuffle x, y), (shuffle y, x)"
/// "shuffle (shuffle x, undef), (shuffle y, undef)"
/// "shuffle (shuffle x, undef), y"
/// "shuffle x, (shuffle y, undef)"
/// into "shuffle x, y".
///
/// The outer mask is composed with the inner masks into one mask over at most
/// two distinct sources (NewX, NewY).
/// - Lanes that resolve to a poison source become poison.
/// - A lane that resolves to a non-poison undef source makes the fold bail.
/// - Needing a third source also makes the fold bail.
/// If the composed mask is an identity mask, I is replaced by NewX directly.
/// Otherwise one shuffle is emitted when it is not more expensive.
bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
  ArrayRef<int> OuterMask;
  Value *OuterV0, *OuterV1;
  if (!match(&I,
             m_Shuffle(m_Value(OuterV0), m_Value(OuterV1), m_Mask(OuterMask))))
    return false;

  ArrayRef<int> InnerMask0, InnerMask1;
  Value *X0, *X1, *Y0, *Y1;
  bool Match0 =
      match(OuterV0, m_Shuffle(m_Value(X0), m_Value(Y0), m_Mask(InnerMask0)));
  bool Match1 =
      match(OuterV1, m_Shuffle(m_Value(X1), m_Value(Y1), m_Mask(InnerMask1)));
  if (!Match0 && !Match1)
    return false;

  // If the outer shuffle is a permute, then create a fake inner all-poison
  // shuffle. This is easier than accounting for length-changing shuffles below.
  SmallVector<int, 16> PoisonMask1;
  if (!Match1 && isa<PoisonValue>(OuterV1)) {
    X1 = X0;
    Y1 = Y0;
    PoisonMask1.append(InnerMask0.size(), PoisonMaskElem);
    InnerMask1 = PoisonMask1;
    Match1 = true; // fake match
  }

  // An outer operand that is not a shuffle acts as its own inner source.
  X0 = Match0 ? X0 : OuterV0;
  Y0 = Match0 ? Y0 : OuterV0;
  X1 = Match1 ? X1 : OuterV1;
  Y1 = Match1 ? Y1 : OuterV1;
  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
  auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(X0->getType());
  auto *ShuffleImmTy = dyn_cast<FixedVectorType>(OuterV0->getType());
  if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
      X0->getType() != X1->getType())
    return false;

  unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
  unsigned NumImmElts = ShuffleImmTy->getNumElements();

  // Attempt to merge shuffles, matching up to 2 source operands.
  // Replace index to a poison arg with PoisonMaskElem.
  // Bail if either inner mask references an undef arg.
  SmallVector<int, 16> NewMask(OuterMask);
  Value *NewX = nullptr, *NewY = nullptr;
  for (int &M : NewMask) {
    Value *Src = nullptr;
    if (0 <= M && M < (int)NumImmElts) {
      Src = OuterV0;
      if (Match0) {
        M = InnerMask0[M];
        Src = M >= (int)NumSrcElts ? Y0 : X0;
        M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
      }
    } else if (M >= (int)NumImmElts) {
      Src = OuterV1;
      M -= NumImmElts;
      if (Match1) {
        M = InnerMask1[M];
        Src = M >= (int)NumSrcElts ? Y1 : X1;
        M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
      }
    }
    if (Src && M != PoisonMaskElem) {
      assert(0 <= M && M < (int)NumSrcElts && "Unexpected shuffle mask index");
      if (isa<UndefValue>(Src)) {
        // We've referenced an undef element - if it's poison, update the
        // shuffle mask, else bail.
        if (!isa<PoisonValue>(Src))
          return false;
        M = PoisonMaskElem;
        continue;
      }
      // The first distinct source seen becomes NewX and the second NewY;
      // lanes from NewY are offset by NumSrcElts.
      if (!NewX || NewX == Src) {
        NewX = Src;
        continue;
      }
      if (!NewY || NewY == Src) {
        M += NumSrcElts;
        NewY = Src;
        continue;
      }
      return false;
    }
  }

  // NOTE(review): this returns a non-null Value* from a bool function. It
  // converts to `true` ("changed") even though I is not replaced. The intent
  // was probably `replaceValue(I, *PoisonValue::get(ShuffleDstTy)); return
  // true;` -- confirm before changing.
  if (!NewX)
    return PoisonValue::get(ShuffleDstTy);
  if (!NewY)
    NewY = PoisonValue::get(ShuffleSrcTy);

  // Have we folded to an Identity shuffle?
  if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) {
    replaceValue(I, *NewX);
    return true;
  }

  // Try to merge the shuffles if the new shuffle is not costly.
  InstructionCost InnerCost0 = 0;
  if (Match0)
    InnerCost0 = TTI.getInstructionCost(cast<User>(OuterV0), CostKind);

  InstructionCost InnerCost1 = 0;
  if (Match1)
    InnerCost1 = TTI.getInstructionCost(cast<User>(OuterV1), CostKind);


  InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost;

  bool IsUnary = all_of(NewMask, [&](int M) { return M < (int)NumSrcElts; });
  InstructionCost NewCost =
      TTI.getShuffleCost(SK, ShuffleDstTy, ShuffleSrcTy, NewMask, CostKind, 0,
                         nullptr, {NewX, NewY});
  // Inner shuffles with other users stay alive after the fold.
  if (!OuterV0->hasOneUse())
    NewCost += InnerCost0;
  if (!OuterV1->hasOneUse())
    NewCost += InnerCost1;

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");
  if (NewCost > OldCost)
    return false;

  Value *Shuf = Builder.CreateShuffleVector(NewX, NewY, NewMask);
  replaceValue(I, *Shuf);
  return true;
}
3075
/// Try to convert a chain of length-preserving shuffles that are fed by
/// length-changing shuffles from the same source, e.g. a chain of length 3:
///
/// "shuffle (shuffle (shuffle x, (shuffle y, undef)),
///                   (shuffle y, undef)),
///          (shuffle y, undef)"
///
/// into a single shuffle fed by a length-changing shuffle:
///
/// "shuffle x, (shuffle y, undef)"
///
/// Such chains arise e.g. from folding extract/insert sequences.
bool VectorCombine::foldShufflesOfLengthChangingShuffles(Instruction &I) {
  FixedVectorType *TrunkType = dyn_cast<FixedVectorType>(I.getType());
  if (!TrunkType)
    return false;

  // Running state:
  // - Mask: combined root mask, indexing the current trunk and the leaf.
  // - YMask: combined leaf mask over Y.
  // - OldCost: sum of LocalOldCost over the links consumed so far.
  // - NewCost: LocalNewCost estimate of the merged replacement so far.
  unsigned ChainLength = 0;
  SmallVector<int> Mask;
  SmallVector<int> YMask;
  InstructionCost OldCost = 0;
  InstructionCost NewCost = 0;
  Value *Trunk = &I;
  unsigned NumTrunkElts = TrunkType->getNumElements();
  Value *Y = nullptr;

  for (;;) {
    // Match the current trunk against (commutations of) the pattern
    // "shuffle trunk', (shuffle y, undef)"
    ArrayRef<int> OuterMask;
    Value *OuterV0, *OuterV1;
    if (ChainLength != 0 && !Trunk->hasOneUse())
      break;
    if (!match(Trunk, m_Shuffle(m_Value(OuterV0), m_Value(OuterV1),
                                m_Mask(OuterMask))))
      break;
    if (OuterV0->getType() != TrunkType) {
      // This shuffle is not length-preserving, so it cannot be part of the
      // chain.
      break;
    }

    ArrayRef<int> InnerMask0, InnerMask1;
    Value *A0, *A1, *B0, *B1;
    bool Match0 =
        match(OuterV0, m_Shuffle(m_Value(A0), m_Value(B0), m_Mask(InnerMask0)));
    bool Match1 =
        match(OuterV1, m_Shuffle(m_Value(A1), m_Value(B1), m_Mask(InnerMask1)));
    // A "leaf" is an inner shuffle whose source type differs from I's type,
    // i.e. a length-changing shuffle.
    bool Match0Leaf = Match0 && A0->getType() != I.getType();
    bool Match1Leaf = Match1 && A1->getType() != I.getType();
    if (Match0Leaf == Match1Leaf) {
      // Only handle the case of exactly one leaf in each step. The "two leaves"
      // case is handled by foldShuffleOfShuffles.
      break;
    }

    // Canonicalize so the leaf is always outer operand 1, commuting the
    // outer mask to match.
    SmallVector<int> CommutedOuterMask;
    if (Match0Leaf) {
      std::swap(OuterV0, OuterV1);
      std::swap(InnerMask0, InnerMask1);
      std::swap(A0, A1);
      std::swap(B0, B1);
      llvm::append_range(CommutedOuterMask, OuterMask);
      for (int &M : CommutedOuterMask) {
        if (M == PoisonMaskElem)
          continue;
        if (M < (int)NumTrunkElts)
          M += NumTrunkElts;
        else
          M -= NumTrunkElts;
      }
      OuterMask = CommutedOuterMask;
    }
    if (!OuterV1->hasOneUse())
      break;

    // Every non-undef operand of every leaf must be the same value Y.
    if (!isa<UndefValue>(A1)) {
      if (!Y)
        Y = A1;
      else if (Y != A1)
        break;
    }
    if (!isa<UndefValue>(B1)) {
      if (!Y)
        Y = B1;
      else if (Y != B1)
        break;
    }

    // Fold the leaf's two-source mask onto a single source (Y).
    auto *YType = cast<FixedVectorType>(A1->getType());
    int NumLeafElts = YType->getNumElements();
    SmallVector<int> LocalYMask(InnerMask1);
    for (int &M : LocalYMask) {
      if (M >= NumLeafElts)
        M -= NumLeafElts;
    }

    InstructionCost LocalOldCost =

    // Handle the initial (start of chain) case.
    if (!ChainLength) {
      Mask.assign(OuterMask);
      YMask.assign(LocalYMask);
      OldCost = NewCost = LocalOldCost;
      Trunk = OuterV0;
      ChainLength++;
      continue;
    }

    // For the non-root case, first attempt to combine masks.
    // Leaf lanes must agree with the lanes already chosen. A poison (-1) lane
    // on either side is compatible with anything.
    SmallVector<int> NewYMask(YMask);
    bool Valid = true;
    for (auto [CombinedM, LeafM] : llvm::zip(NewYMask, LocalYMask)) {
      if (LeafM == -1 || CombinedM == LeafM)
        continue;
      if (CombinedM == -1) {
        CombinedM = LeafM;
      } else {
        Valid = false;
        break;
      }
    }
    if (!Valid)
      break;

    // Compose the root mask with this link's outer mask. Lanes that pointed
    // into the current trunk now point through OuterMask. Poison lanes and
    // leaf lanes are kept unchanged.
    SmallVector<int> NewMask;
    NewMask.reserve(NumTrunkElts);
    for (int M : Mask) {
      if (M < 0 || M >= static_cast<int>(NumTrunkElts))
        NewMask.push_back(M);
      else
        NewMask.push_back(OuterMask[M]);
    }

    // Break the chain if adding this new step complicates the shuffles such
    // that it would increase the new cost by more than the old cost of this
    // step.
    InstructionCost LocalNewCost =
            YType, NewYMask, CostKind) +
            TrunkType, NewMask, CostKind);

    if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
      break;

    LLVM_DEBUG({
      if (ChainLength == 1) {
        dbgs() << "Found chain of shuffles fed by length-changing shuffles: "
               << I << '\n';
      }
      dbgs() << " next chain link: " << *Trunk << '\n'
             << " old cost: " << (OldCost + LocalOldCost)
             << " new cost: " << LocalNewCost << '\n';
    });

    Mask = NewMask;
    YMask = NewYMask;
    OldCost += LocalOldCost;
    NewCost = LocalNewCost;
    Trunk = OuterV0;
    ChainLength++;
  }
  // A single link is not a chain; nothing to merge.
  if (ChainLength <= 1)
    return false;

  // NOTE(review): Y is only set from non-undef leaf operands above. If every
  // leaf in the chain had only undef operands, Y is still null here and
  // Y->getType() below would dereference null. TODO confirm such IR cannot
  // reach this point.
  if (llvm::all_of(Mask, [&](int M) {
        return M < 0 || M >= static_cast<int>(NumTrunkElts);
      })) {
    // Produce a canonical simplified form if all elements are sourced from Y.
    for (int &M : Mask) {
      if (M >= static_cast<int>(NumTrunkElts))
        M = YMask[M - NumTrunkElts];
    }
    Value *Root =
        Builder.CreateShuffleVector(Y, PoisonValue::get(Y->getType()), Mask);
    replaceValue(I, *Root);
    return true;
  }

  Value *Leaf =
      Builder.CreateShuffleVector(Y, PoisonValue::get(Y->getType()), YMask);
  Value *Root = Builder.CreateShuffleVector(Trunk, Leaf, Mask);
  replaceValue(I, *Root);
  return true;
}
3264
/// Try to convert
/// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
///
/// Both calls must use the same trivially vectorizable intrinsic and return a
/// fixed vector. Arguments that are not shuffled are passed through from
/// II0. Each shuffled argument pair is permuted with the original mask, and
/// identical pairs share one shuffle.
/// \returns true if \p I was replaced.
bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
  Value *V0, *V1;
  ArrayRef<int> OldMask;
  if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
    return false;

  auto *II0 = dyn_cast<IntrinsicInst>(V0);
  auto *II1 = dyn_cast<IntrinsicInst>(V1);
  if (!II0 || !II1)
    return false;

  Intrinsic::ID IID = II0->getIntrinsicID();
  if (IID != II1->getIntrinsicID())
    return false;
  InstructionCost CostII0 =
      TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II0), CostKind);
  InstructionCost CostII1 =
      TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II1), CostKind);

  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
  auto *II0Ty = dyn_cast<FixedVectorType>(II0->getType());
  if (!ShuffleDstTy || !II0Ty)
    return false;

  if (!isTriviallyVectorizable(IID))
    return false;

  // Arguments that will not be shuffled must be identical in both calls.
  // (Presumably the scalar-operand positions -- confirm.)
  for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
        II0->getArgOperand(I) != II1->getArgOperand(I))
      return false;

  // NOTE(review): when II0 == II1 (the shuffle uses one call twice), OldCost
  // counts the call twice, while NewCost below adds CostII1 only if
  // II1 != II0. In that case hasOneUse() is also false, because the shuffle
  // uses II0 twice. foldShuffleOfBinops counts a shared binop once and uses
  // hasOneUser(). Confirm which accounting is intended here.
  InstructionCost OldCost =
      CostII0 + CostII1 +
          II0Ty, OldMask, CostKind, 0, nullptr, {II0, II1}, &I);

  SmallVector<Type *> NewArgsTy;
  InstructionCost NewCost = 0;
  // Cost each distinct (II0 arg, II1 arg) shuffle only once.
  SmallDenseSet<std::pair<Value *, Value *>> SeenOperandPairs;
  for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
      NewArgsTy.push_back(II0->getArgOperand(I)->getType());
    } else {
      auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType());
      auto *ArgTy = FixedVectorType::get(VecTy->getElementType(),
                                         ShuffleDstTy->getNumElements());
      NewArgsTy.push_back(ArgTy);
      std::pair<Value *, Value *> OperandPair =
          std::make_pair(II0->getArgOperand(I), II1->getArgOperand(I));
      if (!SeenOperandPairs.insert(OperandPair).second) {
        // We've already computed the cost for this operand pair.
        continue;
      }
      NewCost += TTI.getShuffleCost(
          TargetTransformInfo::SK_PermuteTwoSrc, ArgTy, VecTy, OldMask,
          CostKind, 0, nullptr, {II0->getArgOperand(I), II1->getArgOperand(I)});
    }
  }
  IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);

  NewCost += TTI.getIntrinsicInstrCost(NewAttr, CostKind);
  // Calls with other users stay alive after the fold.
  if (!II0->hasOneUse())
    NewCost += CostII0;
  if (II1 != II0 && !II1->hasOneUse())
    NewCost += CostII1;

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two intrinsics: " << I
                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  if (NewCost > OldCost)
    return false;

  SmallVector<Value *> NewArgs;
  SmallDenseMap<std::pair<Value *, Value *>, Value *> ShuffleCache;
  for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
      NewArgs.push_back(II0->getArgOperand(I));
    } else {
      std::pair<Value *, Value *> OperandPair =
          std::make_pair(II0->getArgOperand(I), II1->getArgOperand(I));
      auto It = ShuffleCache.find(OperandPair);
      if (It != ShuffleCache.end()) {
        // Reuse previously created shuffle for this operand pair.
        NewArgs.push_back(It->second);
        continue;
      }
      Value *Shuf = Builder.CreateShuffleVector(II0->getArgOperand(I),
                                                II1->getArgOperand(I), OldMask);
      ShuffleCache[OperandPair] = Shuf;
      NewArgs.push_back(Shuf);
      Worklist.pushValue(Shuf);
    }
  Value *NewIntrinsic = Builder.CreateIntrinsic(ShuffleDstTy, IID, NewArgs);

  // Intersect flags from the old intrinsics.
  if (auto *NewInst = dyn_cast<Instruction>(NewIntrinsic)) {
    NewInst->copyIRFlags(II0);
    NewInst->andIRFlags(II1);
  }

  replaceValue(I, *NewIntrinsic);
  return true;
}
3372
/// Try to convert
/// "shuffle (intrinsic), (poison/undef)" into "intrinsic (shuffle)".
///
/// The mask may only select lanes of the intrinsic result, though it may
/// change the vector length. Each shuffled argument is permuted with the same
/// mask, and a new call returning the shuffle's type is created.
/// \returns true if \p I was replaced.
bool VectorCombine::foldPermuteOfIntrinsic(Instruction &I) {
  Value *V0;
  ArrayRef<int> Mask;
  if (!match(&I, m_Shuffle(m_Value(V0), m_Undef(), m_Mask(Mask))))
    return false;

  auto *II0 = dyn_cast<IntrinsicInst>(V0);
  if (!II0)
    return false;

  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
  auto *IntrinsicSrcTy = dyn_cast<FixedVectorType>(II0->getType());
  if (!ShuffleDstTy || !IntrinsicSrcTy)
    return false;

  // Validate it's a pure permute, mask should only reference the first vector
  unsigned NumSrcElts = IntrinsicSrcTy->getNumElements();
  if (any_of(Mask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
    return false;

  Intrinsic::ID IID = II0->getIntrinsicID();
  if (!isTriviallyVectorizable(IID))
    return false;

  // Cost analysis
      TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II0), CostKind);
  InstructionCost OldCost =
          IntrinsicSrcTy, Mask, CostKind, 0, nullptr, {V0}, &I);

  SmallVector<Type *> NewArgsTy;
  InstructionCost NewCost = 0;
  for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
      NewArgsTy.push_back(II0->getArgOperand(I)->getType());
    } else {
      auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType());
      auto *ArgTy = FixedVectorType::get(VecTy->getElementType(),
                                         ShuffleDstTy->getNumElements());
      NewArgsTy.push_back(ArgTy);
          ArgTy, VecTy, Mask, CostKind, 0, nullptr,
          {II0->getArgOperand(I)});
    }
  }
  IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
  NewCost += TTI.getIntrinsicInstrCost(NewAttr, CostKind);

  // If the intrinsic has multiple uses, we need to account for the cost of
  // keeping the original intrinsic around.
  if (!II0->hasOneUse())
    NewCost += IntrinsicCost;

  LLVM_DEBUG(dbgs() << "Found a permute of intrinsic: " << I << "\n OldCost: "
                    << OldCost << " vs NewCost: " << NewCost << "\n");

  if (NewCost > OldCost)
    return false;

  // Transform
  SmallVector<Value *> NewArgs;
  for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
      NewArgs.push_back(II0->getArgOperand(I));
    } else {
      Value *Shuf = Builder.CreateShuffleVector(II0->getArgOperand(I), Mask);
      NewArgs.push_back(Shuf);
      Worklist.pushValue(Shuf);
    }
  }

  Value *NewIntrinsic = Builder.CreateIntrinsic(ShuffleDstTy, IID, NewArgs);

  // Only one source call, so its flags carry over unchanged.
  if (auto *NewInst = dyn_cast<Instruction>(NewIntrinsic))
    NewInst->copyIRFlags(II0);

  replaceValue(I, *NewIntrinsic);
  return true;
}
3456
/// A (value, lane) pair naming one element of a vector value. A
/// {nullptr, PoisonMaskElem} entry stands for a lane that is poison/undef
/// (see lookThroughShuffles).
using InstLane = std::pair<Value *, int>;
3458
3459static InstLane lookThroughShuffles(Value *V, int Lane) {
3460 while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
3461 unsigned NumElts =
3462 cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
3463 int M = SV->getMaskValue(Lane);
3464 if (M < 0)
3465 return {nullptr, PoisonMaskElem};
3466 if (static_cast<unsigned>(M) < NumElts) {
3467 V = SV->getOperand(0);
3468 Lane = M;
3469 } else {
3470 V = SV->getOperand(1);
3471 Lane = M - NumElts;
3472 }
3473 }
3474 return InstLane{V, Lane};
3475}
3476
  // For each (value, lane) entry of Item, take operand Op of that instruction
  // and trace the same lane back through any shuffles. Null entries map to
  // {nullptr, PoisonMaskElem}. One result is appended to NItem per entry.
  for (InstLane IL : Item) {
    auto [U, Lane] = IL;
    InstLane OpLane =
        U ? lookThroughShuffles(cast<Instruction>(U)->getOperand(Op), Lane)
          : InstLane{nullptr, PoisonMaskElem};
    NItem.emplace_back(OpLane);
  }
  return NItem;
}
3489
/// Detect concat of multiple values into a vector
///
/// Returns true if Item (one (value, lane) entry per result lane) consists of
/// all of the following:
/// - A power-of-two number (more than one) of whole vectors of
///   Item.front()'s type.
/// - Each slice supplies its lanes 0..NumElts-1 in order.
/// - The target reports a zero-cost two-source concat shuffle of that type.
/// NOTE(review): Item.front().first is dereferenced before the loop's null
/// check for slice 0, so callers presumably guarantee a non-null first
/// entry -- confirm.
                         const TargetTransformInfo &TTI) {
  auto *Ty = cast<FixedVectorType>(Item.front().first->getType());
  unsigned NumElts = Ty->getNumElements();
  // Need at least two whole slices of single-lane-or-wider vectors.
  if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0)
    return false;

  // Check that the concat is free, usually meaning that the type will be split
  // during legalization.
  SmallVector<int, 16> ConcatMask(NumElts * 2);
  std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
  if (TTI.getShuffleCost(TTI::SK_PermuteTwoSrc,
                         FixedVectorType::get(Ty->getScalarType(), NumElts * 2),
                         Ty, ConcatMask, CostKind) != 0)
    return false;

  unsigned NumSlices = Item.size() / NumElts;
  // Currently we generate a tree of shuffles for the concats, which limits us
  // to a power2.
  if (!isPowerOf2_32(NumSlices))
    return false;
  // Every slice must be a single value of type Ty, read lane-for-lane.
  for (unsigned Slice = 0; Slice < NumSlices; ++Slice) {
    Value *SliceV = Item[Slice * NumElts].first;
    if (!SliceV || SliceV->getType() != Ty)
      return false;
    for (unsigned Elt = 0; Elt < NumElts; ++Elt) {
      auto [V, Lane] = Item[Slice * NumElts + Elt];
      if (Lane != static_cast<int>(Elt) || SliceV != V)
        return false;
    }
  }
  return true;
}
3524
3525static Value *
3527 const DenseSet<std::pair<Value *, Use *>> &IdentityLeafs,
3528 const DenseSet<std::pair<Value *, Use *>> &SplatLeafs,
3529 const DenseSet<std::pair<Value *, Use *>> &ConcatLeafs,
3530 IRBuilderBase &Builder, const TargetTransformInfo *TTI) {
3531 auto [FrontV, FrontLane] = Item.front();
3532
3533 if (IdentityLeafs.contains(std::make_pair(FrontV, From))) {
3534 return FrontV;
3535 }
3536 if (SplatLeafs.contains(std::make_pair(FrontV, From))) {
3537 SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
3538 return Builder.CreateShuffleVector(FrontV, Mask);
3539 }
3540 if (ConcatLeafs.contains(std::make_pair(FrontV, From))) {
3541 unsigned NumElts =
3542 cast<FixedVectorType>(FrontV->getType())->getNumElements();
3543 SmallVector<Value *> Values(Item.size() / NumElts, nullptr);
3544 for (unsigned S = 0; S < Values.size(); ++S)
3545 Values[S] = Item[S * NumElts].first;
3546
3547 while (Values.size() > 1) {
3548 NumElts *= 2;
3549 SmallVector<int, 16> Mask(NumElts, 0);
3550 std::iota(Mask.begin(), Mask.end(), 0);
3551 SmallVector<Value *> NewValues(Values.size() / 2, nullptr);
3552 for (unsigned S = 0; S < NewValues.size(); ++S)
3553 NewValues[S] =
3554 Builder.CreateShuffleVector(Values[S * 2], Values[S * 2 + 1], Mask);
3555 Values = NewValues;
3556 }
3557 return Values[0];
3558 }
3559
3560 auto *I = cast<Instruction>(FrontV);
3561 auto *II = dyn_cast<IntrinsicInst>(I);
3562 unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
3564 for (unsigned Idx = 0; Idx < NumOps; Idx++) {
3565 if (II &&
3566 isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, TTI)) {
3567 Ops[Idx] = II->getOperand(Idx);
3568 continue;
3569 }
3571 &I->getOperandUse(Idx), Ty, IdentityLeafs,
3572 SplatLeafs, ConcatLeafs, Builder, TTI);
3573 }
3574
3575 SmallVector<Value *, 8> ValueList;
3576 for (const auto &Lane : Item)
3577 if (Lane.first)
3578 ValueList.push_back(Lane.first);
3579
3580 Type *DstTy =
3581 FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements());
3582 if (auto *BI = dyn_cast<BinaryOperator>(I)) {
3583 auto *Value = Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(),
3584 Ops[0], Ops[1]);
3585 propagateIRFlags(Value, ValueList);
3586 return Value;
3587 }
3588 if (auto *CI = dyn_cast<CmpInst>(I)) {
3589 auto *Value = Builder.CreateCmp(CI->getPredicate(), Ops[0], Ops[1]);
3590 propagateIRFlags(Value, ValueList);
3591 return Value;
3592 }
3593 if (auto *SI = dyn_cast<SelectInst>(I)) {
3594 auto *Value = Builder.CreateSelect(Ops[0], Ops[1], Ops[2], "", SI);
3595 propagateIRFlags(Value, ValueList);
3596 return Value;
3597 }
3598 if (auto *CI = dyn_cast<CastInst>(I)) {
3599 auto *Value = Builder.CreateCast(CI->getOpcode(), Ops[0], DstTy);
3600 propagateIRFlags(Value, ValueList);
3601 return Value;
3602 }
3603 if (II) {
3604 auto *Value = Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
3605 propagateIRFlags(Value, ValueList);
3606 return Value;
3607 }
3608 assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
3609 auto *Value =
3610 Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
3611 propagateIRFlags(Value, ValueList);
3612 return Value;
3613}
3614
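// Starting from a shuffle, walk up through its operands while tracking which
// source lane feeds each result lane. If every path ends in an identity, splat
// or free-concat leaf, the shuffles in between are superfluous: rebuild the
// operation tree directly on the leaves and replace the shuffle with it.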
3618bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
3619 auto *Ty = dyn_cast<FixedVectorType>(I.getType());
3620 if (!Ty || I.use_empty())
3621 return false;
3622
3623 SmallVector<InstLane> Start(Ty->getNumElements());
3624 for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
3625 Start[M] = lookThroughShuffles(&I, M);
3626
3628 Worklist.push_back(std::make_pair(Start, &*I.use_begin()));
3629 DenseSet<std::pair<Value *, Use *>> IdentityLeafs, SplatLeafs, ConcatLeafs;
3630 unsigned NumVisited = 0;
3631
3632 while (!Worklist.empty()) {
3633 if (++NumVisited > MaxInstrsToScan)
3634 return false;
3635
3636 auto ItemFrom = Worklist.pop_back_val();
3637 auto Item = ItemFrom.first;
3638 auto From = ItemFrom.second;
3639 auto [FrontV, FrontLane] = Item.front();
3640
3641 // If we found an undef first lane then bail out to keep things simple.
3642 if (!FrontV)
3643 return false;
3644
3645 // Helper to peek through bitcasts to the same value.
3646 auto IsEquiv = [&](Value *X, Value *Y) {
3647 return X->getType() == Y->getType() &&
3649 };
3650
3651 // Look for an identity value.
3652 if (FrontLane == 0 &&
3653 cast<FixedVectorType>(FrontV->getType())->getNumElements() ==
3654 Ty->getNumElements() &&
3655 all_of(drop_begin(enumerate(Item)), [IsEquiv, Item](const auto &E) {
3656 Value *FrontV = Item.front().first;
3657 return !E.value().first || (IsEquiv(E.value().first, FrontV) &&
3658 E.value().second == (int)E.index());
3659 })) {
3660 IdentityLeafs.insert(std::make_pair(FrontV, From));
3661 continue;
3662 }
3663 // Look for constants, for the moment only supporting constant splats.
3664 if (auto *C = dyn_cast<Constant>(FrontV);
3665 C && C->getSplatValue() &&
3666 all_of(drop_begin(Item), [Item](InstLane &IL) {
3667 Value *FrontV = Item.front().first;
3668 Value *V = IL.first;
3669 return !V || (isa<Constant>(V) &&
3670 cast<Constant>(V)->getSplatValue() ==
3671 cast<Constant>(FrontV)->getSplatValue());
3672 })) {
3673 SplatLeafs.insert(std::make_pair(FrontV, From));
3674 continue;
3675 }
3676 // Look for a splat value.
3677 if (all_of(drop_begin(Item), [Item](InstLane &IL) {
3678 auto [FrontV, FrontLane] = Item.front();
3679 auto [V, Lane] = IL;
3680 return !V || (V == FrontV && Lane == FrontLane);
3681 })) {
3682 SplatLeafs.insert(std::make_pair(FrontV, From));
3683 continue;
3684 }
3685
3686 // We need each element to be the same type of value, and check that each
3687 // element has a single use.
3688 auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) {
3689 Value *FrontV = Item.front().first;
3690 if (!IL.first)
3691 return true;
3692 Value *V = IL.first;
3693 if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUser())
3694 return false;
3695 if (V->getValueID() != FrontV->getValueID())
3696 return false;
3697 if (auto *CI = dyn_cast<CmpInst>(V))
3698 if (CI->getPredicate() != cast<CmpInst>(FrontV)->getPredicate())
3699 return false;
3700 if (auto *CI = dyn_cast<CastInst>(V))
3701 if (CI->getSrcTy()->getScalarType() !=
3702 cast<CastInst>(FrontV)->getSrcTy()->getScalarType())
3703 return false;
3704 if (auto *SI = dyn_cast<SelectInst>(V))
3705 if (!isa<VectorType>(SI->getOperand(0)->getType()) ||
3706 SI->getOperand(0)->getType() !=
3707 cast<SelectInst>(FrontV)->getOperand(0)->getType())
3708 return false;
3709 if (isa<CallInst>(V) && !isa<IntrinsicInst>(V))
3710 return false;
3711 auto *II = dyn_cast<IntrinsicInst>(V);
3712 return !II || (isa<IntrinsicInst>(FrontV) &&
3713 II->getIntrinsicID() ==
3714 cast<IntrinsicInst>(FrontV)->getIntrinsicID() &&
3715 !II->hasOperandBundles());
3716 };
3717 if (all_of(drop_begin(Item), CheckLaneIsEquivalentToFirst)) {
3718 // Check the operator is one that we support.
3719 if (isa<BinaryOperator, CmpInst>(FrontV)) {
3720 // We exclude div/rem in case they hit UB from poison lanes.
3721 if (auto *BO = dyn_cast<BinaryOperator>(FrontV);
3722 BO && BO->isIntDivRem())
3723 return false;
3725 &cast<Instruction>(FrontV)->getOperandUse(0));
3727 &cast<Instruction>(FrontV)->getOperandUse(1));
3728 continue;
3729 } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst, FPToSIInst,
3730 FPToUIInst, SIToFPInst, UIToFPInst>(FrontV)) {
3732 &cast<Instruction>(FrontV)->getOperandUse(0));
3733 continue;
3734 } else if (auto *BitCast = dyn_cast<BitCastInst>(FrontV)) {
3735 // TODO: Handle vector widening/narrowing bitcasts.
3736 auto *DstTy = dyn_cast<FixedVectorType>(BitCast->getDestTy());
3737 auto *SrcTy = dyn_cast<FixedVectorType>(BitCast->getSrcTy());
3738 if (DstTy && SrcTy &&
3739 SrcTy->getNumElements() == DstTy->getNumElements()) {
3741 &BitCast->getOperandUse(0));
3742 continue;
3743 }
3744 } else if (auto *Sel = dyn_cast<SelectInst>(FrontV)) {
3746 &Sel->getOperandUse(0));
3748 &Sel->getOperandUse(1));
3750 &Sel->getOperandUse(2));
3751 continue;
3752 } else if (auto *II = dyn_cast<IntrinsicInst>(FrontV);
3753 II && isTriviallyVectorizable(II->getIntrinsicID()) &&
3754 !II->hasOperandBundles()) {
3755 for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
3756 if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op,
3757 &TTI)) {
3758 if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
3759 Value *FrontV = Item.front().first;
3760 Value *V = IL.first;
3761 return !V || (cast<Instruction>(V)->getOperand(Op) ==
3762 cast<Instruction>(FrontV)->getOperand(Op));
3763 }))
3764 return false;
3765 continue;
3766 }
3768 &cast<Instruction>(FrontV)->getOperandUse(Op));
3769 }
3770 continue;
3771 }
3772 }
3773
3774 if (isFreeConcat(Item, CostKind, TTI)) {
3775 ConcatLeafs.insert(std::make_pair(FrontV, From));
3776 continue;
3777 }
3778
3779 return false;
3780 }
3781
3782 if (NumVisited <= 1)
3783 return false;
3784
3785 LLVM_DEBUG(dbgs() << "Found a superfluous identity shuffle: " << I << "\n");
3786
3787 // If we got this far, we know the shuffles are superfluous and can be
3788 // removed. Scan through again and generate the new tree of instructions.
3789 Builder.SetInsertPoint(&I);
3790 Value *V = generateNewInstTree(Start, &*I.use_begin(), Ty, IdentityLeafs,
3791 SplatLeafs, ConcatLeafs, Builder, &TTI);
3792 replaceValue(I, *V);
3793 return true;
3794}
3795
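/// Given a commutative reduction, the order of the input lanes does not alter
/// the result. Look through lane-wise binary operators feeding the reduction
/// for a single shuffle. If the same shuffle with its mask elements sorted into
/// ascending order (most likely an identity or concat mask) is cheaper, use it
/// instead; otherwise fall back to foldSelectShuffle, which can also ignore the
/// lane order of a shuffle that feeds a reduction.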
3799bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
3800 auto *II = dyn_cast<IntrinsicInst>(&I);
3801 if (!II)
3802 return false;
3803 switch (II->getIntrinsicID()) {
3804 case Intrinsic::vector_reduce_add:
3805 case Intrinsic::vector_reduce_mul:
3806 case Intrinsic::vector_reduce_and:
3807 case Intrinsic::vector_reduce_or:
3808 case Intrinsic::vector_reduce_xor:
3809 case Intrinsic::vector_reduce_smin:
3810 case Intrinsic::vector_reduce_smax:
3811 case Intrinsic::vector_reduce_umin:
3812 case Intrinsic::vector_reduce_umax:
3813 break;
3814 default:
3815 return false;
3816 }
3817
3818 // Find all the inputs when looking through operations that do not alter the
3819 // lane order (binops, for example). Currently we look for a single shuffle,
3820 // and can ignore splat values.
3821 std::queue<Value *> Worklist;
3822 SmallPtrSet<Value *, 4> Visited;
3823 ShuffleVectorInst *Shuffle = nullptr;
3824 if (auto *Op = dyn_cast<Instruction>(I.getOperand(0)))
3825 Worklist.push(Op);
3826
3827 while (!Worklist.empty()) {
3828 Value *CV = Worklist.front();
3829 Worklist.pop();
3830 if (Visited.contains(CV))
3831 continue;
3832
3833 // Splats don't change the order, so can be safely ignored.
3834 if (isSplatValue(CV))
3835 continue;
3836
3837 Visited.insert(CV);
3838
3839 if (auto *CI = dyn_cast<Instruction>(CV)) {
3840 if (CI->isBinaryOp()) {
3841 for (auto *Op : CI->operand_values())
3842 Worklist.push(Op);
3843 continue;
3844 } else if (auto *SV = dyn_cast<ShuffleVectorInst>(CI)) {
3845 if (Shuffle && Shuffle != SV)
3846 return false;
3847 Shuffle = SV;
3848 continue;
3849 }
3850 }
3851
3852 // Anything else is currently an unknown node.
3853 return false;
3854 }
3855
3856 if (!Shuffle)
3857 return false;
3858
3859 // Check all uses of the binary ops and shuffles are also included in the
3860 // lane-invariant operations (Visited should be the list of lanewise
3861 // instructions, including the shuffle that we found).
3862 for (auto *V : Visited)
3863 for (auto *U : V->users())
3864 if (!Visited.contains(U) && U != &I)
3865 return false;
3866
3867 FixedVectorType *VecType =
3868 dyn_cast<FixedVectorType>(II->getOperand(0)->getType());
3869 if (!VecType)
3870 return false;
3871 FixedVectorType *ShuffleInputType =
3873 if (!ShuffleInputType)
3874 return false;
3875 unsigned NumInputElts = ShuffleInputType->getNumElements();
3876
3877 // Find the mask from sorting the lanes into order. This is most likely to
3878 // become a identity or concat mask. Undef elements are pushed to the end.
3879 SmallVector<int> ConcatMask;
3880 Shuffle->getShuffleMask(ConcatMask);
3881 sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
3882 bool UsesSecondVec =
3883 any_of(ConcatMask, [&](int M) { return M >= (int)NumInputElts; });
3884
3886 UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType,
3887 ShuffleInputType, Shuffle->getShuffleMask(), CostKind);
3889 UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType,
3890 ShuffleInputType, ConcatMask, CostKind);
3891
3892 LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
3893 << "\n");
3894 LLVM_DEBUG(dbgs() << " OldCost: " << OldCost << " vs NewCost: " << NewCost
3895 << "\n");
3896 bool MadeChanges = false;
3897 if (NewCost < OldCost) {
3898 Builder.SetInsertPoint(Shuffle);
3899 Value *NewShuffle = Builder.CreateShuffleVector(
3900 Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask);
3901 LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
3902 replaceValue(*Shuffle, *NewShuffle);
3903 return true;
3904 }
3905
3906 // See if we can re-use foldSelectShuffle, getting it to reduce the size of
3907 // the shuffle into a nicer order, as it can ignore the order of the shuffles.
3908 MadeChanges |= foldSelectShuffle(*Shuffle, true);
3909 return MadeChanges;
3910}
3911
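/// Fold a chain of the following form into a single vector reduction
/// intrinsic (llvm.vector.reduce.*) applied to %0:
///
/// ```
/// %1 = shufflevector <n x ty1> %0, <n x ty1> poison, <n x ty2> mask
///
/// %2 = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %0,
///                                                     <n x ty1> %1)
/// OR
/// %2 = add/mul/or/and/xor <n x ty1> %0, %1
///
/// %3 = shufflevector <n x ty1> %2, <n x ty1> poison, <n x ty2> mask
/// ...
/// ...
/// %(i - 1) = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %(i - 3),
///                                                           <n x ty1> %(i - 2))
/// OR
/// %(i - 1) = add/mul/or/and/xor <n x ty1> %(i - 3), %(i - 2)
///
/// %(i) = extractelement <n x ty1> %(i - 1), 0
/// ```
///
/// Every step must use the same operation. Each `mask` follows a partition
/// pattern, shown here top-down (in program order):
///
/// Ex:
/// [n = 8, p = poison]
///
/// 4 5 6 7 | p p p p
/// 2 3 | p p p p p p
/// 1 | p p p p p p p
///
/// For powers of 2 each step simply halves the live lanes. For other sizes,
/// the parity of the live-lane count at each step decides the next partition
/// half (see `ExpectedParityMask` for the details of this generalisation).
///
/// Ex:
/// [n = 6]
///
/// 3 4 5 | p p p
/// 1 2 | p p p p
/// 1 | p p p p p
///
/// For non-power-of-2 sizes the partitions overlap. In the n = 6 example the
/// second step computes lane 0 from lanes {0, 1} and lane 1 from lanes {1, 2},
/// so lane 1 of the first step's result is combined into the final value
/// twice. That is harmless for idempotent operations (umin/umax/smin/smax,
/// and, or) but changes the result of add, mul and xor.
/// NOTE(review): the matcher does not visibly restrict non-power-of-2 sizes
/// to idempotent operations; confirm that such add/mul/xor chains are rejected
/// before this fold rewrites them.
///
/// If the chain stops early and n is a power of 2, it is treated as a partial
/// reduction: after k matched steps, the first 2^k lanes of the source vector
/// are extracted with a shuffle and reduced instead. In both cases the rewrite
/// happens only when the cost model rates it no more expensive than the chain
/// (strictly cheaper if the reduced vector has other uses).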
3954bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
3955 // Going bottom-up for the pattern.
3956 std::queue<Value *> InstWorklist;
3957 InstructionCost OrigCost = 0;
3958
3959 // Common instruction operation after each shuffle op.
3960 std::optional<unsigned int> CommonCallOp = std::nullopt;
3961 std::optional<Instruction::BinaryOps> CommonBinOp = std::nullopt;
3962
3963 bool IsFirstCallOrBinInst = true;
3964 bool ShouldBeCallOrBinInst = true;
3965
3966 // This stores the last used instructions for shuffle/common op.
3967 //
3968 // PrevVecV[0] / PrevVecV[1] store the last two simultaneous
3969 // instructions from either shuffle/common op.
3970 SmallVector<Value *, 2> PrevVecV(2, nullptr);
3971
3972 Value *VecOpEE;
3973 if (!match(&I, m_ExtractElt(m_Value(VecOpEE), m_Zero())))
3974 return false;
3975
3976 auto *FVT = dyn_cast<FixedVectorType>(VecOpEE->getType());
3977 if (!FVT)
3978 return false;
3979
3980 int64_t VecSize = FVT->getNumElements();
3981 if (VecSize < 2)
3982 return false;
3983
3984 // Number of levels would be ~log2(n), considering we always partition
3985 // by half for this fold pattern.
3986 unsigned int NumLevels = Log2_64_Ceil(VecSize), VisitedCnt = 0;
3987 int64_t ShuffleMaskHalf = 1, ExpectedParityMask = 0;
3988
3989 // This is how we generalise for all element sizes.
3990 // At each step, if vector size is odd, we need non-poison
3991 // values to cover the dominant half so we don't miss out on any element.
3992 //
3993 // This mask will help us retrieve this as we go from bottom to top:
3994 //
3995 // Mask Set -> N = N * 2 - 1
3996 // Mask Unset -> N = N * 2
3997 for (int Cur = VecSize, Mask = NumLevels - 1; Cur > 1;
3998 Cur = (Cur + 1) / 2, --Mask) {
3999 if (Cur & 1)
4000 ExpectedParityMask |= (1ll << Mask);
4001 }
4002
4003 InstWorklist.push(VecOpEE);
4004
4005 bool IsPartialReduction = false;
4006
4007 while (!InstWorklist.empty()) {
4008 Value *CI = InstWorklist.front();
4009 InstWorklist.pop();
4010
4011 if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
4012 if (!ShouldBeCallOrBinInst)
4013 return false;
4014
4015 if (!IsFirstCallOrBinInst && any_of(PrevVecV, equal_to(nullptr)))
4016 return false;
4017
4018 // For the first found call/bin op, the vector has to come from the
4019 // extract element op.
4020 if (II != (IsFirstCallOrBinInst ? VecOpEE : PrevVecV[0]))
4021 return false;
4022 IsFirstCallOrBinInst = false;
4023
4024 if (!CommonCallOp)
4025 CommonCallOp = II->getIntrinsicID();
4026 if (II->getIntrinsicID() != *CommonCallOp)
4027 return false;
4028
4029 switch (II->getIntrinsicID()) {
4030 case Intrinsic::umin:
4031 case Intrinsic::umax:
4032 case Intrinsic::smin:
4033 case Intrinsic::smax: {
4034 auto *Op0 = II->getOperand(0);
4035 auto *Op1 = II->getOperand(1);
4036 PrevVecV[0] = Op0;
4037 PrevVecV[1] = Op1;
4038 break;
4039 }
4040 default:
4041 return false;
4042 }
4043 ShouldBeCallOrBinInst ^= 1;
4044
4045 IntrinsicCostAttributes ICA(
4046 *CommonCallOp, II->getType(),
4047 {PrevVecV[0]->getType(), PrevVecV[1]->getType()});
4048 OrigCost += TTI.getIntrinsicInstrCost(ICA, CostKind);
4049
4050 // We may need a swap here since it can be (a, b) or (b, a)
4051 // and accordingly change as we go up.
4052 if (!isa<ShuffleVectorInst>(PrevVecV[1]))
4053 std::swap(PrevVecV[0], PrevVecV[1]);
4054 InstWorklist.push(PrevVecV[1]);
4055 InstWorklist.push(PrevVecV[0]);
4056 } else if (auto *BinOp = dyn_cast<BinaryOperator>(CI)) {
4057 // Similar logic for bin ops.
4058
4059 if (!ShouldBeCallOrBinInst)
4060 return false;
4061
4062 if (!IsFirstCallOrBinInst && any_of(PrevVecV, equal_to(nullptr)))
4063 return false;
4064
4065 if (BinOp != (IsFirstCallOrBinInst ? VecOpEE : PrevVecV[0]))
4066 return false;
4067 IsFirstCallOrBinInst = false;
4068
4069 if (!CommonBinOp)
4070 CommonBinOp = BinOp->getOpcode();
4071
4072 if (BinOp->getOpcode() != *CommonBinOp)
4073 return false;
4074
4075 switch (*CommonBinOp) {
4076 case BinaryOperator::Add:
4077 case BinaryOperator::Mul:
4078 case BinaryOperator::Or:
4079 case BinaryOperator::And:
4080 case BinaryOperator::Xor: {
4081 auto *Op0 = BinOp->getOperand(0);
4082 auto *Op1 = BinOp->getOperand(1);
4083 PrevVecV[0] = Op0;
4084 PrevVecV[1] = Op1;
4085 break;
4086 }
4087 default:
4088 return false;
4089 }
4090 ShouldBeCallOrBinInst ^= 1;
4091
4092 OrigCost +=
4093 TTI.getArithmeticInstrCost(*CommonBinOp, BinOp->getType(), CostKind);
4094
4095 if (!isa<ShuffleVectorInst>(PrevVecV[1]))
4096 std::swap(PrevVecV[0], PrevVecV[1]);
4097 InstWorklist.push(PrevVecV[1]);
4098 InstWorklist.push(PrevVecV[0]);
4099 } else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
4100 // We shouldn't have any null values in the previous vectors,
4101 // is so, there was a mismatch in pattern.
4102 if (ShouldBeCallOrBinInst || any_of(PrevVecV, equal_to(nullptr)))
4103 return false;
4104
4105 if (SVInst != PrevVecV[1])
4106 return false;
4107
4108 ArrayRef<int> CurMask;
4109 if (!match(SVInst, m_Shuffle(m_Specific(PrevVecV[0]), m_Poison(),
4110 m_Mask(CurMask))))
4111 return false;
4112
4113 // Subtract the parity mask when checking the condition.
4114 for (int Mask = 0, MaskSize = CurMask.size(); Mask != MaskSize; ++Mask) {
4115 if (Mask < ShuffleMaskHalf &&
4116 CurMask[Mask] != ShuffleMaskHalf + Mask - (ExpectedParityMask & 1))
4117 return false;
4118 if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1)
4119 return false;
4120 }
4121
4122 // Update mask values.
4123 ShuffleMaskHalf *= 2;
4124 ShuffleMaskHalf -= (ExpectedParityMask & 1);
4125 ExpectedParityMask >>= 1;
4126
4128 SVInst->getType(), SVInst->getType(),
4129 CurMask, CostKind);
4130
4131 VisitedCnt += 1;
4132 if (!ExpectedParityMask && VisitedCnt == NumLevels)
4133 break;
4134
4135 ShouldBeCallOrBinInst ^= 1;
4136 } else {
4137 // Check if this is a partial reduction - the chain ended because
4138 // the source vector is not a recognized op/shuffle.
4139 // Reject non-power-of-2 vectors because parity-based masks cause
4140 // lane duplication in the reduction tree, making the partial result
4141 // not a simple subvector reduction.
4142 if (ShouldBeCallOrBinInst && VisitedCnt >= 1 && CI == PrevVecV[0] &&
4143 isPowerOf2_64(VecSize)) {
4144 IsPartialReduction = true;
4145 break;
4146 }
4147 return false;
4148 }
4149 }
4150
4151 // Full reduction pattern should end with a shuffle op.
4152 // Partial reduction ends when the source vector is reached.
4153 if (ShouldBeCallOrBinInst && !IsPartialReduction)
4154 return false;
4155
4156 assert(VecSize != -1 && "Expected Match for Vector Size");
4157
4158 Value *FinalVecV = PrevVecV[0];
4159 if (!FinalVecV)
4160 return false;
4161
4162 auto *FinalVecVTy = cast<FixedVectorType>(FinalVecV->getType());
4163
4164 Intrinsic::ID ReducedOp =
4165 (CommonCallOp ? getMinMaxReductionIntrinsicID(*CommonCallOp)
4166 : getReductionForBinop(*CommonBinOp));
4167 if (!ReducedOp)
4168 return false;
4169
4170 InstructionCost NewCost = 0;
4171 FixedVectorType *ReduceVecTy = FinalVecVTy;
4172 SmallVector<int> ExtractMask;
4173
4174 if (IsPartialReduction) {
4175 unsigned SubVecSize = ShuffleMaskHalf;
4176 ReduceVecTy = FixedVectorType::get(FVT->getElementType(), SubVecSize);
4177 ExtractMask.resize(SubVecSize);
4178 std::iota(ExtractMask.begin(), ExtractMask.end(), 0);
4180 ReduceVecTy, FinalVecVTy, ExtractMask,
4181 CostKind, 0, ReduceVecTy);
4182 }
4183
4184 IntrinsicCostAttributes ICA(ReducedOp, ReduceVecTy, {ReduceVecTy});
4185 NewCost += TTI.getIntrinsicInstrCost(ICA, CostKind);
4186
4187 LLVM_DEBUG(dbgs() << "Found reduction shuffle chain: " << I << "\n OldCost : "
4188 << OrigCost << " vs NewCost: " << NewCost << "\n");
4189
4190 if (VecOpEE->hasOneUse() ? (NewCost > OrigCost) : (NewCost >= OrigCost))
4191 return false;
4192
4193 Value *ReduceInput = FinalVecV;
4194 if (IsPartialReduction)
4195 ReduceInput = Builder.CreateShuffleVector(FinalVecV, ExtractMask);
4196
4197 auto *ReducedResult = Builder.CreateIntrinsic(
4198 ReducedOp, {ReduceInput->getType()}, {ReduceInput});
4199 replaceValue(I, *ReducedResult);
4200
4201 return true;
4202}
4203
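/// Determine if it's more efficient to fold:
///   reduce(trunc(x)) -> trunc(reduce(x)).
///   reduce(sext(x)) -> sext(reduce(x)).
///   reduce(zext(x)) -> zext(reduce(x)).
/// add/mul reductions only accept trunc; and/or/xor reductions accept all
/// three casts. The cast feeding the reduction must have no other uses.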
4208bool VectorCombine::foldCastFromReductions(Instruction &I) {
4209 auto *II = dyn_cast<IntrinsicInst>(&I);
4210 if (!II)
4211 return false;
4212
4213 bool TruncOnly = false;
4214 Intrinsic::ID IID = II->getIntrinsicID();
4215 switch (IID) {
4216 case Intrinsic::vector_reduce_add:
4217 case Intrinsic::vector_reduce_mul:
4218 TruncOnly = true;
4219 break;
4220 case Intrinsic::vector_reduce_and:
4221 case Intrinsic::vector_reduce_or:
4222 case Intrinsic::vector_reduce_xor:
4223 break;
4224 default:
4225 return false;
4226 }
4227
4228 unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
4229 Value *ReductionSrc = I.getOperand(0);
4230
4231 Value *Src;
4232 if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) &&
4233 (TruncOnly || !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
4234 return false;
4235
4236 auto CastOpc =
4237 (Instruction::CastOps)cast<Instruction>(ReductionSrc)->getOpcode();
4238
4239 auto *SrcTy = cast<VectorType>(Src->getType());
4240 auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
4241 Type *ResultTy = I.getType();
4242
4244 ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
4245 OldCost += TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
4247 cast<CastInst>(ReductionSrc));
4248 InstructionCost NewCost =
4249 TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
4250 CostKind) +
4251 TTI.getCastInstrCost(CastOpc, ResultTy, ReductionSrcTy->getScalarType(),
4253
4254 if (OldCost <= NewCost || !NewCost.isValid())
4255 return false;
4256
4257 Value *NewReduction = Builder.CreateIntrinsic(SrcTy->getScalarType(),
4258 II->getIntrinsicID(), {Src});
4259 Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy);
4260 replaceValue(I, *NewCast);
4261 return true;
4262}
4263
/// Fold:
///   icmp pred (reduce.{add,or,and,umax,umin}(signbit_extract(x))), C
/// into:
///   icmp sgt/slt (reduce.{or,umax,and,umin}(x)), -1/0
///
/// Sign-bit reductions produce values with known semantics:
///   - reduce.{or,umax}: 0 if no element is negative, 1 if any is
///   - reduce.{and,umin}: 1 if all elements are negative, 0 if any isn't
///   - reduce.add: count of negative elements (0 to NumElts)
///
/// Both lshr and ashr are supported:
///   - lshr produces 0 or 1, so the reduce.add range is [0, N]
///   - ashr produces 0 or -1, so the reduce.add range is [-N, 0]
///
/// The fold generalizes to multiple source vectors combined with the same
/// operation as the reduction. For example, reduce.or(or(shr A, shr B))
/// conceptually reduces the concatenation of A and B. For reduce.add, this
/// changes the count to M*N, where M is the number of source vectors.
///
/// We transform to a direct sign check on the original vector(s) using
/// reduce.{or,umax} or reduce.{and,umin}.
///
/// In spirit, it's similar to foldSignBitCheck in InstCombine.
4288bool VectorCombine::foldSignBitReductionCmp(Instruction &I) {
4289 CmpPredicate Pred;
4290 IntrinsicInst *ReduceOp;
4291 const APInt *CmpVal;
4292 if (!match(&I,
4293 m_ICmp(Pred, m_OneUse(m_AnyIntrinsic(ReduceOp)), m_APInt(CmpVal))))
4294 return false;
4295
4296 Intrinsic::ID OrigIID = ReduceOp->getIntrinsicID();
4297 switch (OrigIID) {
4298 case Intrinsic::vector_reduce_or:
4299 case Intrinsic::vector_reduce_umax:
4300 case Intrinsic::vector_reduce_and:
4301 case Intrinsic::vector_reduce_umin:
4302 case Intrinsic::vector_reduce_add:
4303 break;
4304 default:
4305 return false;
4306 }
4307
4308 Value *ReductionSrc = ReduceOp->getArgOperand(0);
4309 auto *VecTy = dyn_cast<FixedVectorType>(ReductionSrc->getType());
4310 if (!VecTy)
4311 return false;
4312
4313 unsigned BitWidth = VecTy->getScalarSizeInBits();
4314 if (BitWidth == 1)
4315 return false;
4316
4317 unsigned NumElts = VecTy->getNumElements();
4318
4319 // Determine the expected tree opcode for multi-vector patterns.
4320 // The tree opcode must match the reduction's underlying operation.
4321 //
4322 // TODO: for pairs of equivalent operators, we should match both,
4323 // not only the most common.
4324 Instruction::BinaryOps TreeOpcode;
4325 switch (OrigIID) {
4326 case Intrinsic::vector_reduce_or:
4327 case Intrinsic::vector_reduce_umax:
4328 TreeOpcode = Instruction::Or;
4329 break;
4330 case Intrinsic::vector_reduce_and:
4331 case Intrinsic::vector_reduce_umin:
4332 TreeOpcode = Instruction::And;
4333 break;
4334 case Intrinsic::vector_reduce_add:
4335 TreeOpcode = Instruction::Add;
4336 break;
4337 default:
4338 llvm_unreachable("Unexpected intrinsic");
4339 }
4340
4341 // Collect sign-bit extraction leaves from an associative tree of TreeOpcode.
4342 // The tree conceptually extends the vector being reduced.
4343 SmallVector<Value *, 8> Worklist;
4344 SmallVector<Value *, 8> Sources; // Original vectors (X in shr X, BW-1)
4345 Worklist.push_back(ReductionSrc);
4346 std::optional<bool> IsAShr;
4347 constexpr unsigned MaxSources = 8;
4348
4349 // Calculate old cost: all shifts + tree ops + reduction
4350 InstructionCost OldCost = TTI.getInstructionCost(ReduceOp, CostKind);
4351
4352 while (!Worklist.empty() && Worklist.size() <= MaxSources &&
4353 Sources.size() <= MaxSources) {
4354 Value *V = Worklist.pop_back_val();
4355
4356 // Try to match sign-bit extraction: shr X, (bitwidth-1)
4357 Value *X;
4358 if (match(V, m_OneUse(m_Shr(m_Value(X), m_SpecificInt(BitWidth - 1))))) {
4359 auto *Shr = cast<Instruction>(V);
4360
4361 // All shifts must be the same type (all lshr or all ashr)
4362 bool ThisIsAShr = Shr->getOpcode() == Instruction::AShr;
4363 if (!IsAShr)
4364 IsAShr = ThisIsAShr;
4365 else if (*IsAShr != ThisIsAShr)
4366 return false;
4367
4368 Sources.push_back(X);
4369
4370 // As part of the fold, we remove all of the shifts, so we need to keep
4371 // track of their costs.
4372 OldCost += TTI.getInstructionCost(Shr, CostKind);
4373
4374 continue;
4375 }
4376
4377 // Try to extend through a tree node of the expected opcode
4378 Value *A, *B;
4379 if (!match(V, m_OneUse(m_BinOp(TreeOpcode, m_Value(A), m_Value(B)))))
4380 return false;
4381
4382 // We are potentially replacing these operations as well, so we add them
4383 // to the costs.
4385
4386 Worklist.push_back(A);
4387 Worklist.push_back(B);
4388 }
4389
4390 // Must have at least one source and not exceed limit
4391 if (Sources.empty() || Sources.size() > MaxSources ||
4392 Worklist.size() > MaxSources || !IsAShr)
4393 return false;
4394
4395 unsigned NumSources = Sources.size();
4396
4397 // For reduce.add, the total count must fit as a signed integer.
4398 // Range is [0, M*N] for lshr or [-M*N, 0] for ashr.
4399 if (OrigIID == Intrinsic::vector_reduce_add &&
4400 !isIntN(BitWidth, NumSources * NumElts))
4401 return false;
4402
4403 // Compute the boundary value when all elements are negative:
4404 // - Per-element contribution: 1 for lshr, -1 for ashr
4405 // - For add: M*N (total elements across all sources); for others: just 1
4406 unsigned Count =
4407 (OrigIID == Intrinsic::vector_reduce_add) ? NumSources * NumElts : 1;
4408 APInt NegativeVal(CmpVal->getBitWidth(), Count);
4409 if (*IsAShr)
4410 NegativeVal.negate();
4411
4412 // Range is [min(0, AllNegVal), max(0, AllNegVal)]
4413 APInt Zero = APInt::getZero(CmpVal->getBitWidth());
4414 APInt RangeLow = APIntOps::smin(Zero, NegativeVal);
4415 APInt RangeHigh = APIntOps::smax(Zero, NegativeVal);
4416
4417 // Determine comparison semantics:
4418 // - IsEq: true for equality test, false for inequality
4419 // - TestsNegative: true if testing against AllNegVal, false for zero
4420 //
4421 // In addition to EQ/NE against 0 or AllNegVal, we support inequalities
4422 // that fold to boundary tests given the narrow value range:
4423 // < RangeHigh -> != RangeHigh
4424 // > RangeHigh-1 -> == RangeHigh
4425 // > RangeLow -> != RangeLow
4426 // < RangeLow+1 -> == RangeLow
4427 //
4428 // For inequalities, we work with signed predicates only. Unsigned predicates
4429 // are canonicalized to signed when the range is non-negative (where they are
4430 // equivalent). When the range includes negative values, unsigned predicates
4431 // would have different semantics due to wrap-around, so we reject them.
4432 if (!ICmpInst::isEquality(Pred) && !ICmpInst::isSigned(Pred)) {
4433 if (RangeLow.isNegative())
4434 return false;
4435 Pred = ICmpInst::getSignedPredicate(Pred);
4436 }
4437
4438 bool IsEq;
4439 bool TestsNegative;
4440 if (ICmpInst::isEquality(Pred)) {
4441 if (CmpVal->isZero()) {
4442 TestsNegative = false;
4443 } else if (*CmpVal == NegativeVal) {
4444 TestsNegative = true;
4445 } else {
4446 return false;
4447 }
4448 IsEq = Pred == ICmpInst::ICMP_EQ;
4449 } else if (Pred == ICmpInst::ICMP_SLT && *CmpVal == RangeHigh) {
4450 IsEq = false;
4451 TestsNegative = (RangeHigh == NegativeVal);
4452 } else if (Pred == ICmpInst::ICMP_SGT && *CmpVal == RangeHigh - 1) {
4453 IsEq = true;
4454 TestsNegative = (RangeHigh == NegativeVal);
4455 } else if (Pred == ICmpInst::ICMP_SGT && *CmpVal == RangeLow) {
4456 IsEq = false;
4457 TestsNegative = (RangeLow == NegativeVal);
4458 } else if (Pred == ICmpInst::ICMP_SLT && *CmpVal == RangeLow + 1) {
4459 IsEq = true;
4460 TestsNegative = (RangeLow == NegativeVal);
4461 } else {
4462 return false;
4463 }
4464
4465 // For this fold we support four types of checks:
4466 //
4467 // 1. All lanes are negative - AllNeg
4468 // 2. All lanes are non-negative - AllNonNeg
4469 // 3. At least one negative lane - AnyNeg
4470 // 4. At least one non-negative lane - AnyNonNeg
4471 //
4472 // For each case, we can generate the following code:
4473 //
4474 // 1. AllNeg - reduce.and/umin(X) < 0
4475 // 2. AllNonNeg - reduce.or/umax(X) > -1
4476 // 3. AnyNeg - reduce.or/umax(X) < 0
4477 // 4. AnyNonNeg - reduce.and/umin(X) > -1
4478 //
4479 // The table below shows the aggregation of all supported cases
4480 // using these four cases.
4481 //
4482 // Reduction | == 0 | != 0 | == MAX | != MAX
4483 // ------------+-----------+-----------+-----------+-----------
4484 // or/umax | AllNonNeg | AnyNeg | AnyNeg | AllNonNeg
4485 // and/umin | AnyNonNeg | AllNeg | AllNeg | AnyNonNeg
4486 // add | AllNonNeg | AnyNeg | AllNeg | AnyNonNeg
4487 //
4488 // NOTE: MAX = 1 for or/and/umax/umin, and the vector size N for add
4489 //
4490 // For easier codegen and check inversion, we use the following encoding:
4491 //
4492 // 1. Bit-3 === requires or/umax (1) or and/umin (0) check
4493 // 2. Bit-2 === checks < 0 (1) or > -1 (0)
4494 // 3. Bit-1 === universal (1) or existential (0) check
4495 //
4496 // AnyNeg = 0b110: uses or/umax, checks negative, any-check
4497 // AllNonNeg = 0b101: uses or/umax, checks non-neg, all-check
4498 // AnyNonNeg = 0b000: uses and/umin, checks non-neg, any-check
4499 // AllNeg = 0b011: uses and/umin, checks negative, all-check
4500 //
4501 // XOR with 0b011 inverts the check (swaps all/any and neg/non-neg).
4502 //
4503 enum CheckKind : unsigned {
4504 AnyNonNeg = 0b000,
4505 AllNeg = 0b011,
4506 AllNonNeg = 0b101,
4507 AnyNeg = 0b110,
4508 };
4509 // Return true if we fold this check into or/umax and false for and/umin
4510 auto RequiresOr = [](CheckKind C) -> bool { return C & 0b100; };
4511 // Return true if we should check if result is negative and false otherwise
4512 auto IsNegativeCheck = [](CheckKind C) -> bool { return C & 0b010; };
4513 // Logically invert the check
4514 auto Invert = [](CheckKind C) { return CheckKind(C ^ 0b011); };
4515
4516 CheckKind Base;
4517 switch (OrigIID) {
4518 case Intrinsic::vector_reduce_or:
4519 case Intrinsic::vector_reduce_umax:
4520 Base = TestsNegative ? AnyNeg : AllNonNeg;
4521 break;
4522 case Intrinsic::vector_reduce_and:
4523 case Intrinsic::vector_reduce_umin:
4524 Base = TestsNegative ? AllNeg : AnyNonNeg;
4525 break;
4526 case Intrinsic::vector_reduce_add:
4527 Base = TestsNegative ? AllNeg : AllNonNeg;
4528 break;
4529 default:
4530 llvm_unreachable("Unexpected intrinsic");
4531 }
4532
4533 CheckKind Check = IsEq ? Base : Invert(Base);
4534
4535 auto PickCheaper = [&](Intrinsic::ID Arith, Intrinsic::ID MinMax) {
4536 InstructionCost ArithCost =
4538 VecTy, std::nullopt, CostKind);
4539 InstructionCost MinMaxCost =
4541 FastMathFlags(), CostKind);
4542 return ArithCost <= MinMaxCost ? std::make_pair(Arith, ArithCost)
4543 : std::make_pair(MinMax, MinMaxCost);
4544 };
4545
4546 // Choose output reduction based on encoding's MSB
4547 auto [NewIID, NewCost] = RequiresOr(Check)
4548 ? PickCheaper(Intrinsic::vector_reduce_or,
4549 Intrinsic::vector_reduce_umax)
4550 : PickCheaper(Intrinsic::vector_reduce_and,
4551 Intrinsic::vector_reduce_umin);
4552
4553 // Add cost of combining multiple sources with or/and
4554 if (NumSources > 1) {
4555 unsigned CombineOpc =
4556 RequiresOr(Check) ? Instruction::Or : Instruction::And;
4557 NewCost += TTI.getArithmeticInstrCost(CombineOpc, VecTy, CostKind) *
4558 (NumSources - 1);
4559 }
4560
4561 LLVM_DEBUG(dbgs() << "Found sign-bit reduction cmp: " << I << "\n OldCost: "
4562 << OldCost << " vs NewCost: " << NewCost << "\n");
4563
4564 if (NewCost > OldCost)
4565 return false;
4566
4567 // Generate the combined input and reduction
4568 Builder.SetInsertPoint(&I);
4569 Type *ScalarTy = VecTy->getScalarType();
4570
4571 Value *Input;
4572 if (NumSources == 1) {
4573 Input = Sources[0];
4574 } else {
4575 // Combine sources with or/and based on check type
4576 Input = RequiresOr(Check) ? Builder.CreateOr(Sources)
4577 : Builder.CreateAnd(Sources);
4578 }
4579
4580 Value *NewReduce = Builder.CreateIntrinsic(ScalarTy, NewIID, {Input});
4581 Value *NewCmp = IsNegativeCheck(Check) ? Builder.CreateIsNeg(NewReduce)
4582 : Builder.CreateIsNotNeg(NewReduce);
4583 replaceValue(I, *NewCmp);
4584 return true;
4585}
4586
/// vector.reduce.OP f(X_i) == 0 -> vector.reduce.OP X_i == 0
///
/// Matches `icmp eq/ne (vector.reduce.OP (f X)), 0`, where the reduction has
/// a single use and f is one of a small set of zero-preserving operations.
/// It rewrites this to `icmp eq/ne (vector.reduce.OP X), 0`, reusing the
/// original predicate.
///
/// We can prove it for cases when:
///
/// 1. OP X_i == 0 <=> \forall i \in [1, N] X_i == 0
/// 1'. OP X_i == 0 <=> \exists j \in [1, N] X_j == 0
/// 2. f(x) == 0 <=> x == 0
///
/// From 1 and 2 (or 1' and 2), we can infer that
///
/// OP f(X_i) == 0 <=> OP X_i == 0.
///
/// Proof sketch (using property 1; property 1' is analogous with \exists):
///
///   OP f(X_i) == 0 <=> \forall i \in [1, N] f(X_i) == 0   (by 1)
///                  <=> \forall i \in [1, N] X_i == 0      (by 2)
///                  <=> OP(X_i) == 0                       (by 1)
///
/// For some of the OP's and f's, we need to have domain constraints on X
/// to ensure properties 1 (or 1') and 2.
///
/// \returns true if \p I was replaced by the new compare.
bool VectorCombine::foldICmpEqZeroVectorReduce(Instruction &I) {
  CmpPredicate Pred;
  Value *Op;
  // Only equality/inequality against a zero constant is handled.
  if (!match(&I, m_ICmp(Pred, m_Value(Op), m_Zero())) ||
      !ICmpInst::isEquality(Pred))
    return false;

  auto *II = dyn_cast<IntrinsicInst>(Op);
  if (!II)
    return false;

  // Reductions for which property 1 or 1' can be established (possibly under
  // the domain constraints checked further down). Any other intrinsic falls
  // to the default case and is rejected.
  switch (II->getIntrinsicID()) {
  case Intrinsic::vector_reduce_add:
  case Intrinsic::vector_reduce_or:
  case Intrinsic::vector_reduce_umin:
  case Intrinsic::vector_reduce_umax:
  case Intrinsic::vector_reduce_smin:
  case Intrinsic::vector_reduce_smax:
    break;
  default:
    return false;
  }

  Value *InnerOp = II->getArgOperand(0);

  // The reduction must feed only this compare, so it can be replaced.
  // TODO: fixed vector type might be too restrictive
  if (!II->hasOneUse() || !isa<FixedVectorType>(InnerOp->getType()))
    return false;

  Value *X = nullptr;

  // Check for zero-preserving operations where f(x) = 0 <=> x = 0
  //
  // 1. f(x) = shl nuw x, y for arbitrary y
  // 2. f(x) = mul nuw x, c for defined c != 0
  // 3. f(x) = zext x
  // 4. f(x) = sext x
  // 5. f(x) = neg x
  //
  if (!(match(InnerOp, m_NUWShl(m_Value(X), m_Value())) ||      // Case 1
        match(InnerOp, m_NUWMul(m_Value(X), m_NonZeroInt())) || // Case 2
        match(InnerOp, m_ZExt(m_Value(X))) ||                   // Case 3
        match(InnerOp, m_SExt(m_Value(X))) ||                   // Case 4
        match(InnerOp, m_Neg(m_Value(X)))                       // Case 5
        ))
    return false;

  SimplifyQuery S = SQ.getWithInstruction(&I);
  // For zext/sext, X has a narrower element type than f(X).
  auto *XTy = cast<FixedVectorType>(X->getType());

  // Check for domain constraints for all supported reductions.
  //
  // a. OR X_i - has property 1 for every X
  // b. UMAX X_i - has property 1 for every X
  // c. UMIN X_i - has property 1' for every X
  // d. SMAX X_i - has property 1 for X >= 0
  // e. SMIN X_i - has property 1' for X >= 0
  // f. ADD X_i - has property 1 for X >= 0 && ADD X_i doesn't sign wrap
  //
  // In order for the proof to work, we need 1 (or 1') to be true for both
  // OP f(X_i) and OP X_i and that's why below we check constraints twice.
  //
  // NOTE: ADD X_i holds property 1 for a mirror case as well, i.e. when
  //       X <= 0 && ADD X_i doesn't sign wrap. However, due to the nature
  //       of known bits, we can't reasonably hold knowledge of "either 0
  //       or negative".
  switch (II->getIntrinsicID()) {
  case Intrinsic::vector_reduce_add: {
    // We need to check that both X_i and f(X_i) have enough leading
    // zeros to not overflow.
    KnownBits KnownX = computeKnownBits(X, S);
    KnownBits KnownFX = computeKnownBits(InnerOp, S);
    unsigned NumElems = XTy->getNumElements();
    // Adding N elements loses at most ceil(log2(N)) leading bits.
    unsigned LostBits = Log2_32_Ceil(NumElems);
    unsigned LeadingZerosX = KnownX.countMinLeadingZeros();
    unsigned LeadingZerosFX = KnownFX.countMinLeadingZeros();
    // Need at least one leading zero left after summation to ensure no overflow
    if (LeadingZerosX <= LostBits || LeadingZerosFX <= LostBits)
      return false;

    // We are not checking whether X or f(X) are non-negative explicitly
    // because at least one known leading zero already implies it, and we
    // required more than LostBits of them above.
    break;
  }
  case Intrinsic::vector_reduce_smin:
  case Intrinsic::vector_reduce_smax:
    // Check whether X >= 0 and f(X) >= 0
    if (!isKnownNonNegative(InnerOp, S) || !isKnownNonNegative(X, S))
      return false;

    break;
  default:
    // or/umin/umax need no domain constraint (cases a-c above).
    break;
  };

  LLVM_DEBUG(dbgs() << "Found a reduction to 0 comparison with removable op: "
                    << *II << "\n");

  // For zext/sext, check if the transform is profitable using cost model:
  // the reduction moves to the narrower element type of X.
  // For other operations (shl, mul, neg) the reduction keeps its element
  // type, so its cost is unchanged and the fold is not expected to be worse.
  // NOTE(review): InnerOp's own use count is not checked on that path, so
  // f(X) is only actually removed when this reduction was its sole user.
  if (isa<ZExtInst>(InnerOp) || isa<SExtInst>(InnerOp)) {
    auto *FXTy = cast<FixedVectorType>(InnerOp->getType());
    Intrinsic::ID IID = II->getIntrinsicID();

    // NOTE(review): this listing omits the source lines just before and
    // after the next line (rendered numbers 4715 and 4717). Together they
    // presumably form the cast-cost query that declares `ExtCost`, which is
    // used below. Confirm against the upstream file.
        cast<CastInst>(InnerOp)->getOpcode(), FXTy, XTy,
    // NOTE(review): second omitted line (rendered number 4717) belongs here.

    // Cost of the reduction on the wide type f(X) versus the narrow type X.
    InstructionCost OldReduceCost, NewReduceCost;
    switch (IID) {
    case Intrinsic::vector_reduce_add:
    case Intrinsic::vector_reduce_or:
      OldReduceCost = TTI.getArithmeticReductionCost(
          getArithmeticReductionInstruction(IID), FXTy, std::nullopt, CostKind);
      NewReduceCost = TTI.getArithmeticReductionCost(
          getArithmeticReductionInstruction(IID), XTy, std::nullopt, CostKind);
      break;
    case Intrinsic::vector_reduce_umin:
    case Intrinsic::vector_reduce_umax:
    case Intrinsic::vector_reduce_smin:
    case Intrinsic::vector_reduce_smax:
      OldReduceCost = TTI.getMinMaxReductionCost(
          getMinMaxReductionIntrinsicOp(IID), FXTy, FastMathFlags(), CostKind);
      NewReduceCost = TTI.getMinMaxReductionCost(
          getMinMaxReductionIntrinsicOp(IID), XTy, FastMathFlags(), CostKind);
      break;
    default:
      llvm_unreachable("Unexpected reduction");
    }

    // The extension is only saved when this reduction was its only user;
    // otherwise it stays alive and is charged to the new sequence too.
    InstructionCost OldCost = OldReduceCost + ExtCost;
    InstructionCost NewCost =
        NewReduceCost + (InnerOp->hasOneUse() ? 0 : ExtCost);

    LLVM_DEBUG(dbgs() << "Found a removable extension before reduction: "
                      << *InnerOp << "\n OldCost: " << OldCost
                      << " vs NewCost: " << NewCost << "\n");

    // We consider transformation to still be potentially beneficial even
    // when the costs are the same because we might remove a use from f(X)
    // and unlock other optimizations. Equal costs would just mean that we
    // didn't make it worse in the worst case.
    if (NewCost > OldCost)
      return false;
  }

  // Since we support zext and sext as f, we might change the scalar type
  // of the intrinsic.
  Type *Ty = XTy->getScalarType();
  Value *NewReduce = Builder.CreateIntrinsic(Ty, II->getIntrinsicID(), {X});
  Value *NewCmp =
      Builder.CreateICmp(Pred, NewReduce, ConstantInt::getNullValue(Ty));
  replaceValue(I, *NewCmp);
  return true;
}
4766
/// Fold comparisons of reduce.or/reduce.and with reduce.umax/reduce.umin
/// based on cost, preserving the comparison semantics.
///
/// The rewrite works in both directions: or <-> umax and and <-> umin. It is
/// only performed when the alternative reduction is strictly cheaper, and
/// only when the reduction's sole user is the compare.
///
/// We use three fundamental properties for each pair:
///
/// 1. or(X) == 0 <=> umax(X) == 0
/// 2. or(X) == 1 <=> umax(X) == 1
/// 3. sign(or(X)) == sign(umax(X))
///
/// 1. and(X) == -1 <=> umin(X) == -1
/// 2. and(X) == -2 <=> umin(X) == -2
/// 3. sign(and(X)) == sign(umin(X))
///
/// From these we can infer the following transformations:
/// a. or(X) ==/!= 0 <-> umax(X) ==/!= 0
/// b. or(X) s< 0 <-> umax(X) s< 0
/// c. or(X) s> -1 <-> umax(X) s> -1
/// d. or(X) s< 1 <-> umax(X) s< 1
/// e. or(X) ==/!= 1 <-> umax(X) ==/!= 1
/// f. or(X) s< 2 <-> umax(X) s< 2
/// g. and(X) ==/!= -1 <-> umin(X) ==/!= -1
/// h. and(X) s< 0 <-> umin(X) s< 0
/// i. and(X) s> -1 <-> umin(X) s> -1
/// j. and(X) s> -2 <-> umin(X) s> -2
/// k. and(X) ==/!= -2 <-> umin(X) ==/!= -2
/// l. and(X) s> -3 <-> umin(X) s> -3
///
/// \returns true if \p I was replaced by a compare of the alternative
/// reduction.
bool VectorCombine::foldEquivalentReductionCmp(Instruction &I) {
  CmpPredicate Pred;
  Value *ReduceOp;
  const APInt *CmpVal;
  if (!match(&I, m_ICmp(Pred, m_Value(ReduceOp), m_APInt(CmpVal))))
    return false;

  auto *II = dyn_cast<IntrinsicInst>(ReduceOp);
  if (!II || !II->hasOneUse())
    return false;

  // True when (Pred, CmpVal) is one of cases a-f above.
  const auto IsValidOrUmaxCmp = [&]() {
    // or === umax for i1
    if (CmpVal->getBitWidth() == 1)
      return true;

    // Cases a and e
    bool IsEquality =
        (CmpVal->isZero() || CmpVal->isOne()) && ICmpInst::isEquality(Pred);
    // Case c
    bool IsPositive = CmpVal->isAllOnes() && Pred == ICmpInst::ICMP_SGT;
    // Cases b, d, and f
    bool IsNegative = (CmpVal->isZero() || CmpVal->isOne() || *CmpVal == 2) &&
                      Pred == ICmpInst::ICMP_SLT;
    return IsEquality || IsPositive || IsNegative;
  };

  // True when (Pred, CmpVal) is one of cases g-l above.
  const auto IsValidAndUminCmp = [&]() {
    // and === umin for i1
    if (CmpVal->getBitWidth() == 1)
      return true;

    const auto LeadingOnes = CmpVal->countl_one();

    // Cases g and k (-2 has every bit set except bit 0)
    bool IsEquality =
        (CmpVal->isAllOnes() || LeadingOnes + 1 == CmpVal->getBitWidth()) &&
    // NOTE(review): this listing omits the line that completes the
    // expression above (rendered number 4831); presumably the equality-
    // predicate test, by symmetry with IsValidOrUmaxCmp. Confirm upstream.
    // Case h
    bool IsNegative = CmpVal->isZero() && Pred == ICmpInst::ICMP_SLT;
    // Cases i, j, and l
    bool IsPositive =
        // if the number has at least N - 2 leading ones
        // and its two LSBs (bit1 bit0) are:
        // - 11 -> -1
        // - 10 -> -2
        // - 01 -> -3
        // (00, i.e. -4, is excluded by requiring one of the two bits set)
        LeadingOnes + 2 >= CmpVal->getBitWidth() &&
        ((*CmpVal)[0] || (*CmpVal)[1]) && Pred == ICmpInst::ICMP_SGT;
    return IsEquality || IsNegative || IsPositive;
  };

  Intrinsic::ID OriginalIID = II->getIntrinsicID();
  Intrinsic::ID AlternativeIID;

  // Check if this is a valid comparison pattern and determine the alternate
  // reduction intrinsic.
  switch (OriginalIID) {
  case Intrinsic::vector_reduce_or:
    if (!IsValidOrUmaxCmp())
      return false;
    AlternativeIID = Intrinsic::vector_reduce_umax;
    break;
  case Intrinsic::vector_reduce_umax:
    if (!IsValidOrUmaxCmp())
      return false;
    AlternativeIID = Intrinsic::vector_reduce_or;
    break;
  case Intrinsic::vector_reduce_and:
    if (!IsValidAndUminCmp())
      return false;
    AlternativeIID = Intrinsic::vector_reduce_umin;
    break;
  case Intrinsic::vector_reduce_umin:
    if (!IsValidAndUminCmp())
      return false;
    AlternativeIID = Intrinsic::vector_reduce_and;
    break;
  default:
    return false;
  }

  Value *X = II->getArgOperand(0);
  auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
  if (!VecTy)
    return false;

  // Cost of reducing X with IID. A result of Instruction::ICmp from
  // getArithmeticReductionInstruction is treated as "min/max reduction"
  // (presumably that helper's convention for min/max; verify).
  const auto GetReductionCost = [&](Intrinsic::ID IID) -> InstructionCost {
    unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
    if (ReductionOpc != Instruction::ICmp)
      return TTI.getArithmeticReductionCost(ReductionOpc, VecTy, std::nullopt,
                                            CostKind);
    // NOTE(review): this listing omits the first line of the min/max cost
    // query that the next line completes (rendered number 4886).
                                        FastMathFlags(), CostKind);
  };

  InstructionCost OrigCost = GetReductionCost(OriginalIID);
  InstructionCost AltCost = GetReductionCost(AlternativeIID);

  LLVM_DEBUG(dbgs() << "Found equivalent reduction cmp: " << I
                    << "\n OrigCost: " << OrigCost
                    << " vs AltCost: " << AltCost << "\n");

  // Require a strict improvement. Otherwise two equally priced reductions
  // would keep being swapped back and forth.
  if (AltCost >= OrigCost)
    return false;

  Builder.SetInsertPoint(&I);
  Type *ScalarTy = VecTy->getScalarType();
  Value *NewReduce = Builder.CreateIntrinsic(ScalarTy, AlternativeIID, {X});
  Value *NewCmp =
      Builder.CreateICmp(Pred, NewReduce, ConstantInt::get(ScalarTy, *CmpVal));

  replaceValue(I, *NewCmp);
  return true;
}
4909
4910/// Used by foldReduceAddCmpZero to check if we can prove that a value is
4911/// non-positive.
4912/// KnownBits cannot see sext <? x i1> as non-positive: each top bit equals a
4913/// single unknown input bit, which a per-bit lattice cannot track. The fold's
4914/// target shape is popcount-style sums of <N x i1> valid/invalid masks (e.g.
4915/// ray-intersection hits) tested for any-hit.
4916/// Previous attempts to approximate the known bits of such expressions were
4917/// using a fully recursive value tracking approach to infer a constant range
4918/// but ultimately turned to be too expensive in compile time.
4919static bool isKnownNonPositive(const Value *V, const SimplifyQuery &SQ,
4920 unsigned Depth = 0) {
4921 constexpr unsigned MaxLocalDepth = 2;
4922 if (Depth > MaxLocalDepth)
4923 return false;
4924
4925 auto NumSignBits = [&](const Value *X) {
4926 return ComputeNumSignBits(X, SQ.DL, SQ.AC, SQ.CxtI, SQ.DT);
4927 };
4928 if (NumSignBits(V) == V->getType()->getScalarSizeInBits())
4929 return true;
4930
4931 Value *A, *B;
4932 if (match(V, m_Add(m_Value(A), m_Value(B))))
4933 return NumSignBits(A) >= 2 && NumSignBits(B) >= 2 &&
4934 isKnownNonPositive(A, SQ, Depth + 1) &&
4935 isKnownNonPositive(B, SQ, Depth + 1);
4936
4937 return computeKnownBits(V, SQ).isNonPositive();
4938}
4939
/// Fold (icmp pred (reduce.add X), 0) to (icmp pred' (reduce.or X), 0) when X
/// has lanes known to all be non-negative or all non-positive, so that
/// sum == 0 iff every lane is 0. Falls back to reduce.umax if reduce.or is
/// more expensive on the target (or has no valid cost).
///
/// pred' is always EQ or NE:
///   EQ, ULE, SLE, SGE  -> EQ
///   NE, UGT, SGT, SLT  -> NE
/// Signed predicates are only accepted on the side where they are not
/// tautological: SGT/SLE need a non-negative X, and SLT/SGE need a
/// non-positive X.
///
/// The rewrite is done only if the chosen reduction costs no more than the
/// original reduce.add.
///
/// \returns true if \p I was replaced.
bool VectorCombine::foldReduceAddCmpZero(Instruction &I) {
  CmpPredicate Pred;
  Value *Vec;
  if (!match(&I, m_ICmp(Pred,
  // NOTE(review): this listing omits one line of the pattern here (rendered
  // number 4948); from the doc comment it presumably matches the
  // reduce.add intrinsic wrapping m_Value(Vec). Confirm against upstream.
                        m_Value(Vec))),
                  m_Zero())))
    return false;

  auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
  if (!VecTy || VecTy->getNumElements() < 2)
    return false;

  // Establish which sign class all lanes share; non-positivity is only
  // queried when non-negativity could not be shown.
  SimplifyQuery Q = SQ.getWithInstruction(&I);
  bool IsNonNegative = isKnownNonNegative(Vec, Q);
  bool IsNonPositive = !IsNonNegative && isKnownNonPositive(Vec, Q);
  if (!IsNonNegative && !IsNonPositive)
    return false;

  // Summing NumElts lanes can consume up to log2(NumElts) sign bits. Require
  // strictly more headroom than that so the sum cannot wrap to zero.
  unsigned NumElts = VecTy->getNumElements();
  unsigned NumSignBits = ComputeNumSignBits(Vec, *DL, SQ.AC, &I, &DT);
  if (Log2_32(NumElts) >= NumSignBits)
    return false;

  ICmpInst::Predicate NewPred;
  switch (Pred) {
  case ICmpInst::ICMP_EQ:
  case ICmpInst::ICMP_ULE:
  case ICmpInst::ICMP_SLE:
  case ICmpInst::ICMP_SGE:
    NewPred = ICmpInst::ICMP_EQ;
    break;
  case ICmpInst::ICMP_NE:
  case ICmpInst::ICMP_UGT:
  case ICmpInst::ICMP_SGT:
  case ICmpInst::ICMP_SLT:
    NewPred = ICmpInst::ICMP_NE;
    break;
  default:
    return false;
  }

  // SGT and SLE on a non-positive tree, and SLT and SGE on a non-negative
  // tree, are tautologies (always true or always false). Leave those to
  // InstCombine rather than mapping them here. Remaining signed inequalities
  // also need one extra sign bit so the sum cannot flip sign.
  if (!IsNonNegative &&
      (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE))
    return false;
  if (!IsNonPositive &&
      (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE))
    return false;
  if ((Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE ||
       Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) &&
      Log2_32(NumElts) >= NumSignBits - 1)
    return false;

  // NOTE(review): this listing omits the first line of each of the three
  // cost queries below (rendered numbers 5003, 5005, 5007). They presumably
  // declare OrigCost (reduce.add), OrCost (reduce.or) and UmaxCost
  // (reduce.umax), which are used afterwards.
      Instruction::Add, VecTy, std::nullopt, CostKind);
      Instruction::Or, VecTy, std::nullopt, CostKind);
      Intrinsic::umax, VecTy, FastMathFlags(), CostKind);
  // Prefer reduce.or on ties; use reduce.umax only if it is cheaper or the
  // or-reduction cost is invalid on this target.
  if (!OrCost.isValid() && !UmaxCost.isValid())
    return false;
  bool UseOr = OrCost.isValid() && (!UmaxCost.isValid() || OrCost <= UmaxCost);
  InstructionCost AltCost = UseOr ? OrCost : UmaxCost;
  if (AltCost > OrigCost)
    return false;

  Builder.SetInsertPoint(&I);
  Value *NewReduce = UseOr ? Builder.CreateOrReduce(Vec)
                           : Builder.CreateIntrinsic(
                                 Intrinsic::vector_reduce_umax, {VecTy}, {Vec});
  Worklist.pushValue(NewReduce);
  Value *NewCmp = Builder.CreateICmp(
      NewPred, NewReduce, ConstantInt::getNullValue(VecTy->getScalarType()));
  replaceValue(I, *NewCmp);
  return true;
}
5026
/// Returns true if this ShuffleVectorInst eventually feeds into a
/// vector reduction intrinsic (e.g., vector_reduce_add) by only following
/// chains of shuffles and binary operators (in any combination/order).
/// The walk is breadth-limited rather than depth-limited: it gives up and
/// returns false once more than MaxVisited distinct users have been seen.
/// It also returns false if a second reduction is reached, or if any
/// non-reduction intrinsic is reached.
// NOTE(review): this listing omits the function signature line (rendered
// number 5031) and the declarations of Visited and WorkList (rendered
// numbers 5033-5034). Their types are not visible here.
  constexpr unsigned MaxVisited = 32;
  bool FoundReduction = false;

  // Depth-first walk over the transitive users of the shuffle.
  WorkList.push_back(SVI);
  while (!WorkList.empty()) {
    Instruction *I = WorkList.pop_back_val();
    for (User *U : I->users()) {
      // NOTE(review): cast<> never yields null (it asserts on mismatch), so
      // the `!UI` test below is dead; dyn_cast<> may have been intended.
      auto *UI = cast<Instruction>(U);
      if (!UI || !Visited.insert(UI).second)
        continue;
      if (Visited.size() > MaxVisited)
        return false;
      if (auto *II = dyn_cast<IntrinsicInst>(UI)) {
        // More than one reduction reached
        if (FoundReduction)
          return false;
        switch (II->getIntrinsicID()) {
        case Intrinsic::vector_reduce_add:
        case Intrinsic::vector_reduce_mul:
        case Intrinsic::vector_reduce_and:
        case Intrinsic::vector_reduce_or:
        case Intrinsic::vector_reduce_xor:
        case Intrinsic::vector_reduce_smin:
        case Intrinsic::vector_reduce_smax:
        case Intrinsic::vector_reduce_umin:
        case Intrinsic::vector_reduce_umax:
          // A reduction is a sink: record it but do not walk its users.
          FoundReduction = true;
          continue;
        default:
          return false;
        }
      }

      // NOTE(review): this listing omits the condition guarding the next
      // `return false` (rendered number 5067). Per the doc comment it
      // presumably rejects users that are neither shuffles nor binary
      // operators.
        return false;

      WorkList.emplace_back(UI);
    }
  }
  return FoundReduction;
}
5075
5076/// This method looks for groups of shuffles acting on binops, of the form:
5077/// %x = shuffle ...
5078/// %y = shuffle ...
5079/// %a = binop %x, %y
5080/// %b = binop %x, %y
5081/// shuffle %a, %b, selectmask
5082/// We may, especially if the shuffle is wider than legal, be able to convert
5083/// the shuffle to a form where only parts of a and b need to be computed. On
5084/// architectures with no obvious "select" shuffle, this can reduce the total
5085/// number of operations if the target reports them as cheaper.
5086bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
5087 auto *SVI = cast<ShuffleVectorInst>(&I);
5088 auto *VT = cast<FixedVectorType>(I.getType());
5089 auto *Op0 = dyn_cast<Instruction>(SVI->getOperand(0));
5090 auto *Op1 = dyn_cast<Instruction>(SVI->getOperand(1));
5091 if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() ||
5092 VT != Op0->getType())
5093 return false;
5094
5095 auto *SVI0A = dyn_cast<Instruction>(Op0->getOperand(0));
5096 auto *SVI0B = dyn_cast<Instruction>(Op0->getOperand(1));
5097 auto *SVI1A = dyn_cast<Instruction>(Op1->getOperand(0));
5098 auto *SVI1B = dyn_cast<Instruction>(Op1->getOperand(1));
5099 SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B});
5100 auto checkSVNonOpUses = [&](Instruction *I) {
5101 if (!I || I->getOperand(0)->getType() != VT)
5102 return true;
5103 return any_of(I->users(), [&](User *U) {
5104 return U != Op0 && U != Op1 &&
5105 !(isa<ShuffleVectorInst>(U) &&
5106 (InputShuffles.contains(cast<Instruction>(U)) ||
5107 isInstructionTriviallyDead(cast<Instruction>(U))));
5108 });
5109 };
5110 if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) ||
5111 checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B))
5112 return false;
5113
5114 // Collect all the uses that are shuffles that we can transform together. We
5115 // may not have a single shuffle, but a group that can all be transformed
5116 // together profitably.
5118 auto collectShuffles = [&](Instruction *I) {
5119 for (auto *U : I->users()) {
5120 auto *SV = dyn_cast<ShuffleVectorInst>(U);
5121 if (!SV || SV->getType() != VT)
5122 return false;
5123 if ((SV->getOperand(0) != Op0 && SV->getOperand(0) != Op1) ||
5124 (SV->getOperand(1) != Op0 && SV->getOperand(1) != Op1))
5125 return false;
5126 if (!llvm::is_contained(Shuffles, SV))
5127 Shuffles.push_back(SV);
5128 }
5129 return true;
5130 };
5131 if (!collectShuffles(Op0) || !collectShuffles(Op1))
5132 return false;
5133 // From a reduction, we need to be processing a single shuffle, otherwise the
5134 // other uses will not be lane-invariant.
5135 if (FromReduction && Shuffles.size() > 1)
5136 return false;
5137
5138 // Add any shuffle uses for the shuffles we have found, to include them in our
5139 // cost calculations.
5140 if (!FromReduction) {
5141 for (size_t Idx = 0, E = Shuffles.size(); Idx != E; ++Idx) {
5142 for (auto *U : Shuffles[Idx]->users()) {
5143 ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(U);
5144 if (SSV && isa<UndefValue>(SSV->getOperand(1)) && SSV->getType() == VT)
5145 Shuffles.push_back(SSV);
5146 }
5147 }
5148 }
5149
5150 // For each of the output shuffles, we try to sort all the first vector
5151 // elements to the beginning, followed by the second array elements at the
5152 // end. If the binops are legalized to smaller vectors, this may reduce total
5153 // number of binops. We compute the ReconstructMask mask needed to convert
5154 // back to the original lane order.
5156 SmallVector<SmallVector<int>> OrigReconstructMasks;
5157 int MaxV1Elt = 0, MaxV2Elt = 0;
5158 unsigned NumElts = VT->getNumElements();
5159 for (ShuffleVectorInst *SVN : Shuffles) {
5160 SmallVector<int> Mask;
5161 SVN->getShuffleMask(Mask);
5162
5163 // Check the operands are the same as the original, or reversed (in which
5164 // case we need to commute the mask).
5165 Value *SVOp0 = SVN->getOperand(0);
5166 Value *SVOp1 = SVN->getOperand(1);
5167 if (isa<UndefValue>(SVOp1)) {
5168 auto *SSV = cast<ShuffleVectorInst>(SVOp0);
5169 SVOp0 = SSV->getOperand(0);
5170 SVOp1 = SSV->getOperand(1);
5171 for (int &Elem : Mask) {
5172 if (Elem >= static_cast<int>(SSV->getShuffleMask().size()))
5173 return false;
5174 Elem = Elem < 0 ? Elem : SSV->getMaskValue(Elem);
5175 }
5176 }
5177 if (SVOp0 == Op1 && SVOp1 == Op0) {
5178 std::swap(SVOp0, SVOp1);
5180 }
5181 if (SVOp0 != Op0 || SVOp1 != Op1)
5182 return false;
5183
5184 // Calculate the reconstruction mask for this shuffle, as the mask needed to
5185 // take the packed values from Op0/Op1 and reconstructing to the original
5186 // order.
5187 SmallVector<int> ReconstructMask;
5188 for (unsigned I = 0; I < Mask.size(); I++) {
5189 if (Mask[I] < 0) {
5190 ReconstructMask.push_back(-1);
5191 } else if (Mask[I] < static_cast<int>(NumElts)) {
5192 MaxV1Elt = std::max(MaxV1Elt, Mask[I]);
5193 auto It = find_if(V1, [&](const std::pair<int, int> &A) {
5194 return Mask[I] == A.first;
5195 });
5196 if (It != V1.end())
5197 ReconstructMask.push_back(It - V1.begin());
5198 else {
5199 ReconstructMask.push_back(V1.size());
5200 V1.emplace_back(Mask[I], V1.size());
5201 }
5202 } else {
5203 MaxV2Elt = std::max<int>(MaxV2Elt, Mask[I] - NumElts);
5204 auto It = find_if(V2, [&](const std::pair<int, int> &A) {
5205 return Mask[I] - static_cast<int>(NumElts) == A.first;
5206 });
5207 if (It != V2.end())
5208 ReconstructMask.push_back(NumElts + It - V2.begin());
5209 else {
5210 ReconstructMask.push_back(NumElts + V2.size());
5211 V2.emplace_back(Mask[I] - NumElts, NumElts + V2.size());
5212 }
5213 }
5214 }
5215
5216 // For reductions, we know that the lane ordering out doesn't alter the
5217 // result. In-order can help simplify the shuffle away.
5218 if (FromReduction)
5219 sort(ReconstructMask);
5220 OrigReconstructMasks.push_back(std::move(ReconstructMask));
5221 }
5222
5223 // If the Maximum element used from V1 and V2 are not larger than the new
5224 // vectors, the vectors are already packes and performing the optimization
5225 // again will likely not help any further. This also prevents us from getting
5226 // stuck in a cycle in case the costs do not also rule it out.
5227 if (V1.empty() || V2.empty() ||
5228 (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
5229 MaxV2Elt == static_cast<int>(V2.size()) - 1))
5230 return false;
5231
5232 // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
5233 // shuffle of another shuffle, or not a shuffle (that is treated like a
5234 // identity shuffle).
5235 auto GetBaseMaskValue = [&](Instruction *I, int M) {
5236 auto *SV = dyn_cast<ShuffleVectorInst>(I);
5237 if (!SV)
5238 return M;
5239 if (isa<UndefValue>(SV->getOperand(1)))
5240 if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
5241 if (InputShuffles.contains(SSV))
5242 return SSV->getMaskValue(SV->getMaskValue(M));
5243 return SV->getMaskValue(M);
5244 };
5245
5246 // Attempt to sort the inputs my ascending mask values to make simpler input
5247 // shuffles and push complex shuffles down to the uses. We sort on the first
5248 // of the two input shuffle orders, to try and get at least one input into a
5249 // nice order.
5250 auto SortBase = [&](Instruction *A, std::pair<int, int> X,
5251 std::pair<int, int> Y) {
5252 int MXA = GetBaseMaskValue(A, X.first);
5253 int MYA = GetBaseMaskValue(A, Y.first);
5254 return MXA < MYA;
5255 };
5256 stable_sort(V1, [&](std::pair<int, int> A, std::pair<int, int> B) {
5257 return SortBase(SVI0A, A, B);
5258 });
5259 stable_sort(V2, [&](std::pair<int, int> A, std::pair<int, int> B) {
5260 return SortBase(SVI1A, A, B);
5261 });
5262 // Calculate our ReconstructMasks from the OrigReconstructMasks and the
5263 // modified order of the input shuffles.
5264 SmallVector<SmallVector<int>> ReconstructMasks;
5265 for (const auto &Mask : OrigReconstructMasks) {
5266 SmallVector<int> ReconstructMask;
5267 for (int M : Mask) {
5268 auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) {
5269 auto It = find_if(V, [M](auto A) { return A.second == M; });
5270 assert(It != V.end() && "Expected all entries in Mask");
5271 return std::distance(V.begin(), It);
5272 };
5273 if (M < 0)
5274 ReconstructMask.push_back(-1);
5275 else if (M < static_cast<int>(NumElts)) {
5276 ReconstructMask.push_back(FindIndex(V1, M));
5277 } else {
5278 ReconstructMask.push_back(NumElts + FindIndex(V2, M));
5279 }
5280 }
5281 ReconstructMasks.push_back(std::move(ReconstructMask));
5282 }
5283
5284 // Calculate the masks needed for the new input shuffles, which get padded
5285 // with undef
5286 SmallVector<int> V1A, V1B, V2A, V2B;
5287 for (unsigned I = 0; I < V1.size(); I++) {
5288 V1A.push_back(GetBaseMaskValue(SVI0A, V1[I].first));
5289 V1B.push_back(GetBaseMaskValue(SVI0B, V1[I].first));
5290 }
5291 for (unsigned I = 0; I < V2.size(); I++) {
5292 V2A.push_back(GetBaseMaskValue(SVI1A, V2[I].first));
5293 V2B.push_back(GetBaseMaskValue(SVI1B, V2[I].first));
5294 }
5295 while (V1A.size() < NumElts) {
5298 }
5299 while (V2A.size() < NumElts) {
5302 }
5303
5304 auto AddShuffleCost = [&](InstructionCost C, Instruction *I) {
5305 auto *SV = dyn_cast<ShuffleVectorInst>(I);
5306 if (!SV)
5307 return C;
5308 return C + TTI.getShuffleCost(isa<UndefValue>(SV->getOperand(1))
5311 VT, VT, SV->getShuffleMask(), CostKind);
5312 };
5313 auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) {
5314 return C +
5316 };
5317
5318 unsigned ElementSize = VT->getElementType()->getPrimitiveSizeInBits();
5319 unsigned MaxVectorSize =
5321 unsigned MaxElementsInVector = MaxVectorSize / ElementSize;
5322 if (MaxElementsInVector == 0)
5323 return false;
5324 // When there are multiple shufflevector operations on the same input,
5325 // especially when the vector length is larger than the register size,
5326 // identical shuffle patterns may occur across different groups of elements.
5327 // To avoid overestimating the cost by counting these repeated shuffles more
5328 // than once, we only account for unique shuffle patterns. This adjustment
5329 // prevents inflated costs in the cost model for wide vectors split into
5330 // several register-sized groups.
5331 std::set<SmallVector<int, 4>> UniqueShuffles;
5332 auto AddShuffleMaskAdjustedCost = [&](InstructionCost C, ArrayRef<int> Mask) {
5333 // Compute the cost for performing the shuffle over the full vector.
5334 auto ShuffleCost =
5336 unsigned NumFullVectors = Mask.size() / MaxElementsInVector;
5337 if (NumFullVectors < 2)
5338 return C + ShuffleCost;
5339 SmallVector<int, 4> SubShuffle(MaxElementsInVector);
5340 unsigned NumUniqueGroups = 0;
5341 unsigned NumGroups = Mask.size() / MaxElementsInVector;
5342 // For each group of MaxElementsInVector contiguous elements,
5343 // collect their shuffle pattern and insert into the set of unique patterns.
5344 for (unsigned I = 0; I < NumFullVectors; ++I) {
5345 for (unsigned J = 0; J < MaxElementsInVector; ++J)
5346 SubShuffle[J] = Mask[MaxElementsInVector * I + J];
5347 if (UniqueShuffles.insert(SubShuffle).second)
5348 NumUniqueGroups += 1;
5349 }
5350 return C + ShuffleCost * NumUniqueGroups / NumGroups;
5351 };
5352 auto AddShuffleAdjustedCost = [&](InstructionCost C, Instruction *I) {
5353 auto *SV = dyn_cast<ShuffleVectorInst>(I);
5354 if (!SV)
5355 return C;
5356 SmallVector<int, 16> Mask;
5357 SV->getShuffleMask(Mask);
5358 return AddShuffleMaskAdjustedCost(C, Mask);
5359 };
5360 // Check that input consists of ShuffleVectors applied to the same input
5361 auto AllShufflesHaveSameOperands =
5362 [](SmallPtrSetImpl<Instruction *> &InputShuffles) {
5363 if (InputShuffles.size() < 2)
5364 return false;
5365 ShuffleVectorInst *FirstSV =
5366 dyn_cast<ShuffleVectorInst>(*InputShuffles.begin());
5367 if (!FirstSV)
5368 return false;
5369
5370 Value *In0 = FirstSV->getOperand(0), *In1 = FirstSV->getOperand(1);
5371 return std::all_of(
5372 std::next(InputShuffles.begin()), InputShuffles.end(),
5373 [&](Instruction *I) {
5374 ShuffleVectorInst *SV = dyn_cast<ShuffleVectorInst>(I);
5375 return SV && SV->getOperand(0) == In0 && SV->getOperand(1) == In1;
5376 });
5377 };
5378
5379 // Get the costs of the shuffles + binops before and after with the new
5380 // shuffle masks.
5381 InstructionCost CostBefore =
5382 TTI.getArithmeticInstrCost(Op0->getOpcode(), VT, CostKind) +
5383 TTI.getArithmeticInstrCost(Op1->getOpcode(), VT, CostKind);
5384 CostBefore += std::accumulate(Shuffles.begin(), Shuffles.end(),
5385 InstructionCost(0), AddShuffleCost);
5386 if (AllShufflesHaveSameOperands(InputShuffles)) {
5387 UniqueShuffles.clear();
5388 CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
5389 InstructionCost(0), AddShuffleAdjustedCost);
5390 } else {
5391 CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
5392 InstructionCost(0), AddShuffleCost);
5393 }
5394
5395 // The new binops will be unused for lanes past the used shuffle lengths.
5396 // These types attempt to get the correct cost for that from the target.
5397 FixedVectorType *Op0SmallVT =
5398 FixedVectorType::get(VT->getScalarType(), V1.size());
5399 FixedVectorType *Op1SmallVT =
5400 FixedVectorType::get(VT->getScalarType(), V2.size());
5401 InstructionCost CostAfter =
5402 TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT, CostKind) +
5403 TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT, CostKind);
5404 UniqueShuffles.clear();
5405 CostAfter += std::accumulate(ReconstructMasks.begin(), ReconstructMasks.end(),
5406 InstructionCost(0), AddShuffleMaskAdjustedCost);
5407 std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
5408 CostAfter +=
5409 std::accumulate(OutputShuffleMasks.begin(), OutputShuffleMasks.end(),
5410 InstructionCost(0), AddShuffleMaskCost);
5411
5412 LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
5413 LLVM_DEBUG(dbgs() << " CostBefore: " << CostBefore
5414 << " vs CostAfter: " << CostAfter << "\n");
5415 if (CostBefore < CostAfter ||
5416 (CostBefore == CostAfter && !feedsIntoVectorReduction(SVI)))
5417 return false;
5418
5419 // The cost model has passed, create the new instructions.
5420 auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * {
5421 auto *SV = dyn_cast<ShuffleVectorInst>(I);
5422 if (!SV)
5423 return I;
5424 if (isa<UndefValue>(SV->getOperand(1)))
5425 if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
5426 if (InputShuffles.contains(SSV))
5427 return SSV->getOperand(Op);
5428 return SV->getOperand(Op);
5429 };
5430 Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef());
5431 Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0),
5432 GetShuffleOperand(SVI0A, 1), V1A);
5433 Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef());
5434 Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0),
5435 GetShuffleOperand(SVI0B, 1), V1B);
5436 Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef());
5437 Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0),
5438 GetShuffleOperand(SVI1A, 1), V2A);
5439 Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef());
5440 Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0),
5441 GetShuffleOperand(SVI1B, 1), V2B);
5442 Builder.SetInsertPoint(Op0);
5443 Value *NOp0 = Builder.CreateBinOp((Instruction::BinaryOps)Op0->getOpcode(),
5444 NSV0A, NSV0B);
5445 if (auto *I = dyn_cast<Instruction>(NOp0))
5446 I->copyIRFlags(Op0, true);
5447 Builder.SetInsertPoint(Op1);
5448 Value *NOp1 = Builder.CreateBinOp((Instruction::BinaryOps)Op1->getOpcode(),
5449 NSV1A, NSV1B);
5450 if (auto *I = dyn_cast<Instruction>(NOp1))
5451 I->copyIRFlags(Op1, true);
5452
5453 for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
5454 Builder.SetInsertPoint(Shuffles[S]);
5455 Value *NSV = Builder.CreateShuffleVector(NOp0, NOp1, ReconstructMasks[S]);
5456 replaceValue(*Shuffles[S], *NSV, false);
5457 }
5458
5459 Worklist.pushValue(NSV0A);
5460 Worklist.pushValue(NSV0B);
5461 Worklist.pushValue(NSV1A);
5462 Worklist.pushValue(NSV1B);
5463 return true;
5464}
5465
5466/// Check if instruction depends on ZExt and this ZExt can be moved after the
5467/// instruction. Move ZExt if it is profitable. For example:
5468/// logic(zext(x),y) -> zext(logic(x,trunc(y)))
5469/// lshr((zext(x),y) -> zext(lshr(x,trunc(y)))
5470/// Cost model calculations takes into account if zext(x) has other users and
5471/// whether it can be propagated through them too.
5472bool VectorCombine::shrinkType(Instruction &I) {
5473 Value *ZExted, *OtherOperand;
5474 if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
5475 m_Value(OtherOperand))) &&
5476 !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
5477 return false;
5478
5479 Value *ZExtOperand = I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0);
5480
5481 auto *BigTy = cast<FixedVectorType>(I.getType());
5482 auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
5483 unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
5484
5485 if (I.getOpcode() == Instruction::LShr) {
5486 // Check that the shift amount is less than the number of bits in the
5487 // smaller type. Otherwise, the smaller lshr will return a poison value.
5488 KnownBits ShAmtKB = computeKnownBits(I.getOperand(1), *DL);
5489 if (ShAmtKB.getMaxValue().uge(BW))
5490 return false;
5491 } else {
5492 // Check that the expression overall uses at most the same number of bits as
5493 // ZExted
5494 KnownBits KB = computeKnownBits(&I, *DL);
5495 if (KB.countMaxActiveBits() > BW)
5496 return false;
5497 }
5498
5499 // Calculate costs of leaving current IR as it is and moving ZExt operation
5500 // later, along with adding truncates if needed
5502 Instruction::ZExt, BigTy, SmallTy,
5503 TargetTransformInfo::CastContextHint::None, CostKind);
5504 InstructionCost CurrentCost = ZExtCost;
5505 InstructionCost ShrinkCost = 0;
5506
5507 // Calculate total cost and check that we can propagate through all ZExt users
5508 for (User *U : ZExtOperand->users()) {
5509 auto *UI = cast<Instruction>(U);
5510 if (UI == &I) {
5511 CurrentCost +=
5512 TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
5513 ShrinkCost +=
5514 TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
5515 ShrinkCost += ZExtCost;
5516 continue;
5517 }
5518
5519 if (!Instruction::isBinaryOp(UI->getOpcode()))
5520 return false;
5521
5522 // Check if we can propagate ZExt through its other users
5523 KnownBits KB = computeKnownBits(UI, *DL);
5524 if (KB.countMaxActiveBits() > BW)
5525 return false;
5526
5527 CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
5528 ShrinkCost +=
5529 TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
5530 ShrinkCost += ZExtCost;
5531 }
5532
5533 // If the other instruction operand is not a constant, we'll need to
5534 // generate a truncate instruction. So we have to adjust cost
5535 if (!isa<Constant>(OtherOperand))
5536 ShrinkCost += TTI.getCastInstrCost(
5537 Instruction::Trunc, SmallTy, BigTy,
5538 TargetTransformInfo::CastContextHint::None, CostKind);
5539
5540 // If the cost of shrinking types and leaving the IR is the same, we'll lean
5541 // towards modifying the IR because shrinking opens opportunities for other
5542 // shrinking optimisations.
5543 if (ShrinkCost > CurrentCost)
5544 return false;
5545
5546 Builder.SetInsertPoint(&I);
5547 Value *Op0 = ZExted;
5548 Value *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
5549 // Keep the order of operands the same
5550 if (I.getOperand(0) == OtherOperand)
5551 std::swap(Op0, Op1);
5552 Value *NewBinOp =
5553 Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
5554 cast<Instruction>(NewBinOp)->copyIRFlags(&I);
5555 cast<Instruction>(NewBinOp)->copyMetadata(I);
5556 Value *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
5557 replaceValue(I, *NewZExtr);
5558 return true;
5559}
5560
5561/// insert (DstVec, (extract SrcVec, ExtIdx), InsIdx) -->
5562/// shuffle (DstVec, SrcVec, Mask)
5563bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
5564 Value *DstVec, *SrcVec;
5565 uint64_t ExtIdx, InsIdx;
5566 if (!match(&I,
5567 m_InsertElt(m_Value(DstVec),
5568 m_ExtractElt(m_Value(SrcVec), m_ConstantInt(ExtIdx)),
5569 m_ConstantInt(InsIdx))))
5570 return false;
5571
5572 auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
5573 auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcVec->getType());
5574 // We can try combining vectors with different element sizes.
5575 if (!DstVecTy || !SrcVecTy ||
5576 SrcVecTy->getElementType() != DstVecTy->getElementType())
5577 return false;
5578
5579 unsigned NumDstElts = DstVecTy->getNumElements();
5580 unsigned NumSrcElts = SrcVecTy->getNumElements();
5581 if (InsIdx >= NumDstElts || ExtIdx >= NumSrcElts || NumDstElts == 1)
5582 return false;
5583
5584 // Insertion into poison is a cheaper single operand shuffle.
5586 SmallVector<int> Mask(NumDstElts, PoisonMaskElem);
5587
5588 bool NeedExpOrNarrow = NumSrcElts != NumDstElts;
5589 bool NeedDstSrcSwap = isa<PoisonValue>(DstVec) && !isa<UndefValue>(SrcVec);
5590 if (NeedDstSrcSwap) {
5592 Mask[InsIdx] = ExtIdx % NumDstElts;
5593 std::swap(DstVec, SrcVec);
5594 } else {
5596 std::iota(Mask.begin(), Mask.end(), 0);
5597 Mask[InsIdx] = (ExtIdx % NumDstElts) + NumDstElts;
5598 }
5599
5600 // Cost
5601 auto *Ins = cast<InsertElementInst>(&I);
5602 auto *Ext = cast<ExtractElementInst>(I.getOperand(1));
5603 InstructionCost InsCost =
5604 TTI.getVectorInstrCost(*Ins, DstVecTy, CostKind, InsIdx);
5605 InstructionCost ExtCost =
5606 TTI.getVectorInstrCost(*Ext, DstVecTy, CostKind, ExtIdx);
5607 InstructionCost OldCost = ExtCost + InsCost;
5608
5609 InstructionCost NewCost = 0;
5610 SmallVector<int> ExtToVecMask;
5611 if (!NeedExpOrNarrow) {
5612 // Ignore 'free' identity insertion shuffle.
5613 // TODO: getShuffleCost should return TCC_Free for Identity shuffles.
5614 if (!ShuffleVectorInst::isIdentityMask(Mask, NumSrcElts))
5615 NewCost += TTI.getShuffleCost(SK, DstVecTy, DstVecTy, Mask, CostKind, 0,
5616 nullptr, {DstVec, SrcVec});
5617 } else {
5618 // When creating a length-changing-vector, always try to keep the relevant
5619 // element in an equivalent position, so that bulk shuffles are more likely
5620 // to be useful.
5621 ExtToVecMask.assign(NumDstElts, PoisonMaskElem);
5622 ExtToVecMask[ExtIdx % NumDstElts] = ExtIdx;
5623 // Add cost for expanding or narrowing
5625 DstVecTy, SrcVecTy, ExtToVecMask, CostKind);
5626 NewCost += TTI.getShuffleCost(SK, DstVecTy, DstVecTy, Mask, CostKind);
5627 }
5628
5629 if (!Ext->hasOneUse())
5630 NewCost += ExtCost;
5631
5632 LLVM_DEBUG(dbgs() << "Found a insert/extract shuffle-like pair: " << I
5633 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
5634 << "\n");
5635
5636 if (OldCost < NewCost)
5637 return false;
5638
5639 if (NeedExpOrNarrow) {
5640 if (!NeedDstSrcSwap)
5641 SrcVec = Builder.CreateShuffleVector(SrcVec, ExtToVecMask);
5642 else
5643 DstVec = Builder.CreateShuffleVector(DstVec, ExtToVecMask);
5644 }
5645
5646 // Canonicalize undef param to RHS to help further folds.
5647 if (isa<UndefValue>(DstVec) && !isa<UndefValue>(SrcVec)) {
5648 ShuffleVectorInst::commuteShuffleMask(Mask, NumDstElts);
5649 std::swap(DstVec, SrcVec);
5650 }
5651
5652 Value *Shuf = Builder.CreateShuffleVector(DstVec, SrcVec, Mask);
5653 replaceValue(I, *Shuf);
5654
5655 return true;
5656}
5657
5658/// If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32>
5659/// <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a
5660/// larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first
5661/// before casting it back into `<vscale x 16 x i32>`.
5662bool VectorCombine::foldInterleaveIntrinsics(Instruction &I) {
5663 const APInt *SplatVal0, *SplatVal1;
5665 m_APInt(SplatVal0), m_APInt(SplatVal1))))
5666 return false;
5667
5668 LLVM_DEBUG(dbgs() << "VC: Folding interleave2 with two splats: " << I
5669 << "\n");
5670
5671 auto *VTy =
5672 cast<VectorType>(cast<IntrinsicInst>(I).getArgOperand(0)->getType());
5673 auto *ExtVTy = VectorType::getExtendedElementVectorType(VTy);
5674 unsigned Width = VTy->getElementType()->getIntegerBitWidth();
5675
5676 // Just in case the cost of interleave2 intrinsic and bitcast are both
5677 // invalid, in which case we want to bail out, we use <= rather
5678 // than < here. Even they both have valid and equal costs, it's probably
5679 // not a good idea to emit a high-cost constant splat.
5681 TTI.getCastInstrCost(Instruction::BitCast, I.getType(), ExtVTy,
5683 LLVM_DEBUG(dbgs() << "VC: The cost to cast from " << *ExtVTy << " to "
5684 << *I.getType() << " is too high.\n");
5685 return false;
5686 }
5687
5688 APInt NewSplatVal = SplatVal1->zext(Width * 2);
5689 NewSplatVal <<= Width;
5690 NewSplatVal |= SplatVal0->zext(Width * 2);
5691 auto *NewSplat = ConstantVector::getSplat(
5692 ExtVTy->getElementCount(), ConstantInt::get(F.getContext(), NewSplatVal));
5693
5694 IRBuilder<> Builder(&I);
5695 replaceValue(I, *Builder.CreateBitCast(NewSplat, I.getType()));
5696 return true;
5697}
5698
5699/// Given this sequence:
5700/// ```
5701/// %d = llvm.vector.deinterleave2 <vscale x 16 x i32> %v
5702/// %f0 = extractvalue { <vscale x 8 x i32>, <vscale x 8 x i32> } %d, 0
5703/// %f1 = extractvalue { <vscale x 8 x i32>, <vscale x 8 x i32> } %d, 1
5704///
5705/// %low0 = and <vscale x 8 x i32> %f0, splat (i32 65535)
5706/// %low1 = shl <vscale x 8 x i32> %f1, splat (i32 16)
5707/// %merge0 = or disjoint <vscale x 8 x i32> %low0, %low1
5708///
5709/// %high0 = and <vscale x 8 x i32> %f1, splat (i32 -65536)
5710/// %high1 = lshr <vscale x 8 x i32> %f0, splat (i32 16)
5711/// %merge1 = or disjoint <vscale x 8 x i32> %high0, %high1
5712/// ```
5713/// It is actually just de-interleaving a 16-bit vector with double the
5714/// vector length. More generally speaking, it's de-interleaving on a vector
5715/// with half the element width as the original vector.
5716///
5717/// Therefore, we can turn it into:
5718/// ```
5719/// %narrow.v = bitcast <vscale x 16 x i32> %v to <vscale x 32 x i16>
5720/// %d = llvm.vector.deinterleave2 <vscale x 32 x i16> %narrow.v
5721/// %f0 = extractvalue { <vscale x 16 x i16>, <vscale x 16 x i16> } %d, 0
5722/// %f1 = extractvalue { <vscale x 16 x i16>, <vscale x 16 x i16> } %d, 1
5723///
5724/// %merge0 = bitcast <vscale x 16 x i16> %f0 to <vscale x 8 x i32>
5725/// %merge1 = bitcast <vscale x 16 x i16> %f1 to <vscale x 8 x i32>
5726/// ```
5727bool VectorCombine::foldDeinterleaveIntrinsics(Instruction &I) {
5728 // This pattern involves bitcast that is not compatible with big endian.
5729 if (DL->isBigEndian())
5730 return false;
5731
5732 using namespace PatternMatch;
5733 Value *DeinterleavedVal;
5734 if (!match(&I, m_Deinterleave2(m_Value(DeinterleavedVal))))
5735 return false;
5736
5737 VectorType *VecTy = cast<VectorType>(DeinterleavedVal->getType());
5738 IntegerType *ElementTy = dyn_cast<IntegerType>(VecTy->getElementType());
5739 if (!ElementTy)
5740 return false;
5741 unsigned ElementWidth = ElementTy->getBitWidth();
5742 if (ElementWidth < 2 || !isPowerOf2_32(ElementWidth))
5743 return false;
5744 unsigned HalfElementWidth = ElementWidth / 2;
5745
5746 if (!I.hasNUses(2))
5747 return false;
5748 std::array<ExtractValueInst *, 2> OrigFields{};
5749 for (User *Usr : I.users()) {
5750 auto *E = dyn_cast<ExtractValueInst>(Usr);
5751 // The deinterleave result can only be used by extractions.
5752 if (!E || E->getNumIndices() != 1)
5753 return false;
5754 unsigned Idx = *E->idx_begin();
5755 // A single field cannot be extracted more than once.
5756 if (Idx >= 2 || OrigFields[Idx] || !E->hasNUses(2))
5757 return false;
5758 OrigFields[Idx] = E;
5759 }
5760
5761 // Find the merge instruction (i.e. OR) first.
5762 SmallVector<Instruction *, 2> MergeInsts;
5763 for (auto *FieldUsr : OrigFields[0]->users()) {
5764 if (!FieldUsr->hasOneUse() || !isa<Instruction>(FieldUsr->user_back()))
5765 return false;
5766 MergeInsts.push_back(cast<Instruction>(FieldUsr->user_back()));
5767 }
5768 assert(MergeInsts.size() == 2);
5769
5770 // Pattern match bottom-up from the merge instructions.
5771 auto MatchMerge = [&](void) -> bool {
5772 APInt LoMask = APInt::getLowBitsSet(ElementWidth, HalfElementWidth);
5773 APInt HiMask = APInt::getHighBitsSet(ElementWidth, HalfElementWidth);
5774 return match(MergeInsts[0],
5775 m_c_Or(m_And(m_Specific(OrigFields[0]), m_SpecificInt(LoMask)),
5776 m_Shl(m_Specific(OrigFields[1]),
5777 m_SpecificInt(HalfElementWidth)))) &&
5778 match(MergeInsts[1],
5779 m_c_Or(m_And(m_Specific(OrigFields[1]), m_SpecificInt(HiMask)),
5780 m_LShr(m_Specific(OrigFields[0]),
5781 m_SpecificInt(HalfElementWidth))));
5782 };
5783 if (!MatchMerge()) {
5784 std::swap(MergeInsts[0], MergeInsts[1]);
5785 if (!MatchMerge())
5786 return false;
5787 }
5788
5789 // Profitability check.
5790 InstructionCost OldCost =
5791 TTI.getInstructionCost(MergeInsts[0], CostKind) +
5792 TTI.getInstructionCost(cast<Instruction>(MergeInsts[0]->getOperand(0)),
5793 CostKind) +
5794 TTI.getInstructionCost(cast<Instruction>(MergeInsts[0]->getOperand(1)),
5795 CostKind);
5796 // There are two fields (assuming SHL has the same cost as LSHR).
5797 OldCost *= 2;
5798
5799 auto *NewFieldTy = VecTy->getWithNewBitWidth(HalfElementWidth);
5800 auto *NewVecTy =
5801 VectorType::getDoubleElementsVectorType(cast<VectorType>(NewFieldTy));
5802 InstructionCost NewCost =
5803 TTI.getCastInstrCost(Instruction::BitCast, VecTy, NewVecTy,
5805 TTI.getCastInstrCost(Instruction::BitCast, NewFieldTy,
5806 MergeInsts[0]->getType(), TTI::CastContextHint::None,
5807 CostKind) *
5808 2;
5809 if (OldCost <= NewCost || !NewCost.isValid()) {
5810 LLVM_DEBUG(
5811 dbgs() << "VC: New deinterleave2 sequence cost (" << NewCost << ")"
5812 << " is higher than that of the old one (" << OldCost << ")\n");
5813 return false;
5814 }
5815
5816 // Do the replacement.
5817 IRBuilder<> Builder(&I);
5818 Value *NewVecCast = Builder.CreateBitCast(DeinterleavedVal, NewVecTy);
5819 Value *NewDeinterleave = Builder.CreateIntrinsic(
5820 Intrinsic::vector_deinterleave2, {NewVecTy}, {NewVecCast});
5821 for (auto [Idx, MergeInst] : enumerate(MergeInsts)) {
5822 Value *NewField = Builder.CreateExtractValue(NewDeinterleave, Idx);
5823 NewField = Builder.CreateBitCast(NewField, MergeInst->getType());
5824 replaceValue(*MergeInst, *NewField);
5825 }
5826
5827 return true;
5828}
5829
5830bool VectorCombine::foldBitcastOfVPLoad(Instruction &I) {
5831 const DataLayout &DL = I.getDataLayout();
5832 auto *Cast = dyn_cast<CastInst>(&I);
5833 if (!Cast || !Cast->isNoopCast(DL) || !isa<VectorType>(Cast->getDestTy()))
5834 return false;
5835
5836 // Fold away bit casts of the loaded value by loading the desired type,
5837 // if the mask is all-ones.
5838 Value *EVL;
5839 auto *II = dyn_cast<VPIntrinsic>(I.getOperand(0));
5841 m_Value(), m_AllOnes(), m_Value(EVL)))))
5842 return false;
5843
5844 VectorType *OrigVecTy = cast<VectorType>(II->getType());
5845 Align OrigAlign =
5846 DL.getValueOrABITypeAlignment(II->getPointerAlignment(), OrigVecTy);
5847 ElementCount OrigVecCnt = OrigVecTy->getElementCount();
5848 VectorType *NewVecTy = cast<VectorType>(Cast->getDestTy());
5849 ElementCount NewVecCnt = NewVecTy->getElementCount();
5850
5851 // Right now we only support cases where the NewVec is longer, because for
5852 // cases where it's shorter, we have to be sure that EVL can be exactly
5853 // divided, otherwise it might yield incorrect results or even page faults
5854 // (if we round-up during the division).
5855 if (!(OrigVecCnt.isScalable() == NewVecCnt.isScalable() &&
5856 NewVecCnt.hasKnownScalarFactor(OrigVecCnt)))
5857 return false;
5858
5859 InstructionCost OldCost =
5860 TTI.getMemIntrinsicInstrCost({Intrinsic::vp_load, OrigVecTy,
5861 II->getMemoryPointerParam(), false,
5862 OrigAlign},
5863 CostKind) +
5864 TTI.getCastInstrCost(Instruction::BitCast, Cast->getType(), OrigVecTy,
5867 {Intrinsic::vp_load, NewVecTy, II->getMemoryPointerParam(), false,
5868 OrigAlign},
5869 CostKind);
5870 LLVM_DEBUG(dbgs() << "foldBitcastOfVPLoad: OldCost=" << OldCost
5871 << " NewCost=" << NewCost << "\n");
5872 if (NewCost > OldCost || !NewCost.isValid())
5873 return false;
5874
5875 unsigned Factor = NewVecCnt.getKnownScalarFactor(OrigVecCnt);
5876 Value *NewEVL = Builder.CreateNUWMul(EVL, Builder.getInt32(Factor));
5877 Value *NewMask = Builder.CreateVectorSplat(NewVecCnt, Builder.getTrue());
5878 CallInst *NewVP =
5879 Builder.CreateIntrinsic(NewVecTy, Intrinsic::vp_load,
5880 {II->getMemoryPointerParam(), NewMask, NewEVL});
5881 // Preserve the original alignment.
5882 NewVP->addParamAttrs(
5883 0, AttrBuilder(II->getContext()).addAlignmentAttr(OrigAlign));
5884 replaceValue(*Cast, *NewVP);
5885 return true;
5886}
5887
5888// Attempt to shrink loads that are only used by shufflevector instructions.
5889bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
5890 auto *OldLoad = dyn_cast<LoadInst>(&I);
5891 if (!OldLoad || !OldLoad->isSimple())
5892 return false;
5893
5894 auto *OldLoadTy = dyn_cast<FixedVectorType>(OldLoad->getType());
5895 if (!OldLoadTy)
5896 return false;
5897
5898 unsigned const OldNumElements = OldLoadTy->getNumElements();
5899
5900 // Search all uses of load. If all uses are shufflevector instructions, and
5901 // the second operands are all poison values, find the minimum and maximum
5902 // indices of the vector elements referenced by all shuffle masks.
5903 // Otherwise return `std::nullopt`.
5904 using IndexRange = std::pair<int, int>;
5905 auto GetIndexRangeInShuffles = [&]() -> std::optional<IndexRange> {
5906 IndexRange OutputRange = IndexRange(OldNumElements, -1);
5907 for (llvm::Use &Use : I.uses()) {
5908 // Ensure all uses match the required pattern.
5909 User *Shuffle = Use.getUser();
5910 ArrayRef<int> Mask;
5911
5912 if (!match(Shuffle,
5913 m_Shuffle(m_Specific(OldLoad), m_Undef(), m_Mask(Mask))))
5914 return std::nullopt;
5915
5916 // Ignore shufflevector instructions that have no uses.
5917 if (Shuffle->use_empty())
5918 continue;
5919
5920 // Find the min and max indices used by the shufflevector instruction.
5921 for (int Index : Mask) {
5922 if (Index >= 0 && Index < static_cast<int>(OldNumElements)) {
5923 OutputRange.first = std::min(Index, OutputRange.first);
5924 OutputRange.second = std::max(Index, OutputRange.second);
5925 }
5926 }
5927 }
5928
5929 if (OutputRange.second < OutputRange.first)
5930 return std::nullopt;
5931
5932 return OutputRange;
5933 };
5934
5935 // Get the range of vector elements used by shufflevector instructions.
5936 if (std::optional<IndexRange> Indices = GetIndexRangeInShuffles()) {
5937 unsigned const NewNumElements = Indices->second + 1u;
5938
5939 // If the range of vector elements is smaller than the full load, attempt
5940 // to create a smaller load.
5941 if (NewNumElements < OldNumElements) {
5942 IRBuilder Builder(&I);
5943 Builder.SetCurrentDebugLocation(I.getDebugLoc());
5944
5945 // Calculate costs of old and new ops.
5946 Type *ElemTy = OldLoadTy->getElementType();
5947 FixedVectorType *NewLoadTy = FixedVectorType::get(ElemTy, NewNumElements);
5948 Value *PtrOp = OldLoad->getPointerOperand();
5949
5951 Instruction::Load, OldLoad->getType(), OldLoad->getAlign(),
5952 OldLoad->getPointerAddressSpace(), CostKind);
5953 InstructionCost NewCost =
5954 TTI.getMemoryOpCost(Instruction::Load, NewLoadTy, OldLoad->getAlign(),
5955 OldLoad->getPointerAddressSpace(), CostKind);
5956
5957 using UseEntry = std::pair<ShuffleVectorInst *, std::vector<int>>;
5959 unsigned const MaxIndex = NewNumElements * 2u;
5960
5961 for (llvm::Use &Use : I.uses()) {
5962 auto *Shuffle = cast<ShuffleVectorInst>(Use.getUser());
5963
5964 // Ignore shufflevector instructions that have no uses.
5965 if (Shuffle->use_empty())
5966 continue;
5967
5968 ArrayRef<int> OldMask = Shuffle->getShuffleMask();
5969
5970 // Create entry for new use.
5971 NewUses.push_back({Shuffle, OldMask});
5972
5973 // Validate mask indices.
5974 for (int Index : OldMask) {
5975 if (Index >= static_cast<int>(MaxIndex))
5976 return false;
5977 }
5978
5979 // Update costs.
5980 OldCost +=
5982 OldLoadTy, OldMask, CostKind);
5983 NewCost +=
5985 NewLoadTy, OldMask, CostKind);
5986 }
5987
5988 LLVM_DEBUG(
5989 dbgs() << "Found a load used only by shufflevector instructions: "
5990 << I << "\n OldCost: " << OldCost
5991 << " vs NewCost: " << NewCost << "\n");
5992
5993 if (OldCost < NewCost || !NewCost.isValid())
5994 return false;
5995
5996 // Create new load of smaller vector.
5997 auto *NewLoad = cast<LoadInst>(
5998 Builder.CreateAlignedLoad(NewLoadTy, PtrOp, OldLoad->getAlign()));
5999 NewLoad->copyMetadata(I);
6000
6001 // Replace all uses.
6002 for (UseEntry &Use : NewUses) {
6003 ShuffleVectorInst *Shuffle = Use.first;
6004 std::vector<int> &NewMask = Use.second;
6005
6006 Builder.SetInsertPoint(Shuffle);
6007 Builder.SetCurrentDebugLocation(Shuffle->getDebugLoc());
6008 Value *NewShuffle = Builder.CreateShuffleVector(
6009 NewLoad, PoisonValue::get(NewLoadTy), NewMask);
6010
6011 replaceValue(*Shuffle, *NewShuffle, false);
6012 }
6013
6014 return true;
6015 }
6016 }
6017 return false;
6018}
6019
6020// Attempt to narrow a phi of shufflevector instructions where the two incoming
6021// values have the same operands but different masks. If the two shuffle masks
6022// are offsets of one another we can use one branch to rotate the incoming
6023// vector and perform one larger shuffle after the phi.
6024bool VectorCombine::shrinkPhiOfShuffles(Instruction &I) {
6025 auto *Phi = dyn_cast<PHINode>(&I);
6026 if (!Phi || Phi->getNumIncomingValues() != 2u)
6027 return false;
6028
6029 Value *Op = nullptr;
6030 ArrayRef<int> Mask0;
6031 ArrayRef<int> Mask1;
6032
6033 if (!match(Phi->getOperand(0u),
6034 m_OneUse(m_Shuffle(m_Value(Op), m_Poison(), m_Mask(Mask0)))) ||
6035 !match(Phi->getOperand(1u),
6036 m_OneUse(m_Shuffle(m_Specific(Op), m_Poison(), m_Mask(Mask1)))))
6037 return false;
6038
6039 auto *Shuf = cast<ShuffleVectorInst>(Phi->getOperand(0u));
6040
6041 // Ensure result vectors are wider than the argument vector.
6042 auto *InputVT = cast<FixedVectorType>(Op->getType());
6043 auto *ResultVT = cast<FixedVectorType>(Shuf->getType());
6044 auto const InputNumElements = InputVT->getNumElements();
6045
6046 if (InputNumElements >= ResultVT->getNumElements())
6047 return false;
6048
6049 // Take the difference of the two shuffle masks at each index. Ignore poison
6050 // values at the same index in both masks.
6051 SmallVector<int, 16> NewMask;
6052 NewMask.reserve(Mask0.size());
6053
6054 for (auto [M0, M1] : zip(Mask0, Mask1)) {
6055 if (M0 >= 0 && M1 >= 0)
6056 NewMask.push_back(M0 - M1);
6057 else if (M0 == -1 && M1 == -1)
6058 continue;
6059 else
6060 return false;
6061 }
6062
6063 // Ensure all elements of the new mask are equal. If the difference between
6064 // the incoming mask elements is the same, the two must be constant offsets
6065 // of one another.
6066 if (NewMask.empty() || !all_equal(NewMask))
6067 return false;
6068
6069 // Create new mask using difference of the two incoming masks.
6070 int MaskOffset = NewMask[0u];
6071 unsigned Index = (InputNumElements + MaskOffset) % InputNumElements;
6072 NewMask.clear();
6073
6074 for (unsigned I = 0u; I < InputNumElements; ++I) {
6075 NewMask.push_back(Index);
6076 Index = (Index + 1u) % InputNumElements;
6077 }
6078
6079 // Calculate costs for worst cases and compare.
6080 auto const Kind = TTI::SK_PermuteSingleSrc;
6081 auto OldCost =
6082 std::max(TTI.getShuffleCost(Kind, ResultVT, InputVT, Mask0, CostKind),
6083 TTI.getShuffleCost(Kind, ResultVT, InputVT, Mask1, CostKind));
6084 auto NewCost = TTI.getShuffleCost(Kind, InputVT, InputVT, NewMask, CostKind) +
6085 TTI.getShuffleCost(Kind, ResultVT, InputVT, Mask1, CostKind);
6086
6087 LLVM_DEBUG(dbgs() << "Found a phi of mergeable shuffles: " << I
6088 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
6089 << "\n");
6090
6091 if (NewCost > OldCost)
6092 return false;
6093
6094 // Create new shuffles and narrowed phi.
6095 auto Builder = IRBuilder(Shuf);
6096 Builder.SetCurrentDebugLocation(Shuf->getDebugLoc());
6097 auto *PoisonVal = PoisonValue::get(InputVT);
6098 auto *NewShuf0 = Builder.CreateShuffleVector(Op, PoisonVal, NewMask);
6099 Worklist.push(cast<Instruction>(NewShuf0));
6100
6101 Builder.SetInsertPoint(Phi);
6102 Builder.SetCurrentDebugLocation(Phi->getDebugLoc());
6103 auto *NewPhi = Builder.CreatePHI(NewShuf0->getType(), 2u);
6104 NewPhi->addIncoming(NewShuf0, Phi->getIncomingBlock(0u));
6105 NewPhi->addIncoming(Op, Phi->getIncomingBlock(1u));
6106
6107 Builder.SetInsertPoint(*NewPhi->getInsertionPointAfterDef());
6108 PoisonVal = PoisonValue::get(NewPhi->getType());
6109 auto *NewShuf1 = Builder.CreateShuffleVector(NewPhi, PoisonVal, Mask1);
6110
6111 replaceValue(*Phi, *NewShuf1);
6112 return true;
6113}
6114
6115/// This is the entry point for all transforms. Pass manager differences are
6116/// handled in the callers of this function.
///
/// Visits every non-debug instruction in each reachable basic block once,
/// trying the applicable folds on it. It then drains the worklist of
/// instructions queued by successful folds. Returns true if any fold changed
/// the IR.
6117bool VectorCombine::run() {
// NOTE(review): rendered line 6118 (the condition guarding this early
// return) is missing from this listing; it is presumably a check of the
// DisableVectorCombine option declared at the top of the file - confirm
// against the upstream source.
6119 return false;
6120
6121 // Don't attempt vectorization if the target does not support vectors.
6122 if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
6123 return false;
6124
6125 LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n");
6126
// Tries the folds that apply to I, in priority order, and returns true as
// soon as one of them changes the IR. Most folds are gated on I's type
// (fixed vector / any vector / other) and opcode, so folds with no chance
// of matching are skipped.
6127 auto FoldInst = [this](Instruction &I) {
// Instructions built by a fold are inserted before I unless the fold
// moves the insertion point itself.
6128 Builder.SetInsertPoint(&I);
6129 bool IsVectorType = isa<VectorType>(I.getType());
6130 bool IsFixedVectorType = isa<FixedVectorType>(I.getType());
6131 auto Opcode = I.getOpcode();
6132
6133 LLVM_DEBUG(dbgs() << "VC: Visiting: " << I << '\n');
6134
6135 // These folds should be beneficial regardless of when this pass is run
6136 // in the optimization pipeline.
6137 // The type checking is for run-time efficiency. We can avoid wasting time
6138 // dispatching to folding functions if there's no chance of matching.
6139 if (IsFixedVectorType) {
6140 switch (Opcode) {
6141 case Instruction::InsertElement:
6142 if (vectorizeLoadInsert(I))
6143 return true;
6144 break;
6145 case Instruction::ShuffleVector:
6146 if (widenSubvectorLoad(I))
6147 return true;
6148 break;
6149 default:
6150 break;
6151 }
6152 }
6153
6154 // This transform works with scalable and fixed vectors
6155 // TODO: Identify and allow other scalable transforms
6156 if (IsVectorType) {
6157 if (scalarizeOpOrCmp(I))
6158 return true;
6159 if (scalarizeLoad(I))
6160 return true;
6161 if (scalarizeExtExtract(I))
6162 return true;
6163 if (scalarizeVPIntrinsic(I))
6164 return true;
6165 if (foldInterleaveIntrinsics(I))
6166 return true;
6167 if (foldBitcastOfVPLoad(I))
6168 return true;
6169 }
6170
// Tried for every instruction, regardless of result type.
6171 if (foldDeinterleaveIntrinsics(I))
6172 return true;
6173
6174 if (Opcode == Instruction::Store)
6175 if (foldSingleElementStore(I))
6176 return true;
6177
6178 // If this is an early pipeline invocation of this pass, we are done.
6179 if (TryEarlyFoldsOnly)
6180 return false;
6181
6182 // Otherwise, try folds that improve codegen but may interfere with
6183 // early IR canonicalizations.
6184 // The type checking is for run-time efficiency. We can avoid wasting time
6185 // dispatching to folding functions if there's no chance of matching.
6186 if (IsFixedVectorType) {
6187 switch (Opcode) {
6188 case Instruction::InsertElement:
6189 if (foldInsExtFNeg(I))
6190 return true;
6191 if (foldInsExtBinop(I))
6192 return true;
6193 if (foldInsExtVectorToShuffle(I))
6194 return true;
6195 break;
6196 case Instruction::ShuffleVector:
6197 if (foldPermuteOfBinops(I))
6198 return true;
6199 if (foldShuffleOfBinops(I))
6200 return true;
6201 if (foldShuffleOfSelects(I))
6202 return true;
6203 if (foldShuffleOfCastops(I))
6204 return true;
6205 if (foldShuffleOfShuffles(I))
6206 return true;
6207 if (foldPermuteOfIntrinsic(I))
6208 return true;
6209 if (foldShufflesOfLengthChangingShuffles(I))
6210 return true;
6211 if (foldShuffleOfIntrinsics(I))
6212 return true;
6213 if (foldSelectShuffle(I))
6214 return true;
6215 if (foldShuffleToIdentity(I))
6216 return true;
6217 break;
6218 case Instruction::Load:
6219 if (shrinkLoadForShuffles(I))
6220 return true;
6221 break;
6222 case Instruction::BitCast:
6223 if (foldBitcastShuffle(I))
6224 return true;
6225 if (foldSelectsFromBitcast(I))
6226 return true;
6227 break;
6228 case Instruction::And:
6229 case Instruction::Or:
6230 case Instruction::Xor:
6231 if (foldBitOpOfCastops(I))
6232 return true;
6233 if (foldBitOpOfCastConstant(I))
6234 return true;
6235 break;
6236 case Instruction::PHI:
6237 if (shrinkPhiOfShuffles(I))
6238 return true;
6239 break;
6240 default:
6241 if (shrinkType(I))
6242 return true;
6243 break;
6244 }
6245 } else {
// Results that are not fixed-width vectors: scalars and scalable vectors.
6246 switch (Opcode) {
6247 case Instruction::Call:
6248 if (foldShuffleFromReductions(I))
6249 return true;
6250 if (foldCastFromReductions(I))
6251 return true;
6252 break;
6253 case Instruction::ExtractElement:
6254 if (foldShuffleChainsToReduce(I))
6255 return true;
6256 break;
6257 case Instruction::ICmp:
6258 if (foldSignBitReductionCmp(I))
6259 return true;
6260 if (foldICmpEqZeroVectorReduce(I))
6261 return true;
6262 if (foldEquivalentReductionCmp(I))
6263 return true;
6264 if (foldReduceAddCmpZero(I))
6265 return true;
// An integer compare not handled by the reduction folds above also gets
// the extract/extract fold tried for floating-point compares.
6266 [[fallthrough]];
6267 case Instruction::FCmp:
6268 if (foldExtractExtract(I))
6269 return true;
6270 break;
6271 case Instruction::Or:
6272 if (foldConcatOfBoolMasks(I))
6273 return true;
// Or is a binary operator, so it continues to the generic binop folds.
6274 [[fallthrough]];
6275 default:
6276 if (Instruction::isBinaryOp(Opcode)) {
6277 if (foldExtractExtract(I))
6278 return true;
6279 if (foldExtractedCmps(I))
6280 return true;
6281 if (foldBinopOfReductions(I))
6282 return true;
6283 }
6284 break;
6285 }
6286 }
6287 return false;
6288 };
6289
6290 bool MadeChange = false;
6291 for (BasicBlock &BB : F) {
6292 // Ignore unreachable basic blocks.
6293 if (!DT.isReachableFromEntry(&BB))
6294 continue;
6295 // Use early increment range so that we can erase instructions in loop.
6296 // make_early_inc_range is not applicable here, as the next iterator may
6297 // be invalidated by RecursivelyDeleteTriviallyDeadInstructions.
6298 // We manually maintain the next instruction and update it when it is about
6299 // to be deleted.
6300 Instruction *I = &BB.front();
6301 while (I) {
6302 NextInst = I->getNextNode();
6303 if (!I->isDebugOrPseudoInst())
6304 MadeChange |= FoldInst(*I);
6305 I = NextInst;
6306 }
6307 }
6308
// NextInst is only maintained during the block walk above; reset it so no
// stale pointer is left behind for the worklist phase.
6309 NextInst = nullptr;
6310
// Revisit instructions queued with Worklist.push by successful folds
// (for example, newly created shuffles). This gives them a chance to fold
// further.
6311 while (!Worklist.isEmpty()) {
6312 Instruction *I = Worklist.removeOne();
6313 if (!I)
6314 continue;
6315
// NOTE(review): rendered lines 6316-6317 are missing from this listing. The
// `continue;` and closing brace below belong to a guard whose condition
// (and any cleanup it performs on I) is not shown here - confirm against
// the upstream source.
6318 continue;
6319 }
6320
6321 MadeChange |= FoldInst(*I);
6322 }
6323
6324 return MadeChange;
6325}
6326
6329 auto &AC = FAM.getResult<AssumptionAnalysis>(F);
6331 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
6332 AAResults &AA = FAM.getResult<AAManager>(F);
6333 const DataLayout *DL = &F.getDataLayout();
6334 VectorCombine Combiner(F, TTI, DT, AA, AC, DL, TTI::TCK_RecipThroughput,
6335 TryEarlyFoldsOnly);
6336 if (!Combiner.run())
6337 return PreservedAnalyses::all();
6340 return PA;
6341}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static cl::opt< unsigned > MaxInstrsToScan("aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden, cl::desc("Max number of instructions to scan for aggressive instcombine."))
This is the interface for LLVM's primary stateless and local alias analysis.
#define X(NUM, ENUM, NAME)
Definition ELF.h:853
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static cl::opt< OutputCostKind > CostKind("cost-kind", cl::desc("Target cost kind"), cl::init(OutputCostKind::RecipThroughput), cl::values(clEnumValN(OutputCostKind::RecipThroughput, "throughput", "Reciprocal throughput"), clEnumValN(OutputCostKind::Latency, "latency", "Instruction latency"), clEnumValN(OutputCostKind::CodeSize, "code-size", "Code size"), clEnumValN(OutputCostKind::SizeAndLatency, "size-latency", "Code size and latency"), clEnumValN(OutputCostKind::All, "all", "Print all cost kinds")))
static cl::opt< IntrinsicCostStrategy > IntrinsicCost("intrinsic-cost-strategy", cl::desc("Costing strategy for intrinsic instructions"), cl::init(IntrinsicCostStrategy::InstructionCost), cl::values(clEnumValN(IntrinsicCostStrategy::InstructionCost, "instruction-cost", "Use TargetTransformInfo::getInstructionCost"), clEnumValN(IntrinsicCostStrategy::IntrinsicCost, "intrinsic-cost", "Use TargetTransformInfo::getIntrinsicInstrCost"), clEnumValN(IntrinsicCostStrategy::TypeBasedIntrinsicCost, "type-based-intrinsic-cost", "Calculate the intrinsic cost based only on argument types")))
This file defines the DenseMap class.
#define Check(C,...)
This is the interface for a simple mod/ref and alias analysis over globals.
Hexagon Common GEP
iv users
Definition IVUsers.cpp:48
static Value * getOpcode(Value &V, Type &Ty, InstrumentationConfig &IConf, InstrumentorIRBuilderTy &IIRB)
const size_t AbstractManglingParser< Derived, Alloc >::NumOps
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo, MemorySSAUpdater &MSSAU)
Definition LICM.cpp:1457
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define T1
MachineInstr unsigned OpIdx
uint64_t IntrinsicInst * II
FunctionAnalysisManager FAM
const SmallVectorImpl< MachineOperand > & Cond
unsigned OpIndex
This file contains some templates that are useful if you are working with the STL at all.
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope exit.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metrics.
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:119
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
This pass exposes codegen information to IR-level passes.
static bool isFreeConcat(ArrayRef< InstLane > Item, TTI::TargetCostKind CostKind, const TargetTransformInfo &TTI)
Detect concat of multiple values into a vector.
static void analyzeCostOfVecReduction(const IntrinsicInst &II, TTI::TargetCostKind CostKind, const TargetTransformInfo &TTI, InstructionCost &CostBeforeReduction, InstructionCost &CostAfterReduction)
static SmallVector< InstLane > generateInstLaneVectorFromOperand(ArrayRef< InstLane > Item, int Op)
static Value * createShiftShuffle(Value *Vec, unsigned OldIndex, unsigned NewIndex, IRBuilderBase &Builder)
Create a shuffle that translates (shifts) 1 element from the input vector to a new element location.
static Value * generateNewInstTree(ArrayRef< InstLane > Item, Use *From, FixedVectorType *Ty, const DenseSet< std::pair< Value *, Use * > > &IdentityLeafs, const DenseSet< std::pair< Value *, Use * > > &SplatLeafs, const DenseSet< std::pair< Value *, Use * > > &ConcatLeafs, IRBuilderBase &Builder, const TargetTransformInfo *TTI)
std::pair< Value *, int > InstLane
static bool isKnownNonPositive(const Value *V, const SimplifyQuery &SQ, unsigned Depth=0)
Used by foldReduceAddCmpZero to check if we can prove that a value is non-positive.
static Align computeAlignmentAfterScalarization(Align VectorAlignment, Type *ScalarType, Value *Idx, const DataLayout &DL)
The memory operation on a vector of ScalarType had alignment of VectorAlignment.
static bool feedsIntoVectorReduction(ShuffleVectorInst *SVI)
Returns true if this ShuffleVectorInst eventually feeds into a vector reduction intrinsic (e....
static cl::opt< bool > DisableVectorCombine("disable-vector-combine", cl::init(false), cl::Hidden, cl::desc("Disable all vector combine transforms"))
static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI)
static const unsigned InvalidIndex
static Value * translateExtract(ExtractElementInst *ExtElt, unsigned NewIndex, IRBuilderBase &Builder)
Given an extract element instruction with constant index operand, shuffle the source vector (shift th...
static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx, const SimplifyQuery &SQ)
Check if it is legal to scalarize a memory access to VecTy at index Idx.
static cl::opt< unsigned > MaxInstrsToScan("vector-combine-max-scan-instrs", cl::init(30), cl::Hidden, cl::desc("Max number of instructions to scan for vector combining."))
static cl::opt< bool > DisableBinopExtractShuffle("disable-binop-extract-shuffle", cl::init(false), cl::Hidden, cl::desc("Disable binop extract to shuffle transforms"))
static InstLane lookThroughShuffles(Value *V, int Lane)
static bool isMemModifiedBetween(BasicBlock::iterator Begin, BasicBlock::iterator End, const MemoryLocation &Loc, AAResults &AA)
static constexpr int Concat[]
Value * RHS
Value * LHS
A manager for alias analyses.
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1055
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1511
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
unsigned countl_one() const
Count the number of leading one bits.
Definition APInt.h:1638
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
static APInt getHighBitsSet(unsigned numBits, unsigned hiBitsSet)
Constructs an APInt value that has the top hiBitsSet bits set.
Definition APInt.h:297
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isOne() const
Determine if this is a value of 1.
Definition APInt.h:390
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1228
Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
const T & front() const
Get the first element.
Definition ArrayRef.h:144
size_t size() const
Get the array size.
Definition ArrayRef.h:141
A function analysis which provides an AssumptionCache.
A cache of @llvm.assume calls within a function.
LLVM_ABI bool hasAttribute(Attribute::AttrKind Kind) const
Return true if the attribute exists in this set.
InstListType::iterator iterator
Instruction iterators...
Definition BasicBlock.h:170
BinaryOps getOpcode() const
Definition InstrTypes.h:409
Represents analyses that only rely on functions' control flow.
Definition Analysis.h:73
Value * getArgOperand(unsigned i) const
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
void addParamAttrs(unsigned ArgNo, const AttrBuilder &B)
Adds attributes to the indicated argument.
static LLVM_ABI CastInst * Create(Instruction::CastOps, Value *S, Type *Ty, const Twine &Name="", InsertPosition InsertBefore=nullptr)
Provides a way to construct any of the CastInst subclasses using an opcode instead of the subclass's ...
static Type * makeCmpResultType(Type *opnd_type)
Create a result type for fcmp/icmp.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:740
bool isFPPredicate() const
Definition InstrTypes.h:845
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
Combiner implementation.
Definition Combiner.h:33
static LLVM_ABI Constant * getExtractElement(Constant *Vec, Constant *Idx, Type *OnlyIfReducedTy=nullptr)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
This class represents a range of values.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI ConstantRange binaryAnd(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a binary-and of a value in this ra...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
static LLVM_ABI Constant * getSplat(ElementCount EC, Constant *Elt)
Return a ConstantVector with the specified constant in each element.
static LLVM_ABI Constant * get(ArrayRef< Constant * > V)
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:178
bool empty() const
Definition DenseMap.h:113
iterator end()
Definition DenseMap.h:85
Implements a dense probed hash-table based set.
Definition DenseSet.h:289
Analysis pass which computes a DominatorTree.
Definition Dominators.h:278
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:159
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
This instruction extracts a single (scalar) element from a VectorType value.
Convenience struct for specifying and reasoning about fast-math flags.
Definition FMF.h:23
Class to represent fixed width SIMD vectors.
unsigned getNumElements() const
static FixedVectorType * getDoubleElementsVectorType(FixedVectorType *VTy)
static LLVM_ABI FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition Type.cpp:869
Predicate getSignedPredicate() const
Return the signed version of this predicate. For example, EQ->EQ, SLE->SLE, UGT->SGT, etc.
bool isEquality() const
Return true if this predicate is either EQ or NE.
Common base class shared among various IRBuilders.
Definition IRBuilder.h:114
Value * CreateNUWMul(Value *LHS, Value *RHS, const Twine &Name="")
Definition IRBuilder.h:1491
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Definition IRBuilder.h:2627
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Definition IRBuilder.h:2615
LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)
Definition IRBuilder.h:1935
LLVM_ABI Value * CreateSelectFMF(Value *C, Value *True, Value *False, FMFSource FMFSource, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")
Return a vector value that contains V broadcast to NumElts elements.
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Definition IRBuilder.h:2674
ConstantInt * getTrue()
Get the constant value for i1 true.
Definition IRBuilder.h:509
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > OverloadTypes, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="", ArrayRef< OperandBundleDef > OpBundles={})
Create a call to intrinsic ID with Args, mangled using OverloadTypes.
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
Value * CreateFreeze(Value *V, const Twine &Name="")
Definition IRBuilder.h:2693
void SetCurrentDebugLocation(const DebugLoc &L)
Set location information used by debugging information.
Definition IRBuilder.h:247
Value * CreateLShr(Value *LHS, Value *RHS, const Twine &Name="", bool isExact=false)
Definition IRBuilder.h:1554
Value * CreateCast(Instruction::CastOps Op, Value *V, Type *DestTy, const Twine &Name="", MDNode *FPMathTag=nullptr, FMFSource FMFSource={})
Definition IRBuilder.h:2276
Value * CreateIsNotNeg(Value *Arg, const Twine &Name="")
Return a boolean value testing if Arg > -1.
Definition IRBuilder.h:2717
Value * CreateInBoundsGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="")
Definition IRBuilder.h:2018
Value * CreatePointerBitCastOrAddrSpaceCast(Value *V, Type *DestTy, const Twine &Name="")
Definition IRBuilder.h:2301
ConstantInt * getInt64(uint64_t C)
Get a constant 64-bit value.
Definition IRBuilder.h:534
LLVM_ABI CallInst * CreateOrReduce(Value *Src)
Create a vector int OR reduction intrinsic of the source vector.
ConstantInt * getInt32(uint32_t C)
Get a constant 32-bit value.
Definition IRBuilder.h:529
Value * CreateCmp(CmpInst::Predicate Pred, Value *LHS, Value *RHS, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:2508
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Definition IRBuilder.h:2539
InstTy * Insert(InstTy *I, const Twine &Name="") const
Insert and return the specified instruction.
Definition IRBuilder.h:172
Value * CreateIsNeg(Value *Arg, const Twine &Name="")
Return a boolean value testing if Arg < 0.
Definition IRBuilder.h:2712
Value * CreateBitCast(Value *V, Type *DestTy, const Twine &Name="")
Definition IRBuilder.h:2242
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition IRBuilder.h:1918
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition IRBuilder.h:1533
LLVM_ABI Value * CreateNAryOp(unsigned Opc, ArrayRef< Value * > Ops, const Twine &Name="", MDNode *FPMathTag=nullptr)
Create either a UnaryOperator or BinaryOperator depending on Opc.
Value * CreateZExt(Value *V, Type *DestTy, const Twine &Name="", bool IsNonNeg=false)
Definition IRBuilder.h:2120
Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")
Definition IRBuilder.h:2649
Value * CreateAnd(Value *LHS, Value *RHS, const Twine &Name="")
Definition IRBuilder.h:1592
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Definition IRBuilder.h:1931
Value * CreateTrunc(Value *V, Type *DestTy, const Twine &Name="", bool IsNUW=false, bool IsNSW=false)
Definition IRBuilder.h:2106
PointerType * getPtrTy(unsigned AddrSpace=0)
Fetch the type representing a pointer.
Definition IRBuilder.h:629
Value * CreateBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:1753
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition IRBuilder.h:207
Value * CreateFNegFMF(Value *V, FMFSource FMFSource, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:1856
Value * CreateICmp(CmpInst::Predicate P, Value *LHS, Value *RHS, const Twine &Name="")
Definition IRBuilder.h:2484
Value * CreateOr(Value *LHS, Value *RHS, const Twine &Name="", bool IsDisjoint=false)
Definition IRBuilder.h:1614
InstSimplifyFolder - Use InstructionSimplify to fold operations to existing values.
CostType getValue() const
This function is intended to be used as sparingly as possible, since the class provides the full rang...
void push(Instruction *I)
Push the instruction onto the worklist stack.
LLVM_ABI void setHasNoUnsignedWrap(bool b=true)
Set or clear the nuw flag on this instruction, which must be an operator which supports this flag.
LLVM_ABI void copyIRFlags(const Value *V, bool IncludeWrapFlags=true)
Convenience method to copy supported exact, fast-math, and (optionally) wrapping flags from V to this instruction.
LLVM_ABI void setHasNoSignedWrap(bool b=true)
Set or clear the nsw flag on this instruction, which must be an operator which supports this flag.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
LLVM_ABI void andIRFlags(const Value *V)
Logical 'and' of any supported wrapping, exact, and fast-math flags of V and this instruction.
bool isBinaryOp() const
LLVM_ABI void setNonNeg(bool b=true)
Set or clear the nneg flag on this instruction, which must be a zext instruction.
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instruction comes before Other.
LLVM_ABI AAMDNodes getAAMetadata() const
Returns the AA metadata for this instruction.
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
LLVM_ABI void copyMetadata(const Instruction &SrcInst, ArrayRef< unsigned > WL=ArrayRef< unsigned >())
Copy metadata from SrcInst to this instruction.
bool isIntDivRem() const
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:350
unsigned getBitWidth() const
Get the number of bits in this IntegerType.
A wrapper class for inspecting calls to intrinsic functions.
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
An instruction for reading from memory.
unsigned getPointerAddressSpace() const
Returns the address space of the pointer operand.
void setAlignment(Align Align)
Type * getPointerOperandType() const
Align getAlign() const
Return the alignment of the access that is being performed.
Representation for a specific memory location.
static LLVM_ABI MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserveSet()
Mark an analysis set as preserved.
Definition Analysis.h:151
const SDValue & getOperand(unsigned Num) const
This instruction constructs a fixed permutation of two input vectors.
int getMaskValue(unsigned Elt) const
Return the shuffle mask value of this instruction for the given element index.
VectorType * getType() const
Overload to return most specific vector type.
static LLVM_ABI void getShuffleMask(const Constant *Mask, SmallVectorImpl< int > &Result)
Convert the input shuffle mask operand to a vector of integers.
static LLVM_ABI bool isIdentityMask(ArrayRef< int > Mask, int NumSrcElts)
Return true if this shuffle mask chooses elements from exactly one source vector without lane crossings.
static void commuteShuffleMask(MutableArrayRef< int > Mask, unsigned InVecNumElts)
Change values in a shuffle permute mask assuming the two vector operands of length InVecNumElts have ...
size_type size() const
Definition SmallPtrSet.h:99
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
void assign(size_type NumElts, ValueParamT Elt)
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void resize(size_type N)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
void setAlignment(Align Align)
Analysis pass providing the TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
static LLVM_ABI CastContextHint getCastContextHint(const Instruction *I)
Calculates a CastContextHint from I.
@ None
The insert/extract is not used with a load/store.
LLVM_ABI InstructionCost getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, OperandValueInfo Op1Info={OK_AnyValue, OP_None}, OperandValueInfo Op2Info={OK_AnyValue, OP_None}, const Instruction *I=nullptr) const
LLVM_ABI TypeSize getRegisterBitWidth(RegisterKind K) const
static LLVM_ABI OperandValueInfo commonOperandInfo(const Value *X, const Value *Y)
Collect common data between two OperandValueInfo inputs.
LLVM_ABI InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, OperandValueInfo OpdInfo={OK_AnyValue, OP_None}, const Instruction *I=nullptr) const
LLVM_ABI bool allowVectorElementIndexingUsingGEP() const
Returns true if GEP may be used to index into vectors for this target (false means vector element indexing via GEP is not allowed).
LLVM_ABI InstructionCost getShuffleCost(ShuffleKind Kind, VectorType *DstTy, VectorType *SrcTy, ArrayRef< int > Mask={}, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, int Index=0, VectorType *SubTp=nullptr, ArrayRef< const Value * > Args={}, const Instruction *CxtI=nullptr) const
LLVM_ABI InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, TTI::TargetCostKind CostKind) const
LLVM_ABI InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, std::optional< FastMathFlags > FMF, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput) const
Calculate the cost of vector reduction intrinsics.
LLVM_ABI InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, TTI::CastContextHint CCH, TTI::TargetCostKind CostKind=TTI::TCK_SizeAndLatency, const Instruction *I=nullptr) const
LLVM_ABI InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index=-1, const Value *Op0=nullptr, const Value *Op1=nullptr, TTI::VectorInstrContext VIC=TTI::VectorInstrContext::None) const
LLVM_ABI unsigned getRegisterClassForType(bool Vector, Type *Ty=nullptr) const
LLVM_ABI InstructionCost getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty, FastMathFlags FMF=FastMathFlags(), TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput) const
TargetCostKind
The kind of cost model.
@ TCK_RecipThroughput
Reciprocal throughput.
LLVM_ABI InstructionCost getArithmeticInstrCost(unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, TTI::OperandValueInfo Opd1Info={TTI::OK_AnyValue, TTI::OP_None}, TTI::OperandValueInfo Opd2Info={TTI::OK_AnyValue, TTI::OP_None}, ArrayRef< const Value * > Args={}, const Instruction *CxtI=nullptr, const TargetLibraryInfo *TLibInfo=nullptr) const
This is an approximation of reciprocal throughput of a math/logic op.
LLVM_ABI InstructionCost getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA, TTI::TargetCostKind CostKind) const
LLVM_ABI unsigned getMinVectorRegisterBitWidth() const
LLVM_ABI InstructionCost getAddressComputationCost(Type *PtrTy, ScalarEvolution *SE, const SCEV *Ptr, TTI::TargetCostKind CostKind) const
LLVM_ABI unsigned getNumberOfRegisters(unsigned ClassID) const
LLVM_ABI InstructionCost getInstructionCost(const User *U, ArrayRef< const Value * > Operands, TargetCostKind CostKind) const
Estimate the cost of a given IR user when lowered.
LLVM_ABI InstructionCost getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract, TTI::TargetCostKind CostKind, bool ForPoisonSrc=true, ArrayRef< Value * > VL={}, TTI::VectorInstrContext VIC=TTI::VectorInstrContext::None) const
Estimate the overhead of scalarizing an instruction.
ShuffleKind
The various kinds of shuffle patterns for vector queries.
@ SK_PermuteSingleSrc
Shuffle elements of single source vector with any shuffle mask.
@ SK_Broadcast
Broadcast element 0 to all other elements.
@ SK_PermuteTwoSrc
Merge elements from two source vectors into one with any shuffle mask.
@ SK_ExtractSubvector
ExtractSubvector Index indicates start offset.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:282
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition Type.h:368
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:197
LLVMContext & getContext() const
Return the LLVMContext in which this type was uniqued.
Definition Type.h:130
LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Definition Type.cpp:232
bool isFloatingPointTy() const
Return true if this is one of the floating-point types.
Definition Type.h:186
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:257
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:267
Value * getOperand(unsigned i) const
Definition User.h:207
static LLVM_ABI bool isVPBinOp(Intrinsic::ID ID)
std::optional< unsigned > getFunctionalIntrinsicID() const
std::optional< unsigned > getFunctionalOpcode() const
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
const Value * stripAndAccumulateInBoundsConstantOffsets(const DataLayout &DL, APInt &Offset) const
This is a wrapper around stripAndAccumulateConstantOffsets with the in-bounds requirement set to fals...
Definition Value.h:727
LLVM_ABI bool hasOneUser() const
Return true if there is exactly one user of this value.
Definition Value.cpp:162
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition Value.h:439
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:549
iterator_range< user_iterator > users()
Definition Value.h:426
LLVM_ABI Align getPointerAlignment(const DataLayout &DL) const
Returns an alignment of the pointer value.
Definition Value.cpp:969
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI bool hasNUses(unsigned N) const
Return true if this Value has exactly N uses.
Definition Value.cpp:146
bool use_empty() const
Definition Value.h:346
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:318
bool user_empty() const
Definition Value.h:389
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &)
static LLVM_ABI VectorType * get(Type *ElementType, ElementCount EC)
This static method is the primary way to construct an VectorType.
Type * getElementType() const
std::pair< iterator, bool > insert(const ValueT &V)
Definition DenseSet.h:212
size_type size() const
Definition DenseSet.h:87
constexpr bool hasKnownScalarFactor(const FixedOrScalableQuantity &RHS) const
Returns true if there exists a value X where RHS.multiplyCoefficientBy(X) will result in a value whos...
Definition TypeSize.h:269
constexpr ScalarTy getKnownScalarFactor(const FixedOrScalableQuantity &RHS) const
Returns a value X where RHS.multiplyCoefficientBy(X) will result in a value whose quantity matches ou...
Definition TypeSize.h:277
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
const ParentTy * getParent() const
Definition ilist_node.h:34
self_iterator getIterator()
Definition ilist_node.h:123
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
Abstract Attribute helper functions.
Definition Attributor.h:165
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
constexpr char Attrs[]
Key for Kernel::Metadata::mAttrs.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2277
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2282
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI AttributeSet getFnAttributes(LLVMContext &C, ID id)
Return the function attributes for an intrinsic.
SpecificConstantMatch m_ZeroInt()
Convenience matchers for specific integer values.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
OneUse_match< SubPat > m_OneUse(const SubPat &SP)
match_combine_and< Ty... > m_CombineAnd(const Ty &...Ps)
Combine pattern matchers matching all of Ps patterns.
cst_pred_ty< is_all_ones > m_AllOnes()
Match an integer or vector with all bits set.
BinaryOp_match< LHS, RHS, Instruction::And > m_And(const LHS &L, const RHS &R)
auto m_Cmp()
Matches any compare instruction and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::URem > m_URem(const LHS &L, const RHS &R)
auto m_Poison()
Match an arbitrary poison constant.
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
CastInst_match< OpTy, TruncInst > m_Trunc(const OpTy &Op)
Matches Trunc.
specific_intval< false > m_SpecificInt(const APInt &V)
Match a specific integer value or vector with all elements equal to the value.
bool match(Val *V, const Pattern &P)
match_bind< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
DisjointOr_match< LHS, RHS > m_DisjointOr(const LHS &L, const RHS &R)
BinOpPred_match< LHS, RHS, is_right_shift_op > m_Shr(const LHS &L, const RHS &R)
Matches logical shift operations.
TwoOps_match< Val_t, Idx_t, Instruction::ExtractElement > m_ExtractElt(const Val_t &Val, const Idx_t &Idx)
Matches ExtractElementInst.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
auto m_BinOp()
Match an arbitrary binary operation and ignore it.
auto m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
auto m_Constant()
Match an arbitrary Constant and ignore it.
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
cst_pred_ty< is_non_zero_int > m_NonZeroInt()
Match a non-zero integer or a vector with all non-zero elements.
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
CastInst_match< OpTy, ZExtInst > m_ZExt(const OpTy &Op)
Matches ZExt.
OverflowingBinaryOp_match< LHS, RHS, Instruction::Shl, OverflowingBinaryOperator::NoUnsignedWrap > m_NUWShl(const LHS &L, const RHS &R)
auto m_AnyIntrinsic()
Matches any intrinsic call and ignore it.
OverflowingBinaryOp_match< LHS, RHS, Instruction::Mul, OverflowingBinaryOperator::NoUnsignedWrap > m_NUWMul(const LHS &L, const RHS &R)
BinOpPred_match< LHS, RHS, is_bitwiselogic_op, true > m_c_BitwiseLogic(const LHS &L, const RHS &R)
Matches bitwise logic operations in either order.
CastOperator_match< OpTy, Instruction::BitCast > m_BitCast(const OpTy &Op)
Matches BitCast.
match_combine_or< CastInst_match< OpTy, SExtInst >, NNegZExt_match< OpTy > > m_SExtLike(const OpTy &Op)
Match either "sext" or "zext nneg".
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
CmpClass_match< LHS, RHS, ICmpInst > m_ICmp(CmpPredicate &Pred, const LHS &L, const RHS &R)
match_combine_or< CastInst_match< OpTy, ZExtInst >, CastInst_match< OpTy, SExtInst > > m_ZExtOrSExt(const OpTy &Op)
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_Undef()
Match an arbitrary undef constant.
CastInst_match< OpTy, SExtInst > m_SExt(const OpTy &Op)
Matches SExt.
is_zero m_Zero()
Match any null constant or a vector with all elements equal to 0.
BinaryOp_match< LHS, RHS, Instruction::Or, true > m_c_Or(const LHS &L, const RHS &R)
Matches an Or with LHS and RHS in either order.
ThreeOps_match< Val_t, Elt_t, Idx_t, Instruction::InsertElement > m_InsertElt(const Val_t &Val, const Elt_t &Elt, const Idx_t &Idx)
Matches InsertElementInst.
m_Intrinsic_Ty< Opnd >::Ty m_Deinterleave2(const Opnd &Op)
auto m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
DXILDebugInfoMap run(Module &M)
@ User
could "use" a pointer
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
NodeAddr< UseNode * > Use
Definition RDFGraph.h:385
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:315
unsigned Log2_32_Ceil(uint32_t Value)
Return the ceil log base 2 of the specified value, 32 if the value is zero.
Definition MathExtras.h:344
@ Offset
Definition DWP.cpp:558
detail::zippy< detail::zip_shortest, T, U, Args... > zip(T &&t, U &&u, Args &&...args)
zip iterator for two or more iteratable types.
Definition STLExtras.h:830
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
void stable_sort(R &&Range)
Definition STLExtras.h:2115
UnaryFunction for_each(R &&Range, UnaryFunction F)
Provide wrappers to std::for_each which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1731
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1738
LLVM_ABI Intrinsic::ID getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID)
Returns the min/max intrinsic used when expanding a min/max reduction.
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition Local.cpp:535
LLVM_ABI SDValue peekThroughBitcasts(SDValue V)
Return the non-bitcasted source operand of V if it exists.
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B,...
Definition STLExtras.h:2553
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
unsigned Log2_64_Ceil(uint64_t Value)
Return the ceil log base 2 of the specified value, 64 if the value is zero.
Definition MathExtras.h:350
LLVM_ABI Value * simplifyUnOp(unsigned Opcode, Value *Op, const SimplifyQuery &Q)
Given operand for a UnaryOperator, fold the result or return null.
scope_exit(Callable) -> scope_exit< Callable >
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
LLVM_ABI unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID)
Returns the arithmetic instruction opcode used when expanding a reduction.
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2207
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Value * simplifyCall(CallBase *Call, Value *Callee, ArrayRef< Value * > Args, const SimplifyQuery &Q)
Given a callsite, callee, and arguments, fold the result or return null.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition STLExtras.h:633
LLVM_ABI bool mustSuppressSpeculation(const LoadInst &LI)
Return true if speculation of the given load must be suppressed to avoid ordering or interfering with...
Definition Loads.cpp:432
constexpr bool isPowerOf2_64(uint64_t Value)
Return true if the argument is a power of two > 0 (64 bit edition.)
Definition MathExtras.h:284
LLVM_ABI bool widenShuffleMaskElts(int Scale, ArrayRef< int > Mask, SmallVectorImpl< int > &ScaledMask)
Try to transform a shuffle mask by replacing elements with the scaled index for an equivalent mask of...
LLVM_ABI bool isSafeToSpeculativelyExecute(const Instruction *I, const Instruction *CtxI=nullptr, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr, bool UseVariableInfo=true, bool IgnoreUBImplyingAttrs=true)
Return true if the instruction does not have any effects besides calculating the result and does not ...
LLVM_ABI Value * getSplatValue(const Value *V)
Get splat value if the input is a splat vector or return nullptr.
constexpr auto equal_to(T &&Arg)
Functor variant of std::equal_to that can be used as a UnaryPredicate in functional algorithms like a...
Definition STLExtras.h:2172
unsigned M1(unsigned Val)
Definition VE.h:377
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1745
LLVM_ABI bool isInstructionTriviallyDead(Instruction *I, const TargetLibraryInfo *TLI=nullptr)
Return true if the result produced by the instruction is not used, and the instruction will return.
Definition Local.cpp:403
LLVM_ABI bool isSplatValue(const Value *V, int Index=-1, unsigned Depth=0)
Return true if each element of the vector value V is poisoned or equal to every other non-poisoned el...
unsigned Log2_32(uint32_t Value)
Return the floor log base 2 of the specified value, -1 if the value is zero.
Definition MathExtras.h:331
constexpr bool isPowerOf2_32(uint32_t Value)
Return true if the argument is a power of two > 0.
Definition MathExtras.h:279
bool isModSet(const ModRefInfo MRI)
Definition ModRef.h:49
void sort(IteratorTy Start, IteratorTy End)
Definition STLExtras.h:1635
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI bool isSafeToLoadUnconditionally(Value *V, Align Alignment, const APInt &Size, const DataLayout &DL, Instruction *ScanFrom, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr)
Return true if we know that executing a load from this value cannot trap.
Definition Loads.cpp:447
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ABI void propagateIRFlags(Value *I, ArrayRef< Value * > VL, Value *OpValue=nullptr, bool IncludeWrapFlags=true)
Get the intersection (logical and) of all of the potential IR flags of each scalar operation (VL) tha...
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
MutableArrayRef(T &OneElt) -> MutableArrayRef< T >
constexpr int PoisonMaskElem
LLVM_ABI bool isSafeToSpeculativelyExecuteWithOpcode(unsigned Opcode, const Instruction *Inst, const Instruction *CtxI=nullptr, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr, bool UseVariableInfo=true, bool IgnoreUBImplyingAttrs=true)
This returns the same result as isSafeToSpeculativelyExecute if Opcode is the actual opcode of Inst.
@ Other
Any other memory.
Definition ModRef.h:68
TargetTransformInfo TTI
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
LLVM_ABI Value * simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, const SimplifyQuery &Q)
Given operands for a BinaryOperator, fold the result or return null.
LLVM_ABI void narrowShuffleMaskElts(int Scale, ArrayRef< int > Mask, SmallVectorImpl< int > &ScaledMask)
Replace each shuffle mask index with the scaled sequential indices for an equivalent mask of narrowed...
LLVM_ABI Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc)
Returns the reduction intrinsic id corresponding to the binary operation.
@ And
Bitwise or logical AND of integers.
LLVM_ABI bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx, const TargetTransformInfo *TTI)
Identifies if the vector form of the intrinsic has a scalar operand.
DWARFExpression::Operation Op
unsigned M0(unsigned Val)
Definition VE.h:376
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
LLVM_ABI Constant * getLosslessInvCast(Constant *C, Type *InvCastTo, unsigned CastOp, const DataLayout &DL, PreservedCastFlags *Flags=nullptr)
Try to cast C to InvC losslessly, satisfying CastOp(InvC) equals C, or CastOp(InvC) is a refined valu...
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1771
constexpr bool isIntN(unsigned N, int64_t x)
Checks if an signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1946
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition Alignment.h:201
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer lists are equal or the list.
Definition STLExtras.h:2165
LLVM_ABI Value * simplifyCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS, const SimplifyQuery &Q)
Given operands for a CmpInst, fold the result or return null.
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI bool isKnownNonNegative(const Value *V, const SimplifyQuery &SQ, unsigned Depth=0)
Returns true if the give value is known to be non-negative.
LLVM_ABI bool isTriviallyVectorizable(Intrinsic::ID ID)
Identify if the intrinsic is trivially vectorizable.
LLVM_ABI Intrinsic::ID getMinMaxReductionIntrinsicID(Intrinsic::ID IID)
Returns the llvm.vector.reduce min/max intrinsic that corresponds to the intrinsic op.
LLVM_ABI ConstantRange computeConstantRange(const Value *V, bool ForSigned, const SimplifyQuery &SQ, unsigned Depth=0)
Determine the possible constant range of an integer or vector of integer value.
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:876
LLVM_ABI AAMDNodes adjustForAccess(unsigned AccessSize)
Create a new AAMDNode for accessing AccessSize bytes of this AAMDNode.
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition Alignment.h:39
unsigned countMaxActiveBits() const
Returns the maximum number of bits needed to represent all possible unsigned values with these known ...
Definition KnownBits.h:310
unsigned countMinLeadingZeros() const
Returns the minimum number of leading zero bits.
Definition KnownBits.h:262
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:146
const DataLayout & DL
const Instruction * CxtI
const DominatorTree * DT
SimplifyQuery getWithInstruction(const Instruction *I) const
AssumptionCache * AC