LLVM 23.0.0git
VectorCombine.cpp
Go to the documentation of this file.
1//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass optimizes scalar/vector interactions using target cost models. The
10// transforms implemented here may not fit in traditional loop-based or SLP
11// vectorization passes.
12//
13//===----------------------------------------------------------------------===//
14
16#include "llvm/ADT/DenseMap.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/ScopeExit.h"
20#include "llvm/ADT/Statistic.h"
25#include "llvm/Analysis/Loads.h"
30#include "llvm/IR/Dominators.h"
31#include "llvm/IR/Function.h"
32#include "llvm/IR/IRBuilder.h"
40#include <numeric>
41#include <optional>
42#include <queue>
43#include <set>
44
45#define DEBUG_TYPE "vector-combine"
47
48using namespace llvm;
49using namespace llvm::PatternMatch;
50
// Statistic counters for this pass; each string is the counter's description.
51STATISTIC(NumVecLoad, "Number of vector loads formed");
52STATISTIC(NumVecCmp, "Number of vector compares formed");
53STATISTIC(NumVecBO, "Number of vector binops formed");
54STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
55STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
56STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
57STATISTIC(NumScalarCmp, "Number of scalar compares formed");
58STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");
59
// NOTE(review): the first line of each of the three command-line option
// declarations below (listing lines 60, 64 and 68) is absent from this view,
// so the option variable names and value types cannot be confirmed here.
// All three are hidden options; the first two default to false, the third to 30.
61 "disable-vector-combine", cl::init(false), cl::Hidden,
62 cl::desc("Disable all vector combine transforms"));
63
65 "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
66 cl::desc("Disable binop extract to shuffle transforms"));
67
69 "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
70 cl::desc("Max number of instructions to scan for vector combining."));
71
// Sentinel meaning "no index". It is the default PreferredExtractIndex of
// getShuffleExtract and the initial InsertIndex in foldExtractExtract.
72static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();
73
74namespace {
/// Per-function driver for the vector-combine transforms. It holds references
/// to the function and its analyses, a cost kind, a worklist of instructions
/// to (re)visit, and declares one member function per individual fold.
75class VectorCombine {
76public:
77 VectorCombine(Function &F, const TargetTransformInfo &TTI,
// NOTE(review): listing lines 78-79 (the remaining constructor parameters) are
// absent from this view. The initializer list below additionally reads DT, AA,
// AC, DL and CostKind, so those are presumably declared on the missing lines.
80 bool TryEarlyFoldsOnly)
81 : F(F), Builder(F.getContext(), InstSimplifyFolder(*DL)), TTI(TTI),
82 DT(DT), AA(AA), DL(DL), CostKind(CostKind),
83 SQ(*DL, /*TLI=*/nullptr, &DT, &AC),
84 TryEarlyFoldsOnly(TryEarlyFoldsOnly) {}
85
// Entry point of the pass object; defined outside this view.
86 bool run();
87
88private:
89 Function &F;
// NOTE(review): listing line 90 is absent from this view; presumably it
// declares the `Builder` member that the constructor initializes -- confirm.
91 const TargetTransformInfo &TTI;
92 const DominatorTree &DT;
93 AAResults &AA;
94 const DataLayout *DL;
95 TTI::TargetCostKind CostKind;
// Built in the constructor from DL, DT and AC with a null TLI; the load-widening
// folds read SQ.AC and SQ.DT from it.
96 const SimplifyQuery SQ;
97
98 /// If true, only perform beneficial early IR transforms. Do not introduce new
99 /// vector operations.
100 bool TryEarlyFoldsOnly;
101
// Instructions queued for (re)visiting; replaceValue and eraseInstruction keep
// it in sync with the IR.
102 InstructionWorklist Worklist;
103
104 /// Next instruction to iterate. It will be updated when it is erased by
105 /// RecursivelyDeleteTriviallyDeadInstructions.
106 Instruction *NextInst;
107
108 // TODO: Direct calls from the top-level "run" loop use a plain "Instruction"
109 // parameter. That should be updated to specific sub-classes because the
110 // run loop was changed to dispatch on opcode.
// Individual transforms. In the definitions visible in this file, a fold
// returns true after it has replaced the instruction and false when the
// pattern does not match or the cost model rejects the rewrite.
111 bool vectorizeLoadInsert(Instruction &I);
112 bool widenSubvectorLoad(Instruction &I);
113 ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
114 ExtractElementInst *Ext1,
115 unsigned PreferredExtractIndex) const;
116 bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
117 const Instruction &I,
118 ExtractElementInst *&ConvertToShuffle,
119 unsigned PreferredExtractIndex);
120 Value *foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
121 Value *foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
122 bool foldExtractExtract(Instruction &I);
123 bool foldInsExtFNeg(Instruction &I);
124 bool foldInsExtBinop(Instruction &I);
125 bool foldInsExtVectorToShuffle(Instruction &I);
126 bool foldBitOpOfCastops(Instruction &I);
127 bool foldBitOpOfCastConstant(Instruction &I);
128 bool foldBitcastShuffle(Instruction &I);
129 bool scalarizeOpOrCmp(Instruction &I);
130 bool scalarizeVPIntrinsic(Instruction &I);
131 bool foldExtractedCmps(Instruction &I);
132 bool foldSelectsFromBitcast(Instruction &I);
133 bool foldBinopOfReductions(Instruction &I);
134 bool foldSingleElementStore(Instruction &I);
135 bool scalarizeLoad(Instruction &I);
136 bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
137 bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
138 bool scalarizeExtExtract(Instruction &I);
139 bool foldConcatOfBoolMasks(Instruction &I);
140 bool foldPermuteOfBinops(Instruction &I);
141 bool foldShuffleOfBinops(Instruction &I);
142 bool foldShuffleOfSelects(Instruction &I);
143 bool foldShuffleOfCastops(Instruction &I);
144 bool foldShuffleOfShuffles(Instruction &I);
145 bool foldPermuteOfIntrinsic(Instruction &I);
146 bool foldShufflesOfLengthChangingShuffles(Instruction &I);
147 bool foldShuffleOfIntrinsics(Instruction &I);
148 bool foldShuffleToIdentity(Instruction &I);
149 bool foldShuffleFromReductions(Instruction &I);
150 bool foldShuffleChainsToReduce(Instruction &I);
151 bool foldCastFromReductions(Instruction &I);
152 bool foldSignBitReductionCmp(Instruction &I);
153 bool foldICmpEqZeroVectorReduce(Instruction &I);
154 bool foldEquivalentReductionCmp(Instruction &I);
155 bool foldReduceAddCmpZero(Instruction &I);
156 bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
157 bool foldInterleaveIntrinsics(Instruction &I);
158 bool foldDeinterleaveIntrinsics(Instruction &I);
159 bool foldBitcastOfVPLoad(Instruction &I);
160 bool shrinkType(Instruction &I);
161 bool shrinkLoadForShuffles(Instruction &I);
162 bool shrinkPhiOfShuffles(Instruction &I);
163
/// Redirect every use of \p Old to \p New and update the worklist.
/// If \p New is an instruction it takes over Old's name, and both its users
/// and the instruction itself are queued. \p Old is erased when \p Erase is
/// true and it is trivially dead; otherwise it is pushed back on the worklist.
164 void replaceValue(Instruction &Old, Value &New, bool Erase = true) {
165 LLVM_DEBUG(dbgs() << "VC: Replacing: " << Old << '\n');
166 LLVM_DEBUG(dbgs() << " With: " << New << '\n');
167 Old.replaceAllUsesWith(&New);
168 if (auto *NewI = dyn_cast<Instruction>(&New)) {
169 New.takeName(&Old);
170 Worklist.pushUsersToWorkList(*NewI);
171 Worklist.pushValue(NewI);
172 }
173 if (Erase && isInstructionTriviallyDead(&Old)) {
174 eraseInstruction(Old);
175 } else {
176 Worklist.push(&Old);
177 }
178 }
179
/// Erase \p I from the worklist and from its parent, then revisit the
/// instructions that were its operands.
180 void eraseInstruction(Instruction &I) {
181 LLVM_DEBUG(dbgs() << "VC: Erasing: " << I << '\n');
// Copy the operand list up front: it is still needed after I has been erased.
182 SmallVector<Value *> Ops(I.operands());
183 Worklist.remove(&I);
184 I.eraseFromParent();
185
186 // Push remaining users of the operands and then the operand itself - allows
187 // further folds that were hindered by OneUse limits.
// Visited records instructions reported to the callback below, so an operand
// that was already handled that way is skipped by the contains() check.
188 SmallPtrSet<Value *, 4> Visited;
189 for (Value *Op : Ops) {
190 if (!Visited.contains(Op)) {
191 if (auto *OpI = dyn_cast<Instruction>(Op)) {
// NOTE(review): listing line 192 (the head of the conditional call that takes
// the arguments below) is absent from this view. Given the NextInst comment
// above, it is presumably RecursivelyDeleteTriviallyDeadInstructions -- confirm.
193 OpI, nullptr, nullptr, [&](Value *V) {
194 if (auto *I = dyn_cast<Instruction>(V)) {
195 LLVM_DEBUG(dbgs() << "VC: Erased: " << *I << '\n');
196 Worklist.remove(I);
// Step NextInst past an instruction that is reported here so the saved
// iteration position does not refer to it.
197 if (I == NextInst)
198 NextInst = NextInst->getNextNode();
199 Visited.insert(I);
200 }
201 }))
// The call above returned true: skip re-queuing this operand.
202 continue;
203 Worklist.pushUsersToWorkList(*OpI);
204 Worklist.pushValue(OpI);
205 }
206 }
207 }
208 }
209};
210} // namespace
211
212/// Return the source operand of a potentially bitcasted value. If there is no
213/// bitcast, return the input value itself.
// NOTE(review): the signature line (listing line 214) is absent from this
// view, so the function name and the declared type of `V` cannot be confirmed.
// The loop strips an entire chain of nested bitcasts, not just the outermost.
215 while (auto *BitCast = dyn_cast<BitCastInst>(V))
216 V = BitCast->getOperand(0);
217 return V;
218}
219
/// Return true if \p Load may be replaced by a wider vector load.
/// Rejects a null load, a non-simple load, a load with more than one use, and
/// loads in functions carrying the SanitizeMemTag attribute. Also requires the
/// scalar element size to be a non-zero multiple of 8 bits that evenly divides
/// the target's (non-zero) minimum vector register width.
220static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) {
221 // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan.
222 // The widened load may load data from dirty regions or create data races
223 // non-existent in the source.
224 if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
225 Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
// NOTE(review): listing line 226 (the final term of this condition) is absent
// from this view; the comment above suggests it covers the other sanitizers.
227 return false;
228
229 // We are potentially transforming byte-sized (8-bit) memory accesses, so make
230 // sure we have all of our type-based constraints in place for this target.
231 Type *ScalarTy = Load->getType()->getScalarType();
232 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
233 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
// The zero checks come first so the modulo by ScalarSize is never reached with
// a zero divisor.
234 if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
235 ScalarSize % 8 != 0)
236 return false;
237
238 return true;
239}
240
/// Try to replace an insert of a loaded scalar into element 0 of a vector
/// (optionally reached through an extract of element 0) with a load of a
/// minimum-register-width vector followed by a shuffle. The rewrite happens
/// only when the wide load is safe to execute unconditionally and the new cost
/// is valid and not higher than the old cost. Returns true if \p I was replaced.
241bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
242 // Match insert into fixed vector of scalar value.
243 // TODO: Handle non-zero insert index.
244 Value *Scalar;
245 if (!match(&I,
// NOTE(review): listing line 246 (the pattern passed to match, which binds
// Scalar) is absent from this view.
247 return false;
248
249 // Optionally match an extract from another vector.
250 Value *X;
251 bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
252 if (!HasExtract)
253 X = Scalar;
254
// canWidenLoad also rejects a null Load, so a non-load X bails out here.
255 auto *Load = dyn_cast<LoadInst>(X);
256 if (!canWidenLoad(Load, TTI))
257 return false;
258
259 Type *ScalarTy = Scalar->getType();
260 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
261 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
262
263 // Check safety of replacing the scalar load with a larger vector load.
264 // We use minimal alignment (maximum flexibility) because we only care about
265 // the dereferenceable region. When calculating cost and creating a new op,
266 // we may use a larger value based on alignment attributes.
267 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
268 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
269
270 unsigned MinVecNumElts = MinVectorSize / ScalarSize;
271 auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
// Lane of the wide load that holds the wanted scalar; stays 0 unless the
// fallback path below finds a base pointer at a positive offset.
272 unsigned OffsetEltIndex = 0;
273 Align Alignment = Load->getAlign();
274 if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, SQ.AC,
275 SQ.DT)) {
276 // It is not safe to load directly from the pointer, but we can still peek
277 // through gep offsets and check if it safe to load from a base address with
278 // updated alignment. If it is, we can shuffle the element(s) into place
279 // after loading.
280 unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(SrcPtr->getType());
281 APInt Offset(OffsetBitWidth, 0);
// NOTE(review): listing line 282 is absent from this view. As shown, nothing
// updates Offset or SrcPtr before the checks below; the missing line
// presumably strips constant GEP offsets from SrcPtr into Offset -- confirm.
283
284 // We want to shuffle the result down from a high element of a vector, so
285 // the offset must be positive.
286 if (Offset.isNegative())
287 return false;
288
289 // The offset must be a multiple of the scalar element to shuffle cleanly
290 // in the element's size.
291 uint64_t ScalarSizeInBytes = ScalarSize / 8;
292 if (Offset.urem(ScalarSizeInBytes) != 0)
293 return false;
294
295 // If we load MinVecNumElts, will our target element still be loaded?
296 OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue();
297 if (OffsetEltIndex >= MinVecNumElts)
298 return false;
299
300 if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load,
301 SQ.AC, SQ.DT))
302 return false;
303
304 // Update alignment with offset value. Note that the offset could be negated
305 // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but
306 // negation does not change the result of the alignment calculation.
307 Alignment = commonAlignment(Alignment, Offset.getZExtValue());
308 }
309
310 // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
311 // Use the greater of the alignment on the load or its source pointer.
312 Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
313 Type *LoadTy = Load->getType();
314 unsigned AS = Load->getPointerAddressSpace();
315 InstructionCost OldCost =
316 TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
// Old cost also includes inserting into lane 0 and, when HasExtract is set,
// the extract.
317 APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
318 OldCost +=
319 TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
320 /* Insert */ true, HasExtract, CostKind);
321
322 // New pattern: load VecPtr
323 InstructionCost NewCost =
324 TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS, CostKind);
325 // Optionally, we are shuffling the loaded vector element(s) into place.
326 // For the mask set everything but element 0 to undef to prevent poison from
327 // propagating from the extra loaded memory. This will also optionally
328 // shrink/grow the vector from the loaded size to the output size.
329 // We assume this operation has no cost in codegen if there was no offset.
330 // Note that we could use freeze to avoid poison problems, but then we might
331 // still need a shuffle to change the vector size.
332 auto *Ty = cast<FixedVectorType>(I.getType());
333 unsigned OutputNumElts = Ty->getNumElements();
334 SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
335 assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
336 Mask[0] = OffsetEltIndex;
337 if (OffsetEltIndex)
338 NewCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, Ty, MinVecTy, Mask,
339 CostKind);
340
341 // We can aggressively convert to the vector form because the backend can
342 // invert this transform if it does not result in a performance win.
343 if (OldCost < NewCost || !NewCost.isValid())
344 return false;
345
346 // It is safe and potentially profitable to load a vector directly:
347 // inselt undef, load Scalar, 0 --> load VecPtr
// This local builder is positioned at the original load and shadows the class
// member `Builder` for the rest of the function.
348 IRBuilder<> Builder(Load);
349 Value *CastedPtr =
350 Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
351 Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);
352 VecLd = Builder.CreateShuffleVector(VecLd, Mask);
353
354 replaceValue(I, *VecLd);
355 ++NumVecLoad;
356 return true;
357}
358
359/// If we are loading a vector and then inserting it into a larger vector with
360/// undefined elements, try to load the larger vector and eliminate the insert.
361/// This removes a shuffle in IR and may allow combining of other loaded values.
362bool VectorCombine::widenSubvectorLoad(Instruction &I) {
363 // Match subvector insert of fixed vector.
364 auto *Shuf = cast<ShuffleVectorInst>(&I);
365 if (!Shuf->isIdentityWithPadding())
366 return false;
367
368 // Allow a non-canonical shuffle mask that is choosing elements from op1.
369 unsigned NumOpElts =
370 cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
371 unsigned OpIndex = any_of(Shuf->getShuffleMask(), [&NumOpElts](int M) {
372 return M >= (int)(NumOpElts);
373 });
374
375 auto *Load = dyn_cast<LoadInst>(Shuf->getOperand(OpIndex));
376 if (!canWidenLoad(Load, TTI))
377 return false;
378
379 // We use minimal alignment (maximum flexibility) because we only care about
380 // the dereferenceable region. When calculating cost and creating a new op,
381 // we may use a larger value based on alignment attributes.
382 auto *Ty = cast<FixedVectorType>(I.getType());
383 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
384 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
385 Align Alignment = Load->getAlign();
386 if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), *DL, Load, SQ.AC,
387 SQ.DT))
388 return false;
389
390 Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
391 Type *LoadTy = Load->getType();
392 unsigned AS = Load->getPointerAddressSpace();
393
394 // Original pattern: insert_subvector (load PtrOp)
395 // This conservatively assumes that the cost of a subvector insert into an
396 // undef value is 0. We could add that cost if the cost model accurately
397 // reflects the real cost of that operation.
398 InstructionCost OldCost =
399 TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
400
401 // New pattern: load PtrOp
402 InstructionCost NewCost =
403 TTI.getMemoryOpCost(Instruction::Load, Ty, Alignment, AS, CostKind);
404
405 // We can aggressively convert to the vector form because the backend can
406 // invert this transform if it does not result in a performance win.
407 if (OldCost < NewCost || !NewCost.isValid())
408 return false;
409
410 IRBuilder<> Builder(Load);
411 Value *CastedPtr =
412 Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
413 Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment);
414 replaceValue(I, *VecLd);
415 ++NumVecLoad;
416 return true;
417}
418
419/// Determine which, if any, of the inputs should be replaced by a shuffle
420/// followed by extract from a different index.
421ExtractElementInst *VectorCombine::getShuffleExtract(
422 ExtractElementInst *Ext0, ExtractElementInst *Ext1,
423 unsigned PreferredExtractIndex = InvalidIndex) const {
424 auto *Index0C = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
425 auto *Index1C = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
426 assert(Index0C && Index1C && "Expected constant extract indexes");
427
428 unsigned Index0 = Index0C->getZExtValue();
429 unsigned Index1 = Index1C->getZExtValue();
430
431 // If the extract indexes are identical, no shuffle is needed.
432 if (Index0 == Index1)
433 return nullptr;
434
435 Type *VecTy = Ext0->getVectorOperand()->getType();
436 assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
437 InstructionCost Cost0 =
438 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
439 InstructionCost Cost1 =
440 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
441
442 // If both costs are invalid no shuffle is needed
443 if (!Cost0.isValid() && !Cost1.isValid())
444 return nullptr;
445
446 // We are extracting from 2 different indexes, so one operand must be shuffled
447 // before performing a vector operation and/or extract. The more expensive
448 // extract will be replaced by a shuffle.
449 if (Cost0 > Cost1)
450 return Ext0;
451 if (Cost1 > Cost0)
452 return Ext1;
453
454 // If the costs are equal and there is a preferred extract index, shuffle the
455 // opposite operand.
456 if (PreferredExtractIndex == Index0)
457 return Ext1;
458 if (PreferredExtractIndex == Index1)
459 return Ext0;
460
461 // Otherwise, replace the extract with the higher index.
462 return Index0 > Index1 ? Ext0 : Ext1;
463}
464
465/// Compare the relative costs of 2 extracts followed by scalar operation vs.
466/// vector operation(s) followed by extract. Return true if the existing
467/// instructions are cheaper than a vector alternative. Otherwise, return false
468/// and if one of the extracts should be transformed to a shufflevector, set
469/// \p ConvertToShuffle to that extract instruction.
// Note: ConvertToShuffle is assigned on every path (it may be set to null),
// including the early `return true` for the disabled binop-shuffle case.
470bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
471 ExtractElementInst *Ext1,
472 const Instruction &I,
473 ExtractElementInst *&ConvertToShuffle,
474 unsigned PreferredExtractIndex) {
475 auto *Ext0IndexC = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
476 auto *Ext1IndexC = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
477 assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes");
478
479 unsigned Opcode = I.getOpcode();
480 Value *Ext0Src = Ext0->getVectorOperand();
481 Value *Ext1Src = Ext1->getVectorOperand();
482 Type *ScalarTy = Ext0->getType();
483 auto *VecTy = cast<VectorType>(Ext0Src->getType());
484 InstructionCost ScalarOpCost, VectorOpCost;
485
486 // Get cost estimates for scalar and vector versions of the operation.
487 bool IsBinOp = Instruction::isBinaryOp(Opcode);
488 if (IsBinOp) {
489 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
490 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
491 } else {
492 assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
493 "Expected a compare");
494 CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
495 ScalarOpCost = TTI.getCmpSelInstrCost(
496 Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
497 VectorOpCost = TTI.getCmpSelInstrCost(
498 Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
499 }
500
501 // Get cost estimates for the extract elements. These costs will factor into
502 // both sequences.
503 unsigned Ext0Index = Ext0IndexC->getZExtValue();
504 unsigned Ext1Index = Ext1IndexC->getZExtValue();
505
506 InstructionCost Extract0Cost =
507 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Ext0Index);
508 InstructionCost Extract1Cost =
509 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Ext1Index);
510
511 // A more expensive extract will always be replaced by a splat shuffle.
512 // For example, if Ext0 is more expensive:
513 // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
514 // extelt (opcode (splat V0, Ext0), V1), Ext1
515 // TODO: Evaluate whether that always results in lowest cost. Alternatively,
516 // check the cost of creating a broadcast shuffle and shuffling both
517 // operands to element 0.
// BestExtIndex is the lane of the costlier extract (the shuffle source);
// BestInsIndex is the lane of the cheaper one (the shuffle destination).
// On a cost tie, Ext1's lane is the source and Ext0's the destination.
518 unsigned BestExtIndex = Extract0Cost > Extract1Cost ? Ext0Index : Ext1Index;
519 unsigned BestInsIndex = Extract0Cost > Extract1Cost ? Ext1Index : Ext0Index;
520 InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost);
521
522 // Extra uses of the extracts mean that we include those costs in the
523 // vector total because those instructions will not be eliminated.
524 InstructionCost OldCost, NewCost;
525 if (Ext0Src == Ext1Src && Ext0Index == Ext1Index) {
526 // Handle a special case. If the 2 extracts are identical, adjust the
527 // formulas to account for that. The extra use charge allows for either the
528 // CSE'd pattern or an unoptimized form with identical values:
529 // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
// When both operands are the very same instruction, it is expected to have
// exactly two uses (both operands of I); anything else incurs the use tax.
530 bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
531 : !Ext0->hasOneUse() || !Ext1->hasOneUse();
532 OldCost = CheapExtractCost + ScalarOpCost;
533 NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
534 } else {
535 // Handle the general case. Each extract is actually a different value:
536 // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
537 OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
538 NewCost = VectorOpCost + CheapExtractCost +
539 !Ext0->hasOneUse() * Extract0Cost +
540 !Ext1->hasOneUse() * Extract1Cost;
541 }
542
543 ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
544 if (ConvertToShuffle) {
// With the binop-extract-shuffle option disabled, report the existing code as
// cheaper so the caller leaves it alone.
545 if (IsBinOp && DisableBinopExtractShuffle)
546 return true;
547
548 // If we are extracting from 2 different indexes, then one operand must be
549 // shuffled before performing the vector operation. The shuffle mask is
550 // poison except for 1 lane that is being translated to the remaining
551 // extraction lane. Therefore, it is a splat shuffle. Ex:
552 // ShufMask = { poison, poison, 0, poison }
553 // TODO: The cost model has an option for a "broadcast" shuffle
554 // (splat-from-element-0), but no option for a more general splat.
// NOTE(review): listing lines 557, 559 and 563 are absent from this view: the
// rest of the ShuffleMask constructor call and the heads of the two statements
// whose arguments appear below. Presumably each adds a shuffle cost to NewCost
// (with an explicit mask for fixed vectors, an empty one otherwise) -- confirm.
555 if (auto *FixedVecTy = dyn_cast<FixedVectorType>(VecTy)) {
556 SmallVector<int> ShuffleMask(FixedVecTy->getNumElements(),
558 ShuffleMask[BestInsIndex] = BestExtIndex;
560 VecTy, VecTy, ShuffleMask, CostKind, 0,
561 nullptr, {ConvertToShuffle});
562 } else {
564 VecTy, VecTy, {}, CostKind, 0, nullptr,
565 {ConvertToShuffle});
566 }
567 }
568
569 // Aggressively form a vector op if the cost is equal because the transform
570 // may enable further optimization.
571 // Codegen can reverse this transform (scalarize) if it was not profitable.
572 return OldCost < NewCost;
573}
574
575/// Create a shuffle that translates (shifts) 1 element from the input vector
576/// to a new element location.
577static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
578 unsigned NewIndex, IRBuilderBase &Builder) {
579 // The shuffle mask is poison except for 1 lane that is being translated
580 // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
581 // ShufMask = { 2, poison, poison, poison }
582 auto *VecTy = cast<FixedVectorType>(Vec->getType());
583 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
584 ShufMask[NewIndex] = OldIndex;
585 return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
586}
587
588/// Given an extract element instruction with constant index operand, shuffle
589/// the source vector (shift the scalar element) to a NewIndex for extraction.
590/// Return null if the input can be constant folded, so that we are not creating
591/// unnecessary instructions.
592static Value *translateExtract(ExtractElementInst *ExtElt, unsigned NewIndex,
593 IRBuilderBase &Builder) {
594 // Shufflevectors can only be created for fixed-width vectors.
595 Value *X = ExtElt->getVectorOperand();
596 if (!isa<FixedVectorType>(X->getType()))
597 return nullptr;
598
599 // If the extract can be constant-folded, this code is unsimplified. Defer
600 // to other passes to handle that.
601 Value *C = ExtElt->getIndexOperand();
602 assert(isa<ConstantInt>(C) && "Expected a constant index operand");
603 if (isa<Constant>(X))
604 return nullptr;
605
606 Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
607 NewIndex, Builder);
608 return Shuf;
609}
610
611/// Try to reduce extract element costs by converting scalar compares to vector
612/// compares followed by extract.
613/// cmp (ext0 V0, ExtIndex), (ext1 V1, ExtIndex)
614Value *VectorCombine::foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex,
615 Instruction &I) {
616 assert(isa<CmpInst>(&I) && "Expected a compare");
617
618 // cmp Pred (extelt V0, ExtIndex), (extelt V1, ExtIndex)
619 // --> extelt (cmp Pred V0, V1), ExtIndex
620 ++NumVecCmp;
621 CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
622 Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
623 return Builder.CreateExtractElement(VecCmp, ExtIndex, "foldExtExtCmp");
624}
625
626/// Try to reduce extract element costs by converting scalar binops to vector
627/// binops followed by extract.
628/// bo (ext0 V0, ExtIndex), (ext1 V1, ExtIndex)
629Value *VectorCombine::foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex,
630 Instruction &I) {
631 assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
632
633 // bo (extelt V0, ExtIndex), (extelt V1, ExtIndex)
634 // --> extelt (bo V0, V1), ExtIndex
635 ++NumVecBO;
636 Value *VecBO = Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0,
637 V1, "foldExtExtBinop");
638
639 // All IR flags are safe to back-propagate because any potential poison
640 // created in unused vector elements is discarded by the extract.
641 if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
642 VecBOInst->copyIRFlags(&I);
643
644 return Builder.CreateExtractElement(VecBO, ExtIndex, "foldExtExtBinop");
645}
646
647/// Match an instruction with extracted vector operands.
/// When the cost model favors it, the scalar compare or binop \p I on two
/// constant-index extracts is replaced by a vector operation followed by one
/// extract (shuffling one operand first if the extract lanes differ).
/// Returns true if \p I was replaced.
648bool VectorCombine::foldExtractExtract(Instruction &I) {
649 // It is not safe to transform things like div, urem, etc. because we may
650 // create undefined behavior when executing those on unknown vector elements.
// NOTE(review): listing line 651 (the condition guarding this return) is
// absent from this view.
652 return false;
653
654 Instruction *I0, *I1;
// Pred keeps this sentinel unless the compare pattern matches; it is tested
// further down to choose between the compare fold and the binop fold.
655 CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
656 if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
// NOTE(review): listing line 657 (the second alternative of this match) is
// absent from this view; the later code implies it matches a binary operator.
658 return false;
659
660 Value *V0, *V1;
661 uint64_t C0, C1;
662 if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
// NOTE(review): listing line 663 is absent from this view; presumably it is
// the matching extract pattern on I1 that binds V1 and C1 -- confirm.
664 V0->getType() != V1->getType())
665 return false;
666
667 // For fixed-width vectors, reject out-of-bounds extract indexes
668 if (auto *FixedVecTy = dyn_cast<FixedVectorType>(V0->getType())) {
669 unsigned NumElts = FixedVecTy->getNumElements();
670 if (C0 >= NumElts || C1 >= NumElts)
671 return false;
672 }
673
674 // If the scalar value 'I' is going to be re-inserted into a vector, then try
675 // to create an extract to that same element. The extract/insert can be
676 // reduced to a "select shuffle".
677 // TODO: If we add a larger pattern match that starts from an insert, this
678 // probably becomes unnecessary.
679 auto *Ext0 = cast<ExtractElementInst>(I0);
680 auto *Ext1 = cast<ExtractElementInst>(I1);
// InsertIndex stays at InvalidIndex unless I's single user is an insertelement
// with a constant index; the match result itself is intentionally ignored.
681 uint64_t InsertIndex = InvalidIndex;
682 if (I.hasOneUse())
683 match(I.user_back(),
684 m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));
685
// isExtractExtractCheap returns true when the existing scalar form is cheaper.
// Otherwise it sets ExtractToChange to the extract to shuffle (or null).
686 ExtractElementInst *ExtractToChange;
687 if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
688 return false;
689
690 Value *ExtOp0 = Ext0->getVectorOperand();
691 Value *ExtOp1 = Ext1->getVectorOperand();
692
693 if (ExtractToChange) {
// Shift the chosen operand's element into the lane used by the other extract.
694 unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
695 Value *NewExtOp =
696 translateExtract(ExtractToChange, CheapExtractIdx, Builder);
697 if (!NewExtOp)
698 return false;
699 if (ExtractToChange == Ext0)
700 ExtOp0 = NewExtOp;
701 else
702 ExtOp1 = NewExtOp;
703 }
704
// Extract from the lane of the operand that was not shuffled. When
// ExtractToChange is null this selects Ext0's index.
705 Value *ExtIndex = ExtractToChange == Ext0 ? Ext1->getIndexOperand()
706 : Ext0->getIndexOperand();
707 Value *NewExt = Pred != CmpInst::BAD_ICMP_PREDICATE
708 ? foldExtExtCmp(ExtOp0, ExtOp1, ExtIndex, I)
709 : foldExtExtBinop(ExtOp0, ExtOp1, ExtIndex, I);
// Re-queue the original extracts so they are revisited after I is replaced.
710 Worklist.push(Ext0);
711 Worklist.push(Ext1);
712 replaceValue(I, *NewExt);
713 return true;
714}
715
716/// Try to replace an extract + scalar fneg + insert with a vector fneg +
717/// shuffle.
/// The source and destination vectors may differ in length; in that case an
/// extra length-changing shuffle is created. The rewrite is skipped when the
/// new cost exceeds the old cost. Returns true if \p I was replaced.
718bool VectorCombine::foldInsExtFNeg(Instruction &I) {
719 // Match an insert (op (extract)) pattern.
720 Value *DstVec;
721 uint64_t ExtIdx, InsIdx;
722 Instruction *FNeg;
723 if (!match(&I, m_InsertElt(m_Value(DstVec), m_OneUse(m_Instruction(FNeg)),
724 m_ConstantInt(InsIdx))))
725 return false;
726
727 // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
728 Value *SrcVec;
729 Instruction *Extract;
730 if (!match(FNeg, m_FNeg(m_CombineAnd(
731 m_Instruction(Extract),
732 m_ExtractElt(m_Value(SrcVec), m_ConstantInt(ExtIdx))))))
733 return false;
734
735 auto *DstVecTy = cast<FixedVectorType>(DstVec->getType());
736 auto *DstVecScalarTy = DstVecTy->getScalarType();
737 auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcVec->getType());
738 if (!SrcVecTy || DstVecScalarTy != SrcVecTy->getScalarType())
739 return false;
740
741 // Ignore if insert/extract index is out of bounds or destination vector has
742 // one element
743 unsigned NumDstElts = DstVecTy->getNumElements();
744 unsigned NumSrcElts = SrcVecTy->getNumElements();
// NOTE(review): `ExtIdx > NumSrcElts` still lets ExtIdx == NumSrcElts through,
// which is an out-of-bounds extract index. The comment above and the InsIdx
// check suggest `>=` was intended -- confirm.
745 if (ExtIdx > NumSrcElts || InsIdx >= NumDstElts || NumDstElts == 1)
746 return false;
747
748 // We are inserting the negated element into the same lane that we extracted
749 // from. This is equivalent to a select-shuffle that chooses all but the
750 // negated element from the destination vector.
// The mask is the identity over DstVec except lane InsIdx, which selects lane
// (ExtIdx % NumDstElts) of the second shuffle operand.
751 SmallVector<int> Mask(NumDstElts);
752 std::iota(Mask.begin(), Mask.end(), 0);
753 Mask[InsIdx] = (ExtIdx % NumDstElts) + NumDstElts;
754 InstructionCost OldCost =
755 TTI.getArithmeticInstrCost(Instruction::FNeg, DstVecScalarTy, CostKind) +
756 TTI.getVectorInstrCost(I, DstVecTy, CostKind, InsIdx);
757
758 // If the extract has one use, it will be eliminated, so count it in the
759 // original cost. If it has more than one use, ignore the cost because it will
760 // be the same before/after.
761 if (Extract->hasOneUse())
762 OldCost += TTI.getVectorInstrCost(*Extract, SrcVecTy, CostKind, ExtIdx);
763
764 InstructionCost NewCost =
765 TTI.getArithmeticInstrCost(Instruction::FNeg, SrcVecTy, CostKind) +
// NOTE(review): listing line 766 (the head of the call whose trailing
// arguments follow) is absent from this view; presumably a shuffle-cost query.
767 DstVecTy, Mask, CostKind);
768
769 bool NeedLenChg = SrcVecTy->getNumElements() != NumDstElts;
770 // If the lengths of the two vectors are not equal,
771 // we need to add a length-change vector. Add this cost.
772 SmallVector<int> SrcMask;
773 if (NeedLenChg) {
// SrcMask is all poison except one lane, which carries source element ExtIdx
// into lane (ExtIdx % NumDstElts) of a DstVec-sized vector.
774 SrcMask.assign(NumDstElts, PoisonMaskElem);
775 SrcMask[ExtIdx % NumDstElts] = ExtIdx;
// NOTE(review): listing line 776 (the head of the statement whose arguments
// follow) is absent from this view; presumably it adds the length-change
// shuffle cost to NewCost -- confirm.
777 DstVecTy, SrcVecTy, SrcMask, CostKind);
778 }
779
780 LLVM_DEBUG(dbgs() << "Found an insertion of (extract)fneg : " << I
781 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
782 << "\n");
// Equal cost still proceeds with the rewrite.
783 if (NewCost > OldCost)
784 return false;
785
786 Value *NewShuf, *LenChgShuf = nullptr;
787 // insertelt DstVec, (fneg (extractelt SrcVec, Index)), Index
788 Value *VecFNeg = Builder.CreateFNegFMF(SrcVec, FNeg);
789 if (NeedLenChg) {
790 // shuffle DstVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask
791 LenChgShuf = Builder.CreateShuffleVector(VecFNeg, SrcMask);
792 NewShuf = Builder.CreateShuffleVector(DstVec, LenChgShuf, Mask);
793 Worklist.pushValue(LenChgShuf);
794 } else {
795 // shuffle DstVec, (fneg SrcVec), Mask
796 NewShuf = Builder.CreateShuffleVector(DstVec, VecFNeg, Mask);
797 }
798
// Queue the new vector fneg as well; replaceValue queues NewShuf and its users.
799 Worklist.pushValue(VecFNeg);
800 replaceValue(I, *NewShuf);
801 return true;
802}
803
804/// Try to fold insert(binop(x,y),binop(a,b),idx)
805/// --> binop(insert(x,a,idx),insert(y,b,idx))
806bool VectorCombine::foldInsExtBinop(Instruction &I) {
807 BinaryOperator *VecBinOp, *SclBinOp;
808 uint64_t Index;
809 if (!match(&I,
810 m_InsertElt(m_OneUse(m_BinOp(VecBinOp)),
811 m_OneUse(m_BinOp(SclBinOp)), m_ConstantInt(Index))))
812 return false;
813
814 // TODO: Add support for addlike etc.
815 Instruction::BinaryOps BinOpcode = VecBinOp->getOpcode();
816 if (BinOpcode != SclBinOp->getOpcode())
817 return false;
818
819 auto *ResultTy = dyn_cast<FixedVectorType>(I.getType());
820 if (!ResultTy)
821 return false;
822
823 // TODO: Attempt to detect m_ExtractElt for scalar operands and convert to
824 // shuffle?
825
827 TTI.getInstructionCost(VecBinOp, CostKind) +
829 InstructionCost NewCost =
830 TTI.getArithmeticInstrCost(BinOpcode, ResultTy, CostKind) +
831 TTI.getVectorInstrCost(Instruction::InsertElement, ResultTy, CostKind,
832 Index, VecBinOp->getOperand(0),
833 SclBinOp->getOperand(0)) +
834 TTI.getVectorInstrCost(Instruction::InsertElement, ResultTy, CostKind,
835 Index, VecBinOp->getOperand(1),
836 SclBinOp->getOperand(1));
837
838 LLVM_DEBUG(dbgs() << "Found an insertion of two binops: " << I
839 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
840 << "\n");
841 if (NewCost > OldCost)
842 return false;
843
844 Value *NewIns0 = Builder.CreateInsertElement(VecBinOp->getOperand(0),
845 SclBinOp->getOperand(0), Index);
846 Value *NewIns1 = Builder.CreateInsertElement(VecBinOp->getOperand(1),
847 SclBinOp->getOperand(1), Index);
848 Value *NewBO = Builder.CreateBinOp(BinOpcode, NewIns0, NewIns1);
849
850 // Intersect flags from the old binops.
851 if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
852 NewInst->copyIRFlags(VecBinOp);
853 NewInst->andIRFlags(SclBinOp);
854 }
855
856 Worklist.pushValue(NewIns0);
857 Worklist.pushValue(NewIns1);
858 replaceValue(I, *NewBO);
859 return true;
860}
861
862/// Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
863/// Supports: bitcast, trunc, sext, zext
864bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
865 // Check if this is a bitwise logic operation
866 auto *BinOp = dyn_cast<BinaryOperator>(&I);
867 if (!BinOp || !BinOp->isBitwiseLogicOp())
868 return false;
869
870 // Get the cast instructions
871 auto *LHSCast = dyn_cast<CastInst>(BinOp->getOperand(0));
872 auto *RHSCast = dyn_cast<CastInst>(BinOp->getOperand(1));
873 if (!LHSCast || !RHSCast) {
874 LLVM_DEBUG(dbgs() << " One or both operands are not cast instructions\n");
875 return false;
876 }
877
878 // Both casts must be the same type
879 Instruction::CastOps CastOpcode = LHSCast->getOpcode();
880 if (CastOpcode != RHSCast->getOpcode())
881 return false;
882
883 // Only handle supported cast operations
884 switch (CastOpcode) {
885 case Instruction::BitCast:
886 case Instruction::Trunc:
887 case Instruction::SExt:
888 case Instruction::ZExt:
889 break;
890 default:
891 return false;
892 }
893
894 Value *LHSSrc = LHSCast->getOperand(0);
895 Value *RHSSrc = RHSCast->getOperand(0);
896
897 // Source types must match
898 if (LHSSrc->getType() != RHSSrc->getType())
899 return false;
900
901 auto *SrcTy = LHSSrc->getType();
902 auto *DstTy = I.getType();
903 // Bitcasts can handle scalar/vector mixes, such as i16 -> <16 x i1>.
904 // Other casts only handle vector types with integer elements.
905 if (CastOpcode != Instruction::BitCast &&
906 (!isa<FixedVectorType>(SrcTy) || !isa<FixedVectorType>(DstTy)))
907 return false;
908
909 // Only integer scalar/vector values are legal for bitwise logic operations.
910 if (!SrcTy->getScalarType()->isIntegerTy() ||
911 !DstTy->getScalarType()->isIntegerTy())
912 return false;
913
914 // Cost Check :
915 // OldCost = bitlogic + 2*casts
916 // NewCost = bitlogic + cast
917
918 // Calculate specific costs for each cast with instruction context
920 CastOpcode, DstTy, SrcTy, TTI::CastContextHint::None, CostKind, LHSCast);
922 CastOpcode, DstTy, SrcTy, TTI::CastContextHint::None, CostKind, RHSCast);
923
924 InstructionCost OldCost =
925 TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstTy, CostKind) +
926 LHSCastCost + RHSCastCost;
927
928 // For new cost, we can't provide an instruction (it doesn't exist yet)
929 InstructionCost GenericCastCost = TTI.getCastInstrCost(
930 CastOpcode, DstTy, SrcTy, TTI::CastContextHint::None, CostKind);
931
932 InstructionCost NewCost =
933 TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcTy, CostKind) +
934 GenericCastCost;
935
936 // Account for multi-use casts using specific costs
937 if (!LHSCast->hasOneUse())
938 NewCost += LHSCastCost;
939 if (!RHSCast->hasOneUse())
940 NewCost += RHSCastCost;
941
942 LLVM_DEBUG(dbgs() << "foldBitOpOfCastops: OldCost=" << OldCost
943 << " NewCost=" << NewCost << "\n");
944
945 if (NewCost > OldCost)
946 return false;
947
948 // Create the operation on the source type
949 Value *NewOp = Builder.CreateBinOp(BinOp->getOpcode(), LHSSrc, RHSSrc,
950 BinOp->getName() + ".inner");
951 if (auto *NewBinOp = dyn_cast<BinaryOperator>(NewOp))
952 NewBinOp->copyIRFlags(BinOp);
953
954 Worklist.pushValue(NewOp);
955
956 // Create the cast operation directly to ensure we get a new instruction
957 Instruction *NewCast = CastInst::Create(CastOpcode, NewOp, I.getType());
958
959 // Preserve cast instruction flags
960 NewCast->copyIRFlags(LHSCast);
961 NewCast->andIRFlags(RHSCast);
962
963 // Insert the new instruction
964 Value *Result = Builder.Insert(NewCast);
965
966 replaceValue(I, *Result);
967 return true;
968}
969
970/// Match:
971// bitop(castop(x), C) ->
972// bitop(castop(x), castop(InvC)) ->
973// castop(bitop(x, InvC))
974// Supports: bitcast
975bool VectorCombine::foldBitOpOfCastConstant(Instruction &I) {
977 Constant *C;
978
979 // Check if this is a bitwise logic operation
981 return false;
982
983 // Get the cast instructions
984 auto *LHSCast = dyn_cast<CastInst>(LHS);
985 if (!LHSCast)
986 return false;
987
988 Instruction::CastOps CastOpcode = LHSCast->getOpcode();
989
990 // Only handle supported cast operations
991 switch (CastOpcode) {
992 case Instruction::BitCast:
993 case Instruction::ZExt:
994 case Instruction::SExt:
995 case Instruction::Trunc:
996 break;
997 default:
998 return false;
999 }
1000
1001 Value *LHSSrc = LHSCast->getOperand(0);
1002
1003 auto *SrcTy = LHSSrc->getType();
1004 auto *DstTy = I.getType();
1005 // Bitcasts can handle scalar/vector mixes, such as i16 -> <16 x i1>.
1006 // Other casts only handle vector types with integer elements.
1007 if (CastOpcode != Instruction::BitCast &&
1008 (!isa<FixedVectorType>(SrcTy) || !isa<FixedVectorType>(DstTy)))
1009 return false;
1010
1011 // Only integer scalar/vector values are legal for bitwise logic operations.
1012 if (!SrcTy->getScalarType()->isIntegerTy() ||
1013 !DstTy->getScalarType()->isIntegerTy())
1014 return false;
1015
1016 // Find the constant InvC, such that castop(InvC) equals to C.
1017 PreservedCastFlags RHSFlags;
1018 Constant *InvC = getLosslessInvCast(C, SrcTy, CastOpcode, *DL, &RHSFlags);
1019 if (!InvC)
1020 return false;
1021
1022 // Cost Check :
1023 // OldCost = bitlogic + cast
1024 // NewCost = bitlogic + cast
1025
1026 // Calculate specific costs for each cast with instruction context
1027 InstructionCost LHSCastCost = TTI.getCastInstrCost(
1028 CastOpcode, DstTy, SrcTy, TTI::CastContextHint::None, CostKind, LHSCast);
1029
1030 InstructionCost OldCost =
1031 TTI.getArithmeticInstrCost(I.getOpcode(), DstTy, CostKind) + LHSCastCost;
1032
1033 // For new cost, we can't provide an instruction (it doesn't exist yet)
1034 InstructionCost GenericCastCost = TTI.getCastInstrCost(
1035 CastOpcode, DstTy, SrcTy, TTI::CastContextHint::None, CostKind);
1036
1037 InstructionCost NewCost =
1038 TTI.getArithmeticInstrCost(I.getOpcode(), SrcTy, CostKind) +
1039 GenericCastCost;
1040
1041 // Account for multi-use casts using specific costs
1042 if (!LHSCast->hasOneUse())
1043 NewCost += LHSCastCost;
1044
1045 LLVM_DEBUG(dbgs() << "foldBitOpOfCastConstant: OldCost=" << OldCost
1046 << " NewCost=" << NewCost << "\n");
1047
1048 if (NewCost > OldCost)
1049 return false;
1050
1051 // Create the operation on the source type
1052 Value *NewOp = Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(),
1053 LHSSrc, InvC, I.getName() + ".inner");
1054 if (auto *NewBinOp = dyn_cast<BinaryOperator>(NewOp))
1055 NewBinOp->copyIRFlags(&I);
1056
1057 Worklist.pushValue(NewOp);
1058
1059 // Create the cast operation directly to ensure we get a new instruction
1060 Instruction *NewCast = CastInst::Create(CastOpcode, NewOp, I.getType());
1061
1062 // Preserve cast instruction flags
1063 if (RHSFlags.NNeg)
1064 NewCast->setNonNeg();
1065 if (RHSFlags.NUW)
1066 NewCast->setHasNoUnsignedWrap();
1067 if (RHSFlags.NSW)
1068 NewCast->setHasNoSignedWrap();
1069
1070 NewCast->andIRFlags(LHSCast);
1071
1072 // Insert the new instruction
1073 Value *Result = Builder.Insert(NewCast);
1074
1075 replaceValue(I, *Result);
1076 return true;
1077}
1078
1079/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
1080/// destination type followed by shuffle. This can enable further transforms by
1081/// moving bitcasts or shuffles together.
1082bool VectorCombine::foldBitcastShuffle(Instruction &I) {
1083 Value *V0, *V1;
1084 ArrayRef<int> Mask;
1085 if (!match(&I, m_BitCast(m_OneUse(
1086 m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(Mask))))))
1087 return false;
1088
1089 // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
1090 // scalable type is unknown; Second, we cannot reason if the narrowed shuffle
1091 // mask for scalable type is a splat or not.
1092 // 2) Disallow non-vector casts.
1093 // TODO: We could allow any shuffle.
1094 auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
1095 auto *SrcTy = dyn_cast<FixedVectorType>(V0->getType());
1096 if (!DestTy || !SrcTy)
1097 return false;
1098
1099 unsigned DestEltSize = DestTy->getScalarSizeInBits();
1100 unsigned SrcEltSize = SrcTy->getScalarSizeInBits();
1101 if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
1102 return false;
1103
1104 bool IsUnary = isa<UndefValue>(V1);
1105
1106 // For binary shuffles, only fold bitcast(shuffle(X,Y))
1107 // if it won't increase the number of bitcasts.
1108 if (!IsUnary) {
1111 if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) &&
1112 !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType()))
1113 return false;
1114 }
1115
1116 SmallVector<int, 16> NewMask;
1117 if (DestEltSize <= SrcEltSize) {
1118 // The bitcast is from wide to narrow/equal elements. The shuffle mask can
1119 // always be expanded to the equivalent form choosing narrower elements.
1120 if (SrcEltSize % DestEltSize != 0)
1121 return false;
1122 unsigned ScaleFactor = SrcEltSize / DestEltSize;
1123 narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
1124 } else {
1125 // The bitcast is from narrow elements to wide elements. The shuffle mask
1126 // must choose consecutive elements to allow casting first.
1127 if (DestEltSize % SrcEltSize != 0)
1128 return false;
1129 unsigned ScaleFactor = DestEltSize / SrcEltSize;
1130 if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
1131 return false;
1132 }
1133
1134 // Bitcast the shuffle src - keep its original width but using the destination
1135 // scalar type.
1136 unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
1137 auto *NewShuffleTy =
1138 FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);
1139 auto *OldShuffleTy =
1140 FixedVectorType::get(SrcTy->getScalarType(), Mask.size());
1141 unsigned NumOps = IsUnary ? 1 : 2;
1142
1143 // The new shuffle must not cost more than the old shuffle.
1147
1148 InstructionCost NewCost =
1149 TTI.getShuffleCost(SK, DestTy, NewShuffleTy, NewMask, CostKind) +
1150 (NumOps * TTI.getCastInstrCost(Instruction::BitCast, NewShuffleTy, SrcTy,
1151 TargetTransformInfo::CastContextHint::None,
1152 CostKind));
1153 InstructionCost OldCost =
1154 TTI.getShuffleCost(SK, OldShuffleTy, SrcTy, Mask, CostKind) +
1155 TTI.getCastInstrCost(Instruction::BitCast, DestTy, OldShuffleTy,
1156 TargetTransformInfo::CastContextHint::None,
1157 CostKind);
1158
1159 LLVM_DEBUG(dbgs() << "Found a bitcasted shuffle: " << I << "\n OldCost: "
1160 << OldCost << " vs NewCost: " << NewCost << "\n");
1161
1162 if (NewCost > OldCost || !NewCost.isValid())
1163 return false;
1164
1165 // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC'
1166 ++NumShufOfBitcast;
1167 Value *CastV0 = Builder.CreateBitCast(peekThroughBitcasts(V0), NewShuffleTy);
1168 Value *CastV1 = Builder.CreateBitCast(peekThroughBitcasts(V1), NewShuffleTy);
1169 Value *Shuf = Builder.CreateShuffleVector(CastV0, CastV1, NewMask);
1170 replaceValue(I, *Shuf);
1171 return true;
1172}
1173
1174/// VP Intrinsics whose vector operands are both splat values may be simplified
1175/// into the scalar version of the operation and the result splatted. This
1176/// can lead to scalarization down the line.
1177bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
1178 if (!isa<VPIntrinsic>(I))
1179 return false;
1180 VPIntrinsic &VPI = cast<VPIntrinsic>(I);
1181 Value *Op0 = VPI.getArgOperand(0);
1182 Value *Op1 = VPI.getArgOperand(1);
1183
1184 if (!isSplatValue(Op0) || !isSplatValue(Op1))
1185 return false;
1186
1187 // Check getSplatValue early in this function, to avoid doing unnecessary
1188 // work.
1189 Value *ScalarOp0 = getSplatValue(Op0);
1190 Value *ScalarOp1 = getSplatValue(Op1);
1191 if (!ScalarOp0 || !ScalarOp1)
1192 return false;
1193
1194 // For the binary VP intrinsics supported here, the result on disabled lanes
1195 // is a poison value. For now, only do this simplification if all lanes
1196 // are active.
1197 // TODO: Relax the condition that all lanes are active by using insertelement
1198 // on inactive lanes.
1199 auto IsAllTrueMask = [](Value *MaskVal) {
1200 if (Value *SplattedVal = getSplatValue(MaskVal))
1201 if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
1202 return ConstValue->isAllOnesValue();
1203 return false;
1204 };
1205 if (!IsAllTrueMask(VPI.getArgOperand(2)))
1206 return false;
1207
1208 // Check to make sure we support scalarization of the intrinsic
1209 Intrinsic::ID IntrID = VPI.getIntrinsicID();
1210 if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
1211 return false;
1212
1213 // Calculate cost of splatting both operands into vectors and the vector
1214 // intrinsic
1215 VectorType *VecTy = cast<VectorType>(VPI.getType());
1216 SmallVector<int> Mask;
1217 if (auto *FVTy = dyn_cast<FixedVectorType>(VecTy))
1218 Mask.resize(FVTy->getNumElements(), 0);
1219 InstructionCost SplatCost =
1220 TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
1222 CostKind);
1223
1224 // Calculate the cost of the VP Intrinsic
1226 for (Value *V : VPI.args())
1227 Args.push_back(V->getType());
1228 IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
1229 InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
1230 InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
1231
1232 // Determine scalar opcode
1233 std::optional<unsigned> FunctionalOpcode =
1234 VPI.getFunctionalOpcode();
1235 std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
1236 if (!FunctionalOpcode) {
1237 ScalarIntrID = VPI.getFunctionalIntrinsicID();
1238 if (!ScalarIntrID)
1239 return false;
1240 }
1241
1242 // Calculate cost of scalarizing
1243 InstructionCost ScalarOpCost = 0;
1244 if (ScalarIntrID) {
1245 IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
1246 ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
1247 } else {
1248 ScalarOpCost = TTI.getArithmeticInstrCost(*FunctionalOpcode,
1249 VecTy->getScalarType(), CostKind);
1250 }
1251
1252 // The existing splats may be kept around if other instructions use them.
1253 InstructionCost CostToKeepSplats =
1254 (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
1255 InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
1256
1257 LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
1258 << "\n");
1259 LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
1260 << ", Cost of scalarizing:" << NewCost << "\n");
1261
1262 // We want to scalarize unless the vector variant actually has lower cost.
1263 if (OldCost < NewCost || !NewCost.isValid())
1264 return false;
1265
1266 // Scalarize the intrinsic
1267 ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount();
1268 Value *EVL = VPI.getArgOperand(3);
1269
1270 // If the VP op might introduce UB or poison, we can scalarize it provided
1271 // that we know the EVL > 0: If the EVL is zero, then the original VP op
1272 // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by
1273 // scalarizing it.
1274 bool SafeToSpeculate;
1275 if (ScalarIntrID)
1276 SafeToSpeculate = Intrinsic::getFnAttributes(I.getContext(), *ScalarIntrID)
1277 .hasAttribute(Attribute::AttrKind::Speculatable);
1278 else
1280 *FunctionalOpcode, &VPI, nullptr, SQ.AC, SQ.DT);
1281 if (!SafeToSpeculate &&
1282 !isKnownNonZero(EVL, SimplifyQuery(*DL, SQ.DT, SQ.AC, &VPI)))
1283 return false;
1284
1285 Value *ScalarVal =
1286 ScalarIntrID
1287 ? Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID,
1288 {ScalarOp0, ScalarOp1})
1289 : Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode),
1290 ScalarOp0, ScalarOp1);
1291
1292 replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal));
1293 return true;
1294}
1295
1296/// Match a vector op/compare/intrinsic with at least one
1297/// inserted scalar operand and convert to scalar op/cmp/intrinsic followed
1298/// by insertelement.
1299bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
1300 auto *UO = dyn_cast<UnaryOperator>(&I);
1301 auto *BO = dyn_cast<BinaryOperator>(&I);
1302 auto *CI = dyn_cast<CmpInst>(&I);
1303 auto *II = dyn_cast<IntrinsicInst>(&I);
1304 if (!UO && !BO && !CI && !II)
1305 return false;
1306
1307 // TODO: Allow intrinsics with different argument types
1308 if (II) {
1309 if (!isTriviallyVectorizable(II->getIntrinsicID()))
1310 return false;
1311 for (auto [Idx, Arg] : enumerate(II->args()))
1312 if (Arg->getType() != II->getType() &&
1313 !isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, &TTI))
1314 return false;
1315 }
1316
1317 // Do not convert the vector condition of a vector select into a scalar
1318 // condition. That may cause problems for codegen because of differences in
1319 // boolean formats and register-file transfers.
1320 // TODO: Can we account for that in the cost model?
1321 if (CI)
1322 for (User *U : I.users())
1323 if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
1324 return false;
1325
1326 // Match constant vectors or scalars being inserted into constant vectors:
1327 // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
1328 SmallVector<Value *> VecCs, ScalarOps;
1329 std::optional<uint64_t> Index;
1330
1331 auto Ops = II ? II->args() : I.operands();
1332 for (auto [OpNum, Op] : enumerate(Ops)) {
1333 Constant *VecC;
1334 Value *V;
1335 uint64_t InsIdx = 0;
1336 if (match(Op.get(), m_InsertElt(m_Constant(VecC), m_Value(V),
1337 m_ConstantInt(InsIdx)))) {
1338 // Bail if any inserts are out of bounds.
1339 VectorType *OpTy = cast<VectorType>(Op->getType());
1340 if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
1341 return false;
1342 // All inserts must have the same index.
1343 // TODO: Deal with mismatched index constants and variable indexes?
1344 if (!Index)
1345 Index = InsIdx;
1346 else if (InsIdx != *Index)
1347 return false;
1348 VecCs.push_back(VecC);
1349 ScalarOps.push_back(V);
1350 } else if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(),
1351 OpNum, &TTI)) {
1352 VecCs.push_back(Op.get());
1353 ScalarOps.push_back(Op.get());
1354 } else if (match(Op.get(), m_Constant(VecC))) {
1355 VecCs.push_back(VecC);
1356 ScalarOps.push_back(nullptr);
1357 } else {
1358 return false;
1359 }
1360 }
1361
1362 // Bail if all operands are constant.
1363 if (!Index.has_value())
1364 return false;
1365
1366 VectorType *VecTy = cast<VectorType>(I.getType());
1367 Type *ScalarTy = VecTy->getScalarType();
1368 assert(VecTy->isVectorTy() &&
1369 (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
1370 ScalarTy->isPointerTy()) &&
1371 "Unexpected types for insert element into binop or cmp");
1372
1373 unsigned Opcode = I.getOpcode();
1374 InstructionCost ScalarOpCost, VectorOpCost;
1375 if (CI) {
1376 CmpInst::Predicate Pred = CI->getPredicate();
1377 ScalarOpCost = TTI.getCmpSelInstrCost(
1378 Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
1379 VectorOpCost = TTI.getCmpSelInstrCost(
1380 Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
1381 } else if (UO || BO) {
1382 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
1383 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
1384 } else {
1385 IntrinsicCostAttributes ScalarICA(
1386 II->getIntrinsicID(), ScalarTy,
1387 SmallVector<Type *>(II->arg_size(), ScalarTy));
1388 ScalarOpCost = TTI.getIntrinsicInstrCost(ScalarICA, CostKind);
1389 IntrinsicCostAttributes VectorICA(
1390 II->getIntrinsicID(), VecTy,
1391 SmallVector<Type *>(II->arg_size(), VecTy));
1392 VectorOpCost = TTI.getIntrinsicInstrCost(VectorICA, CostKind);
1393 }
1394
1395 // Fold the vector constants in the original vectors into a new base vector to
1396 // get more accurate cost modelling.
1397 Value *NewVecC = nullptr;
1398 if (CI)
1399 NewVecC = simplifyCmpInst(CI->getPredicate(), VecCs[0], VecCs[1], SQ);
1400 else if (UO)
1401 NewVecC =
1402 simplifyUnOp(UO->getOpcode(), VecCs[0], UO->getFastMathFlags(), SQ);
1403 else if (BO)
1404 NewVecC = simplifyBinOp(BO->getOpcode(), VecCs[0], VecCs[1], SQ);
1405 else if (II)
1406 NewVecC = simplifyCall(II, II->getCalledOperand(), VecCs, SQ);
1407
1408 if (!NewVecC)
1409 return false;
1410
1411 // Get cost estimate for the insert element. This cost will factor into
1412 // both sequences.
1413 InstructionCost OldCost = VectorOpCost;
1414 InstructionCost NewCost =
1415 ScalarOpCost + TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
1416 CostKind, *Index, NewVecC);
1417
1418 for (auto [Idx, Op, VecC, Scalar] : enumerate(Ops, VecCs, ScalarOps)) {
1419 if (!Scalar || (II && isVectorIntrinsicWithScalarOpAtArg(
1420 II->getIntrinsicID(), Idx, &TTI)))
1421 continue;
1423 Instruction::InsertElement, VecTy, CostKind, *Index, VecC, Scalar);
1424 OldCost += InsertCost;
1425 NewCost += !Op->hasOneUse() * InsertCost;
1426 }
1427
1428 // We want to scalarize unless the vector variant actually has lower cost.
1429 if (OldCost < NewCost || !NewCost.isValid())
1430 return false;
1431
1432 // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
1433 // inselt NewVecC, (scalar_op V0, V1), Index
1434 if (CI)
1435 ++NumScalarCmp;
1436 else if (UO || BO)
1437 ++NumScalarOps;
1438 else
1439 ++NumScalarIntrinsic;
1440
1441 // For constant cases, extract the scalar element, this should constant fold.
1442 for (auto [OpIdx, Scalar, VecC] : enumerate(ScalarOps, VecCs))
1443 if (!Scalar)
1445 cast<Constant>(VecC), Builder.getInt64(*Index));
1446
1447 Value *Scalar;
1448 if (CI)
1449 Scalar = Builder.CreateCmp(CI->getPredicate(), ScalarOps[0], ScalarOps[1]);
1450 else if (UO || BO)
1451 Scalar = Builder.CreateNAryOp(Opcode, ScalarOps);
1452 else
1453 Scalar = Builder.CreateIntrinsic(ScalarTy, II->getIntrinsicID(), ScalarOps);
1454
1455 Scalar->setName(I.getName() + ".scalar");
1456
1457 // All IR flags are safe to back-propagate. There is no potential for extra
1458 // poison to be created by the scalar instruction.
1459 if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
1460 ScalarInst->copyIRFlags(&I);
1461
1462 Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, *Index);
1463 replaceValue(I, *Insert);
1464 return true;
1465}
1466
1467/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
1468/// a vector into vector operations followed by extract. Note: The SLP pass
1469/// may miss this pattern because of implementation problems.
1470bool VectorCombine::foldExtractedCmps(Instruction &I) {
1471 auto *BI = dyn_cast<BinaryOperator>(&I);
1472
1473 // We are looking for a scalar binop of booleans.
1474 // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
1475 if (!BI || !I.getType()->isIntegerTy(1))
1476 return false;
1477
1478 // The compare predicates should match, and each compare should have a
1479 // constant operand.
1480 Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
1481 Instruction *I0, *I1;
1482 Constant *C0, *C1;
1483 CmpPredicate P0, P1;
1484 if (!match(B0, m_Cmp(P0, m_Instruction(I0), m_Constant(C0))) ||
1485 !match(B1, m_Cmp(P1, m_Instruction(I1), m_Constant(C1))))
1486 return false;
1487
1488 auto MatchingPred = CmpPredicate::getMatching(P0, P1);
1489 if (!MatchingPred)
1490 return false;
1491
1492 // The compare operands must be extracts of the same vector with constant
1493 // extract indexes.
1494 Value *X;
1495 uint64_t Index0, Index1;
1496 if (!match(I0, m_ExtractElt(m_Value(X), m_ConstantInt(Index0))) ||
1497 !match(I1, m_ExtractElt(m_Specific(X), m_ConstantInt(Index1))))
1498 return false;
1499
1500 auto *Ext0 = cast<ExtractElementInst>(I0);
1501 auto *Ext1 = cast<ExtractElementInst>(I1);
1502 ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1, CostKind);
1503 if (!ConvertToShuf)
1504 return false;
1505 assert((ConvertToShuf == Ext0 || ConvertToShuf == Ext1) &&
1506 "Unknown ExtractElementInst");
1507
1508 // The original scalar pattern is:
1509 // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
1510 CmpInst::Predicate Pred = *MatchingPred;
1511 unsigned CmpOpcode =
1512 CmpInst::isFPPredicate(Pred) ? Instruction::FCmp : Instruction::ICmp;
1513 auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
1514 if (!VecTy)
1515 return false;
1516
1517 InstructionCost Ext0Cost =
1518 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
1519 InstructionCost Ext1Cost =
1520 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
1522 CmpOpcode, I0->getType(), CmpInst::makeCmpResultType(I0->getType()), Pred,
1523 CostKind);
1524
1525 InstructionCost OldCost =
1526 Ext0Cost + Ext1Cost + CmpCost * 2 +
1527 TTI.getArithmeticInstrCost(I.getOpcode(), I.getType(), CostKind);
1528
1529 // The proposed vector pattern is:
1530 // vcmp = cmp Pred X, VecC
1531 // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
1532 int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
1533 int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
1536 CmpOpcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
1537 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
1538 ShufMask[CheapIndex] = ExpensiveIndex;
1540 CmpTy, ShufMask, CostKind);
1541 NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy, CostKind);
1542 NewCost += TTI.getVectorInstrCost(*Ext0, CmpTy, CostKind, CheapIndex);
1543 NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost;
1544 NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost;
1545
1546 // Aggressively form vector ops if the cost is equal because the transform
1547 // may enable further optimization.
1548 // Codegen can reverse this transform (scalarize) if it was not profitable.
1549 if (OldCost < NewCost || !NewCost.isValid())
1550 return false;
1551
1552 // Create a vector constant from the 2 scalar constants.
1553 SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
1554 PoisonValue::get(VecTy->getElementType()));
1555 CmpC[Index0] = C0;
1556 CmpC[Index1] = C1;
1557 Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));
1558 Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
1559 Value *LHS = ConvertToShuf == Ext0 ? Shuf : VCmp;
1560 Value *RHS = ConvertToShuf == Ext0 ? VCmp : Shuf;
1561 Value *VecLogic = Builder.CreateBinOp(BI->getOpcode(), LHS, RHS);
1562 Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
1563 replaceValue(I, *NewExt);
1564 ++NumVecCmpBO;
1565 return true;
1566}
1567
1568/// Try to fold scalar selects that select between extracted elements and zero
1569/// into extracting from a vector select. This is rooted at the bitcast.
1570///
1571/// This pattern arises when a vector is bitcast to a smaller element type,
1572/// elements are extracted, and then conditionally selected with zero:
1573///
1574/// %bc = bitcast <4 x i32> %src to <16 x i8>
1575/// %e0 = extractelement <16 x i8> %bc, i32 0
1576/// %s0 = select i1 %cond, i8 %e0, i8 0
1577/// %e1 = extractelement <16 x i8> %bc, i32 1
1578/// %s1 = select i1 %cond, i8 %e1, i8 0
1579/// ...
1580///
1581/// Transforms to:
1582/// %sel = select i1 %cond, <4 x i32> %src, <4 x i32> zeroinitializer
1583/// %bc = bitcast <4 x i32> %sel to <16 x i8>
1584/// %e0 = extractelement <16 x i8> %bc, i32 0
1585/// %e1 = extractelement <16 x i8> %bc, i32 1
1586/// ...
1587///
1588/// This is profitable because vector select on wider types produces fewer
1589/// select/cndmask instructions than scalar selects on each element.
1590bool VectorCombine::foldSelectsFromBitcast(Instruction &I) {
1591 auto *BC = dyn_cast<BitCastInst>(&I);
1592 if (!BC)
1593 return false;
1594
1595 FixedVectorType *SrcVecTy = dyn_cast<FixedVectorType>(BC->getSrcTy());
1596 FixedVectorType *DstVecTy = dyn_cast<FixedVectorType>(BC->getDestTy());
1597 if (!SrcVecTy || !DstVecTy)
1598 return false;
1599
1600 // Source must be 32-bit or 64-bit elements, destination must be smaller
1601 // integer elements. Zero in all these types is all-bits-zero.
1602 Type *SrcEltTy = SrcVecTy->getElementType();
1603 Type *DstEltTy = DstVecTy->getElementType();
1604 unsigned SrcEltBits = SrcEltTy->getPrimitiveSizeInBits();
1605 unsigned DstEltBits = DstEltTy->getPrimitiveSizeInBits();
1606
1607 if (SrcEltBits != 32 && SrcEltBits != 64)
1608 return false;
1609
1610 if (!DstEltTy->isIntegerTy() || DstEltBits >= SrcEltBits)
1611 return false;
1612
1613 // Check profitability using TTI before collecting users.
1614 Type *CondTy = CmpInst::makeCmpResultType(DstEltTy);
1615 Type *VecCondTy = CmpInst::makeCmpResultType(SrcVecTy);
1616
1617 InstructionCost ScalarSelCost =
1618 TTI.getCmpSelInstrCost(Instruction::Select, DstEltTy, CondTy,
1620 InstructionCost VecSelCost =
1621 TTI.getCmpSelInstrCost(Instruction::Select, SrcVecTy, VecCondTy,
1623
1624 // We need at least this many selects for vectorization to be profitable.
1625 // VecSelCost < ScalarSelCost * NumSelects => NumSelects > VecSelCost /
1626 // ScalarSelCost
1627 if (!ScalarSelCost.isValid() || ScalarSelCost == 0)
1628 return false;
1629
1630 unsigned MinSelects = (VecSelCost.getValue() / ScalarSelCost.getValue()) + 1;
1631
1632 // Quick check: if bitcast doesn't have enough users, bail early.
1633 if (!BC->hasNUsesOrMore(MinSelects))
1634 return false;
1635
1636 // Collect all select users that match the pattern, grouped by condition.
1637 // Pattern: select i1 %cond, (extractelement %bc, idx), 0
1638 DenseMap<Value *, SmallVector<SelectInst *, 8>> CondToSelects;
1639
1640 for (User *U : BC->users()) {
1641 auto *Ext = dyn_cast<ExtractElementInst>(U);
1642 if (!Ext)
1643 continue;
1644
1645 for (User *ExtUser : Ext->users()) {
1646 Value *Cond;
1647 // Match: select i1 %cond, %ext, 0
1648 if (match(ExtUser, m_Select(m_Value(Cond), m_Specific(Ext), m_Zero())) &&
1649 Cond->getType()->isIntegerTy(1))
1650 CondToSelects[Cond].push_back(cast<SelectInst>(ExtUser));
1651 }
1652 }
1653
1654 if (CondToSelects.empty())
1655 return false;
1656
1657 bool MadeChange = false;
1658 Value *SrcVec = BC->getOperand(0);
1659
1660 // Process each group of selects with the same condition.
1661 for (auto [Cond, Selects] : CondToSelects) {
1662 // Only profitable if vector select cost < total scalar select cost.
1663 if (Selects.size() < MinSelects) {
1664 LLVM_DEBUG(dbgs() << "VectorCombine: foldSelectsFromBitcast not "
1665 << "profitable (VecCost=" << VecSelCost
1666 << ", ScalarCost=" << ScalarSelCost
1667 << ", NumSelects=" << Selects.size() << ")\n");
1668 continue;
1669 }
1670
1671 // Create the vector select and bitcast once for this condition.
1672 auto InsertPt = std::next(BC->getIterator());
1673
1674 if (auto *CondInst = dyn_cast<Instruction>(Cond))
1675 if (DT.dominates(BC, CondInst))
1676 InsertPt = std::next(CondInst->getIterator());
1677
1678 Builder.SetInsertPoint(InsertPt);
1679 Value *VecSel =
1680 Builder.CreateSelect(Cond, SrcVec, Constant::getNullValue(SrcVecTy));
1681 Value *NewBC = Builder.CreateBitCast(VecSel, DstVecTy);
1682
1683 // Replace each scalar select with an extract from the new bitcast.
1684 for (SelectInst *Sel : Selects) {
1685 auto *Ext = cast<ExtractElementInst>(Sel->getTrueValue());
1686 Value *Idx = Ext->getIndexOperand();
1687
1688 Builder.SetInsertPoint(Sel);
1689 Value *NewExt = Builder.CreateExtractElement(NewBC, Idx);
1690 replaceValue(*Sel, *NewExt);
1691 MadeChange = true;
1692 }
1693
1694 LLVM_DEBUG(dbgs() << "VectorCombine: folded " << Selects.size()
1695 << " selects into vector select\n");
1696 }
1697
1698 return MadeChange;
1699}
1700
// NOTE(review): the start of this definition (inner lines 1701-1702: return
// type, name and the first parameters) is missing from this listing. The call
// sites in foldBinopOfReductions invoke it as
// analyzeCostOfVecReduction(II, CostKind, TTI, CostBefore, CostAfter) --
// confirm the exact signature against the real file.
//
// Inspects the operand feeding reduction intrinsic II and reports two costs
// through the reference parameters:
//  - CostBeforeReduction: cost of the extend, or ext(mul(ext, ext)) chain,
//    that feeds the reduction. Only written when one of the two patterns
//    below matches; on the fallthrough path it keeps the caller's value.
//  - CostAfterReduction: cost of the reduction itself, queried as an extended
//    reduction, a multiply-accumulate reduction, or a plain arithmetic
//    reduction depending on which pattern matched.
1703 const TargetTransformInfo &TTI,
1704 InstructionCost &CostBeforeReduction,
1705 InstructionCost &CostAfterReduction) {
1706 Instruction *Op0, *Op1;
1707 auto *RedOp = dyn_cast<Instruction>(II.getOperand(0));
1708 auto *VecRedTy = cast<VectorType>(II.getOperand(0)->getType());
1709 unsigned ReductionOpc =
1710 getArithmeticReductionInstruction(II.getIntrinsicID());
 // Case 1: reduce(ext(A)).
1711 if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value()))) {
1712 bool IsUnsigned = isa<ZExtInst>(RedOp);
1713 auto *ExtType = cast<VectorType>(RedOp->getOperand(0)->getType());
1714
1715 CostBeforeReduction =
1716 TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, ExtType,
 // NOTE(review): inner line 1717 (the remaining arguments of the call above)
 // is missing from this listing.
1718 CostAfterReduction =
1719 TTI.getExtendedReductionCost(ReductionOpc, IsUnsigned, II.getType(),
1720 ExtType, FastMathFlags(), CostKind);
1721 return;
1722 }
 // Case 2: only for add reductions. Op0 and Op1 must be extends of the same
 // kind from the same source type, and either the outer extend has that same
 // kind too or both multiplicands are the very same instruction.
1723 if (RedOp && II.getIntrinsicID() == Intrinsic::vector_reduce_add &&
1724 match(RedOp,
 // NOTE(review): inner line 1725 (the pattern that binds Op0 and Op1) is
 // missing from this listing.
1726 match(Op0, m_ZExtOrSExt(m_Value())) &&
1727 Op0->getOpcode() == Op1->getOpcode() &&
1728 Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
1729 (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
1730 // Matched reduce.add(ext(mul(ext(A), ext(B)))
1731 bool IsUnsigned = isa<ZExtInst>(Op0);
1732 auto *ExtType = cast<VectorType>(Op0->getOperand(0)->getType());
1733 VectorType *MulType = VectorType::get(Op0->getType(), VecRedTy);
1734
1735 InstructionCost ExtCost =
1736 TTI.getCastInstrCost(Op0->getOpcode(), MulType, ExtType,
 // NOTE(review): inner line 1737 (the remaining arguments of the call above)
 // is missing from this listing.
1738 InstructionCost MulCost =
1739 TTI.getArithmeticInstrCost(Instruction::Mul, MulType, CostKind);
1740 InstructionCost Ext2Cost =
1741 TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, MulType,
 // NOTE(review): inner line 1742 (the remaining arguments of the call above)
 // is missing from this listing.
1743
 // Two inner extends, the multiply, and the outer extend.
1744 CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
1745 CostAfterReduction = TTI.getMulAccReductionCost(
1746 IsUnsigned, ReductionOpc, II.getType(), ExtType, CostKind);
1747 return;
1748 }
 // Fallthrough: plain reduction; CostBeforeReduction is left untouched.
1749 CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,
1750 std::nullopt, CostKind);
1751}
1752
1753bool VectorCombine::foldBinopOfReductions(Instruction &I) {
1754 Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(&I)->getOpcode();
1755 Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
1756 if (BinOpOpc == Instruction::Sub)
1757 ReductionIID = Intrinsic::vector_reduce_add;
1758 if (ReductionIID == Intrinsic::not_intrinsic)
1759 return false;
1760 // FP reductions have a start-value operand that this fold doesn't handle.
1761 if (ReductionIID == Intrinsic::vector_reduce_fadd ||
1762 ReductionIID == Intrinsic::vector_reduce_fmul)
1763 return false;
1764
1765 auto checkIntrinsicAndGetItsArgument = [](Value *V,
1766 Intrinsic::ID IID) -> Value * {
1767 auto *II = dyn_cast<IntrinsicInst>(V);
1768 if (!II)
1769 return nullptr;
1770 if (II->getIntrinsicID() == IID && II->hasOneUse())
1771 return II->getArgOperand(0);
1772 return nullptr;
1773 };
1774
1775 Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(0), ReductionIID);
1776 if (!V0)
1777 return false;
1778 Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(1), ReductionIID);
1779 if (!V1)
1780 return false;
1781
1782 auto *VTy = cast<VectorType>(V0->getType());
1783 if (V1->getType() != VTy)
1784 return false;
1785 const auto &II0 = *cast<IntrinsicInst>(I.getOperand(0));
1786 const auto &II1 = *cast<IntrinsicInst>(I.getOperand(1));
1787 unsigned ReductionOpc =
1788 getArithmeticReductionInstruction(II0.getIntrinsicID());
1789
1790 InstructionCost OldCost = 0;
1791 InstructionCost NewCost = 0;
1792 InstructionCost CostOfRedOperand0 = 0;
1793 InstructionCost CostOfRed0 = 0;
1794 InstructionCost CostOfRedOperand1 = 0;
1795 InstructionCost CostOfRed1 = 0;
1796 analyzeCostOfVecReduction(II0, CostKind, TTI, CostOfRedOperand0, CostOfRed0);
1797 analyzeCostOfVecReduction(II1, CostKind, TTI, CostOfRedOperand1, CostOfRed1);
1798 OldCost = CostOfRed0 + CostOfRed1 + TTI.getInstructionCost(&I, CostKind);
1799 NewCost =
1800 CostOfRedOperand0 + CostOfRedOperand1 +
1801 TTI.getArithmeticInstrCost(BinOpOpc, VTy, CostKind) +
1802 TTI.getArithmeticReductionCost(ReductionOpc, VTy, std::nullopt, CostKind);
1803 if (NewCost >= OldCost || !NewCost.isValid())
1804 return false;
1805
1806 LLVM_DEBUG(dbgs() << "Found two mergeable reductions: " << I
1807 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
1808 << "\n");
1809 Value *VectorBO;
1810 if (BinOpOpc == Instruction::Or)
1811 VectorBO = Builder.CreateOr(V0, V1, "",
1812 cast<PossiblyDisjointInst>(I).isDisjoint());
1813 else
1814 VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
1815
1816 Value *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
1817 replaceValue(I, *Rdx);
1818 return true;
1819}
1820
1821// Check if memory loc modified between two instrs in the same BB.
// Returns true when some instruction in the scanned range [Begin, End) may
// modify Loc according to AA, and also (conservatively) once more than
// MaxInstrsToScan non-modifying instructions have been examined.
// NOTE(review): inner lines 1822-1823 (return type, name and the first
// parameters) are missing from this listing. The call in
// foldSingleElementStore passes (Begin, End, Loc, AA) to
// isMemModifiedBetween -- confirm the exact signature against the real file.
1824 const MemoryLocation &Loc, AAResults &AA) {
1825 unsigned NumScanned = 0;
1826 return std::any_of(Begin, End, [&](const Instruction &Instr) {
 // || short-circuits, so the counter only advances for instructions that do
 // not modify Loc.
1827 return isModSet(AA.getModRefInfo(&Instr, Loc)) ||
1828 ++NumScanned > MaxInstrsToScan;
1829 });
1830}
1831
1832namespace {
1833/// Helper class to indicate whether a vector index can be safely scalarized and
1834/// if a freeze needs to be inserted.
1835class ScalarizationResult {
1836 enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
1837
1838 StatusTy Status;
1839 Value *ToFreeze;
1840
1841 ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
1842 : Status(Status), ToFreeze(ToFreeze) {}
1843
1844public:
1845 ScalarizationResult(const ScalarizationResult &Other) = default;
1846 ~ScalarizationResult() {
1847 assert(!ToFreeze && "freeze() not called with ToFreeze being set");
1848 }
1849
1850 static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
1851 static ScalarizationResult safe() { return {StatusTy::Safe}; }
1852 static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
1853 return {StatusTy::SafeWithFreeze, ToFreeze};
1854 }
1855
1856 /// Returns true if the index can be scalarize without requiring a freeze.
1857 bool isSafe() const { return Status == StatusTy::Safe; }
1858 /// Returns true if the index cannot be scalarized.
1859 bool isUnsafe() const { return Status == StatusTy::Unsafe; }
1860 /// Returns true if the index can be scalarize, but requires inserting a
1861 /// freeze.
1862 bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
1863
1864 /// Reset the state of Unsafe and clear ToFreze if set.
1865 void discard() {
1866 ToFreeze = nullptr;
1867 Status = StatusTy::Unsafe;
1868 }
1869
1870 /// Freeze the ToFreeze and update the use in \p User to use it.
1871 void freeze(IRBuilderBase &Builder, Instruction &UserI) {
1872 assert(isSafeWithFreeze() &&
1873 "should only be used when freezing is required");
1874 assert(is_contained(ToFreeze->users(), &UserI) &&
1875 "UserI must be a user of ToFreeze");
1876 IRBuilder<>::InsertPointGuard Guard(Builder);
1877 Builder.SetInsertPoint(cast<Instruction>(&UserI));
1878 Value *Frozen =
1879 Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
1880 for (Use &U : make_early_inc_range((UserI.operands())))
1881 if (U.get() == ToFreeze)
1882 U.set(Frozen);
1883
1884 ToFreeze = nullptr;
1885 }
1886};
1887} // namespace
1888
1889/// Check if it is legal to scalarize a memory access to \p VecTy at index \p
1890/// Idx. \p Idx must access a valid vector element.
1891static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx,
1892 const SimplifyQuery &SQ) {
1893 // We do checks for both fixed vector types and scalable vector types.
1894 // This is the number of elements of fixed vector types,
1895 // or the minimum number of elements of scalable vector types.
1896 uint64_t NumElements = VecTy->getElementCount().getKnownMinValue();
1897 unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
1898
1899 if (auto *C = dyn_cast<ConstantInt>(Idx)) {
1900 if (C->getValue().ult(NumElements))
1901 return ScalarizationResult::safe();
1902 return ScalarizationResult::unsafe();
1903 }
1904
1905 // Always unsafe if the index type can't handle all inbound values.
1906 if (!llvm::isUIntN(IntWidth, NumElements))
1907 return ScalarizationResult::unsafe();
1908
1909 APInt Zero(IntWidth, 0);
1910 APInt MaxElts(IntWidth, NumElements);
1911 ConstantRange ValidIndices(Zero, MaxElts);
1912 ConstantRange IdxRange(IntWidth, true);
1913
1914 if (isGuaranteedNotToBePoison(Idx, SQ.AC, SQ.CxtI, SQ.DT)) {
1915 if (ValidIndices.contains(
1916 computeConstantRange(Idx, /*ForSigned=*/false, SQ)))
1917 return ScalarizationResult::safe();
1918 return ScalarizationResult::unsafe();
1919 }
1920
1921 // If the index may be poison, check if we can insert a freeze before the
1922 // range of the index is restricted.
1923 Value *IdxBase;
1924 ConstantInt *CI;
1925 if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
1926 IdxRange = IdxRange.binaryAnd(CI->getValue());
1927 } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
1928 IdxRange = IdxRange.urem(CI->getValue());
1929 }
1930
1931 if (ValidIndices.contains(IdxRange))
1932 return ScalarizationResult::safeWithFreeze(IdxBase);
1933 return ScalarizationResult::unsafe();
1934}
1935
1936/// The memory operation on a vector of \p ScalarType had alignment of
1937/// \p VectorAlignment. Compute the maximal, but conservatively correct,
1938/// alignment that will be valid for the memory operation on a single scalar
1939/// element of the same type with index \p Idx.
// NOTE(review): inner line 1940 (return type, function name and the first
// parameter) is missing from this listing. Callers in this file invoke it as
// computeAlignmentAfterScalarization(VectorAlignment, ScalarType, Idx, DL)
// and store the result in an Align -- confirm against the real file.
1941 Type *ScalarType, Value *Idx,
1942 const DataLayout &DL) {
 // Constant index: combine the vector alignment with the element's byte
 // offset (index times the element's store size).
1943 if (auto *C = dyn_cast<ConstantInt>(Idx))
1944 return commonAlignment(VectorAlignment,
1945 C->getZExtValue() * DL.getTypeStoreSize(ScalarType));
 // Variable index: combine the vector alignment with a single element's store
 // size instead of a concrete offset.
1946 return commonAlignment(VectorAlignment, DL.getTypeStoreSize(ScalarType));
1947}
1948
1949// Combine patterns like:
1950// %0 = load <4 x i32>, <4 x i32>* %a
1951// %1 = insertelement <4 x i32> %0, i32 %b, i32 1
1952// store <4 x i32> %1, <4 x i32>* %a
1953// to:
1954// %0 = bitcast <4 x i32>* %a to i32*
1955// %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1
1956// store i32 %b, i32* %1
//
// I is cast to a StoreInst unconditionally, so callers must only pass stores.
// Returns true if the vector store was replaced by a scalar store.
1957bool VectorCombine::foldSingleElementStore(Instruction &I) {
 // NOTE(review): inner line 1958 (the condition guarding the early return
 // below) is missing from this listing.
1959 return false;
1960 auto *SI = cast<StoreInst>(&I);
1961 if (!SI->isSimple() || !isa<VectorType>(SI->getValueOperand()->getType()))
1962 return false;
1963
1964 // TODO: Combine more complicated patterns (multiple insert) by referencing
1965 // TargetTransformInfo.
 // NOTE(review): inner line 1966 is missing from this listing; `Source`, which
 // is bound by m_Instruction below, is presumably declared there.
1967 Value *NewElement;
1968 Value *Idx;
1969 if (!match(SI->getValueOperand(),
1970 m_InsertElt(m_Instruction(Source), m_Value(NewElement),
1971 m_Value(Idx))))
1972 return false;
1973
1974 if (auto *Load = dyn_cast<LoadInst>(Source)) {
1975 auto VecTy = cast<VectorType>(SI->getValueOperand()->getType());
1976 Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
1977 // Don't optimize for atomic/volatile load or store. Ensure memory is not
1978 // modified between, vector type matches store size, and index is inbounds.
 // The load and store must also be in the same block and use the same
 // address once pointer casts are stripped.
1979 if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
1980 !DL->typeSizeEqualsStoreSize(Load->getType()->getScalarType()) ||
1981 SrcAddr != SI->getPointerOperand()->stripPointerCasts())
1982 return false;
1983
1984 auto ScalarizableIdx =
1985 canScalarizeAccess(VecTy, Idx, SQ.getWithInstruction(Load));
1986 if (ScalarizableIdx.isUnsafe() ||
1987 isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
1988 MemoryLocation::get(SI), AA))
1989 return false;
1990
1991 // Ensure we add the load back to the worklist BEFORE its users so they can
1992 // be erased in the correct order.
1993 Worklist.push(Load);
1994
 // In the SafeWithFreeze case Idx is the and/urem instruction matched by
 // canScalarizeAccess, hence the cast to Instruction.
1995 if (ScalarizableIdx.isSafeWithFreeze())
1996 ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
 // Address of element Idx inside the stored vector.
1997 Value *GEP = Builder.CreateInBoundsGEP(
1998 SI->getValueOperand()->getType(), SI->getPointerOperand(),
1999 {ConstantInt::get(Idx->getType(), 0), Idx});
2000 StoreInst *NSI = Builder.CreateStore(NewElement, GEP);
2001 NSI->copyMetadata(*SI);
 // Start from the larger of the load's and the store's alignment; presumably
 // valid because both access the same stripped address -- confirm.
2002 Align ScalarOpAlignment = computeAlignmentAfterScalarization(
2003 std::max(SI->getAlign(), Load->getAlign()), NewElement->getType(), Idx,
2004 *DL);
2005 NSI->setAlignment(ScalarOpAlignment);
2006 replaceValue(I, *NSI);
 // NOTE(review): inner line 2007 is missing from this listing.
2008 return true;
2009 }
2010
2011 return false;
2012}
2013
2014/// Try to scalarize vector loads feeding extractelement or bitcast
2015/// instructions.
2016bool VectorCombine::scalarizeLoad(Instruction &I) {
2017 Value *Ptr;
2018 if (!match(&I, m_Load(m_Value(Ptr))))
2019 return false;
2020
2021 auto *LI = cast<LoadInst>(&I);
2022 auto *VecTy = cast<VectorType>(LI->getType());
2023
2024 // The isSimple() check could be isUnordered(), but for now we cowardly
2025 // refuse to handle even unordered atomics.
2026 if (!LI->isSimple() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
2027 return false;
2028
2029 bool AllExtracts = true;
2030 bool AllBitcasts = true;
2031 Instruction *LastCheckedInst = LI;
2032 unsigned NumInstChecked = 0;
2033
2034 // Check what type of users we have (must either all be extracts or
2035 // bitcasts) and ensure no memory modifications between the load and
2036 // its users.
2037 for (User *U : LI->users()) {
2038 auto *UI = dyn_cast<Instruction>(U);
2039 if (!UI || UI->getParent() != LI->getParent())
2040 return false;
2041
2042 // If any user is waiting to be erased, then bail out as this will
2043 // distort the cost calculation and possibly lead to infinite loops.
2044 if (UI->use_empty())
2045 return false;
2046
2047 if (!isa<ExtractElementInst>(UI))
2048 AllExtracts = false;
2049 if (!isa<BitCastInst>(UI))
2050 AllBitcasts = false;
2051
2052 // Check if any instruction between the load and the user may modify memory.
2053 if (LastCheckedInst->comesBefore(UI)) {
2054 for (Instruction &I :
2055 make_range(std::next(LI->getIterator()), UI->getIterator())) {
2056 // Bail out if we reached the check limit or the instruction may write
2057 // to memory.
2058 if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
2059 return false;
2060 NumInstChecked++;
2061 }
2062 LastCheckedInst = UI;
2063 }
2064 }
2065
2066 if (AllExtracts)
2067 return scalarizeLoadExtract(LI, VecTy, Ptr);
2068 if (AllBitcasts)
2069 return scalarizeLoadBitcast(LI, VecTy, Ptr);
2070 return false;
2071}
2072
2073/// Try to scalarize vector loads feeding extractelement instructions.
/// Every user of \p LI is cast to ExtractElementInst unconditionally, so the
/// caller must have checked that already. Each extract is replaced by a
/// scalar load from a GEP into the vector when the accumulated scalar cost is
/// strictly lower than the vector load plus extract cost.
/// \returns true if the IR was changed.
2074bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
2075 Value *Ptr) {
 // NOTE(review): inner line 2076 (the condition guarding the early return
 // below) is missing from this listing.
2077 return false;
2078
 // Extracts whose index needs a freeze before it can be used in a GEP.
2079 DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
2080 llvm::scope_exit FailureGuard([&]() {
2081 // If the transform is aborted, discard the ScalarizationResults.
 // Otherwise their destructors would assert on the pending freeze.
2082 for (auto &Pair : NeedFreeze)
2083 Pair.second.discard();
2084 });
2085
2086 InstructionCost OriginalCost =
2087 TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
 // NOTE(review): inner line 2088 (the remaining arguments of the call above)
 // is missing from this listing.
2089 InstructionCost ScalarizedCost = 0;
2090
 // First pass: check legality for every extract and accumulate both costs.
2091 for (User *U : LI->users()) {
2092 auto *UI = cast<ExtractElementInst>(U);
2093
2094 auto ScalarIdx = canScalarizeAccess(VecTy, UI->getIndexOperand(),
2095 SQ.getWithInstruction(LI));
2096 if (ScalarIdx.isUnsafe())
2097 return false;
2098 if (ScalarIdx.isSafeWithFreeze()) {
 // Keep a copy in the map and neutralize the local so that its destructor
 // does not assert.
2099 NeedFreeze.try_emplace(UI, ScalarIdx);
2100 ScalarIdx.discard();
2101 }
2102
 // A non-constant index is passed to the cost model as -1.
2103 auto *Index = dyn_cast<ConstantInt>(UI->getIndexOperand());
2104 OriginalCost +=
2105 TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
2106 Index ? Index->getZExtValue() : -1);
2107 ScalarizedCost +=
2108 TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
 // NOTE(review): inner line 2109 (the remaining arguments of the call above)
 // is missing from this listing.
2110 ScalarizedCost += TTI.getAddressComputationCost(LI->getPointerOperandType(),
2111 nullptr, nullptr, CostKind);
2112 }
2113
2114 LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
2115 << "\n LoadExtractCost: " << OriginalCost
2116 << " vs ScalarizedCost: " << ScalarizedCost << "\n");
2117
 // Equal cost is not enough: only strictly cheaper scalar code is accepted.
2118 if (ScalarizedCost >= OriginalCost)
2119 return false;
2120
2121 // Ensure we add the load back to the worklist BEFORE its users so they can
2122 // be erased in the correct order.
2123 Worklist.push(LI);
2124
2125 Type *ElemType = VecTy->getElementType();
2126
2127 // Replace extracts with narrow scalar loads.
2128 for (User *U : LI->users()) {
2129 auto *EI = cast<ExtractElementInst>(U);
2130 Value *Idx = EI->getIndexOperand();
2131
2132 // Insert 'freeze' for poison indexes.
2133 auto It = NeedFreeze.find(EI);
2134 if (It != NeedFreeze.end())
2135 It->second.freeze(Builder, *cast<Instruction>(Idx));
2136
2137 Builder.SetInsertPoint(EI);
2138 Value *GEP =
2139 Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
2140 auto *NewLoad = cast<LoadInst>(
2141 Builder.CreateLoad(ElemType, GEP, EI->getName() + ".scalar"));
2142
2143 Align ScalarOpAlignment =
2144 computeAlignmentAfterScalarization(LI->getAlign(), ElemType, Idx, *DL);
2145 NewLoad->setAlignment(ScalarOpAlignment);
2146
 // AA metadata is only carried over when the byte offset is a known constant.
2147 if (auto *ConstIdx = dyn_cast<ConstantInt>(Idx)) {
2148 size_t Offset = ConstIdx->getZExtValue() * DL->getTypeStoreSize(ElemType);
2149 AAMDNodes OldAAMD = LI->getAAMetadata();
2150 NewLoad->setAAMetadata(OldAAMD.adjustForAccess(Offset, ElemType, *DL));
2151 }
2152
2153 replaceValue(*EI, *NewLoad, false);
2154 }
2155
 // Success: every pending freeze has been applied, so the guard must not run.
2156 FailureGuard.release();
2157 return true;
2158}
2159
2160/// Try to scalarize vector loads feeding bitcast instructions.
/// Every user of \p LI is cast to BitCastInst unconditionally, so the caller
/// must have checked that already. All bitcasts must produce the same integer
/// or floating-point type, as wide as the whole vector. They are then
/// replaced by one scalar load of that type when it is strictly cheaper.
/// \returns true if the IR was changed.
2161bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
2162 Value *Ptr) {
2163 InstructionCost OriginalCost =
2164 TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
 // NOTE(review): inner line 2165 (the remaining arguments of the call above)
 // is missing from this listing.
2166
2167 Type *TargetScalarType = nullptr;
2168 unsigned VecBitWidth = DL->getTypeSizeInBits(VecTy);
2169
2170 for (User *U : LI->users()) {
2171 auto *BC = cast<BitCastInst>(U);
2172
 // Only bitcasts to a scalar integer or floating-point type are handled.
2173 Type *DestTy = BC->getDestTy();
2174 if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
2175 return false;
2176
2177 unsigned DestBitWidth = DL->getTypeSizeInBits(DestTy);
2178 if (DestBitWidth != VecBitWidth)
2179 return false;
2180
2181 // All bitcasts must target the same scalar type.
2182 if (!TargetScalarType)
2183 TargetScalarType = DestTy;
2184 else if (TargetScalarType != DestTy)
2185 return false;
2186
2187 OriginalCost +=
2188 TTI.getCastInstrCost(Instruction::BitCast, TargetScalarType, VecTy,
 // NOTE(review): inner line 2189 (the remaining arguments of the call above)
 // is missing from this listing.
2190 }
2191
 // Still null only if the loop above never ran, i.e. the load has no users.
2192 if (!TargetScalarType)
2193 return false;
2194
2195 assert(!LI->user_empty() && "Unexpected load without bitcast users");
2196 InstructionCost ScalarizedCost =
2197 TTI.getMemoryOpCost(Instruction::Load, TargetScalarType, LI->getAlign(),
 // NOTE(review): inner line 2198 (the remaining arguments of the call above)
 // is missing from this listing.
2199
2200 LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
2201 << "\n OriginalCost: " << OriginalCost
2202 << " vs ScalarizedCost: " << ScalarizedCost << "\n");
2203
 // Equal cost is not enough: only a strictly cheaper scalar load is accepted.
2204 if (ScalarizedCost >= OriginalCost)
2205 return false;
2206
2207 // Ensure we add the load back to the worklist BEFORE its users so they can
2208 // be erased in the correct order.
2209 Worklist.push(LI);
2210
 // The scalar load takes the vector load's place, alignment and metadata.
2211 Builder.SetInsertPoint(LI);
2212 auto *ScalarLoad =
2213 Builder.CreateLoad(TargetScalarType, Ptr, LI->getName() + ".scalar");
2214 ScalarLoad->setAlignment(LI->getAlign());
2215 ScalarLoad->copyMetadata(*LI);
2216
2217 // Replace all bitcast users with the scalar load.
2218 for (User *U : LI->users()) {
2219 auto *BC = cast<BitCastInst>(U);
2220 replaceValue(*BC, *ScalarLoad, false);
2221 }
2222
2223 return true;
2224}
2225
// Scalarize a fixed-vector zext whose users are all constant-index
// extractelements: the source vector is bitcast to a single integer as wide
// as the whole source vector and each extract becomes a logical shift right
// followed by a mask. Returns true if the IR was changed.
2226bool VectorCombine::scalarizeExtExtract(Instruction &I) {
 // NOTE(review): inner line 2227 (the condition guarding the early return
 // below) is missing from this listing.
2228 return false;
2229 auto *Ext = dyn_cast<ZExtInst>(&I);
2230 if (!Ext)
2231 return false;
2232
2233 // Try to convert a vector zext feeding only extracts to a set of scalar
2234 // (Src >> (ExtIdx * SrcEltSize)) & SrcEltMask
2235 // operations, if profitable.
2236 auto *SrcTy = dyn_cast<FixedVectorType>(Ext->getOperand(0)->getType());
2237 if (!SrcTy)
2238 return false;
2239 auto *DstTy = cast<FixedVectorType>(Ext->getType());
2240
 // The whole source vector must be exactly as wide as one destination
 // element, so the packed integer built below has the extracts' type width.
2241 Type *ScalarDstTy = DstTy->getElementType();
2242 if (DL->getTypeSizeInBits(SrcTy) != DL->getTypeSizeInBits(ScalarDstTy))
2243 return false;
2244
2245 InstructionCost VectorCost =
2246 TTI.getCastInstrCost(Instruction::ZExt, DstTy, SrcTy,
 // NOTE(review): inner line 2247 (the remaining arguments of the call above)
 // is missing from this listing.
2248 unsigned ExtCnt = 0;
2249 bool ExtLane0 = false;
2250 for (User *U : Ext->users()) {
2251 uint64_t Idx;
2252 if (!match(U, m_ExtractElt(m_Value(), m_ConstantInt(Idx))))
2253 return false;
 // Extracts without uses are not counted in the cost comparison.
2254 if (cast<Instruction>(U)->use_empty())
2255 continue;
2256 ExtCnt += 1;
2257 ExtLane0 |= !Idx;
2258 VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
2259 CostKind, Idx, U);
2260 }
2261
 // One 'and' per counted extract, plus one 'lshr' per counted extract with
 // lane 0 costed without a shift.
 // NOTE(review): inner lines 2265-2266, 2268 and 2270-2271 of the expression
 // below are missing from this listing.
2262 InstructionCost ScalarCost =
2263 ExtCnt * TTI.getArithmeticInstrCost(
2264 Instruction::And, ScalarDstTy, CostKind,
2267 (ExtCnt - ExtLane0) *
2269 Instruction::LShr, ScalarDstTy, CostKind,
 // Equal cost still folds.
2272 if (ScalarCost > VectorCost)
2273 return false;
2274
2275 Value *ScalarV = Ext->getOperand(0);
2276 if (!isGuaranteedNotToBePoison(ScalarV, SQ.AC, dyn_cast<Instruction>(ScalarV),
2277 SQ.DT)) {
2278 // Check whether all lanes are extracted, all extracts trigger UB
2279 // on poison, and the last extract (and hence all previous ones)
2280 // are guaranteed to execute if Ext executes. If so, we do not
2281 // need to insert a freeze.
2282 SmallDenseSet<ConstantInt *, 8> ExtractedLanes;
2283 bool AllExtractsTriggerUB = true;
2284 ExtractElementInst *LastExtract = nullptr;
2285 BasicBlock *ExtBB = Ext->getParent();
2286 for (User *U : Ext->users()) {
2287 auto *Extract = cast<ExtractElementInst>(U);
2288 if (Extract->getParent() != ExtBB || !programUndefinedIfPoison(Extract)) {
2289 AllExtractsTriggerUB = false;
2290 break;
2291 }
2292 ExtractedLanes.insert(cast<ConstantInt>(Extract->getIndexOperand()));
2293 if (!LastExtract || LastExtract->comesBefore(Extract))
2294 LastExtract = Extract;
2295 }
2296 if (ExtractedLanes.size() != DstTy->getNumElements() ||
2297 !AllExtractsTriggerUB ||
 // NOTE(review): inner line 2298 (the start of the third condition, whose
 // trailing argument follows) is missing from this listing.
2299 LastExtract->getIterator()))
2300 ScalarV = Builder.CreateFreeze(ScalarV);
2301 }
 // Reinterpret the whole source vector as a single integer.
2302 ScalarV = Builder.CreateBitCast(
2303 ScalarV,
2304 IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));
2305 uint64_t SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType());
2306 uint64_t TotalBits = DL->getTypeSizeInBits(SrcTy);
 // Mask that keeps the low SrcEltSizeInBits bits, i.e. one source element.
2307 APInt EltBitMask = APInt::getLowBitsSet(TotalBits, SrcEltSizeInBits);
2308 Type *PackedTy = IntegerType::get(SrcTy->getContext(), TotalBits);
2309 Value *Mask = ConstantInt::get(PackedTy, EltBitMask);
2310 for (User *U : Ext->users()) {
2311 auto *Extract = cast<ExtractElementInst>(U);
2312 uint64_t Idx =
2313 cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue();
 // Bit position of lane Idx inside the packed integer: counted from the top
 // on big-endian targets, from the bottom otherwise.
2314 uint64_t ShiftAmt =
2315 DL->isBigEndian()
2316 ? (TotalBits - SrcEltSizeInBits - Idx * SrcEltSizeInBits)
2317 : (Idx * SrcEltSizeInBits);
2318 Value *LShr = Builder.CreateLShr(ScalarV, ShiftAmt);
2319 Value *And = Builder.CreateAnd(LShr, Mask);
 // NOTE(review): uses replaceAllUsesWith directly, unlike the neighbouring
 // folds, which go through replaceValue.
2320 U->replaceAllUsesWith(And);
2321 }
2322 return true;
2323}
2324
2325/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
2326/// to "(bitcast (concat X, Y))"
2327/// where X/Y are bitcasted from i1 mask vectors.
/// Both sides may carry a shl; the smaller shift amount is re-applied to the
/// concatenated result. Only little-endian targets are handled.
/// \returns true if the IR was changed.
2328bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
2329 Type *Ty = I.getType();
2330 if (!Ty->isIntegerTy())
2331 return false;
2332
2333 // TODO: Add big endian test coverage
2334 if (DL->isBigEndian())
2335 return false;
2336
2337 // Restrict to disjoint cases so the mask vectors aren't overlapping.
2338 Instruction *X, *Y;
 // NOTE(review): inner line 2339 (the match that binds X and Y and guards the
 // early return below) is missing from this listing.
2340 return false;
2341
2342 // Allow both sources to contain shl, to handle more generic pattern:
2343 // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
2344 Value *SrcX;
2345 uint64_t ShAmtX = 0;
2346 if (!match(X, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX)))))) &&
2347 !match(X, m_OneUse(
 // NOTE(review): inner line 2348 (the shifted operand pattern) is missing
 // from this listing.
2349 m_ConstantInt(ShAmtX)))))
2350 return false;
2351
2352 Value *SrcY;
2353 uint64_t ShAmtY = 0;
2354 if (!match(Y, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY)))))) &&
2355 !match(Y, m_OneUse(
 // NOTE(review): inner line 2356 (the shifted operand pattern) is missing
 // from this listing.
2357 m_ConstantInt(ShAmtY)))))
2358 return false;
2359
2360 // Canonicalize larger shift to the RHS.
2361 if (ShAmtX > ShAmtY) {
2362 std::swap(X, Y);
2363 std::swap(SrcX, SrcY);
2364 std::swap(ShAmtX, ShAmtY);
2365 }
2366
2367 // Ensure both sources are matching vXi1 bool mask types, and that the shift
2368 // difference is the mask width so they can be easily concatenated together.
 // The concatenation (twice the mask width) must also fit in the result type.
2369 uint64_t ShAmtDiff = ShAmtY - ShAmtX;
2370 unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
2371 unsigned BitWidth = Ty->getPrimitiveSizeInBits();
2372 auto *MaskTy = dyn_cast<FixedVectorType>(SrcX->getType());
2373 if (!MaskTy || SrcX->getType() != SrcY->getType() ||
2374 !MaskTy->getElementType()->isIntegerTy(1) ||
2375 MaskTy->getNumElements() != ShAmtDiff ||
2376 MaskTy->getNumElements() > (BitWidth / 2))
2377 return false;
2378
2379 auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(MaskTy);
2380 auto *ConcatIntTy =
2381 Type::getIntNTy(Ty->getContext(), ConcatTy->getNumElements());
2382 auto *MaskIntTy = Type::getIntNTy(Ty->getContext(), ShAmtDiff);
2383
 // Identity mask 0..2N-1: all lanes of SrcX followed by all lanes of SrcY.
2384 SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
2385 std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
2386
2387 // TODO: Is it worth supporting multi use cases?
 // Old form: the or, up to two shifts, two zexts and two bitcasts.
2388 InstructionCost OldCost = 0;
2389 OldCost += TTI.getArithmeticInstrCost(Instruction::Or, Ty, CostKind);
2390 OldCost +=
2391 NumSHL * TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
2392 OldCost += 2 * TTI.getCastInstrCost(Instruction::ZExt, Ty, MaskIntTy,
 // NOTE(review): inner line 2393 (the remaining arguments of the call above)
 // is missing from this listing.
2394 OldCost += 2 * TTI.getCastInstrCost(Instruction::BitCast, MaskIntTy, MaskTy,
 // NOTE(review): inner line 2395 (the remaining arguments of the call above)
 // is missing from this listing.
2396
 // New form: the concatenating shuffle, one bitcast, then an optional zext
 // and an optional shift.
2397 InstructionCost NewCost = 0;
 // NOTE(review): inner line 2398 (the start of the statement completed on the
 // next line) is missing from this listing.
2399 MaskTy, ConcatMask, CostKind);
2400 NewCost += TTI.getCastInstrCost(Instruction::BitCast, ConcatIntTy, ConcatTy,
 // NOTE(review): inner line 2401 (the remaining arguments of the call above)
 // is missing from this listing.
2402 if (Ty != ConcatIntTy)
2403 NewCost += TTI.getCastInstrCost(Instruction::ZExt, Ty, ConcatIntTy,
 // NOTE(review): inner line 2404 (the remaining arguments of the call above)
 // is missing from this listing.
2405 if (ShAmtX > 0)
2406 NewCost += TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
2407
2408 LLVM_DEBUG(dbgs() << "Found a concatenation of bitcasted bool masks: " << I
2409 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2410 << "\n");
2411
 // Equal cost still folds.
2412 if (NewCost > OldCost)
2413 return false;
2414
2415 // Build bool mask concatenation, bitcast back to scalar integer, and perform
2416 // any residual zero-extension or shifting.
2417 Value *Concat = Builder.CreateShuffleVector(SrcX, SrcY, ConcatMask);
2418 Worklist.pushValue(Concat);
2419
2420 Value *Result = Builder.CreateBitCast(Concat, ConcatIntTy);
2421
2422 if (Ty != ConcatIntTy) {
2423 Worklist.pushValue(Result);
2424 Result = Builder.CreateZExt(Result, Ty);
2425 }
2426
2427 if (ShAmtX > 0) {
2428 Worklist.pushValue(Result);
2429 Result = Builder.CreateShl(Result, ShAmtX);
2430 }
2431
2432 replaceValue(I, *Result);
2433 return true;
2434}
2435
2436/// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
2437/// --> "binop (shuffle), (shuffle)".
/// At least one binop operand must itself be a shuffle; an operand that is
/// not a shuffle is treated as if it were an identity shuffle of itself.
/// \returns true if the IR was changed.
2438bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
2439 BinaryOperator *BinOp;
2440 ArrayRef<int> OuterMask;
2441 if (!match(&I, m_Shuffle(m_BinOp(BinOp), m_Undef(), m_Mask(OuterMask))))
2442 return false;
2443
2444 // Don't introduce poison into div/rem.
2445 if (BinOp->isIntDivRem() && llvm::is_contained(OuterMask, PoisonMaskElem))
2446 return false;
2447
2448 Value *Op00, *Op01, *Op10, *Op11;
2449 ArrayRef<int> Mask0, Mask1;
2450 bool Match0 = match(BinOp->getOperand(0),
2451 m_Shuffle(m_Value(Op00), m_Value(Op01), m_Mask(Mask0)));
2452 bool Match1 = match(BinOp->getOperand(1),
2453 m_Shuffle(m_Value(Op10), m_Value(Op11), m_Mask(Mask1)));
2454 if (!Match0 && !Match1)
2455 return false;
2456
 // For an operand that is not a shuffle, use the operand itself as both
 // shuffle sources.
2457 Op00 = Match0 ? Op00 : BinOp->getOperand(0);
2458 Op01 = Match0 ? Op01 : BinOp->getOperand(0);
2459 Op10 = Match1 ? Op10 : BinOp->getOperand(1);
2460 Op11 = Match1 ? Op11 : BinOp->getOperand(1);
2461
2462 Instruction::BinaryOps Opcode = BinOp->getOpcode();
2463 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
2464 auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType());
2465 auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType());
2466 auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType());
2467 if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
2468 return false;
2469
2470 unsigned NumSrcElts = BinOpTy->getNumElements();
2471
2472 // Don't accept shuffles that reference the second operand in
2473 // div/rem or if its an undef arg.
2474 if ((BinOp->isIntDivRem() || !isa<PoisonValue>(I.getOperand(1))) &&
2475 any_of(OuterMask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
2476 return false;
2477
2478 // Merge outer / inner (or identity if no match) shuffles.
 // Lanes that are poison or select from the second outer operand become
 // poison in both new masks.
2479 SmallVector<int> NewMask0, NewMask1;
2480 for (int M : OuterMask) {
2481 if (M < 0 || M >= (int)NumSrcElts) {
2482 NewMask0.push_back(PoisonMaskElem);
2483 NewMask1.push_back(PoisonMaskElem);
2484 } else {
2485 NewMask0.push_back(Match0 ? Mask0[M] : M);
2486 NewMask1.push_back(Match1 ? Mask1[M] : M);
2487 }
2488 }
2489
 // NOTE(review): NumOpElts comes from Op0Ty but is also used for the Op1
 // identity check below. When the element counts differ this looks like it
 // can only make IsIdentity1 false (a missed simplification), not wrong --
 // confirm.
2490 unsigned NumOpElts = Op0Ty->getNumElements();
2491 bool IsIdentity0 = ShuffleDstTy == Op0Ty &&
2492 all_of(NewMask0, [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
2493 ShuffleVectorInst::isIdentityMask(NewMask0, NumOpElts);
2494 bool IsIdentity1 = ShuffleDstTy == Op1Ty &&
2495 all_of(NewMask1, [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
2496 ShuffleVectorInst::isIdentityMask(NewMask1, NumOpElts);
2497
2498 InstructionCost NewCost = 0;
2499 // Try to merge shuffles across the binop if the new shuffles are not costly.
2500 InstructionCost BinOpCost =
2501 TTI.getArithmeticInstrCost(Opcode, BinOpTy, CostKind);
2502 InstructionCost OldCost =
 // NOTE(review): inner line 2503 (the start of the initializer whose
 // arguments follow) is missing from this listing.
2504 ShuffleDstTy, BinOpTy, OuterMask, CostKind,
2505 0, nullptr, {BinOp}, &I);
 // Instructions that have other users stay alive after the fold, so their
 // cost is charged to the new form as well.
2506 if (!BinOp->hasOneUse())
2507 NewCost += BinOpCost;
2508
2509 if (Match0) {
 // NOTE(review): inner line 2510 (the declaration of Shuf0Cost and the start
 // of its initializer) is missing from this listing.
2511 TargetTransformInfo::SK_PermuteTwoSrc, BinOpTy, Op0Ty, Mask0, CostKind,
2512 0, nullptr, {Op00, Op01}, cast<Instruction>(BinOp->getOperand(0)));
2513 OldCost += Shuf0Cost;
2514 if (!BinOp->hasOneUse() || !BinOp->getOperand(0)->hasOneUse())
2515 NewCost += Shuf0Cost;
2516 }
2517 if (Match1) {
 // NOTE(review): inner line 2518 (the declaration of Shuf1Cost and the start
 // of its initializer) is missing from this listing.
2519 TargetTransformInfo::SK_PermuteTwoSrc, BinOpTy, Op1Ty, Mask1, CostKind,
2520 0, nullptr, {Op10, Op11}, cast<Instruction>(BinOp->getOperand(1)));
2521 OldCost += Shuf1Cost;
2522 if (!BinOp->hasOneUse() || !BinOp->getOperand(1)->hasOneUse())
2523 NewCost += Shuf1Cost;
2524 }
2525
2526 NewCost += TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy, CostKind);
2527
 // Identity shuffles are not emitted, so they add no cost.
2528 if (!IsIdentity0)
2529 NewCost +=
 // NOTE(review): inner line 2530 (the start of the call whose arguments
 // follow) is missing from this listing.
2531 Op0Ty, NewMask0, CostKind, 0, nullptr, {Op00, Op01});
2532 if (!IsIdentity1)
2533 NewCost +=
 // NOTE(review): inner line 2534 (the start of the call whose arguments
 // follow) is missing from this listing.
2535 Op1Ty, NewMask1, CostKind, 0, nullptr, {Op10, Op11});
2536
2537 LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I
2538 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2539 << "\n");
2540
2541 // If costs are equal, still fold as we reduce instruction count.
2542 if (NewCost > OldCost)
2543 return false;
2544
2545 Value *LHS =
2546 IsIdentity0 ? Op00 : Builder.CreateShuffleVector(Op00, Op01, NewMask0);
2547 Value *RHS =
2548 IsIdentity1 ? Op10 : Builder.CreateShuffleVector(Op10, Op11, NewMask1);
2549 Value *NewBO = Builder.CreateBinOp(Opcode, LHS, RHS);
2550
2551 // Intersect flags from the old binops.
 // The builder may return something other than an Instruction, hence the
 // dyn_cast.
2552 if (auto *NewInst = dyn_cast<Instruction>(NewBO))
2553 NewInst->copyIRFlags(BinOp);
2554
2555 Worklist.pushValue(LHS);
2556 Worklist.pushValue(RHS);
2557 replaceValue(I, *NewBO);
2558 return true;
2559}
2560
2561/// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
2562/// Try to convert "shuffle (cmpop), (cmpop)" into "cmpop (shuffle), (shuffle)".
///
/// Returns true when \p I has been replaced by the new binop/compare, and false
/// when the pattern does not match or the cost comparison rejects the fold.
///
/// NOTE(review): the line numbers embedded in this listing skip 2566, 2611,
/// 2615, 2620, 2624, 2630-2631, 2637, 2659 and 2677, so the statements around
/// those gaps are incomplete here (for instance the declarations of SK0, SK1,
/// LHSCost, RHSCost and NewCost are not visible).
2563bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
2564 ArrayRef<int> OldMask;
2565 Instruction *LHS, *RHS;
2567 m_Mask(OldMask))))
2568 return false;
2569
2570 // TODO: Add support for addlike etc.
2571 if (LHS->getOpcode() != RHS->getOpcode())
2572 return false;
2573
 // Operands of the two matched operations: LHS = op(X, Y), RHS = op(Z, W).
2574 Value *X, *Y, *Z, *W;
2575 bool IsCommutative = false;
 // The predicates keep the value BAD_ICMP_PREDICATE on the binop path; PredLHS
 // is tested against it further down to tell the binop case from the compare
 // case.
2576 CmpPredicate PredLHS = CmpInst::BAD_ICMP_PREDICATE;
2577 CmpPredicate PredRHS = CmpInst::BAD_ICMP_PREDICATE;
2578 if (match(LHS, m_BinOp(m_Value(X), m_Value(Y))) &&
2579 match(RHS, m_BinOp(m_Value(Z), m_Value(W)))) {
2580 auto *BO = cast<BinaryOperator>(LHS);
2581 // Don't introduce poison into div/rem.
2582 if (llvm::is_contained(OldMask, PoisonMaskElem) && BO->isIntDivRem())
2583 return false;
2584 IsCommutative = BinaryOperator::isCommutative(BO->getOpcode());
2585 } else if (match(LHS, m_Cmp(PredLHS, m_Value(X), m_Value(Y))) &&
2586 match(RHS, m_Cmp(PredRHS, m_Value(Z), m_Value(W))) &&
2587 (CmpInst::Predicate)PredLHS == (CmpInst::Predicate)PredRHS) {
2588 IsCommutative = cast<CmpInst>(LHS)->isCommutative();
2589 } else
2590 return false;
2591
 // All three types must be fixed-width vectors, and the first operands of both
 // operations must share a type.
2592 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
2593 auto *BinResTy = dyn_cast<FixedVectorType>(LHS->getType());
2594 auto *BinOpTy = dyn_cast<FixedVectorType>(X->getType());
2595 if (!ShuffleDstTy || !BinResTy || !BinOpTy || X->getType() != Z->getType())
2596 return false;
2597
 // Both shuffle operands may be the very same instruction.
2598 bool SameBinOp = LHS == RHS;
2599 unsigned NumSrcElts = BinOpTy->getNumElements();
2600
2601 // If we have something like "add X, Y" and "add Z, X", swap ops to match.
2602 if (IsCommutative && X != Z && Y != W && (X == W || Y == Z))
2603 std::swap(X, Y);
2604
 // Remaps a mask index that selects from the second shuffle source onto the
 // same element of the first source.
2605 auto ConvertToUnary = [NumSrcElts](int &M) {
2606 if (M >= (int)NumSrcElts)
2607 M -= NumSrcElts;
2608 };
2609
2610 SmallVector<int> NewMask0(OldMask);
2612 TTI::OperandValueInfo Op0Info = TTI.commonOperandInfo(X, Z);
 // Identical first operands: fold the mask onto a single source and use poison
 // as the second shuffle operand.
2613 if (X == Z) {
2614 llvm::for_each(NewMask0, ConvertToUnary);
2616 Z = PoisonValue::get(BinOpTy);
2617 }
2618
2619 SmallVector<int> NewMask1(OldMask);
2621 TTI::OperandValueInfo Op1Info = TTI.commonOperandInfo(Y, W);
 // Same treatment for the second operands.
2622 if (Y == W) {
2623 llvm::for_each(NewMask1, ConvertToUnary);
2625 W = PoisonValue::get(BinOpTy);
2626 }
2627
2628 // Try to replace a binop with a shuffle if the shuffle is not costly.
2629 // When SameBinOp, only count the binop cost once.
2632
2633 InstructionCost OldCost = LHSCost;
2634 if (!SameBinOp) {
2635 OldCost += RHSCost;
2636 }
2638 ShuffleDstTy, BinResTy, OldMask, CostKind, 0,
2639 nullptr, {LHS, RHS}, &I);
2640
2641 // Handle shuffle(binop(shuffle(x),y),binop(z,shuffle(w))) style patterns
2642 // where one use shuffles have gotten split across the binop/cmp. These
2643 // often allow a major reduction in total cost that wouldn't happen as
2644 // individual folds.
 // \p Offset is the first outer-mask index that refers to \p Op (0 for the
 // first shuffle source, NumSrcElts for the second). On a match, every entry
 // of \p Mask in [Offset, Offset + NumSrcElts) is rewritten through the inner
 // mask, \p Op is replaced by the inner shuffle's source, and true is returned.
2645 auto MergeInner = [&](Value *&Op, int Offset, MutableArrayRef<int> Mask,
2646 TTI::TargetCostKind CostKind) -> bool {
2647 Value *InnerOp;
2648 ArrayRef<int> InnerMask;
2649 if (match(Op, m_OneUse(m_Shuffle(m_Value(InnerOp), m_Undef(),
2650 m_Mask(InnerMask)))) &&
2651 InnerOp->getType() == Op->getType() &&
2652 all_of(InnerMask,
2653 [NumSrcElts](int M) { return M < (int)NumSrcElts; })) {
2654 for (int &M : Mask)
2655 if (Offset <= M && M < (int)(Offset + NumSrcElts)) {
2656 M = InnerMask[M - Offset];
 // Negative (poison) inner entries are kept as they are.
2657 M = 0 <= M ? M + Offset : M;
2658 }
2660 Op = InnerOp;
2661 return true;
2662 }
2663 return false;
2664 };
2665 bool ReducedInstCount = false;
2666 ReducedInstCount |= MergeInner(X, 0, NewMask0, CostKind);
2667 ReducedInstCount |= MergeInner(Y, 0, NewMask1, CostKind);
2668 ReducedInstCount |= MergeInner(Z, NumSrcElts, NewMask0, CostKind);
2669 ReducedInstCount |= MergeInner(W, NumSrcElts, NewMask1, CostKind);
 // When both operand shuffles would be identical, a single shuffle is created
 // and used for both operands of the new operation.
2670 bool SingleSrcBinOp = (X == Y) && (Z == W) && (NewMask0 == NewMask1);
2671 // SingleSrcBinOp only reduces instruction count if we also eliminate the
2672 // original binop(s). If binops have multiple uses, they won't be eliminated.
2673 ReducedInstCount |= SingleSrcBinOp && LHS->hasOneUser() && RHS->hasOneUser();
2674
 // Type of the shuffled operands: operand element type, shuffle result length.
2675 auto *ShuffleCmpTy =
2676 FixedVectorType::get(BinOpTy->getElementType(), ShuffleDstTy);
2678 SK0, ShuffleCmpTy, BinOpTy, NewMask0, CostKind, 0, nullptr, {X, Z});
2679 if (!SingleSrcBinOp)
2680 NewCost += TTI.getShuffleCost(SK1, ShuffleCmpTy, BinOpTy, NewMask1,
2681 CostKind, 0, nullptr, {Y, W});
2682
2683 if (PredLHS == CmpInst::BAD_ICMP_PREDICATE) {
2684 NewCost += TTI.getArithmeticInstrCost(LHS->getOpcode(), ShuffleDstTy,
2685 CostKind, Op0Info, Op1Info);
2686 } else {
2687 NewCost +=
2688 TTI.getCmpSelInstrCost(LHS->getOpcode(), ShuffleCmpTy, ShuffleDstTy,
2689 PredLHS, CostKind, Op0Info, Op1Info);
2690 }
2691 // If LHS/RHS have other uses, we need to account for the cost of keeping
2692 // the original instructions. When SameBinOp, only add the cost once.
2693 if (!LHS->hasOneUser())
2694 NewCost += LHSCost;
2695 if (!SameBinOp && !RHS->hasOneUser())
2696 NewCost += RHSCost;
2697
2698 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
2699 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2700 << "\n");
2701
2702 // If either shuffle will constant fold away, then fold for the same cost as
2703 // we will reduce the instruction count.
2704 ReducedInstCount |= (isa<Constant>(X) && isa<Constant>(Z)) ||
2705 (isa<Constant>(Y) && isa<Constant>(W));
 // An equal cost is accepted only when the instruction count goes down;
 // otherwise the new sequence has to be strictly cheaper.
2706 if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost))
2707 return false;
2708
2709 Value *Shuf0 = Builder.CreateShuffleVector(X, Z, NewMask0);
2710 Value *Shuf1 =
2711 SingleSrcBinOp ? Shuf0 : Builder.CreateShuffleVector(Y, W, NewMask1);
2712 Value *NewBO = PredLHS == CmpInst::BAD_ICMP_PREDICATE
2713 ? Builder.CreateBinOp(
2714 cast<BinaryOperator>(LHS)->getOpcode(), Shuf0, Shuf1)
2715 : Builder.CreateCmp(PredLHS, Shuf0, Shuf1);
2716
2717 // Intersect flags from the old binops.
2718 if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
2719 NewInst->copyIRFlags(LHS);
2720 NewInst->andIRFlags(RHS);
2721 }
2722
 // Queue the new shuffles so they are revisited by the pass.
2723 Worklist.pushValue(Shuf0);
2724 Worklist.pushValue(Shuf1);
2725 replaceValue(I, *NewBO);
2726 return true;
2727}
2728
2729/// Try to convert,
2730/// (shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) into
2731/// (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))
///
/// Returns true when \p I has been replaced by the new select, and false when
/// the pattern does not match or the new sequence costs more than the old one.
///
/// NOTE(review): the line numbers embedded in this listing skip 2758, 2761,
/// 2763, 2771 and 2781, so the declarations of SK, CostSel1, CostSel2 and
/// NewCost and the tail of one cost call are not visible here.
2732bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
2733 ArrayRef<int> Mask;
2734 Value *C1, *T1, *F1, *C2, *T2, *F2;
2735 if (!match(&I, m_Shuffle(m_Select(m_Value(C1), m_Value(T1), m_Value(F1)),
2736 m_Select(m_Value(C2), m_Value(T2), m_Value(F2)),
2737 m_Mask(Mask))))
2738 return false;
2739
2740 auto *Sel1 = cast<Instruction>(I.getOperand(0));
2741 auto *Sel2 = cast<Instruction>(I.getOperand(1));
2742
 // Both conditions must be fixed-width vectors of the same type, because they
 // are shuffled together below.
2743 auto *C1VecTy = dyn_cast<FixedVectorType>(C1->getType());
2744 auto *C2VecTy = dyn_cast<FixedVectorType>(C2->getType());
2745 if (!C1VecTy || !C2VecTy || C1VecTy != C2VecTy)
2746 return false;
2747
2748 auto *SI0FOp = dyn_cast<FPMathOperator>(I.getOperand(0));
2749 auto *SI1FOp = dyn_cast<FPMathOperator>(I.getOperand(1));
2750 // SelectInsts must have the same FMF.
2751 if (((SI0FOp == nullptr) != (SI1FOp == nullptr)) ||
2752 ((SI0FOp != nullptr) &&
2753 (SI0FOp->getFastMathFlags() != SI1FOp->getFastMathFlags())))
2754 return false;
2755
2756 auto *SrcVecTy = cast<FixedVectorType>(T1->getType());
2757 auto *DstVecTy = cast<FixedVectorType>(I.getType());
2759 auto SelOp = Instruction::Select;
2760
2762 SelOp, SrcVecTy, C1VecTy, CmpInst::BAD_ICMP_PREDICATE, CostKind);
2764 SelOp, SrcVecTy, C2VecTy, CmpInst::BAD_ICMP_PREDICATE, CostKind);
2765
 // Old sequence: both selects plus the shuffle that combines them.
2766 InstructionCost OldCost =
2767 CostSel1 + CostSel2 +
2768 TTI.getShuffleCost(SK, DstVecTy, SrcVecTy, Mask, CostKind, 0, nullptr,
2769 {I.getOperand(0), I.getOperand(1)}, &I);
2770
 // New sequence: one shuffle each for the conditions, the true values and the
 // false values, plus a single select on the shuffled vectors.
2772 SK, FixedVectorType::get(C1VecTy->getScalarType(), Mask.size()), C1VecTy,
2773 Mask, CostKind, 0, nullptr, {C1, C2});
2774 NewCost += TTI.getShuffleCost(SK, DstVecTy, SrcVecTy, Mask, CostKind, 0,
2775 nullptr, {T1, T2});
2776 NewCost += TTI.getShuffleCost(SK, DstVecTy, SrcVecTy, Mask, CostKind, 0,
2777 nullptr, {F1, F2});
2778 auto *C1C2ShuffledVecTy = FixedVectorType::get(
2779 Type::getInt1Ty(I.getContext()), DstVecTy->getNumElements());
2780 NewCost += TTI.getCmpSelInstrCost(SelOp, DstVecTy, C1C2ShuffledVecTy,
2782
 // A select with other users stays in the IR, so its cost is still paid.
2783 if (!Sel1->hasOneUse())
2784 NewCost += CostSel1;
2785 if (!Sel2->hasOneUse())
2786 NewCost += CostSel2;
2787
2788 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two selects: " << I
2789 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2790 << "\n");
 // Equal cost is accepted.
2791 if (NewCost > OldCost)
2792 return false;
2793
2794 Value *ShuffleCmp = Builder.CreateShuffleVector(C1, C2, Mask);
2795 Value *ShuffleTrue = Builder.CreateShuffleVector(T1, T2, Mask);
2796 Value *ShuffleFalse = Builder.CreateShuffleVector(F1, F2, Mask);
2797 Value *NewSel;
2798 // We presuppose that the SelectInsts have the same FMF.
2799 if (SI0FOp)
2800 NewSel = Builder.CreateSelectFMF(ShuffleCmp, ShuffleTrue, ShuffleFalse,
2801 SI0FOp->getFastMathFlags());
2802 else
2803 NewSel = Builder.CreateSelect(ShuffleCmp, ShuffleTrue, ShuffleFalse);
2804
 // Queue the new shuffles so they are revisited by the pass.
2805 Worklist.pushValue(ShuffleCmp);
2806 Worklist.pushValue(ShuffleTrue);
2807 Worklist.pushValue(ShuffleFalse);
2808 replaceValue(I, *NewSel);
2809 return true;
2810}
2811
2812/// Try to convert "shuffle (castop), (castop)" with a shared castop operand
2813/// into "castop (shuffle)".
///
/// A unary shuffle of a single cast (second shuffle operand undef) is handled
/// as well, except for bitcasts. Returns true when \p I has been replaced, and
/// false when the pattern does not match or the new sequence costs more.
///
/// NOTE(review): the line numbers embedded in this listing skip 2886, 2888,
/// 2890, 2892, 2901 and 2907, so the declaration of ShuffleKind and the tails
/// of the getCastInstrCost calls are not visible here.
2814bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
2815 Value *V0, *V1;
2816 ArrayRef<int> OldMask;
2817 if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
2818 return false;
2819
2820 // Check whether this is a binary shuffle.
2821 bool IsBinaryShuffle = !isa<UndefValue>(V1);
2822
 // The first operand must be a cast; the second only when it is used.
2823 auto *C0 = dyn_cast<CastInst>(V0);
2824 auto *C1 = dyn_cast<CastInst>(V1);
2825 if (!C0 || (IsBinaryShuffle && !C1))
2826 return false;
2827
2828 Instruction::CastOps Opcode = C0->getOpcode();
2829
2830 // If this is allowed, foldShuffleOfCastops can get stuck in a loop
2831 // with foldBitcastOfShuffle. Reject in favor of foldBitcastOfShuffle.
2832 if (!IsBinaryShuffle && Opcode == Instruction::BitCast)
2833 return false;
2834
2835 if (IsBinaryShuffle) {
2836 if (C0->getSrcTy() != C1->getSrcTy())
2837 return false;
2838 // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
2839 if (Opcode != C1->getOpcode()) {
2840 if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
2841 Opcode = Instruction::SExt;
2842 else
2843 return false;
2844 }
2845 }
2846
2847 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
2848 auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy());
2849 auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy());
2850 if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
2851 return false;
2852
2853 unsigned NumSrcElts = CastSrcTy->getNumElements();
2854 unsigned NumDstElts = CastDstTy->getNumElements();
2855 assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) &&
2856 "Only bitcasts expected to alter src/dst element counts");
2857
2858 // Check for bitcasting of unscalable vector types.
2859 // e.g. <32 x i40> -> <40 x i32>
2860 if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 &&
2861 (NumDstElts % NumSrcElts) != 0)
2862 return false;
2863
 // Translate the shuffle mask from cast-result elements to cast-source
 // elements.
2864 SmallVector<int, 16> NewMask;
2865 if (NumSrcElts >= NumDstElts) {
2866 // The bitcast is from wide to narrow/equal elements. The shuffle mask can
2867 // always be expanded to the equivalent form choosing narrower elements.
2868 assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask");
2869 unsigned ScaleFactor = NumSrcElts / NumDstElts;
2870 narrowShuffleMaskElts(ScaleFactor, OldMask, NewMask);
2871 } else {
2872 // The bitcast is from narrow elements to wide elements. The shuffle mask
2873 // must choose consecutive elements to allow casting first.
2874 assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask");
2875 unsigned ScaleFactor = NumDstElts / NumSrcElts;
2876 if (!widenShuffleMaskElts(ScaleFactor, OldMask, NewMask))
2877 return false;
2878 }
2879
 // Result type of the shuffle once it is applied to the cast sources.
2880 auto *NewShuffleDstTy =
2881 FixedVectorType::get(CastSrcTy->getScalarType(), NewMask.size());
2882
2883 // Try to replace a castop with a shuffle if the shuffle is not costly.
2884 InstructionCost CostC0 =
2885 TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,
2887
2889 if (IsBinaryShuffle)
2891 else
2893
2894 InstructionCost OldCost = CostC0;
2895 OldCost += TTI.getShuffleCost(ShuffleKind, ShuffleDstTy, CastDstTy, OldMask,
2896 CostKind, 0, nullptr, {}, &I);
2897
2898 InstructionCost NewCost = TTI.getShuffleCost(ShuffleKind, NewShuffleDstTy,
2899 CastSrcTy, NewMask, CostKind);
2900 NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
 // A cast with other users stays in the IR, so its cost is still paid.
2902 if (!C0->hasOneUse())
2903 NewCost += CostC0;
2904 if (IsBinaryShuffle) {
2905 InstructionCost CostC1 =
2906 TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
2908 OldCost += CostC1;
2909 if (!C1->hasOneUse())
2910 NewCost += CostC1;
2911 }
2912
2913 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
2914 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2915 << "\n");
 // Equal cost is accepted.
2916 if (NewCost > OldCost)
2917 return false;
2918
2919 Value *Shuf;
2920 if (IsBinaryShuffle)
2921 Shuf = Builder.CreateShuffleVector(C0->getOperand(0), C1->getOperand(0),
2922 NewMask);
2923 else
2924 Shuf = Builder.CreateShuffleVector(C0->getOperand(0), NewMask);
2925
2926 Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
2927
2928 // Intersect flags from the old casts.
2929 if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
2930 NewInst->copyIRFlags(C0);
2931 if (IsBinaryShuffle)
2932 NewInst->andIRFlags(C1);
2933 }
2934
2935 Worklist.pushValue(Shuf);
2936 replaceValue(I, *Cast);
2937 return true;
2938}
2939
2940/// Try to convert any of:
2941/// "shuffle (shuffle x, y), (shuffle y, x)"
2942/// "shuffle (shuffle x, undef), (shuffle y, undef)"
2943/// "shuffle (shuffle x, undef), y"
2944/// "shuffle x, (shuffle y, undef)"
2945/// into "shuffle x, y".
///
/// Returns true when \p I has been replaced (by a merged shuffle, by one of the
/// inner sources when the merged mask is an identity, or by poison when no
/// source element is referenced), and false otherwise.
///
/// NOTE(review): the line numbers embedded in this listing skip 3056 and
/// 3061-3063, so the declarations of OuterCost and SK are not visible here.
2946bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
2947 ArrayRef<int> OuterMask;
2948 Value *OuterV0, *OuterV1;
2949 if (!match(&I,
2950 m_Shuffle(m_Value(OuterV0), m_Value(OuterV1), m_Mask(OuterMask))))
2951 return false;
2952
 // At least one operand of the outer shuffle must itself be a shuffle.
2953 ArrayRef<int> InnerMask0, InnerMask1;
2954 Value *X0, *X1, *Y0, *Y1;
2955 bool Match0 =
2956 match(OuterV0, m_Shuffle(m_Value(X0), m_Value(Y0), m_Mask(InnerMask0)));
2957 bool Match1 =
2958 match(OuterV1, m_Shuffle(m_Value(X1), m_Value(Y1), m_Mask(InnerMask1)));
2959 if (!Match0 && !Match1)
2960 return false;
2961
2962 // If the outer shuffle is a permute, then create a fake inner all-poison
2963 // shuffle. This is easier than accounting for length-changing shuffles below.
2964 SmallVector<int, 16> PoisonMask1;
2965 if (!Match1 && isa<PoisonValue>(OuterV1)) {
2966 X1 = X0;
2967 Y1 = Y0;
2968 PoisonMask1.append(InnerMask0.size(), PoisonMaskElem);
2969 InnerMask1 = PoisonMask1;
2970 Match1 = true; // fake match
2971 }
2972
 // For an operand that is not a shuffle, use the operand itself as both
 // sources.
2973 X0 = Match0 ? X0 : OuterV0;
2974 Y0 = Match0 ? Y0 : OuterV0;
2975 X1 = Match1 ? X1 : OuterV1;
2976 Y1 = Match1 ? Y1 : OuterV1;
2977 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
2978 auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(X0->getType());
2979 auto *ShuffleImmTy = dyn_cast<FixedVectorType>(OuterV0->getType());
2980 if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
2981 X0->getType() != X1->getType())
2982 return false;
2983
2984 unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
2985 unsigned NumImmElts = ShuffleImmTy->getNumElements();
2986
2987 // Attempt to merge shuffles, matching up to 2 source operands.
2988 // Replace index to a poison arg with PoisonMaskElem.
2989 // Bail if either inner mask references an undef arg.
2990 SmallVector<int, 16> NewMask(OuterMask);
2991 Value *NewX = nullptr, *NewY = nullptr;
2992 for (int &M : NewMask) {
 // Src is the value this lane ultimately reads from; it stays null when the
 // outer mask element is negative.
2993 Value *Src = nullptr;
2994 if (0 <= M && M < (int)NumImmElts) {
2995 Src = OuterV0;
2996 if (Match0) {
2997 M = InnerMask0[M];
2998 Src = M >= (int)NumSrcElts ? Y0 : X0;
2999 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
3000 }
3001 } else if (M >= (int)NumImmElts) {
3002 Src = OuterV1;
3003 M -= NumImmElts;
3004 if (Match1) {
3005 M = InnerMask1[M];
3006 Src = M >= (int)NumSrcElts ? Y1 : X1;
3007 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
3008 }
3009 }
3010 if (Src && M != PoisonMaskElem) {
3011 assert(0 <= M && M < (int)NumSrcElts && "Unexpected shuffle mask index");
3012 if (isa<UndefValue>(Src)) {
3013 // We've referenced an undef element - if it's poison, update the shuffle
3014 // mask, else bail.
3015 if (!isa<PoisonValue>(Src))
3016 return false;
3017 M = PoisonMaskElem;
3018 continue;
3019 }
 // Assign the source to the first or second operand of the merged shuffle;
 // indices into the second operand are offset by NumSrcElts.
3020 if (!NewX || NewX == Src) {
3021 NewX = Src;
3022 continue;
3023 }
3024 if (!NewY || NewY == Src) {
3025 M += NumSrcElts;
3026 NewY = Src;
3027 continue;
3028 }
 // A third distinct source cannot be expressed by a single shuffle.
3029 return false;
3030 }
3031 }
3032
 // No lane references a real source element: the result is poison.
3033 if (!NewX) {
3034 replaceValue(I, *PoisonValue::get(ShuffleDstTy));
3035 return true;
3036 }
3037
3038 if (!NewY)
3039 NewY = PoisonValue::get(ShuffleSrcTy);
3040
3041 // Have we folded to an Identity shuffle?
3042 if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) {
3043 replaceValue(I, *NewX);
3044 return true;
3045 }
3046
3047 // Try to merge the shuffles if the new shuffle is not costly.
3048 InstructionCost InnerCost0 = 0;
3049 if (Match0)
3050 InnerCost0 = TTI.getInstructionCost(cast<User>(OuterV0), CostKind);
3051
3052 InstructionCost InnerCost1 = 0;
3053 if (Match1)
3054 InnerCost1 = TTI.getInstructionCost(cast<User>(OuterV1), CostKind);
3055
3057
3058 InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost;
3059
3060 bool IsUnary = all_of(NewMask, [&](int M) { return M < (int)NumSrcElts; });
3064 InstructionCost NewCost =
3065 TTI.getShuffleCost(SK, ShuffleDstTy, ShuffleSrcTy, NewMask, CostKind, 0,
3066 nullptr, {NewX, NewY});
 // An inner shuffle with other users stays in the IR, so its cost is still
 // paid.
3067 if (!OuterV0->hasOneUse())
3068 NewCost += InnerCost0;
3069 if (!OuterV1->hasOneUse())
3070 NewCost += InnerCost1;
3071
3072 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
3073 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
3074 << "\n");
 // Equal cost is accepted.
3075 if (NewCost > OldCost)
3076 return false;
3077
3078 Value *Shuf = Builder.CreateShuffleVector(NewX, NewY, NewMask);
3079 replaceValue(I, *Shuf);
3080 return true;
3081}
3082
3083/// Try to convert a chain of length-preserving shuffles that are fed by
3084/// length-changing shuffles from the same source, e.g. a chain of length 3:
3085///
3086/// "shuffle (shuffle (shuffle x, (shuffle y, undef)),
3087/// (shuffle y, undef)),
3088/// (shuffle y, undef)"
3089///
3090/// into a single shuffle fed by a length-changing shuffle:
3091///
3092/// "shuffle x, (shuffle y, undef)"
3093///
3094/// Such chains arise e.g. from folding extract/insert sequences.
///
/// Returns true when \p I has been replaced; a chain of fewer than two links
/// is left untouched and false is returned.
///
/// NOTE(review): the line numbers embedded in this listing skip 3181-3182,
/// 3223 and 3225, so the right-hand sides of the LocalOldCost and LocalNewCost
/// computations are only partly visible here.
3095bool VectorCombine::foldShufflesOfLengthChangingShuffles(Instruction &I) {
3096 FixedVectorType *TrunkType = dyn_cast<FixedVectorType>(I.getType());
3097 if (!TrunkType)
3098 return false;
3099
 // Running state of the chain collected so far:
 //   Mask  - mask of the combined outer shuffle (trunk, leaf),
 //   YMask - mask of the combined leaf shuffle of Y,
 //   Trunk - the value the next chain link is matched against,
 //   Y     - the single non-undef source shared by all leaf shuffles.
3100 unsigned ChainLength = 0;
3101 SmallVector<int> Mask;
3102 SmallVector<int> YMask;
3103 InstructionCost OldCost = 0;
3104 InstructionCost NewCost = 0;
3105 Value *Trunk = &I;
3106 unsigned NumTrunkElts = TrunkType->getNumElements();
3107 Value *Y = nullptr;
3108
3109 for (;;) {
3110 // Match the current trunk against (commutations of) the pattern
3111 // "shuffle trunk', (shuffle y, undef)"
3112 ArrayRef<int> OuterMask;
3113 Value *OuterV0, *OuterV1;
 // Every link except the root must have a single use.
3114 if (ChainLength != 0 && !Trunk->hasOneUse())
3115 break;
3116 if (!match(Trunk, m_Shuffle(m_Value(OuterV0), m_Value(OuterV1),
3117 m_Mask(OuterMask))))
3118 break;
3119 if (OuterV0->getType() != TrunkType) {
3120 // This shuffle is not length-preserving, so it cannot be part of the
3121 // chain.
3122 break;
3123 }
3124
 // An operand is a "leaf" when it is a shuffle whose first source has a
 // type different from the type of I.
3125 ArrayRef<int> InnerMask0, InnerMask1;
3126 Value *A0, *A1, *B0, *B1;
3127 bool Match0 =
3128 match(OuterV0, m_Shuffle(m_Value(A0), m_Value(B0), m_Mask(InnerMask0)));
3129 bool Match1 =
3130 match(OuterV1, m_Shuffle(m_Value(A1), m_Value(B1), m_Mask(InnerMask1)));
3131 bool Match0Leaf = Match0 && A0->getType() != I.getType();
3132 bool Match1Leaf = Match1 && A1->getType() != I.getType();
3133 if (Match0Leaf == Match1Leaf) {
3134 // Only handle the case of exactly one leaf in each step. The "two leaves"
3135 // case is handled by foldShuffleOfShuffles.
3136 break;
3137 }
3138
 // Normalize so that the leaf is always the second operand: swap the
 // operands and flip every non-poison mask index to the other half.
3139 SmallVector<int> CommutedOuterMask;
3140 if (Match0Leaf) {
3141 std::swap(OuterV0, OuterV1);
3142 std::swap(InnerMask0, InnerMask1);
3143 std::swap(A0, A1);
3144 std::swap(B0, B1);
3145 llvm::append_range(CommutedOuterMask, OuterMask);
3146 for (int &M : CommutedOuterMask) {
3147 if (M == PoisonMaskElem)
3148 continue;
3149 if (M < (int)NumTrunkElts)
3150 M += NumTrunkElts;
3151 else
3152 M -= NumTrunkElts;
3153 }
3154 OuterMask = CommutedOuterMask;
3155 }
 // The leaf shuffle must have a single use.
3156 if (!OuterV1->hasOneUse())
3157 break;
3158
 // Every non-undef leaf source must be the same value Y across the chain.
3159 if (!isa<UndefValue>(A1)) {
3160 if (!Y)
3161 Y = A1;
3162 else if (Y != A1)
3163 break;
3164 }
3165 if (!isa<UndefValue>(B1)) {
3166 if (!Y)
3167 Y = B1;
3168 else if (Y != B1)
3169 break;
3170 }
3171
 // Fold leaf mask indices that select from the second leaf source onto the
 // first one.
3172 auto *YType = cast<FixedVectorType>(A1->getType());
3173 int NumLeafElts = YType->getNumElements();
3174 SmallVector<int> LocalYMask(InnerMask1);
3175 for (int &M : LocalYMask) {
3176 if (M >= NumLeafElts)
3177 M -= NumLeafElts;
3178 }
3179
3180 InstructionCost LocalOldCost =
3183
3184 // Handle the initial (start of chain) case.
3185 if (!ChainLength) {
3186 Mask.assign(OuterMask);
3187 YMask.assign(LocalYMask);
3188 OldCost = NewCost = LocalOldCost;
3189 Trunk = OuterV0;
3190 ChainLength++;
3191 continue;
3192 }
3193
3194 // For the non-root case, first attempt to combine masks.
 // Each position may be defined by only one of the two masks, unless both
 // agree; a conflicting position ends the chain.
3195 SmallVector<int> NewYMask(YMask);
3196 bool Valid = true;
3197 for (auto [CombinedM, LeafM] : llvm::zip(NewYMask, LocalYMask)) {
3198 if (LeafM == -1 || CombinedM == LeafM)
3199 continue;
3200 if (CombinedM == -1) {
3201 CombinedM = LeafM;
3202 } else {
3203 Valid = false;
3204 break;
3205 }
3206 }
3207 if (!Valid)
3208 break;
3209
 // Compose the accumulated outer mask with this link's outer mask: indices
 // that refer to the trunk are looked up through OuterMask, all others
 // (negative or leaf indices) are kept unchanged.
3210 SmallVector<int> NewMask;
3211 NewMask.reserve(NumTrunkElts);
3212 for (int M : Mask) {
3213 if (M < 0 || M >= static_cast<int>(NumTrunkElts))
3214 NewMask.push_back(M);
3215 else
3216 NewMask.push_back(OuterMask[M]);
3217 }
3218
3219 // Break the chain if adding this new step complicates the shuffles such
3220 // that it would increase the new cost by more than the old cost of this
3221 // step.
3222 InstructionCost LocalNewCost =
3224 YType, NewYMask, CostKind) +
3226 TrunkType, NewMask, CostKind);
3227
3228 if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
3229 break;
3230
3231 LLVM_DEBUG({
3232 if (ChainLength == 1) {
3233 dbgs() << "Found chain of shuffles fed by length-changing shuffles: "
3234 << I << '\n';
3235 }
3236 dbgs() << " next chain link: " << *Trunk << '\n'
3237 << " old cost: " << (OldCost + LocalOldCost)
3238 << " new cost: " << LocalNewCost << '\n';
3239 });
3240
 // Accept this link and continue with its trunk operand.
3241 Mask = NewMask;
3242 YMask = NewYMask;
3243 OldCost += LocalOldCost;
3244 NewCost = LocalNewCost;
3245 Trunk = OuterV0;
3246 ChainLength++;
3247 }
3248 if (ChainLength <= 1)
3249 return false;
3250
 // NOTE(review): Y is only assigned from a non-undef leaf source above; if
 // every leaf source seen was undef it would still be null when dereferenced
 // below - confirm that this cannot happen.
3251 if (llvm::all_of(Mask, [&](int M) {
3252 return M < 0 || M >= static_cast<int>(NumTrunkElts);
3253 })) {
3254 // Produce a canonical simplified form if all elements are sourced from Y.
3255 for (int &M : Mask) {
3256 if (M >= static_cast<int>(NumTrunkElts))
3257 M = YMask[M - NumTrunkElts];
3258 }
3259 Value *Root =
3260 Builder.CreateShuffleVector(Y, PoisonValue::get(Y->getType()), Mask);
3261 replaceValue(I, *Root);
3262 return true;
3263 }
3264
 // General case: one combined leaf shuffle of Y feeding one outer shuffle with
 // the remaining trunk.
3265 Value *Leaf =
3266 Builder.CreateShuffleVector(Y, PoisonValue::get(Y->getType()), YMask);
3267 Value *Root = Builder.CreateShuffleVector(Trunk, Leaf, Mask);
3268 replaceValue(I, *Root);
3269 return true;
3270}
3271
3272/// Try to convert
3273/// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
///
/// Both shuffle operands must be calls to the same intrinsic. Returns true when
/// \p I has been replaced by the new intrinsic call, and false when the pattern
/// does not match or the new sequence costs more than the old one.
///
/// Note that the loops below declare an index named I that shadows the
/// Instruction parameter \p I inside the loop bodies.
///
/// NOTE(review): the line numbers embedded in this listing skip 3304, 3319,
/// 3326 and 3362, so the "if" conditions that open the if/else bodies in the
/// three argument loops and part of the OldCost shuffle-cost call are not
/// visible here.
3274bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
3275 Value *V0, *V1;
3276 ArrayRef<int> OldMask;
3277 if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
3278 return false;
3279
3280 auto *II0 = dyn_cast<IntrinsicInst>(V0);
3281 auto *II1 = dyn_cast<IntrinsicInst>(V1);
3282 if (!II0 || !II1)
3283 return false;
3284
3285 Intrinsic::ID IID = II0->getIntrinsicID();
3286 if (IID != II1->getIntrinsicID())
3287 return false;
3288 InstructionCost CostII0 =
3289 TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II0), CostKind);
3290 InstructionCost CostII1 =
3291 TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II1), CostKind);
3292
3293 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
3294 auto *II0Ty = dyn_cast<FixedVectorType>(II0->getType());
3295 if (!ShuffleDstTy || !II0Ty)
3296 return false;
3297
3298 if (!isTriviallyVectorizable(IID))
3299 return false;
3300
 // Check that every pair of corresponding arguments can be combined.
3301 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3302 Value *Arg0 = II0->getArgOperand(I);
3303 Value *Arg1 = II1->getArgOperand(I);
3305 // Scalar operands must be identical.
3306 if (Arg0 != Arg1)
3307 return false;
3308 } else if (Arg0->getType() != Arg1->getType()) {
3309 // The corresponding vector operands are shuffled together, so they must
3310 // share the same type. For intrinsics overloaded on their operand type
3311 // (e.g. llvm.fptosi.sat), two calls can produce the same result type
3312 // from different operand types; shuffling those would be invalid.
3313 return false;
3314 }
3315 }
3316
 // Old sequence: both intrinsic calls plus the shuffle that combines them.
3317 InstructionCost OldCost =
3318 CostII0 + CostII1 +
3320 II0Ty, OldMask, CostKind, 0, nullptr, {II0, II1}, &I);
3321
 // New sequence: one shuffle per distinct pair of vector arguments plus a
 // single intrinsic call on the shuffled arguments.
3322 SmallVector<Type *> NewArgsTy;
3323 InstructionCost NewCost = 0;
3324 SmallDenseSet<std::pair<Value *, Value *>> SeenOperandPairs;
3325 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3327 NewArgsTy.push_back(II0->getArgOperand(I)->getType());
3328 } else {
 // The shuffled argument keeps its element type but takes the element
 // count of the shuffle result.
3329 auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType());
3330 auto *ArgTy = FixedVectorType::get(VecTy->getElementType(),
3331 ShuffleDstTy->getNumElements());
3332 NewArgsTy.push_back(ArgTy);
3333 std::pair<Value *, Value *> OperandPair =
3334 std::make_pair(II0->getArgOperand(I), II1->getArgOperand(I));
3335 if (!SeenOperandPairs.insert(OperandPair).second) {
3336 // We've already computed the cost for this operand pair.
3337 continue;
3338 }
3339 NewCost += TTI.getShuffleCost(
3340 TargetTransformInfo::SK_PermuteTwoSrc, ArgTy, VecTy, OldMask,
3341 CostKind, 0, nullptr, {II0->getArgOperand(I), II1->getArgOperand(I)});
3342 }
3343 }
3344 IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
3345
3346 NewCost += TTI.getIntrinsicInstrCost(NewAttr, CostKind);
 // An intrinsic call with other users stays in the IR, so its cost is still
 // paid; when both operands are the same call it is counted once.
3347 if (!II0->hasOneUse())
3348 NewCost += CostII0;
3349 if (II1 != II0 && !II1->hasOneUse())
3350 NewCost += CostII1;
3351
3352 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two intrinsics: " << I
3353 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
3354 << "\n");
3355
 // Equal cost is accepted.
3356 if (NewCost > OldCost)
3357 return false;
3358
 // Build the argument list, creating each operand-pair shuffle only once.
3359 SmallVector<Value *> NewArgs;
3360 SmallDenseMap<std::pair<Value *, Value *>, Value *> ShuffleCache;
3361 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
3363 NewArgs.push_back(II0->getArgOperand(I));
3364 } else {
3365 std::pair<Value *, Value *> OperandPair =
3366 std::make_pair(II0->getArgOperand(I), II1->getArgOperand(I));
3367 auto It = ShuffleCache.find(OperandPair);
3368 if (It != ShuffleCache.end()) {
3369 // Reuse previously created shuffle for this operand pair.
3370 NewArgs.push_back(It->second);
3371 continue;
3372 }
3373 Value *Shuf = Builder.CreateShuffleVector(II0->getArgOperand(I),
3374 II1->getArgOperand(I), OldMask);
3375 ShuffleCache[OperandPair] = Shuf;
3376 NewArgs.push_back(Shuf);
3377 Worklist.pushValue(Shuf);
3378 }
3379 Value *NewIntrinsic = Builder.CreateIntrinsic(ShuffleDstTy, IID, NewArgs);
3380
3381 // Intersect flags from the old intrinsics.
3382 if (auto *NewInst = dyn_cast<Instruction>(NewIntrinsic)) {
3383 NewInst->copyIRFlags(II0);
3384 NewInst->andIRFlags(II1);
3385 }
3386
3387 replaceValue(I, *NewIntrinsic);
3388 return true;
3389}
3390
3391/// Try to convert
3392/// "shuffle (intrinsic), (poison/undef)" into "intrinsic (shuffle)".
///
/// Returns true when \p I has been replaced by the new intrinsic call, and
/// false when the pattern does not match or the new sequence costs more than
/// the old one.
///
/// Note that the loops below declare an index named I that shadows the
/// Instruction parameter \p I inside the loop bodies.
///
/// NOTE(review): the line numbers embedded in this listing skip 3418,
/// 3421-3422, 3428, 3435 and 3457, so the declaration of IntrinsicCost, part of
/// the OldCost computation and the "if" conditions that open the if/else bodies
/// in the argument loops are not visible here.
3393bool VectorCombine::foldPermuteOfIntrinsic(Instruction &I) {
3394 Value *V0;
3395 ArrayRef<int> Mask;
3396 if (!match(&I, m_Shuffle(m_Value(V0), m_Undef(), m_Mask(Mask))))
3397 return false;
3398
3399 auto *II0 = dyn_cast<IntrinsicInst>(V0);
3400 if (!II0)
3401 return false;
3402
3403 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
3404 auto *IntrinsicSrcTy = dyn_cast<FixedVectorType>(II0->getType());
3405 if (!ShuffleDstTy || !IntrinsicSrcTy)
3406 return false;
3407
3408 // Validate it's a pure permute, mask should only reference the first vector
3409 unsigned NumSrcElts = IntrinsicSrcTy->getNumElements();
3410 if (any_of(Mask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
3411 return false;
3412
3413 Intrinsic::ID IID = II0->getIntrinsicID();
3414 if (!isTriviallyVectorizable(IID))
3415 return false;
3416
3417 // Cost analysis
3419 TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II0), CostKind);
3420 InstructionCost OldCost =
3423 IntrinsicSrcTy, Mask, CostKind, 0, nullptr, {V0}, &I);
3424
 // New sequence: one shuffle per vector argument plus a single intrinsic call
 // on the shuffled arguments.
3425 SmallVector<Type *> NewArgsTy;
3426 InstructionCost NewCost = 0;
3427 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3429 NewArgsTy.push_back(II0->getArgOperand(I)->getType());
3430 } else {
 // The shuffled argument keeps its element type but takes the element
 // count of the shuffle result.
3431 auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType());
3432 auto *ArgTy = FixedVectorType::get(VecTy->getElementType(),
3433 ShuffleDstTy->getNumElements());
3434 NewArgsTy.push_back(ArgTy);
3436 ArgTy, VecTy, Mask, CostKind, 0, nullptr,
3437 {II0->getArgOperand(I)});
3438 }
3439 }
3440 IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
3441 NewCost += TTI.getIntrinsicInstrCost(NewAttr, CostKind);
3442
3443 // If the intrinsic has multiple uses, we need to account for the cost of
3444 // keeping the original intrinsic around.
3445 if (!II0->hasOneUse())
3446 NewCost += IntrinsicCost;
3447
3448 LLVM_DEBUG(dbgs() << "Found a permute of intrinsic: " << I << "\n OldCost: "
3449 << OldCost << " vs NewCost: " << NewCost << "\n");
3450
 // Equal cost is accepted.
3451 if (NewCost > OldCost)
3452 return false;
3453
3454 // Transform
3455 SmallVector<Value *> NewArgs;
3456 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3458 NewArgs.push_back(II0->getArgOperand(I));
3459 } else {
3460 Value *Shuf = Builder.CreateShuffleVector(II0->getArgOperand(I), Mask);
3461 NewArgs.push_back(Shuf);
3462 Worklist.pushValue(Shuf);
3463 }
3464 }
3465
3466 Value *NewIntrinsic = Builder.CreateIntrinsic(ShuffleDstTy, IID, NewArgs);
3467
 // Carry the IR flags of the original call over to the new one.
3468 if (auto *NewInst = dyn_cast<Instruction>(NewIntrinsic))
3469 NewInst->copyIRFlags(II0);
3470
3471 replaceValue(I, *NewIntrinsic);
3472 return true;
3473}
3474
3475using InstLane = std::pair<Value *, int>;
3476
3477static InstLane lookThroughShuffles(Value *V, int Lane) {
3478 while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
3479 unsigned NumElts =
3480 cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
3481 int M = SV->getMaskValue(Lane);
3482 if (M < 0)
3483 return {nullptr, PoisonMaskElem};
3484 if (static_cast<unsigned>(M) < NumElts) {
3485 V = SV->getOperand(0);
3486 Lane = M;
3487 } else {
3488 V = SV->getOperand(1);
3489 Lane = M - NumElts;
3490 }
3491 }
3492 return InstLane{V, Lane};
3493}
3494
// NOTE(review): this listing jumps from original line 3494 to 3498, so the
// function signature and the declarations of Item, Op and NItem are not
// visible here. Restore them from the upstream file before building.
//
// For each (value, lane) entry in Item, take operand number Op of that value
// and trace the lane back through any shuffles feeding it. Entries with a
// null value are carried over as {nullptr, PoisonMaskElem}. The result has one
// entry per entry of Item, in the same order.
3498 for (InstLane IL : Item) {
3499 auto [U, Lane] = IL;
3500 InstLane OpLane =
// The cast<> asserts that every non-null entry is an Instruction.
3501 U ? lookThroughShuffles(cast<Instruction>(U)->getOperand(Op), Lane)
3502 : InstLane{nullptr, PoisonMaskElem};
3503 NItem.emplace_back(OpLane);
3504 }
3505 return NItem;
3506 }
3507
3508 /// Detect concat of multiple values into a vector
/// Returns true when Item is made of a power-of-two number of slices, each
/// slice being all lanes, in order, of one value whose type matches the
/// first entry's type, and the target reports a zero cost for concatenating
/// two such values.
// NOTE(review): original line 3509 (the start of the signature, declaring at
// least Item and CostKind) is missing from this listing.
3510 const TargetTransformInfo &TTI) {
// The front entry is dereferenced unconditionally, so it must be non-null.
3511 auto *Ty = cast<FixedVectorType>(Item.front().first->getType());
3512 unsigned NumElts = Ty->getNumElements();
// Reject a single slice (that is not a concat), single-element inputs, and
// lane counts that are not a whole number of slices.
3513 if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0)
3514 return false;
3515
3516 // Check that the concat is free, usually meaning that the type will be split
3517 // during legalization.
// ConcatMask is the sequence 0, 1, ..., 2*NumElts-1.
3518 SmallVector<int, 16> ConcatMask(NumElts * 2);
3519 std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
3520 if (TTI.getShuffleCost(TTI::SK_PermuteTwoSrc,
3521 FixedVectorType::get(Ty->getScalarType(), NumElts * 2),
3522 Ty, ConcatMask, CostKind) != 0)
3523 return false;
3524
3525 unsigned NumSlices = Item.size() / NumElts;
3526 // Currently we generate a tree of shuffles for the concats, which limits us
3527 // to a power2.
3528 if (!isPowerOf2_32(NumSlices))
3529 return false;
// Every slice must be one non-null value of type Ty, with lane Elt of the
// slice reading lane Elt of that value.
3530 for (unsigned Slice = 0; Slice < NumSlices; ++Slice) {
3531 Value *SliceV = Item[Slice * NumElts].first;
3532 if (!SliceV || SliceV->getType() != Ty)
3533 return false;
3534 for (unsigned Elt = 0; Elt < NumElts; ++Elt) {
3535 auto [V, Lane] = Item[Slice * NumElts + Elt];
3536 if (Lane != static_cast<int>(Elt) || SliceV != V)
3537 return false;
3538 }
3539 }
3540 return true;
3541 }
3542
// Build the replacement value for the lanes described by Item. A leaf recorded
// in IdentityLeafs is returned as-is, a leaf in SplatLeafs becomes a splat
// shuffle, and a leaf in ConcatLeafs becomes a tree of concatenating shuffles.
// For any other front value, operand values are produced first and a new
// instruction of the same kind as the front value is created from them.
3543 static Value *
// NOTE(review): original line 3544 (function name and the leading parameters;
// the body uses Item, From and Ty) is missing from this listing.
3545 const DenseSet<std::pair<Value *, Use *>> &IdentityLeafs,
3546 const DenseSet<std::pair<Value *, Use *>> &SplatLeafs,
3547 const DenseSet<std::pair<Value *, Use *>> &ConcatLeafs,
3548 IRBuilderBase &Builder, const TargetTransformInfo *TTI) {
3549 auto [FrontV, FrontLane] = Item.front();
3550
// Leaves are keyed by (front value, use it was reached through).
3551 if (IdentityLeafs.contains(std::make_pair(FrontV, From))) {
3552 return FrontV;
3553 }
3554 if (SplatLeafs.contains(std::make_pair(FrontV, From))) {
// Broadcast FrontLane of FrontV to every lane of the result.
3555 SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
3556 return Builder.CreateShuffleVector(FrontV, Mask);
3557 }
3558 if (ConcatLeafs.contains(std::make_pair(FrontV, From))) {
3559 unsigned NumElts =
3560 cast<FixedVectorType>(FrontV->getType())->getNumElements();
// One value per slice: the value at the first lane of each slice.
3561 SmallVector<Value *> Values(Item.size() / NumElts, nullptr);
3562 for (unsigned S = 0; S < Values.size(); ++S)
3563 Values[S] = Item[S * NumElts].first;
3564
// Pairwise-concatenate neighbours, doubling the width each round, until a
// single value remains.
3565 while (Values.size() > 1) {
3566 NumElts *= 2;
3567 SmallVector<int, 16> Mask(NumElts, 0);
3568 std::iota(Mask.begin(), Mask.end(), 0);
3569 SmallVector<Value *> NewValues(Values.size() / 2, nullptr);
3570 for (unsigned S = 0; S < NewValues.size(); ++S)
3571 NewValues[S] =
3572 Builder.CreateShuffleVector(Values[S * 2], Values[S * 2 + 1], Mask);
3573 Values = NewValues;
3574 }
3575 return Values[0];
3576 }
3577
3578 auto *I = cast<Instruction>(FrontV);
3579 auto *II = dyn_cast<IntrinsicInst>(I);
// For intrinsic calls the last operand is not treated as an argument.
3580 unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
// NOTE(review): original line 3581 is missing from this listing; presumably
// it declares Ops, which is indexed up to NumOps below.
3582 for (unsigned Idx = 0; Idx < NumOps; Idx++) {
// Scalar intrinsic arguments are reused unchanged.
3583 if (II &&
3584 isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, TTI)) {
3585 Ops[Idx] = II->getOperand(Idx);
3586 continue;
3587 }
// NOTE(review): original line 3588 is missing from this listing; the
// surviving argument list below looks like the tail of a recursive call that
// assigns Ops[Idx] -- confirm against upstream.
3589 &I->getOperandUse(Idx), Ty, IdentityLeafs,
3590 SplatLeafs, ConcatLeafs, Builder, TTI);
3591 }
3592
// Collect the non-null lane values; they are handed to propagateIRFlags for
// whichever instruction is created below.
3593 SmallVector<Value *, 8> ValueList;
3594 for (const auto &Lane : Item)
3595 if (Lane.first)
3596 ValueList.push_back(Lane.first);
3597
// Result type for casts and intrinsics: I's scalar type at the width of Ty.
3598 Type *DstTy =
3599 FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements());
3600 if (auto *BI = dyn_cast<BinaryOperator>(I)) {
3601 auto *Value = Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(),
3602 Ops[0], Ops[1]);
3603 propagateIRFlags(Value, ValueList);
3604 return Value;
3605 }
3606 if (auto *CI = dyn_cast<CmpInst>(I)) {
3607 auto *Value = Builder.CreateCmp(CI->getPredicate(), Ops[0], Ops[1]);
3608 propagateIRFlags(Value, ValueList);
3609 return Value;
3610 }
3611 if (auto *SI = dyn_cast<SelectInst>(I)) {
3612 auto *Value = Builder.CreateSelect(Ops[0], Ops[1], Ops[2], "", SI);
3613 propagateIRFlags(Value, ValueList);
3614 return Value;
3615 }
3616 if (auto *CI = dyn_cast<CastInst>(I)) {
3617 auto *Value = Builder.CreateCast(CI->getOpcode(), Ops[0], DstTy);
3618 propagateIRFlags(Value, ValueList);
3619 return Value;
3620 }
3621 if (II) {
3622 auto *Value = Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
3623 propagateIRFlags(Value, ValueList);
3624 return Value;
3625 }
// Anything left must be a unary instruction.
3626 assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
3627 auto *Value =
3628 Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
3629 propagateIRFlags(Value, ValueList);
3630 return Value;
3631 }
3632
3633 // Starting from a shuffle, look up through operands tracking the shuffled index
3634 // of each lane. If we can simplify away the shuffles to identities then
3635 // do so.
// Returns true if I was replaced. The search classifies every reached leaf as
// an identity, a splat or a free concat, and gives up (returning false) on
// anything else or once more than MaxInstrsToScan items have been visited.
3636 bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
3637 auto *Ty = dyn_cast<FixedVectorType>(I.getType());
3638 if (!Ty || I.use_empty())
3639 return false;
3640
// Start[M] is the (value, lane) that lane M of I traces back to.
3641 SmallVector<InstLane> Start(Ty->getNumElements());
3642 for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
3643 Start[M] = lookThroughShuffles(&I, M);
3644
// NOTE(review): original line 3645 is missing from this listing; presumably
// it declares Worklist, whose entries are (lane vector, Use *) pairs.
3646 Worklist.push_back(std::make_pair(Start, &*I.use_begin()));
3647 DenseSet<std::pair<Value *, Use *>> IdentityLeafs, SplatLeafs, ConcatLeafs;
3648 unsigned NumVisited = 0;
3649
3650 while (!Worklist.empty()) {
3651 if (++NumVisited > MaxInstrsToScan)
3652 return false;
3653
3654 auto ItemFrom = Worklist.pop_back_val();
3655 auto Item = ItemFrom.first;
3656 auto From = ItemFrom.second;
3657 auto [FrontV, FrontLane] = Item.front();
3658
3659 // If we found an undef first lane then bail out to keep things simple.
3660 if (!FrontV)
3661 return false;
3662
3663 // Helper to peek through bitcasts to the same value.
3664 auto IsEquiv = [&](Value *X, Value *Y) {
3665 return X->getType() == Y->getType() &&
// NOTE(review): original line 3666 (the second half of this condition) is
// missing from this listing.
3667 };
3668
3669 // Look for an identity value.
// Lane 0 must read lane 0 of a value as wide as the result, and every other
// non-null lane must read the same-numbered lane of an equivalent value.
3670 if (FrontLane == 0 &&
3671 cast<FixedVectorType>(FrontV->getType())->getNumElements() ==
3672 Ty->getNumElements() &&
3673 all_of(drop_begin(enumerate(Item)), [IsEquiv, Item](const auto &E) {
3674 Value *FrontV = Item.front().first;
3675 return !E.value().first || (IsEquiv(E.value().first, FrontV) &&
3676 E.value().second == (int)E.index());
3677 })) {
3678 IdentityLeafs.insert(std::make_pair(FrontV, From));
3679 continue;
3680 }
3681 // Look for constants, for the moment only supporting constant splats.
3682 if (auto *C = dyn_cast<Constant>(FrontV);
3683 C && C->getSplatValue() &&
3684 all_of(drop_begin(Item), [Item](InstLane &IL) {
3685 Value *FrontV = Item.front().first;
3686 Value *V = IL.first;
3687 return !V || (isa<Constant>(V) &&
3688 cast<Constant>(V)->getSplatValue() ==
3689 cast<Constant>(FrontV)->getSplatValue());
3690 })) {
3691 SplatLeafs.insert(std::make_pair(FrontV, From));
3692 continue;
3693 }
3694 // Look for a splat value.
// Every non-null lane reads the same lane of the same value as the front.
3695 if (all_of(drop_begin(Item), [Item](InstLane &IL) {
3696 auto [FrontV, FrontLane] = Item.front();
3697 auto [V, Lane] = IL;
3698 return !V || (V == FrontV && Lane == FrontLane);
3699 })) {
3700 SplatLeafs.insert(std::make_pair(FrontV, From));
3701 continue;
3702 }
3703
3704 // We need each element to be the same type of value, and check that each
3705 // element has a single use.
3706 auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) {
3707 Value *FrontV = Item.front().first;
3708 if (!IL.first)
3709 return true;
3710 Value *V = IL.first;
3711 if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUser())
3712 return false;
3713 if (V->getValueID() != FrontV->getValueID())
3714 return false;
// Compares must additionally agree on the predicate.
3715 if (auto *CI = dyn_cast<CmpInst>(V))
3716 if (CI->getPredicate() != cast<CmpInst>(FrontV)->getPredicate())
3717 return false;
// Casts must additionally agree on the source scalar type.
3718 if (auto *CI = dyn_cast<CastInst>(V))
3719 if (CI->getSrcTy()->getScalarType() !=
3720 cast<CastInst>(FrontV)->getSrcTy()->getScalarType())
3721 return false;
// Selects need a vector condition of the same type as the front's condition.
3722 if (auto *SI = dyn_cast<SelectInst>(V))
3723 if (!isa<VectorType>(SI->getOperand(0)->getType()) ||
3724 SI->getOperand(0)->getType() !=
3725 cast<SelectInst>(FrontV)->getOperand(0)->getType())
3726 return false;
// Calls are only accepted when they are intrinsics.
3727 if (isa<CallInst>(V) && !isa<IntrinsicInst>(V))
3728 return false;
3729 auto *II = dyn_cast<IntrinsicInst>(V);
3730 return !II || (isa<IntrinsicInst>(FrontV) &&
3731 II->getIntrinsicID() ==
3732 cast<IntrinsicInst>(FrontV)->getIntrinsicID() &&
3733 !II->hasOperandBundles());
3734 };
3735 if (all_of(drop_begin(Item), CheckLaneIsEquivalentToFirst)) {
3736 // Check the operator is one that we support.
// NOTE(review): in each branch below the listing has dropped the line(s) that
// begin the statement(s); only trailing `&...getOperandUse(N));` arguments
// survive. They look like one Worklist push per operand -- confirm against
// upstream.
3737 if (isa<BinaryOperator, CmpInst>(FrontV)) {
3738 // We exclude div/rem in case they hit UB from poison lanes.
3739 if (auto *BO = dyn_cast<BinaryOperator>(FrontV);
3740 BO && BO->isIntDivRem())
3741 return false;
// NOTE(review): original line 3742 is missing here.
3743 &cast<Instruction>(FrontV)->getOperandUse(0));
// NOTE(review): original line 3744 is missing here.
3745 &cast<Instruction>(FrontV)->getOperandUse(1));
3746 continue;
3747 } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst, FPToSIInst,
3748 FPToUIInst, SIToFPInst, UIToFPInst>(FrontV)) {
// NOTE(review): original line 3749 is missing here.
3750 &cast<Instruction>(FrontV)->getOperandUse(0));
3751 continue;
3752 } else if (auto *BitCast = dyn_cast<BitCastInst>(FrontV)) {
3753 // TODO: Handle vector widening/narrowing bitcasts.
3754 auto *DstTy = dyn_cast<FixedVectorType>(BitCast->getDestTy());
3755 auto *SrcTy = dyn_cast<FixedVectorType>(BitCast->getSrcTy());
// Only bitcasts that keep the element count are looked through.
3756 if (DstTy && SrcTy &&
3757 SrcTy->getNumElements() == DstTy->getNumElements()) {
// NOTE(review): original line 3758 is missing here.
3759 &BitCast->getOperandUse(0));
3760 continue;
3761 }
3762 } else if (auto *Sel = dyn_cast<SelectInst>(FrontV)) {
// NOTE(review): original lines 3763, 3765 and 3767 are missing here.
3764 &Sel->getOperandUse(0));
3766 &Sel->getOperandUse(1));
3768 &Sel->getOperandUse(2));
3769 continue;
3770 } else if (auto *II = dyn_cast<IntrinsicInst>(FrontV);
3771 II && isTriviallyVectorizable(II->getIntrinsicID()) &&
3772 !II->hasOperandBundles()) {
// The last operand of the call is skipped.
3773 for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
3774 if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op,
3775 &TTI)) {
// A scalar argument must be the same value in every lane's call.
3776 if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
3777 Value *FrontV = Item.front().first;
3778 Value *V = IL.first;
3779 return !V || (cast<Instruction>(V)->getOperand(Op) ==
3780 cast<Instruction>(FrontV)->getOperand(Op));
3781 }))
3782 return false;
3783 continue;
3784 }
// NOTE(review): original line 3785 is missing here.
3786 &cast<Instruction>(FrontV)->getOperandUse(Op));
3787 }
3788 continue;
3789 }
3790 }
3791
// Last resort: accept the item if it is a concat the target reports as free.
3792 if (isFreeConcat(Item, CostKind, TTI)) {
3793 ConcatLeafs.insert(std::make_pair(FrontV, From));
3794 continue;
3795 }
3796
3797 return false;
3798 }
3799
// If only the starting item was visited there is nothing to rebuild.
3800 if (NumVisited <= 1)
3801 return false;
3802
3803 LLVM_DEBUG(dbgs() << "Found a superfluous identity shuffle: " << I << "\n");
3804
3805 // If we got this far, we know the shuffles are superfluous and can be
3806 // removed. Scan through again and generate the new tree of instructions.
3807 Builder.SetInsertPoint(&I);
3808 Value *V = generateNewInstTree(Start, &*I.use_begin(), Ty, IdentityLeafs,
3809 SplatLeafs, ConcatLeafs, Builder, &TTI);
3810 replaceValue(I, *V);
3811 return true;
3812 }
3813
3814 /// Given a commutative reduction, the order of the input lanes does not alter
3815 /// the results. We can use this to remove certain shuffles feeding the
3816 /// reduction, removing the need to shuffle at all.
/// Returns true if the shuffle was replaced here or if foldSelectShuffle
/// reports a change.
3817 bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
3818 auto *II = dyn_cast<IntrinsicInst>(&I);
3819 if (!II)
3820 return false;
// Only the integer reductions listed here are handled.
3821 switch (II->getIntrinsicID()) {
3822 case Intrinsic::vector_reduce_add:
3823 case Intrinsic::vector_reduce_mul:
3824 case Intrinsic::vector_reduce_and:
3825 case Intrinsic::vector_reduce_or:
3826 case Intrinsic::vector_reduce_xor:
3827 case Intrinsic::vector_reduce_smin:
3828 case Intrinsic::vector_reduce_smax:
3829 case Intrinsic::vector_reduce_umin:
3830 case Intrinsic::vector_reduce_umax:
3831 break;
3832 default:
3833 return false;
3834 }
3835
3836 // Find all the inputs when looking through operations that do not alter the
3837 // lane order (binops, for example). Currently we look for a single shuffle,
3838 // and can ignore splat values.
3839 std::queue<Value *> Worklist;
3840 SmallPtrSet<Value *, 4> Visited;
3841 ShuffleVectorInst *Shuffle = nullptr;
3842 if (auto *Op = dyn_cast<Instruction>(I.getOperand(0)))
3843 Worklist.push(Op);
3844
3845 while (!Worklist.empty()) {
3846 Value *CV = Worklist.front();
3847 Worklist.pop();
3848 if (Visited.contains(CV))
3849 continue;
3850
3851 // Splats don't change the order, so can be safely ignored.
3852 if (isSplatValue(CV))
3853 continue;
3854
3855 Visited.insert(CV);
3856
3857 if (auto *CI = dyn_cast<Instruction>(CV)) {
3858 if (CI->isBinaryOp()) {
3859 for (auto *Op : CI->operand_values())
3860 Worklist.push(Op);
3861 continue;
3862 } else if (auto *SV = dyn_cast<ShuffleVectorInst>(CI)) {
// At most one distinct shuffle may feed the reduction.
3863 if (Shuffle && Shuffle != SV)
3864 return false;
3865 Shuffle = SV;
3866 continue;
3867 }
3868 }
3869
3870 // Anything else is currently an unknown node.
3871 return false;
3872 }
3873
3874 if (!Shuffle)
3875 return false;
3876
3877 // Check all uses of the binary ops and shuffles are also included in the
3878 // lane-invariant operations (Visited should be the list of lanewise
3879 // instructions, including the shuffle that we found).
3880 for (auto *V : Visited)
3881 for (auto *U : V->users())
3882 if (!Visited.contains(U) && U != &I)
3883 return false;
3884
3885 FixedVectorType *VecType =
3886 dyn_cast<FixedVectorType>(II->getOperand(0)->getType());
3887 if (!VecType)
3888 return false;
3889 FixedVectorType *ShuffleInputType =
// NOTE(review): original line 3890 (the initializer of ShuffleInputType) is
// missing from this listing.
3891 if (!ShuffleInputType)
3892 return false;
3893 unsigned NumInputElts = ShuffleInputType->getNumElements();
3894
3895 // Find the mask from sorting the lanes into order. This is most likely to
3896 // become an identity or concat mask. Undef elements are pushed to the end.
3897 SmallVector<int> ConcatMask;
3898 Shuffle->getShuffleMask(ConcatMask);
// Comparing as unsigned makes negative mask elements sort after all valid
// indices.
3899 sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
3900 bool UsesSecondVec =
3901 any_of(ConcatMask, [&](int M) { return M >= (int)NumInputElts; });
3902
// NOTE(review): original lines 3903 and 3906 are missing from this listing;
// presumably they declare OldCost (cost with the existing mask) and NewCost
// (cost with the sorted mask), both of which are used below.
3904 UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType,
3905 ShuffleInputType, Shuffle->getShuffleMask(), CostKind);
3907 UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType,
3908 ShuffleInputType, ConcatMask, CostKind);
3909
3910 LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
3911 << "\n");
3912 LLVM_DEBUG(dbgs() << " OldCost: " << OldCost << " vs NewCost: " << NewCost
3913 << "\n");
3914 bool MadeChanges = false;
// Only replace the shuffle when the sorted mask is strictly cheaper.
3915 if (NewCost < OldCost) {
3916 Builder.SetInsertPoint(Shuffle);
3917 Value *NewShuffle = Builder.CreateShuffleVector(
3918 Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask);
3919 LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
3920 replaceValue(*Shuffle, *NewShuffle);
3921 return true;
3922 }
3923
3924 // See if we can re-use foldSelectShuffle, getting it to reduce the size of
3925 // the shuffle into a nicer order, as it can ignore the order of the shuffles.
3926 MadeChanges |= foldSelectShuffle(*Shuffle, true);
3927 return MadeChanges;
3928 }
3929
3930 /// Try to fold a chain of shuffles and ops feeding extractelement(..., 0)
3931 /// into llvm.vector.reduce.*, by tracking which lanes contribute to the
3932 /// extracted lane and reducing the widest vector whose lanes each contribute
3933 /// once.
3934 ///
3935 /// For example:
3936 ///
3937 /// %lo = shufflevector <4 x i32> %a, poison, <2 x i32> <i32 0, i32 1>
3938 /// %hi = shufflevector <4 x i32> %a, poison, <2 x i32> <i32 2, i32 3>
3939 /// %s = add <2 x i32> %lo, %hi
3940 /// %sh = shufflevector <2 x i32> %s, poison, <2 x i32> <i32 1, i32 poison>
3941 /// %r = add <2 x i32> %s, %sh
3942 /// %e = extractelement <2 x i32> %r, i64 0
3943 ///
3944 /// transforms to:
3945 ///
3946 /// %e = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
///
/// Returns true if \p I was replaced by a reduction.
3947 bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
// Only an extract of lane 0 from a fixed vector with at least two elements
// is considered.
3948 Value *VecOpEE;
3949 if (!match(&I, m_ExtractElt(m_Value(VecOpEE), m_Zero())))
3950 return false;
3951
3952 auto *FVT = dyn_cast<FixedVectorType>(VecOpEE->getType());
3953 if (!FVT)
3954 return false;
3955
3956 if (FVT->getNumElements() < 2)
3957 return false;
3958
// Exactly one of these is set, depending on whether the extracted vector is
// produced by a binary operator or by a min/max intrinsic.
3959 std::optional<Instruction::BinaryOps> CommonBinOp;
3960 std::optional<Intrinsic::ID> CommonCallOp;
3961
3962 if (auto *BO = dyn_cast<BinaryOperator>(VecOpEE)) {
3963 if (!getReductionForBinop(BO->getOpcode()))
3964 return false;
3965 CommonBinOp = BO->getOpcode();
3966 } else if (auto *MMI = dyn_cast<MinMaxIntrinsic>(VecOpEE)) {
3967 CommonCallOp = MMI->getIntrinsicID();
3968 } else {
3969 return false;
3970 }
3971
3972 // For floating-point reductions, track FMF intersection across all binops.
3973 FastMathFlags CommonFMF;
3974 bool IsFloatReduction = false;
3975
3976 // A chain node is one we walk through, either a matching-opcode binop/min-max
3977 // or a single-source shuffle. Anything else is a leaf source.
3978 auto IsChainNode = [&](Value *V) {
3979 if (auto *BO = dyn_cast<BinaryOperator>(V))
3980 return CommonBinOp && BO->getOpcode() == *CommonBinOp;
3981 if (auto *MMI = dyn_cast<MinMaxIntrinsic>(V))
3982 return CommonCallOp && MMI->getIntrinsicID() == *CommonCallOp;
3983 if (auto *SVI = dyn_cast<ShuffleVectorInst>(V))
3984 return isa<PoisonValue>(SVI->getOperand(1));
3985 return false;
3986 };
3987
3988 // Collect the chain, building Nodes in postorder. Bail if the chain is empty
3989 // or exceeds MaxChainNodes.
3990 constexpr unsigned MaxChainNodes = 32;
3991 SmallSetVector<Value *, 16> Nodes;
3992 SmallSetVector<Value *, 4> Sources;
3993 unsigned NumVisited = 0;
// Record V as a leaf source; fails for anything that is not a fixed vector.
3994 auto AddSource = [&](Value *V) {
3995 if (!isa<FixedVectorType>(V->getType()))
3996 return false;
3997 Sources.insert(V);
3998 return true;
3999 };
// Recursive walk; the lambda receives itself as its second argument so that
// it can recurse. Returns false to abort the whole fold.
4000 auto Walk = [&](Value *V, auto &&Walk) -> bool {
4001 if (Nodes.contains(V) || Sources.contains(V))
4002 return true;
4003 if (++NumVisited > MaxChainNodes)
4004 return false;
4005 if (!IsChainNode(V))
4006 return AddSource(V);
4007 // Chain shuffles always have poison as op1, so only op0 matters.
4008 auto *U = cast<Instruction>(V);
4009 unsigned NumOps = isa<ShuffleVectorInst>(U) ? 1 : 2;
4010 for (unsigned I = 0; I != NumOps; ++I)
4011 if (!Walk(U->getOperand(I), Walk))
4012 return false;
// V is inserted only after its operands were walked, which gives Nodes its
// postorder.
4013 if (isa<ShuffleVectorInst>(U) || Nodes.contains(U->getOperand(0)) ||
4014 Nodes.contains(U->getOperand(1))) {
4015 Nodes.insert(V);
4016 return true;
4017 }
4018 // Both operands are leaves so treat this binop as a source rather than
4019 // walking into it.
4020 return AddSource(V);
4021 };
4022 if (!Walk(VecOpEE, Walk) || Nodes.empty())
4023 return false;
4024
// For an idempotent operation a lane may contribute more than once; min/max
// intrinsics are always treated as idempotent here.
4025 bool IsIdempotent =
4026 CommonCallOp || (CommonBinOp && Instruction::isIdempotent(*CommonBinOp));
4027
4028 // For FP reductions, require reassoc on every binop and collect FMF.
4029 for (Value *V : Nodes) {
4030 auto *BinOp = dyn_cast<BinaryOperator>(V);
4031 if (!BinOp || !BinOp->getType()->isFPOrFPVectorTy())
4032 continue;
4033 if (!BinOp->hasAllowReassoc())
4034 return false;
4035 if (!IsFloatReduction) {
4036 CommonFMF = BinOp->getFastMathFlags();
4037 IsFloatReduction = true;
4038 } else {
4039 CommonFMF &= BinOp->getFastMathFlags();
4040 }
4041 }
4042
4043 // Top-down demanded elements. For each chain value, track which lanes feed
4044 // the extracted lane 0 and which feed it more than once. Reverse postorder
4045 // visits every use before its value. A binop forwards its demand to both
4046 // operands and a shuffle follows its mask back to the source lane.
4047 struct Demand {
4048 APInt Lanes;
4049 APInt Duplicates;
4050 };
4051 DenseMap<Value *, Demand> Demands;
// Get (creating if needed) the demand entry for V, with both bit sets sized
// to V's element count and initially zero.
4052 auto DemandOf = [&](Value *V) -> Demand & {
4053 unsigned N = cast<FixedVectorType>(V->getType())->getNumElements();
4054 Demand &D = Demands[V];
4055 if (D.Lanes.getBitWidth() != N)
4056 D.Lanes = D.Duplicates = APInt::getZero(N);
4057 return D;
4058 };
// Seed: only lane 0 of the extracted vector is demanded.
4059 DemandOf(VecOpEE).Lanes.setBit(0);
4060 for (Value *V : reverse(Nodes)) {
// DV is a copy, so the DemandOf calls below cannot invalidate it.
4061 Demand DV = Demands.lookup(V);
4062 if (DV.Lanes.isZero())
4063 continue;
4064 if (auto *SVI = dyn_cast<ShuffleVectorInst>(V)) {
4065 ArrayRef<int> Mask = SVI->getShuffleMask();
4066 Demand &DS = DemandOf(SVI->getOperand(0));
4067 for (unsigned I = 0, E = Mask.size(); I != E; ++I) {
4068 // Skip lanes that are undemanded or map to poison.
4069 if (!DV.Lanes[I] || Mask[I] < 0 ||
4070 (unsigned)Mask[I] >= DS.Lanes.getBitWidth())
4071 continue;
// A source lane is a duplicate if it was already demanded or if the shuffle
// lane reading it is itself a duplicate.
4072 if (DS.Lanes[Mask[I]] || DV.Duplicates[I])
4073 DS.Duplicates.setBit(Mask[I]);
4074 DS.Lanes.setBit(Mask[I]);
4075 }
4076 } else {
4077 auto *U = cast<User>(V);
4078 for (Value *Op : {U->getOperand(0), U->getOperand(1)}) {
4079 Demand &DOp = DemandOf(Op);
4080 // Lanes demanded through more than one path accumulate in Duplicates.
4081 DOp.Duplicates |= DV.Duplicates | (DOp.Lanes & DV.Lanes);
4082 DOp.Lanes |= DV.Lanes;
4083 }
4084 }
4085 }
4086
4087 // Reducing V replaces the entire chain, so every contribution to the result
4088 // must flow through V. Reject if anything above V reads outside the chain.
4089 auto CoversChain = [&](Value *V) {
4090 SmallVector<Value *, 8> Worklist(1, VecOpEE);
4091 SmallPtrSet<Value *, 8> Seen;
4092 Seen.insert(VecOpEE);
4093 while (!Worklist.empty()) {
4094 auto *U = cast<Instruction>(Worklist.pop_back_val());
4095 unsigned NumOps = isa<ShuffleVectorInst>(U) ? 1 : 2;
4096 for (unsigned I = 0; I != NumOps; ++I) {
4097 Value *Op = U->getOperand(I);
// Stop at V itself and skip operands that were already seen.
4098 if (Op == V || !Seen.insert(Op).second)
4099 continue;
4100 if (!Nodes.contains(Op))
4101 return false;
4102 Worklist.push_back(Op);
4103 }
4104 }
4105 return true;
4106 };
4107
4108 // Reduce a single cleanly demanded source if there is one, otherwise the
4109 // deepest intermediate that covers the chain.
4110 struct ReductionCut {
4111 Value *Src;
4112 APInt Elts;
4113 };
4114 std::optional<ReductionCut> Cut;
4115 for (Value *S : Sources) {
4116 auto It = Demands.find(S);
4117 if (It == Demands.end() || It->second.Lanes.isZero())
4118 continue;
// A second demanded source, or duplicated lanes under a non-idempotent
// operation, rules out cutting at a source.
4119 if (Cut || (!IsIdempotent && !It->second.Duplicates.isZero())) {
4120 Cut.reset();
4121 break;
4122 }
4123 Cut = ReductionCut{S, It->second.Lanes};
4124 }
4125 if (!Cut) {
4126 for (Value *V : Nodes) {
// NOTE(review): original line 4127 (the condition guarding this `continue`)
// is missing from this listing.
4128 continue;
4129 auto It = Demands.find(V);
// An intermediate cut must have every one of its lanes demanded.
4130 if (It == Demands.end() || !It->second.Lanes.isAllOnes())
4131 continue;
4132 if (!IsIdempotent && !It->second.Duplicates.isZero())
4133 continue;
4134 if (!CoversChain(V))
4135 continue;
4136 Cut = ReductionCut{V, It->second.Lanes};
4137 break;
4138 }
4139 }
4140 // Reducing one lane is just an extract and can refold forever.
4141 if (!Cut || Cut->Elts.popcount() < 2)
4142 return false;
4143
4144 Intrinsic::ID ReducedOp =
4145 (CommonCallOp ? getMinMaxReductionIntrinsicID(*CommonCallOp)
4146 : getReductionForBinop(*CommonBinOp));
4147 if (!ReducedOp)
4148 return false;
4149
4150 InstructionCost OrigCost = 0;
4151 for (Value *V : Nodes)
// NOTE(review): original line 4152 (the body of this loop, presumably
// accumulating into OrigCost) is missing from this listing.
4153
4154 auto *SrcVT = cast<FixedVectorType>(Cut->Src->getType());
// When only some lanes of the cut are demanded, those lanes are first
// extracted into a narrower vector that is then reduced.
4155 bool IsPartialReduction = !Cut->Elts.isAllOnes();
4156 FixedVectorType *ReduceVecTy =
4157 IsPartialReduction
4158 ? FixedVectorType::get(FVT->getElementType(), Cut->Elts.popcount())
4159 : SrcVT;
4160
4161 SmallVector<int> ExtractMask;
4162 InstructionCost NewCost = 0;
4163 if (IsPartialReduction) {
// ExtractMask lists the demanded lane indices in ascending order.
4164 for (unsigned I = 0, E = Cut->Elts.getBitWidth(); I != E; ++I)
4165 if (Cut->Elts[I])
4166 ExtractMask.push_back(I);
4167 unsigned SubIdx = 0, SubLen;
4168 auto SK = Cut->Elts.isShiftedMask(SubIdx, SubLen)
// NOTE(review): original lines 4169-4170 (the two alternatives of this
// conditional expression) are missing from this listing.
4171 NewCost += TTI.getShuffleCost(SK, ReduceVecTy, SrcVT, ExtractMask, CostKind,
4172 SubIdx, ReduceVecTy);
4173 }
4174
// Float reductions are costed with an extra leading scalar argument type,
// matching the {Identity, ReduceInput} operands used when the call is created.
4175 IntrinsicCostAttributes ICA(
4176 ReducedOp, ReduceVecTy->getElementType(),
4177 IsFloatReduction
4178 ? SmallVector<Type *, 2>{ReduceVecTy->getElementType(), ReduceVecTy}
4179 : SmallVector<Type *, 2>{ReduceVecTy},
4180 IsFloatReduction ? CommonFMF : FastMathFlags());
4181 NewCost += TTI.getIntrinsicInstrCost(ICA, CostKind);
4182
4183 LLVM_DEBUG(dbgs() << "Found reduction shuffle chain: " << I << "\n OldCost : "
4184 << OrigCost << " vs NewCost: " << NewCost << "\n");
4185
4186 if (!OrigCost.isValid() || !NewCost.isValid())
4187 return false;
4188
// With other users of VecOpEE the new cost must be strictly lower; otherwise
// an equal cost is accepted.
4189 if (VecOpEE->hasOneUse() ? (NewCost > OrigCost) : (NewCost >= OrigCost))
4190 return false;
4191
4192 Value *ReduceInput = Cut->Src;
4193 if (IsPartialReduction)
4194 ReduceInput = Builder.CreateShuffleVector(Cut->Src, ExtractMask);
4195
4196 Value *ReducedResult;
4197 if (IsFloatReduction) {
// NOTE(review): original line 4198 (the start of the declaration of Identity,
// which is used in the call below) is missing from this listing.
4199 *CommonBinOp, ReduceVecTy->getElementType(), /*AllowRHSConstant=*/false,
4200 CommonFMF.noSignedZeros());
4201 ReducedResult = Builder.CreateIntrinsic(ReducedOp, {ReduceVecTy},
4202 {Identity, ReduceInput}, CommonFMF);
4203 } else {
4204 ReducedResult =
4205 Builder.CreateIntrinsic(ReducedOp, {ReduceVecTy}, {ReduceInput});
4206 }
4207 replaceValue(I, *ReducedResult);
4208
4209 return true;
4210 }
4211
4212 /// Determine if it's more efficient to fold:
4213 /// reduce(trunc(x)) -> trunc(reduce(x)).
4214 /// reduce(sext(x)) -> sext(reduce(x)).
4215 /// reduce(zext(x)) -> zext(reduce(x)).
/// Returns true if \p I was replaced by a cast of a reduction over the
/// cast's source.
4216 bool VectorCombine::foldCastFromReductions(Instruction &I) {
4217 auto *II = dyn_cast<IntrinsicInst>(&I);
4218 if (!II)
4219 return false;
4220
// For add/mul only a trunc source is matched; and/or/xor also accept
// zext/sext.
4221 bool TruncOnly = false;
4222 Intrinsic::ID IID = II->getIntrinsicID();
4223 switch (IID) {
4224 case Intrinsic::vector_reduce_add:
4225 case Intrinsic::vector_reduce_mul:
4226 TruncOnly = true;
4227 break;
4228 case Intrinsic::vector_reduce_and:
4229 case Intrinsic::vector_reduce_or:
4230 case Intrinsic::vector_reduce_xor:
4231 break;
4232 default:
4233 return false;
4234 }
4235
4236 unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
4237 Value *ReductionSrc = I.getOperand(0);
4238
// The cast must have the reduction as its only use.
4239 Value *Src;
4240 if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) &&
4241 (TruncOnly || !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
4242 return false;
4243
4244 auto CastOpc =
4245 (Instruction::CastOps)cast<Instruction>(ReductionSrc)->getOpcode();
4246
4247 auto *SrcTy = cast<VectorType>(Src->getType());
4248 auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
4249 Type *ResultTy = I.getType();
4250
// Old form: reduction over the cast result plus the vector cast.
// New form: reduction over the cast source plus a scalar cast.
// NOTE(review): original lines 4251, 4254 and 4260 are missing from this
// listing (the start of the OldCost declaration and trailing arguments of the
// two getCastInstrCost calls).
4252 ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
4253 OldCost += TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
4255 cast<CastInst>(ReductionSrc));
4256 InstructionCost NewCost =
4257 TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
4258 CostKind) +
4259 TTI.getCastInstrCost(CastOpc, ResultTy, ReductionSrcTy->getScalarType(),
4261
// Only transform when the new form is valid and strictly cheaper.
4262 if (OldCost <= NewCost || !NewCost.isValid())
4263 return false;
4264
4265 Value *NewReduction = Builder.CreateIntrinsic(SrcTy->getScalarType(),
4266 II->getIntrinsicID(), {Src});
4267 Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy);
4268 replaceValue(I, *NewCast);
4269 return true;
4270 }
4271
4272/// Fold:
4273/// icmp pred (reduce.{add,or,and,umax,umin}(signbit_extract(x))), C
4274/// into:
4275/// icmp sgt/slt (reduce.{or,umax,and,umin}(x)), -1/0
4276///
4277/// Sign-bit reductions produce values with known semantics:
4278/// - reduce.{or,umax}: 0 if no element is negative, 1 if any is
4279/// - reduce.{and,umin}: 1 if all elements are negative, 0 if any isn't
4280/// - reduce.add: count of negative elements (0 to NumElts)
4281///
4282/// Both lshr and ashr are supported:
4283/// - lshr produces 0 or 1, so reduce.add range is [0, N]
4284/// - ashr produces 0 or -1, so reduce.add range is [-N, 0]
4285///
4286/// The fold generalizes to multiple source vectors combined with the same
4287/// operation as the reduction. For example:
4288/// reduce.or(or(shr A, shr B)) conceptually extends the vector
4289/// For reduce.add, this changes the count to M*N where M is the number of
4290/// source vectors.
4291///
4292/// We transform to a direct sign check on the original vector using
4293/// reduce.{or,umax} or reduce.{and,umin}.
4294///
4295/// In spirit, it's similar to foldSignBitCheck in InstCombine.
4296bool VectorCombine::foldSignBitReductionCmp(Instruction &I) {
4297 CmpPredicate Pred;
4298 IntrinsicInst *ReduceOp;
4299 const APInt *CmpVal;
4300 if (!match(&I,
4301 m_ICmp(Pred, m_OneUse(m_AnyIntrinsic(ReduceOp)), m_APInt(CmpVal))))
4302 return false;
4303
4304 Intrinsic::ID OrigIID = ReduceOp->getIntrinsicID();
4305 switch (OrigIID) {
4306 case Intrinsic::vector_reduce_or:
4307 case Intrinsic::vector_reduce_umax:
4308 case Intrinsic::vector_reduce_and:
4309 case Intrinsic::vector_reduce_umin:
4310 case Intrinsic::vector_reduce_add:
4311 break;
4312 default:
4313 return false;
4314 }
4315
4316 Value *ReductionSrc = ReduceOp->getArgOperand(0);
4317 auto *VecTy = dyn_cast<FixedVectorType>(ReductionSrc->getType());
4318 if (!VecTy)
4319 return false;
4320
4321 unsigned BitWidth = VecTy->getScalarSizeInBits();
4322 if (BitWidth == 1)
4323 return false;
4324
4325 unsigned NumElts = VecTy->getNumElements();
4326
4327 // Determine the expected tree opcode for multi-vector patterns.
4328 // The tree opcode must match the reduction's underlying operation.
4329 //
4330 // TODO: for pairs of equivalent operators, we should match both,
4331 // not only the most common.
4332 Instruction::BinaryOps TreeOpcode;
4333 switch (OrigIID) {
4334 case Intrinsic::vector_reduce_or:
4335 case Intrinsic::vector_reduce_umax:
4336 TreeOpcode = Instruction::Or;
4337 break;
4338 case Intrinsic::vector_reduce_and:
4339 case Intrinsic::vector_reduce_umin:
4340 TreeOpcode = Instruction::And;
4341 break;
4342 case Intrinsic::vector_reduce_add:
4343 TreeOpcode = Instruction::Add;
4344 break;
4345 default:
4346 llvm_unreachable("Unexpected intrinsic");
4347 }
4348
4349 // Collect sign-bit extraction leaves from an associative tree of TreeOpcode.
4350 // The tree conceptually extends the vector being reduced.
4351 SmallVector<Value *, 8> Worklist;
4352 SmallVector<Value *, 8> Sources; // Original vectors (X in shr X, BW-1)
4353 Worklist.push_back(ReductionSrc);
4354 std::optional<bool> IsAShr;
4355 constexpr unsigned MaxSources = 8;
4356
4357 // Calculate old cost: all shifts + tree ops + reduction
4358 InstructionCost OldCost = TTI.getInstructionCost(ReduceOp, CostKind);
4359
4360 while (!Worklist.empty() && Worklist.size() <= MaxSources &&
4361 Sources.size() <= MaxSources) {
4362 Value *V = Worklist.pop_back_val();
4363
4364 // Try to match sign-bit extraction: shr X, (bitwidth-1)
4365 Value *X;
4366 if (match(V, m_OneUse(m_Shr(m_Value(X), m_SpecificInt(BitWidth - 1))))) {
4367 auto *Shr = cast<Instruction>(V);
4368
4369 // All shifts must be the same type (all lshr or all ashr)
4370 bool ThisIsAShr = Shr->getOpcode() == Instruction::AShr;
4371 if (!IsAShr)
4372 IsAShr = ThisIsAShr;
4373 else if (*IsAShr != ThisIsAShr)
4374 return false;
4375
4376 Sources.push_back(X);
4377
4378 // As part of the fold, we remove all of the shifts, so we need to keep
4379 // track of their costs.
4380 OldCost += TTI.getInstructionCost(Shr, CostKind);
4381
4382 continue;
4383 }
4384
4385 // Try to extend through a tree node of the expected opcode
4386 Value *A, *B;
4387 if (!match(V, m_OneUse(m_BinOp(TreeOpcode, m_Value(A), m_Value(B)))))
4388 return false;
4389
4390 // We are potentially replacing these operations as well, so we add them
4391 // to the costs.
4393
4394 Worklist.push_back(A);
4395 Worklist.push_back(B);
4396 }
4397
4398 // Must have at least one source and not exceed limit
4399 if (Sources.empty() || Sources.size() > MaxSources ||
4400 Worklist.size() > MaxSources || !IsAShr)
4401 return false;
4402
4403 unsigned NumSources = Sources.size();
4404
4405 // For reduce.add, the total count must fit as a signed integer.
4406 // Range is [0, M*N] for lshr or [-M*N, 0] for ashr.
4407 if (OrigIID == Intrinsic::vector_reduce_add &&
4408 !isIntN(BitWidth, NumSources * NumElts))
4409 return false;
4410
4411 // Compute the boundary value when all elements are negative:
4412 // - Per-element contribution: 1 for lshr, -1 for ashr
4413 // - For add: M*N (total elements across all sources); for others: just 1
4414 unsigned Count =
4415 (OrigIID == Intrinsic::vector_reduce_add) ? NumSources * NumElts : 1;
4416 APInt NegativeVal(CmpVal->getBitWidth(), Count);
4417 if (*IsAShr)
4418 NegativeVal.negate();
4419
4420 // Range is [min(0, AllNegVal), max(0, AllNegVal)]
4421 APInt Zero = APInt::getZero(CmpVal->getBitWidth());
4422 APInt RangeLow = APIntOps::smin(Zero, NegativeVal);
4423 APInt RangeHigh = APIntOps::smax(Zero, NegativeVal);
4424
4425 // Determine comparison semantics:
4426 // - IsEq: true for equality test, false for inequality
4427 // - TestsNegative: true if testing against AllNegVal, false for zero
4428 //
4429 // In addition to EQ/NE against 0 or AllNegVal, we support inequalities
4430 // that fold to boundary tests given the narrow value range:
4431 // < RangeHigh -> != RangeHigh
4432 // > RangeHigh-1 -> == RangeHigh
4433 // > RangeLow -> != RangeLow
4434 // < RangeLow+1 -> == RangeLow
4435 //
4436 // For inequalities, we work with signed predicates only. Unsigned predicates
4437 // are canonicalized to signed when the range is non-negative (where they are
4438 // equivalent). When the range includes negative values, unsigned predicates
4439 // would have different semantics due to wrap-around, so we reject them.
4440 if (!ICmpInst::isEquality(Pred) && !ICmpInst::isSigned(Pred)) {
4441 if (RangeLow.isNegative())
4442 return false;
4443 Pred = ICmpInst::getSignedPredicate(Pred);
4444 }
4445
4446 bool IsEq;
4447 bool TestsNegative;
4448 if (ICmpInst::isEquality(Pred)) {
4449 if (CmpVal->isZero()) {
4450 TestsNegative = false;
4451 } else if (*CmpVal == NegativeVal) {
4452 TestsNegative = true;
4453 } else {
4454 return false;
4455 }
4456 IsEq = Pred == ICmpInst::ICMP_EQ;
4457 } else if (Pred == ICmpInst::ICMP_SLT && *CmpVal == RangeHigh) {
4458 IsEq = false;
4459 TestsNegative = (RangeHigh == NegativeVal);
4460 } else if (Pred == ICmpInst::ICMP_SGT && *CmpVal == RangeHigh - 1) {
4461 IsEq = true;
4462 TestsNegative = (RangeHigh == NegativeVal);
4463 } else if (Pred == ICmpInst::ICMP_SGT && *CmpVal == RangeLow) {
4464 IsEq = false;
4465 TestsNegative = (RangeLow == NegativeVal);
4466 } else if (Pred == ICmpInst::ICMP_SLT && *CmpVal == RangeLow + 1) {
4467 IsEq = true;
4468 TestsNegative = (RangeLow == NegativeVal);
4469 } else {
4470 return false;
4471 }
4472
4473 // For this fold we support four types of checks:
4474 //
4475 // 1. All lanes are negative - AllNeg
4476 // 2. All lanes are non-negative - AllNonNeg
4477 // 3. At least one negative lane - AnyNeg
4478 // 4. At least one non-negative lane - AnyNonNeg
4479 //
4480 // For each case, we can generate the following code:
4481 //
4482 // 1. AllNeg - reduce.and/umin(X) < 0
4483 // 2. AllNonNeg - reduce.or/umax(X) > -1
4484 // 3. AnyNeg - reduce.or/umax(X) < 0
4485 // 4. AnyNonNeg - reduce.and/umin(X) > -1
4486 //
4487 // The table below shows the aggregation of all supported cases
4488 // using these four cases.
4489 //
4490 // Reduction | == 0 | != 0 | == MAX | != MAX
4491 // ------------+-----------+-----------+-----------+-----------
4492 // or/umax | AllNonNeg | AnyNeg | AnyNeg | AllNonNeg
4493 // and/umin | AnyNonNeg | AllNeg | AllNeg | AnyNonNeg
4494 // add | AllNonNeg | AnyNeg | AllNeg | AnyNonNeg
4495 //
4496 // NOTE: MAX = 1 for or/and/umax/umin, and the vector size N for add
4497 //
4498 // For easier codegen and check inversion, we use the following encoding:
4499 //
4500 // 1. Bit-3 === requires or/umax (1) or and/umin (0) check
4501 // 2. Bit-2 === checks < 0 (1) or > -1 (0)
4502 // 3. Bit-1 === universal (1) or existential (0) check
4503 //
4504 // AnyNeg = 0b110: uses or/umax, checks negative, any-check
4505 // AllNonNeg = 0b101: uses or/umax, checks non-neg, all-check
4506 // AnyNonNeg = 0b000: uses and/umin, checks non-neg, any-check
4507 // AllNeg = 0b011: uses and/umin, checks negative, all-check
4508 //
4509 // XOR with 0b011 inverts the check (swaps all/any and neg/non-neg).
4510 //
4511 enum CheckKind : unsigned {
4512 AnyNonNeg = 0b000,
4513 AllNeg = 0b011,
4514 AllNonNeg = 0b101,
4515 AnyNeg = 0b110,
4516 };
4517 // Return true if we fold this check into or/umax and false for and/umin
4518 auto RequiresOr = [](CheckKind C) -> bool { return C & 0b100; };
4519 // Return true if we should check if result is negative and false otherwise
4520 auto IsNegativeCheck = [](CheckKind C) -> bool { return C & 0b010; };
4521 // Logically invert the check
4522 auto Invert = [](CheckKind C) { return CheckKind(C ^ 0b011); };
4523
4524 CheckKind Base;
4525 switch (OrigIID) {
4526 case Intrinsic::vector_reduce_or:
4527 case Intrinsic::vector_reduce_umax:
4528 Base = TestsNegative ? AnyNeg : AllNonNeg;
4529 break;
4530 case Intrinsic::vector_reduce_and:
4531 case Intrinsic::vector_reduce_umin:
4532 Base = TestsNegative ? AllNeg : AnyNonNeg;
4533 break;
4534 case Intrinsic::vector_reduce_add:
4535 Base = TestsNegative ? AllNeg : AllNonNeg;
4536 break;
4537 default:
4538 llvm_unreachable("Unexpected intrinsic");
4539 }
4540
4541 CheckKind Check = IsEq ? Base : Invert(Base);
4542
4543 auto PickCheaper = [&](Intrinsic::ID Arith, Intrinsic::ID MinMax) {
4544 InstructionCost ArithCost =
4546 VecTy, std::nullopt, CostKind);
4547 InstructionCost MinMaxCost =
4549 FastMathFlags(), CostKind);
4550 return ArithCost <= MinMaxCost ? std::make_pair(Arith, ArithCost)
4551 : std::make_pair(MinMax, MinMaxCost);
4552 };
4553
4554 // Choose output reduction based on encoding's MSB
4555 auto [NewIID, NewCost] = RequiresOr(Check)
4556 ? PickCheaper(Intrinsic::vector_reduce_or,
4557 Intrinsic::vector_reduce_umax)
4558 : PickCheaper(Intrinsic::vector_reduce_and,
4559 Intrinsic::vector_reduce_umin);
4560
4561 // Add cost of combining multiple sources with or/and
4562 if (NumSources > 1) {
4563 unsigned CombineOpc =
4564 RequiresOr(Check) ? Instruction::Or : Instruction::And;
4565 NewCost += TTI.getArithmeticInstrCost(CombineOpc, VecTy, CostKind) *
4566 (NumSources - 1);
4567 }
4568
4569 LLVM_DEBUG(dbgs() << "Found sign-bit reduction cmp: " << I << "\n OldCost: "
4570 << OldCost << " vs NewCost: " << NewCost << "\n");
4571
4572 if (NewCost > OldCost)
4573 return false;
4574
4575 // Generate the combined input and reduction
4576 Builder.SetInsertPoint(&I);
4577 Type *ScalarTy = VecTy->getScalarType();
4578
4579 Value *Input;
4580 if (NumSources == 1) {
4581 Input = Sources[0];
4582 } else {
4583 // Combine sources with or/and based on check type
4584 Input = RequiresOr(Check) ? Builder.CreateOr(Sources)
4585 : Builder.CreateAnd(Sources);
4586 }
4587
4588 Value *NewReduce = Builder.CreateIntrinsic(ScalarTy, NewIID, {Input});
4589 Value *NewCmp = IsNegativeCheck(Check) ? Builder.CreateIsNeg(NewReduce)
4590 : Builder.CreateIsNotNeg(NewReduce);
4591 replaceValue(I, *NewCmp);
4592 return true;
4593}
4594
4595/// vector.reduce.OP f(X_i) == 0 -> vector.reduce.OP X_i == 0
4596///
4597/// We can prove it for cases when:
4598///
4599/// 1. OP X_i == 0 <=> \forall i \in [1, N] X_i == 0
4600/// 1'. OP X_i == 0 <=> \exists j \in [1, N] X_j == 0
4601/// 2. f(x) == 0 <=> x == 0
4602///
4603/// From 1 and 2 (or 1' and 2), we can infer that
4604///
4605/// OP f(X_i) == 0 <=> OP X_i == 0.
4606///
4607/// (1)
4608/// OP f(X_i) == 0 <=> \forall i \in [1, N] f(X_i) == 0
4609/// (2)
4610/// <=> \forall i \in [1, N] X_i == 0
4611/// (1)
4612/// <=> OP(X_i) == 0
4613///
4614/// For some of the OP's and f's, we need to have domain constraints on X
4615/// to ensure properties 1 (or 1') and 2.
4616bool VectorCombine::foldICmpEqZeroVectorReduce(Instruction &I) {
4617 CmpPredicate Pred;
4618 Value *Op;
4619 if (!match(&I, m_ICmp(Pred, m_Value(Op), m_Zero())) ||
4620 !ICmpInst::isEquality(Pred))
4621 return false;
4622
4623 auto *II = dyn_cast<IntrinsicInst>(Op);
4624 if (!II)
4625 return false;
4626
4627 switch (II->getIntrinsicID()) {
4628 case Intrinsic::vector_reduce_add:
4629 case Intrinsic::vector_reduce_or:
4630 case Intrinsic::vector_reduce_umin:
4631 case Intrinsic::vector_reduce_umax:
4632 case Intrinsic::vector_reduce_smin:
4633 case Intrinsic::vector_reduce_smax:
4634 break;
4635 default:
4636 return false;
4637 }
4638
4639 Value *InnerOp = II->getArgOperand(0);
4640
4641 // TODO: fixed vector type might be too restrictive
4642 if (!II->hasOneUse() || !isa<FixedVectorType>(InnerOp->getType()))
4643 return false;
4644
4645 Value *X = nullptr;
4646
4647 // Check for zero-preserving operations where f(x) = 0 <=> x = 0
4648 //
4649 // 1. f(x) = shl nuw x, y for arbitrary y
4650 // 2. f(x) = mul nuw x, c for defined c != 0
4651 // 3. f(x) = zext x
4652 // 4. f(x) = sext x
4653 // 5. f(x) = neg x
4654 //
4655 if (!(match(InnerOp, m_NUWShl(m_Value(X), m_Value())) || // Case 1
4656 match(InnerOp, m_NUWMul(m_Value(X), m_NonZeroInt())) || // Case 2
4657 match(InnerOp, m_ZExt(m_Value(X))) || // Case 3
4658 match(InnerOp, m_SExt(m_Value(X))) || // Case 4
4659 match(InnerOp, m_Neg(m_Value(X))) // Case 5
4660 ))
4661 return false;
4662
4663 SimplifyQuery S = SQ.getWithInstruction(&I);
4664 auto *XTy = cast<FixedVectorType>(X->getType());
4665
4666 // Check for domain constraints for all supported reductions.
4667 //
4668 // a. OR X_i - has property 1 for every X
4669 // b. UMAX X_i - has property 1 for every X
4670 // c. UMIN X_i - has property 1' for every X
4671 // d. SMAX X_i - has property 1 for X >= 0
4672 // e. SMIN X_i - has property 1' for X >= 0
4673 // f. ADD X_i - has property 1 for X >= 0 && ADD X_i doesn't sign wrap
4674 //
4675 // In order for the proof to work, we need 1 (or 1') to be true for both
4676 // OP f(X_i) and OP X_i and that's why below we check constraints twice.
4677 //
4678 // NOTE: ADD X_i holds property 1 for a mirror case as well, i.e. when
4679 // X <= 0 && ADD X_i doesn't sign wrap. However, due to the nature
4680 // of known bits, we can't reasonably hold knowledge of "either 0
4681 // or negative".
4682 switch (II->getIntrinsicID()) {
4683 case Intrinsic::vector_reduce_add: {
4684 // We need to check that both X_i and f(X_i) have enough leading
4685 // zeros to not overflow.
4686 KnownBits KnownX = computeKnownBits(X, S);
4687 KnownBits KnownFX = computeKnownBits(InnerOp, S);
4688 unsigned NumElems = XTy->getNumElements();
4689 // Adding N elements loses at most ceil(log2(N)) leading bits.
4690 unsigned LostBits = Log2_32_Ceil(NumElems);
4691 unsigned LeadingZerosX = KnownX.countMinLeadingZeros();
4692 unsigned LeadingZerosFX = KnownFX.countMinLeadingZeros();
4693 // Need at least one leading zero left after summation to ensure no overflow
4694 if (LeadingZerosX <= LostBits || LeadingZerosFX <= LostBits)
4695 return false;
4696
4697 // We are not checking whether X or f(X) are positive explicitly because
4698 // we implicitly checked for it when we checked if both cases have enough
4699 // leading zeros to not wrap addition.
4700 break;
4701 }
4702 case Intrinsic::vector_reduce_smin:
4703 case Intrinsic::vector_reduce_smax:
4704 // Check whether X >= 0 and f(X) >= 0
4705 if (!isKnownNonNegative(InnerOp, S) || !isKnownNonNegative(X, S))
4706 return false;
4707
4708 break;
4709 default:
4710 break;
4711 };
4712
4713 LLVM_DEBUG(dbgs() << "Found a reduction to 0 comparison with removable op: "
4714 << *II << "\n");
4715
4716 // For zext/sext, check if the transform is profitable using cost model.
4717 // For other operations (shl, mul, neg), we're removing an instruction
4718 // while keeping the same reduction type, so it's always profitable.
4719 if (isa<ZExtInst>(InnerOp) || isa<SExtInst>(InnerOp)) {
4720 auto *FXTy = cast<FixedVectorType>(InnerOp->getType());
4721 Intrinsic::ID IID = II->getIntrinsicID();
4722
4724 cast<CastInst>(InnerOp)->getOpcode(), FXTy, XTy,
4726
4727 InstructionCost OldReduceCost, NewReduceCost;
4728 switch (IID) {
4729 case Intrinsic::vector_reduce_add:
4730 case Intrinsic::vector_reduce_or:
4731 OldReduceCost = TTI.getArithmeticReductionCost(
4732 getArithmeticReductionInstruction(IID), FXTy, std::nullopt, CostKind);
4733 NewReduceCost = TTI.getArithmeticReductionCost(
4734 getArithmeticReductionInstruction(IID), XTy, std::nullopt, CostKind);
4735 break;
4736 case Intrinsic::vector_reduce_umin:
4737 case Intrinsic::vector_reduce_umax:
4738 case Intrinsic::vector_reduce_smin:
4739 case Intrinsic::vector_reduce_smax:
4740 OldReduceCost = TTI.getMinMaxReductionCost(
4741 getMinMaxReductionIntrinsicOp(IID), FXTy, FastMathFlags(), CostKind);
4742 NewReduceCost = TTI.getMinMaxReductionCost(
4743 getMinMaxReductionIntrinsicOp(IID), XTy, FastMathFlags(), CostKind);
4744 break;
4745 default:
4746 llvm_unreachable("Unexpected reduction");
4747 }
4748
4749 InstructionCost OldCost = OldReduceCost + ExtCost;
4750 InstructionCost NewCost =
4751 NewReduceCost + (InnerOp->hasOneUse() ? 0 : ExtCost);
4752
4753 LLVM_DEBUG(dbgs() << "Found a removable extension before reduction: "
4754 << *InnerOp << "\n OldCost: " << OldCost
4755 << " vs NewCost: " << NewCost << "\n");
4756
4757 // We consider transformation to still be potentially beneficial even
4758 // when the costs are the same because we might remove a use from f(X)
4759 // and unlock other optimizations. Equal costs would just mean that we
4760 // didn't make it worse in the worst case.
4761 if (NewCost > OldCost)
4762 return false;
4763 }
4764
4765 // Since we support zext and sext as f, we might change the scalar type
4766 // of the intrinsic.
4767 Type *Ty = XTy->getScalarType();
4768 Value *NewReduce = Builder.CreateIntrinsic(Ty, II->getIntrinsicID(), {X});
4769 Value *NewCmp =
4770 Builder.CreateICmp(Pred, NewReduce, ConstantInt::getNullValue(Ty));
4771 replaceValue(I, *NewCmp);
4772 return true;
4773}
4774
4775/// Fold comparisons of reduce.or/reduce.and with reduce.umax/reduce.umin
4776/// based on cost, preserving the comparison semantics.
4777///
4778/// We use two fundamental properties for each pair:
4779///
4780/// 1. or(X) == 0 <=> umax(X) == 0
4781/// 2. or(X) == 1 <=> umax(X) == 1
4782/// 3. sign(or(X)) == sign(umax(X))
4783///
4784/// 1. and(X) == -1 <=> umin(X) == -1
4785/// 2. and(X) == -2 <=> umin(X) == -2
4786/// 3. sign(and(X)) == sign(umin(X))
4787///
4788/// From these we can infer the following transformations:
4789/// a. or(X) ==/!= 0 <-> umax(X) ==/!= 0
4790/// b. or(X) s< 0 <-> umax(X) s< 0
4791/// c. or(X) s> -1 <-> umax(X) s> -1
4792/// d. or(X) s< 1 <-> umax(X) s< 1
4793/// e. or(X) ==/!= 1 <-> umax(X) ==/!= 1
4794/// f. or(X) s< 2 <-> umax(X) s< 2
4795/// g. and(X) ==/!= -1 <-> umin(X) ==/!= -1
4796/// h. and(X) s< 0 <-> umin(X) s< 0
4797/// i. and(X) s> -1 <-> umin(X) s> -1
4798/// j. and(X) s> -2 <-> umin(X) s> -2
4799/// k. and(X) ==/!= -2 <-> umin(X) ==/!= -2
4800/// l. and(X) s> -3 <-> umin(X) s> -3
4801///
4802bool VectorCombine::foldEquivalentReductionCmp(Instruction &I) {
4803 CmpPredicate Pred;
4804 Value *ReduceOp;
4805 const APInt *CmpVal;
4806 if (!match(&I, m_ICmp(Pred, m_Value(ReduceOp), m_APInt(CmpVal))))
4807 return false;
4808
4809 auto *II = dyn_cast<IntrinsicInst>(ReduceOp);
4810 if (!II || !II->hasOneUse())
4811 return false;
4812
4813 const auto IsValidOrUmaxCmp = [&]() {
4814 // or === umax for i1
4815 if (CmpVal->getBitWidth() == 1)
4816 return true;
4817
4818 // Cases a and e
4819 bool IsEquality =
4820 (CmpVal->isZero() || CmpVal->isOne()) && ICmpInst::isEquality(Pred);
4821 // Case c
4822 bool IsPositive = CmpVal->isAllOnes() && Pred == ICmpInst::ICMP_SGT;
4823 // Cases b, d, and f
4824 bool IsNegative = (CmpVal->isZero() || CmpVal->isOne() || *CmpVal == 2) &&
4825 Pred == ICmpInst::ICMP_SLT;
4826 return IsEquality || IsPositive || IsNegative;
4827 };
4828
4829 const auto IsValidAndUminCmp = [&]() {
4830 // and === umin for i1
4831 if (CmpVal->getBitWidth() == 1)
4832 return true;
4833
4834 const auto LeadingOnes = CmpVal->countl_one();
4835
4836 // Cases g and k
4837 bool IsEquality =
4838 (CmpVal->isAllOnes() || LeadingOnes + 1 == CmpVal->getBitWidth()) &&
4840 // Case h
4841 bool IsNegative = CmpVal->isZero() && Pred == ICmpInst::ICMP_SLT;
4842 // Cases i, j, and l
4843 bool IsPositive =
4844 // if the number has at least N - 2 leading ones
4845 // and the two LSBs are:
4846 // - 1 x 1 -> -1
4847 // - 1 x 0 -> -2
4848 // - 0 x 1 -> -3
4849 LeadingOnes + 2 >= CmpVal->getBitWidth() &&
4850 ((*CmpVal)[0] || (*CmpVal)[1]) && Pred == ICmpInst::ICMP_SGT;
4851 return IsEquality || IsNegative || IsPositive;
4852 };
4853
4854 Intrinsic::ID OriginalIID = II->getIntrinsicID();
4855 Intrinsic::ID AlternativeIID;
4856
4857 // Check if this is a valid comparison pattern and determine the alternate
4858 // reduction intrinsic.
4859 switch (OriginalIID) {
4860 case Intrinsic::vector_reduce_or:
4861 if (!IsValidOrUmaxCmp())
4862 return false;
4863 AlternativeIID = Intrinsic::vector_reduce_umax;
4864 break;
4865 case Intrinsic::vector_reduce_umax:
4866 if (!IsValidOrUmaxCmp())
4867 return false;
4868 AlternativeIID = Intrinsic::vector_reduce_or;
4869 break;
4870 case Intrinsic::vector_reduce_and:
4871 if (!IsValidAndUminCmp())
4872 return false;
4873 AlternativeIID = Intrinsic::vector_reduce_umin;
4874 break;
4875 case Intrinsic::vector_reduce_umin:
4876 if (!IsValidAndUminCmp())
4877 return false;
4878 AlternativeIID = Intrinsic::vector_reduce_and;
4879 break;
4880 default:
4881 return false;
4882 }
4883
4884 Value *X = II->getArgOperand(0);
4885 auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
4886 if (!VecTy)
4887 return false;
4888
4889 const auto GetReductionCost = [&](Intrinsic::ID IID) -> InstructionCost {
4890 unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
4891 if (ReductionOpc != Instruction::ICmp)
4892 return TTI.getArithmeticReductionCost(ReductionOpc, VecTy, std::nullopt,
4893 CostKind);
4895 FastMathFlags(), CostKind);
4896 };
4897
4898 InstructionCost OrigCost = GetReductionCost(OriginalIID);
4899 InstructionCost AltCost = GetReductionCost(AlternativeIID);
4900
4901 LLVM_DEBUG(dbgs() << "Found equivalent reduction cmp: " << I
4902 << "\n OrigCost: " << OrigCost
4903 << " vs AltCost: " << AltCost << "\n");
4904
4905 if (AltCost >= OrigCost)
4906 return false;
4907
4908 Builder.SetInsertPoint(&I);
4909 Type *ScalarTy = VecTy->getScalarType();
4910 Value *NewReduce = Builder.CreateIntrinsic(ScalarTy, AlternativeIID, {X});
4911 Value *NewCmp =
4912 Builder.CreateICmp(Pred, NewReduce, ConstantInt::get(ScalarTy, *CmpVal));
4913
4914 replaceValue(I, *NewCmp);
4915 return true;
4916}
4917
4918/// Used by foldReduceAddCmpZero to check if we can prove that a value is
4919/// non-positive.
4920/// KnownBits cannot see sext <? x i1> as non-positive: each top bit equals a
4921/// single unknown input bit, which a per-bit lattice cannot track. The fold's
4922/// target shape is popcount-style sums of <N x i1> valid/invalid masks (e.g.
4923/// ray-intersection hits) tested for any-hit.
4924/// Previous attempts to approximate the known bits of such expressions were
4925/// using a fully recursive value tracking approach to infer a constant range
4926/// but ultimately turned to be too expensive in compile time.
4927static bool isKnownNonPositive(const Value *V, const SimplifyQuery &SQ,
4928 unsigned Depth = 0) {
4929 constexpr unsigned MaxLocalDepth = 2;
4930 if (Depth > MaxLocalDepth)
4931 return false;
4932
4933 auto NumSignBits = [&](const Value *X) {
4934 return ComputeNumSignBits(X, SQ.DL, SQ.AC, SQ.CxtI, SQ.DT);
4935 };
4936 if (NumSignBits(V) == V->getType()->getScalarSizeInBits())
4937 return true;
4938
4939 Value *A, *B;
4940 if (match(V, m_Add(m_Value(A), m_Value(B))))
4941 return NumSignBits(A) >= 2 && NumSignBits(B) >= 2 &&
4942 isKnownNonPositive(A, SQ, Depth + 1) &&
4943 isKnownNonPositive(B, SQ, Depth + 1);
4944
4945 return computeKnownBits(V, SQ).isNonPositive();
4946}
4947
4948/// Fold (icmp pred (reduce.add X), 0) to (icmp pred' (reduce.or X), 0) when X
4949/// has lanes known to all be non-negative or all non-positive, so that
4950/// sum == 0 iff every lane is 0. Falls back to reduce.umax if reduce.or is
4951/// more expensive on the target.
4952bool VectorCombine::foldReduceAddCmpZero(Instruction &I) {
4953 CmpPredicate Pred;
4954 Value *Vec;
4955 if (!match(&I, m_ICmp(Pred,
4957 m_Value(Vec))),
4958 m_Zero())))
4959 return false;
4960
4961 auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
4962 if (!VecTy || VecTy->getNumElements() < 2)
4963 return false;
4964
4965 SimplifyQuery Q = SQ.getWithInstruction(&I);
4966 bool IsNonNegative = isKnownNonNegative(Vec, Q);
4967 bool IsNonPositive = !IsNonNegative && isKnownNonPositive(Vec, Q);
4968 if (!IsNonNegative && !IsNonPositive)
4969 return false;
4970
4971 // Summing NumElts lanes can consume up to log2(NumElts) sign bits. Require
4972 // strictly more headroom than that so the sum cannot wrap to zero.
4973 unsigned NumElts = VecTy->getNumElements();
4974 unsigned NumSignBits = ComputeNumSignBits(Vec, *DL, SQ.AC, &I, &DT);
4975 if (Log2_32(NumElts) >= NumSignBits)
4976 return false;
4977
4978 ICmpInst::Predicate NewPred;
4979 switch (Pred) {
4980 case ICmpInst::ICMP_EQ:
4981 case ICmpInst::ICMP_ULE:
4982 case ICmpInst::ICMP_SLE:
4983 case ICmpInst::ICMP_SGE:
4984 NewPred = ICmpInst::ICMP_EQ;
4985 break;
4986 case ICmpInst::ICMP_NE:
4987 case ICmpInst::ICMP_UGT:
4988 case ICmpInst::ICMP_SGT:
4989 case ICmpInst::ICMP_SLT:
4990 NewPred = ICmpInst::ICMP_NE;
4991 break;
4992 default:
4993 return false;
4994 }
4995
4996 // SGT and SLE on a non-positive tree, and SLT and SGE on a non-negative
4997 // tree, are tautologies (always true or always false). Leave those to
4998 // InstCombine rather than mapping them here. Remaining signed inequalities
4999 // also need one extra sign bit so the sum cannot flip sign.
5000 if (!IsNonNegative &&
5001 (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE))
5002 return false;
5003 if (!IsNonPositive &&
5004 (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE))
5005 return false;
5006 if ((Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE ||
5007 Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) &&
5008 Log2_32(NumElts) >= NumSignBits - 1)
5009 return false;
5010
5012 Instruction::Add, VecTy, std::nullopt, CostKind);
5014 Instruction::Or, VecTy, std::nullopt, CostKind);
5016 Intrinsic::umax, VecTy, FastMathFlags(), CostKind);
5017 if (!OrCost.isValid() && !UmaxCost.isValid())
5018 return false;
5019 bool UseOr = OrCost.isValid() && (!UmaxCost.isValid() || OrCost <= UmaxCost);
5020 InstructionCost AltCost = UseOr ? OrCost : UmaxCost;
5021 if (AltCost > OrigCost)
5022 return false;
5023
5024 Builder.SetInsertPoint(&I);
5025 Value *NewReduce = UseOr ? Builder.CreateOrReduce(Vec)
5026 : Builder.CreateIntrinsic(
5027 Intrinsic::vector_reduce_umax, {VecTy}, {Vec});
5028 Worklist.pushValue(NewReduce);
5029 Value *NewCmp = Builder.CreateICmp(
5030 NewPred, NewReduce, ConstantInt::getNullValue(VecTy->getScalarType()));
5031 replaceValue(I, *NewCmp);
5032 return true;
5033}
5034
5035/// Returns true if this ShuffleVectorInst eventually feeds into a
5036/// vector reduction intrinsic (e.g., vector_reduce_add) by only following
5037/// chains of shuffles and binary operators (in any combination/order).
5038/// The search does not go deeper than the given Depth.
5040 constexpr unsigned MaxVisited = 32;
5043 bool FoundReduction = false;
5044
5045 WorkList.push_back(SVI);
5046 while (!WorkList.empty()) {
5047 Instruction *I = WorkList.pop_back_val();
5048 for (User *U : I->users()) {
5049 auto *UI = cast<Instruction>(U);
5050 if (!UI || !Visited.insert(UI).second)
5051 continue;
5052 if (Visited.size() > MaxVisited)
5053 return false;
5054 if (auto *II = dyn_cast<IntrinsicInst>(UI)) {
5055 // More than one reduction reached
5056 if (FoundReduction)
5057 return false;
5058 switch (II->getIntrinsicID()) {
5059 case Intrinsic::vector_reduce_add:
5060 case Intrinsic::vector_reduce_mul:
5061 case Intrinsic::vector_reduce_and:
5062 case Intrinsic::vector_reduce_or:
5063 case Intrinsic::vector_reduce_xor:
5064 case Intrinsic::vector_reduce_smin:
5065 case Intrinsic::vector_reduce_smax:
5066 case Intrinsic::vector_reduce_umin:
5067 case Intrinsic::vector_reduce_umax:
5068 FoundReduction = true;
5069 continue;
5070 default:
5071 return false;
5072 }
5073 }
5074
5076 return false;
5077
5078 WorkList.emplace_back(UI);
5079 }
5080 }
5081 return FoundReduction;
5082}
5083
5084/// This method looks for groups of shuffles acting on binops, of the form:
5085/// %x = shuffle ...
5086/// %y = shuffle ...
5087/// %a = binop %x, %y
5088/// %b = binop %x, %y
5089/// shuffle %a, %b, selectmask
5090/// We may, especially if the shuffle is wider than legal, be able to convert
5091/// the shuffle to a form where only parts of a and b need to be computed. On
5092/// architectures with no obvious "select" shuffle, this can reduce the total
5093/// number of operations if the target reports them as cheaper.
5094bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
5095 auto *SVI = cast<ShuffleVectorInst>(&I);
5096 auto *VT = cast<FixedVectorType>(I.getType());
5097 auto *Op0 = dyn_cast<Instruction>(SVI->getOperand(0));
5098 auto *Op1 = dyn_cast<Instruction>(SVI->getOperand(1));
5099 if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() ||
5100 VT != Op0->getType())
5101 return false;
5102
5103 auto *SVI0A = dyn_cast<Instruction>(Op0->getOperand(0));
5104 auto *SVI0B = dyn_cast<Instruction>(Op0->getOperand(1));
5105 auto *SVI1A = dyn_cast<Instruction>(Op1->getOperand(0));
5106 auto *SVI1B = dyn_cast<Instruction>(Op1->getOperand(1));
5107 SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B});
5108 auto checkSVNonOpUses = [&](Instruction *I) {
5109 if (!I || I->getOperand(0)->getType() != VT)
5110 return true;
5111 return any_of(I->users(), [&](User *U) {
5112 return U != Op0 && U != Op1 &&
5113 !(isa<ShuffleVectorInst>(U) &&
5114 (InputShuffles.contains(cast<Instruction>(U)) ||
5115 isInstructionTriviallyDead(cast<Instruction>(U))));
5116 });
5117 };
5118 if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) ||
5119 checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B))
5120 return false;
5121
5122 // Collect all the uses that are shuffles that we can transform together. We
5123 // may not have a single shuffle, but a group that can all be transformed
5124 // together profitably.
5126 auto collectShuffles = [&](Instruction *I) {
5127 for (auto *U : I->users()) {
5128 auto *SV = dyn_cast<ShuffleVectorInst>(U);
5129 if (!SV || SV->getType() != VT)
5130 return false;
5131 if ((SV->getOperand(0) != Op0 && SV->getOperand(0) != Op1) ||
5132 (SV->getOperand(1) != Op0 && SV->getOperand(1) != Op1))
5133 return false;
5134 if (!llvm::is_contained(Shuffles, SV))
5135 Shuffles.push_back(SV);
5136 }
5137 return true;
5138 };
5139 if (!collectShuffles(Op0) || !collectShuffles(Op1))
5140 return false;
5141 // From a reduction, we need to be processing a single shuffle, otherwise the
5142 // other uses will not be lane-invariant.
5143 if (FromReduction && Shuffles.size() > 1)
5144 return false;
5145
5146 // Add any shuffle uses for the shuffles we have found, to include them in our
5147 // cost calculations.
5148 if (!FromReduction) {
5149 for (size_t Idx = 0, E = Shuffles.size(); Idx != E; ++Idx) {
5150 for (auto *U : Shuffles[Idx]->users()) {
5151 ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(U);
5152 if (SSV && isa<UndefValue>(SSV->getOperand(1)) && SSV->getType() == VT)
5153 Shuffles.push_back(SSV);
5154 }
5155 }
5156 }
5157
5158 // For each of the output shuffles, we try to sort all the first vector
5159 // elements to the beginning, followed by the second array elements at the
5160 // end. If the binops are legalized to smaller vectors, this may reduce total
5161 // number of binops. We compute the ReconstructMask mask needed to convert
5162 // back to the original lane order.
5164 SmallVector<SmallVector<int>> OrigReconstructMasks;
5165 int MaxV1Elt = 0, MaxV2Elt = 0;
5166 unsigned NumElts = VT->getNumElements();
5167 for (ShuffleVectorInst *SVN : Shuffles) {
5168 SmallVector<int> Mask;
5169 SVN->getShuffleMask(Mask);
5170
5171 // Check the operands are the same as the original, or reversed (in which
5172 // case we need to commute the mask).
5173 Value *SVOp0 = SVN->getOperand(0);
5174 Value *SVOp1 = SVN->getOperand(1);
5175 if (isa<UndefValue>(SVOp1)) {
5176 auto *SSV = cast<ShuffleVectorInst>(SVOp0);
5177 SVOp0 = SSV->getOperand(0);
5178 SVOp1 = SSV->getOperand(1);
5179 for (int &Elem : Mask) {
5180 if (Elem >= static_cast<int>(SSV->getShuffleMask().size()))
5181 return false;
5182 Elem = Elem < 0 ? Elem : SSV->getMaskValue(Elem);
5183 }
5184 }
5185 if (SVOp0 == Op1 && SVOp1 == Op0) {
5186 std::swap(SVOp0, SVOp1);
5188 }
5189 if (SVOp0 != Op0 || SVOp1 != Op1)
5190 return false;
5191
5192 // Calculate the reconstruction mask for this shuffle, as the mask needed to
5193 // take the packed values from Op0/Op1 and reconstructing to the original
5194 // order.
5195 SmallVector<int> ReconstructMask;
5196 for (unsigned I = 0; I < Mask.size(); I++) {
5197 if (Mask[I] < 0) {
5198 ReconstructMask.push_back(-1);
5199 } else if (Mask[I] < static_cast<int>(NumElts)) {
5200 MaxV1Elt = std::max(MaxV1Elt, Mask[I]);
5201 auto It = find_if(V1, [&](const std::pair<int, int> &A) {
5202 return Mask[I] == A.first;
5203 });
5204 if (It != V1.end())
5205 ReconstructMask.push_back(It - V1.begin());
5206 else {
5207 ReconstructMask.push_back(V1.size());
5208 V1.emplace_back(Mask[I], V1.size());
5209 }
5210 } else {
5211 MaxV2Elt = std::max<int>(MaxV2Elt, Mask[I] - NumElts);
5212 auto It = find_if(V2, [&](const std::pair<int, int> &A) {
5213 return Mask[I] - static_cast<int>(NumElts) == A.first;
5214 });
5215 if (It != V2.end())
5216 ReconstructMask.push_back(NumElts + It - V2.begin());
5217 else {
5218 ReconstructMask.push_back(NumElts + V2.size());
5219 V2.emplace_back(Mask[I] - NumElts, NumElts + V2.size());
5220 }
5221 }
5222 }
5223
5224 // For reductions, we know that the lane ordering out doesn't alter the
5225 // result. In-order can help simplify the shuffle away.
5226 if (FromReduction)
5227 sort(ReconstructMask);
5228 OrigReconstructMasks.push_back(std::move(ReconstructMask));
5229 }
5230
5231 // If the Maximum element used from V1 and V2 are not larger than the new
5232 // vectors, the vectors are already packed and performing the optimization
5233 // again will likely not help any further. This also prevents us from getting
5234 // stuck in a cycle in case the costs do not also rule it out.
5235 if (V1.empty() || V2.empty() ||
5236 (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
5237 MaxV2Elt == static_cast<int>(V2.size()) - 1))
5238 return false;
5239
5240 // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
5241 // shuffle of another shuffle, or not a shuffle (that is treated like a
5242 // identity shuffle).
5243 auto GetBaseMaskValue = [&](Instruction *I, int M) {
5244 auto *SV = dyn_cast<ShuffleVectorInst>(I);
5245 if (!SV)
5246 return M;
5247 if (isa<UndefValue>(SV->getOperand(1)))
5248 if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
5249 if (InputShuffles.contains(SSV))
5250 return SSV->getMaskValue(SV->getMaskValue(M));
5251 return SV->getMaskValue(M);
5252 };
5253
5254 // Attempt to sort the inputs by ascending mask values to make simpler input
5255 // shuffles and push complex shuffles down to the uses. We sort on the first
5256 // of the two input shuffle orders, to try and get at least one input into a
5257 // nice order.
5258 auto SortBase = [&](Instruction *A, std::pair<int, int> X,
5259 std::pair<int, int> Y) {
5260 int MXA = GetBaseMaskValue(A, X.first);
5261 int MYA = GetBaseMaskValue(A, Y.first);
5262 return MXA < MYA;
5263 };
5264 stable_sort(V1, [&](std::pair<int, int> A, std::pair<int, int> B) {
5265 return SortBase(SVI0A, A, B);
5266 });
5267 stable_sort(V2, [&](std::pair<int, int> A, std::pair<int, int> B) {
5268 return SortBase(SVI1A, A, B);
5269 });
5270 // Calculate our ReconstructMasks from the OrigReconstructMasks and the
5271 // modified order of the input shuffles.
5272 SmallVector<SmallVector<int>> ReconstructMasks;
5273 for (const auto &Mask : OrigReconstructMasks) {
5274 SmallVector<int> ReconstructMask;
5275 for (int M : Mask) {
5276 auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) {
5277 auto It = find_if(V, [M](auto A) { return A.second == M; });
5278 assert(It != V.end() && "Expected all entries in Mask");
5279 return std::distance(V.begin(), It);
5280 };
5281 if (M < 0)
5282 ReconstructMask.push_back(-1);
5283 else if (M < static_cast<int>(NumElts)) {
5284 ReconstructMask.push_back(FindIndex(V1, M));
5285 } else {
5286 ReconstructMask.push_back(NumElts + FindIndex(V2, M));
5287 }
5288 }
5289 ReconstructMasks.push_back(std::move(ReconstructMask));
5290 }
5291
5292 // Calculate the masks needed for the new input shuffles, which get padded
5293 // with undef
5294 SmallVector<int> V1A, V1B, V2A, V2B;
5295 for (unsigned I = 0; I < V1.size(); I++) {
5296 V1A.push_back(GetBaseMaskValue(SVI0A, V1[I].first));
5297 V1B.push_back(GetBaseMaskValue(SVI0B, V1[I].first));
5298 }
5299 for (unsigned I = 0; I < V2.size(); I++) {
5300 V2A.push_back(GetBaseMaskValue(SVI1A, V2[I].first));
5301 V2B.push_back(GetBaseMaskValue(SVI1B, V2[I].first));
5302 }
5303 while (V1A.size() < NumElts) {
5306 }
5307 while (V2A.size() < NumElts) {
5310 }
5311
5312 auto AddShuffleCost = [&](InstructionCost C, Instruction *I) {
5313 auto *SV = dyn_cast<ShuffleVectorInst>(I);
5314 if (!SV)
5315 return C;
5316 return C + TTI.getShuffleCost(isa<UndefValue>(SV->getOperand(1))
5319 VT, VT, SV->getShuffleMask(), CostKind);
5320 };
5321 auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) {
5322 return C +
5324 };
5325
5326 unsigned ElementSize = VT->getElementType()->getPrimitiveSizeInBits();
5327 unsigned MaxVectorSize =
5329 unsigned MaxElementsInVector = MaxVectorSize / ElementSize;
5330 if (MaxElementsInVector == 0)
5331 return false;
5332 // When there are multiple shufflevector operations on the same input,
5333 // especially when the vector length is larger than the register size,
5334 // identical shuffle patterns may occur across different groups of elements.
5335 // To avoid overestimating the cost by counting these repeated shuffles more
5336 // than once, we only account for unique shuffle patterns. This adjustment
5337 // prevents inflated costs in the cost model for wide vectors split into
5338 // several register-sized groups.
5339 std::set<SmallVector<int, 4>> UniqueShuffles;
5340 auto AddShuffleMaskAdjustedCost = [&](InstructionCost C, ArrayRef<int> Mask) {
5341 // Compute the cost for performing the shuffle over the full vector.
5342 auto ShuffleCost =
5344 unsigned NumFullVectors = Mask.size() / MaxElementsInVector;
5345 if (NumFullVectors < 2)
5346 return C + ShuffleCost;
5347 SmallVector<int, 4> SubShuffle(MaxElementsInVector);
5348 unsigned NumUniqueGroups = 0;
5349 unsigned NumGroups = Mask.size() / MaxElementsInVector;
5350 // For each group of MaxElementsInVector contiguous elements,
5351 // collect their shuffle pattern and insert into the set of unique patterns.
5352 for (unsigned I = 0; I < NumFullVectors; ++I) {
5353 for (unsigned J = 0; J < MaxElementsInVector; ++J)
5354 SubShuffle[J] = Mask[MaxElementsInVector * I + J];
5355 if (UniqueShuffles.insert(SubShuffle).second)
5356 NumUniqueGroups += 1;
5357 }
5358 return C + ShuffleCost * NumUniqueGroups / NumGroups;
5359 };
5360 auto AddShuffleAdjustedCost = [&](InstructionCost C, Instruction *I) {
5361 auto *SV = dyn_cast<ShuffleVectorInst>(I);
5362 if (!SV)
5363 return C;
5364 SmallVector<int, 16> Mask;
5365 SV->getShuffleMask(Mask);
5366 return AddShuffleMaskAdjustedCost(C, Mask);
5367 };
5368 // Check that input consists of ShuffleVectors applied to the same input
5369 auto AllShufflesHaveSameOperands =
5370 [](SmallPtrSetImpl<Instruction *> &InputShuffles) {
5371 if (InputShuffles.size() < 2)
5372 return false;
5373 ShuffleVectorInst *FirstSV =
5374 dyn_cast<ShuffleVectorInst>(*InputShuffles.begin());
5375 if (!FirstSV)
5376 return false;
5377
5378 Value *In0 = FirstSV->getOperand(0), *In1 = FirstSV->getOperand(1);
5379 return std::all_of(
5380 std::next(InputShuffles.begin()), InputShuffles.end(),
5381 [&](Instruction *I) {
5382 ShuffleVectorInst *SV = dyn_cast<ShuffleVectorInst>(I);
5383 return SV && SV->getOperand(0) == In0 && SV->getOperand(1) == In1;
5384 });
5385 };
5386
5387 // Get the costs of the shuffles + binops before and after with the new
5388 // shuffle masks.
5389 InstructionCost CostBefore =
5390 TTI.getArithmeticInstrCost(Op0->getOpcode(), VT, CostKind) +
5391 TTI.getArithmeticInstrCost(Op1->getOpcode(), VT, CostKind);
5392 CostBefore += std::accumulate(Shuffles.begin(), Shuffles.end(),
5393 InstructionCost(0), AddShuffleCost);
5394 if (AllShufflesHaveSameOperands(InputShuffles)) {
5395 UniqueShuffles.clear();
5396 CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
5397 InstructionCost(0), AddShuffleAdjustedCost);
5398 } else {
5399 CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
5400 InstructionCost(0), AddShuffleCost);
5401 }
5402
5403 // The new binops will be unused for lanes past the used shuffle lengths.
5404 // These types attempt to get the correct cost for that from the target.
5405 FixedVectorType *Op0SmallVT =
5406 FixedVectorType::get(VT->getScalarType(), V1.size());
5407 FixedVectorType *Op1SmallVT =
5408 FixedVectorType::get(VT->getScalarType(), V2.size());
5409 InstructionCost CostAfter =
5410 TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT, CostKind) +
5411 TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT, CostKind);
5412 UniqueShuffles.clear();
5413 CostAfter += std::accumulate(ReconstructMasks.begin(), ReconstructMasks.end(),
5414 InstructionCost(0), AddShuffleMaskAdjustedCost);
5415 std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
5416 CostAfter +=
5417 std::accumulate(OutputShuffleMasks.begin(), OutputShuffleMasks.end(),
5418 InstructionCost(0), AddShuffleMaskCost);
5419
5420 LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
5421 LLVM_DEBUG(dbgs() << " CostBefore: " << CostBefore
5422 << " vs CostAfter: " << CostAfter << "\n");
5423 if (CostBefore < CostAfter ||
5424 (CostBefore == CostAfter && !feedsIntoVectorReduction(SVI)))
5425 return false;
5426
5427 // The cost model has passed, create the new instructions.
5428 auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * {
5429 auto *SV = dyn_cast<ShuffleVectorInst>(I);
5430 if (!SV)
5431 return I;
5432 if (isa<UndefValue>(SV->getOperand(1)))
5433 if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
5434 if (InputShuffles.contains(SSV))
5435 return SSV->getOperand(Op);
5436 return SV->getOperand(Op);
5437 };
5438 Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef());
5439 Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0),
5440 GetShuffleOperand(SVI0A, 1), V1A);
5441 Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef());
5442 Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0),
5443 GetShuffleOperand(SVI0B, 1), V1B);
5444 Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef());
5445 Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0),
5446 GetShuffleOperand(SVI1A, 1), V2A);
5447 Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef());
5448 Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0),
5449 GetShuffleOperand(SVI1B, 1), V2B);
5450 Builder.SetInsertPoint(Op0);
5451 Value *NOp0 = Builder.CreateBinOp((Instruction::BinaryOps)Op0->getOpcode(),
5452 NSV0A, NSV0B);
5453 if (auto *I = dyn_cast<Instruction>(NOp0))
5454 I->copyIRFlags(Op0, true);
5455 Builder.SetInsertPoint(Op1);
5456 Value *NOp1 = Builder.CreateBinOp((Instruction::BinaryOps)Op1->getOpcode(),
5457 NSV1A, NSV1B);
5458 if (auto *I = dyn_cast<Instruction>(NOp1))
5459 I->copyIRFlags(Op1, true);
5460
5461 for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
5462 Builder.SetInsertPoint(Shuffles[S]);
5463 Value *NSV = Builder.CreateShuffleVector(NOp0, NOp1, ReconstructMasks[S]);
5464 replaceValue(*Shuffles[S], *NSV, false);
5465 }
5466
5467 Worklist.pushValue(NSV0A);
5468 Worklist.pushValue(NSV0B);
5469 Worklist.pushValue(NSV1A);
5470 Worklist.pushValue(NSV1B);
5471 return true;
5472}
5473
5474/// Check if instruction depends on ZExt and this ZExt can be moved after the
5475/// instruction. Move ZExt if it is profitable. For example:
5476/// logic(zext(x),y) -> zext(logic(x,trunc(y)))
5477/// lshr((zext(x),y) -> zext(lshr(x,trunc(y)))
5478/// Cost model calculations takes into account if zext(x) has other users and
5479/// whether it can be propagated through them too.
5480bool VectorCombine::shrinkType(Instruction &I) {
5481 Value *ZExted, *OtherOperand;
5482 if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
5483 m_Value(OtherOperand))) &&
5484 !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
5485 return false;
5486
5487 Value *ZExtOperand = I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0);
5488
5489 auto *BigTy = cast<FixedVectorType>(I.getType());
5490 auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
5491 unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
5492
5493 if (I.getOpcode() == Instruction::LShr) {
5494 // Check that the shift amount is less than the number of bits in the
5495 // smaller type. Otherwise, the smaller lshr will return a poison value.
5496 KnownBits ShAmtKB = computeKnownBits(I.getOperand(1), *DL);
5497 if (ShAmtKB.getMaxValue().uge(BW))
5498 return false;
5499 } else {
5500 // Check that the expression overall uses at most the same number of bits as
5501 // ZExted
5502 KnownBits KB = computeKnownBits(&I, *DL);
5503 if (KB.countMaxActiveBits() > BW)
5504 return false;
5505 }
5506
5507 // Calculate costs of leaving current IR as it is and moving ZExt operation
5508 // later, along with adding truncates if needed
5510 Instruction::ZExt, BigTy, SmallTy,
5511 TargetTransformInfo::CastContextHint::None, CostKind);
5512 InstructionCost CurrentCost = ZExtCost;
5513 InstructionCost ShrinkCost = 0;
5514
5515 // Calculate total cost and check that we can propagate through all ZExt users
5516 for (User *U : ZExtOperand->users()) {
5517 auto *UI = cast<Instruction>(U);
5518 if (UI == &I) {
5519 CurrentCost +=
5520 TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
5521 ShrinkCost +=
5522 TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
5523 ShrinkCost += ZExtCost;
5524 continue;
5525 }
5526
5527 if (!Instruction::isBinaryOp(UI->getOpcode()))
5528 return false;
5529
5530 // Check if we can propagate ZExt through its other users
5531 KnownBits KB = computeKnownBits(UI, *DL);
5532 if (KB.countMaxActiveBits() > BW)
5533 return false;
5534
5535 CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
5536 ShrinkCost +=
5537 TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
5538 ShrinkCost += ZExtCost;
5539 }
5540
5541 // If the other instruction operand is not a constant, we'll need to
5542 // generate a truncate instruction. So we have to adjust cost
5543 if (!isa<Constant>(OtherOperand))
5544 ShrinkCost += TTI.getCastInstrCost(
5545 Instruction::Trunc, SmallTy, BigTy,
5546 TargetTransformInfo::CastContextHint::None, CostKind);
5547
5548 // If the cost of shrinking types and leaving the IR is the same, we'll lean
5549 // towards modifying the IR because shrinking opens opportunities for other
5550 // shrinking optimisations.
5551 if (ShrinkCost > CurrentCost)
5552 return false;
5553
5554 Builder.SetInsertPoint(&I);
5555 Value *Op0 = ZExted;
5556 Value *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
5557 // Keep the order of operands the same
5558 if (I.getOperand(0) == OtherOperand)
5559 std::swap(Op0, Op1);
5560 Value *NewBinOp =
5561 Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
5562 cast<Instruction>(NewBinOp)->copyIRFlags(&I);
5563 cast<Instruction>(NewBinOp)->copyMetadata(I);
5564 Value *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
5565 replaceValue(I, *NewZExtr);
5566 return true;
5567}
5568
5569/// insert (DstVec, (extract SrcVec, ExtIdx), InsIdx) -->
5570/// shuffle (DstVec, SrcVec, Mask)
5571bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
5572 Value *DstVec, *SrcVec;
5573 uint64_t ExtIdx, InsIdx;
5574 if (!match(&I,
5575 m_InsertElt(m_Value(DstVec),
5576 m_ExtractElt(m_Value(SrcVec), m_ConstantInt(ExtIdx)),
5577 m_ConstantInt(InsIdx))))
5578 return false;
5579
5580 auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
5581 auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcVec->getType());
5582 // We can try combining vectors with different element sizes.
5583 if (!DstVecTy || !SrcVecTy ||
5584 SrcVecTy->getElementType() != DstVecTy->getElementType())
5585 return false;
5586
5587 unsigned NumDstElts = DstVecTy->getNumElements();
5588 unsigned NumSrcElts = SrcVecTy->getNumElements();
5589 if (InsIdx >= NumDstElts || ExtIdx >= NumSrcElts || NumDstElts == 1)
5590 return false;
5591
5592 // Insertion into poison is a cheaper single operand shuffle.
5594 SmallVector<int> Mask(NumDstElts, PoisonMaskElem);
5595
5596 bool NeedExpOrNarrow = NumSrcElts != NumDstElts;
5597 bool NeedDstSrcSwap = isa<PoisonValue>(DstVec) && !isa<UndefValue>(SrcVec);
5598 if (NeedDstSrcSwap) {
5600 Mask[InsIdx] = ExtIdx % NumDstElts;
5601 std::swap(DstVec, SrcVec);
5602 } else {
5604 std::iota(Mask.begin(), Mask.end(), 0);
5605 Mask[InsIdx] = (ExtIdx % NumDstElts) + NumDstElts;
5606 }
5607
5608 // Cost
5609 auto *Ins = cast<InsertElementInst>(&I);
5610 auto *Ext = cast<ExtractElementInst>(I.getOperand(1));
5611 InstructionCost InsCost =
5612 TTI.getVectorInstrCost(*Ins, DstVecTy, CostKind, InsIdx);
5613 InstructionCost ExtCost =
5614 TTI.getVectorInstrCost(*Ext, DstVecTy, CostKind, ExtIdx);
5615 InstructionCost OldCost = ExtCost + InsCost;
5616
5617 InstructionCost NewCost = 0;
5618 SmallVector<int> ExtToVecMask;
5619 if (!NeedExpOrNarrow) {
5620 // Ignore 'free' identity insertion shuffle.
5621 // TODO: getShuffleCost should return TCC_Free for Identity shuffles.
5622 if (!ShuffleVectorInst::isIdentityMask(Mask, NumSrcElts))
5623 NewCost += TTI.getShuffleCost(SK, DstVecTy, DstVecTy, Mask, CostKind, 0,
5624 nullptr, {DstVec, SrcVec});
5625 } else {
5626 // When creating a length-changing-vector, always try to keep the relevant
5627 // element in an equivalent position, so that bulk shuffles are more likely
5628 // to be useful.
5629 ExtToVecMask.assign(NumDstElts, PoisonMaskElem);
5630 ExtToVecMask[ExtIdx % NumDstElts] = ExtIdx;
5631 // Add cost for expanding or narrowing
5633 DstVecTy, SrcVecTy, ExtToVecMask, CostKind);
5634 NewCost += TTI.getShuffleCost(SK, DstVecTy, DstVecTy, Mask, CostKind);
5635 }
5636
5637 if (!Ext->hasOneUse())
5638 NewCost += ExtCost;
5639
5640 LLVM_DEBUG(dbgs() << "Found a insert/extract shuffle-like pair: " << I
5641 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
5642 << "\n");
5643
5644 if (OldCost < NewCost)
5645 return false;
5646
5647 if (NeedExpOrNarrow) {
5648 if (!NeedDstSrcSwap)
5649 SrcVec = Builder.CreateShuffleVector(SrcVec, ExtToVecMask);
5650 else
5651 DstVec = Builder.CreateShuffleVector(DstVec, ExtToVecMask);
5652 }
5653
5654 // Canonicalize undef param to RHS to help further folds.
5655 if (isa<UndefValue>(DstVec) && !isa<UndefValue>(SrcVec)) {
5656 ShuffleVectorInst::commuteShuffleMask(Mask, NumDstElts);
5657 std::swap(DstVec, SrcVec);
5658 }
5659
5660 Value *Shuf = Builder.CreateShuffleVector(DstVec, SrcVec, Mask);
5661 replaceValue(I, *Shuf);
5662
5663 return true;
5664}
5665
5666/// If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32>
5667/// <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a
5668/// larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first
5669/// before casting it back into `<vscale x 16 x i32>`.
5670bool VectorCombine::foldInterleaveIntrinsics(Instruction &I) {
5671 const APInt *SplatVal0, *SplatVal1;
5673 m_APInt(SplatVal0), m_APInt(SplatVal1))))
5674 return false;
5675
5676 LLVM_DEBUG(dbgs() << "VC: Folding interleave2 with two splats: " << I
5677 << "\n");
5678
5679 auto *VTy =
5680 cast<VectorType>(cast<IntrinsicInst>(I).getArgOperand(0)->getType());
5681 auto *ExtVTy = VectorType::getExtendedElementVectorType(VTy);
5682 unsigned Width = VTy->getElementType()->getIntegerBitWidth();
5683
5684 // Just in case the cost of interleave2 intrinsic and bitcast are both
5685 // invalid, in which case we want to bail out, we use <= rather
5686 // than < here. Even they both have valid and equal costs, it's probably
5687 // not a good idea to emit a high-cost constant splat.
5689 TTI.getCastInstrCost(Instruction::BitCast, I.getType(), ExtVTy,
5691 LLVM_DEBUG(dbgs() << "VC: The cost to cast from " << *ExtVTy << " to "
5692 << *I.getType() << " is too high.\n");
5693 return false;
5694 }
5695
5696 APInt NewSplatVal = SplatVal1->zext(Width * 2);
5697 NewSplatVal <<= Width;
5698 NewSplatVal |= SplatVal0->zext(Width * 2);
5699 auto *NewSplat = ConstantVector::getSplat(
5700 ExtVTy->getElementCount(), ConstantInt::get(F.getContext(), NewSplatVal));
5701
5702 IRBuilder<> Builder(&I);
5703 replaceValue(I, *Builder.CreateBitCast(NewSplat, I.getType()));
5704 return true;
5705}
5706
5707/// Given this sequence:
5708/// ```
5709/// %d = llvm.vector.deinterleave2 <vscale x 16 x i32> %v
5710/// %f0 = extractvalue { <vscale x 8 x i32>, <vscale x 8 x i32> } %d, 0
5711/// %f1 = extractvalue { <vscale x 8 x i32>, <vscale x 8 x i32> } %d, 1
5712///
5713/// %low0 = and <vscale x 8 x i32> %f0, splat (i32 65535)
5714/// %low1 = shl <vscale x 8 x i32> %f1, splat (i32 16)
5715/// %merge0 = or disjoint <vscale x 8 x i32> %low0, %low1
5716///
5717/// %high0 = and <vscale x 8 x i32> %f1, splat (i32 -65536)
5718/// %high1 = lshr <vscale x 8 x i32> %f0, splat (i32 16)
5719/// %merge1 = or disjoint <vscale x 8 x i32> %high0, %high1
5720/// ```
5721/// It is actually just de-interleaving a 16-bit vector with double the
5722/// vector length. More generally speaking, it's de-interleaving on a vector
5723/// with half the element width as the original vector.
5724///
5725/// Therefore, we can turn it into:
5726/// ```
5727/// %narrow.v = bitcast <vscale x 16 x i32> %v to <vscale x 32 x i16>
5728/// %d = llvm.vector.deinterleave2 <vscale x 32 x i16> %narrow.v
5729/// %f0 = extractvalue { <vscale x 16 x i16>, <vscale x 16 x i16> } %d, 0
5730/// %f1 = extractvalue { <vscale x 16 x i16>, <vscale x 16 x i16> } %d, 1
5731///
5732/// %merge0 = bitcast <vscale x 16 x i16> %f0 to <vscale x 8 x i32>
5733/// %merge1 = bitcast <vscale x 16 x i16> %f1 to <vscale x 8 x i32>
5734/// ```
5735bool VectorCombine::foldDeinterleaveIntrinsics(Instruction &I) {
5736 // This pattern involves bitcast that is not compatible with big endian.
5737 if (DL->isBigEndian())
5738 return false;
5739
5740 using namespace PatternMatch;
5741 Value *DeinterleavedVal;
5742 if (!match(&I, m_Deinterleave2(m_Value(DeinterleavedVal))))
5743 return false;
5744
5745 VectorType *VecTy = cast<VectorType>(DeinterleavedVal->getType());
5746 IntegerType *ElementTy = dyn_cast<IntegerType>(VecTy->getElementType());
5747 if (!ElementTy)
5748 return false;
5749 unsigned ElementWidth = ElementTy->getBitWidth();
5750 if (ElementWidth < 2 || !isPowerOf2_32(ElementWidth))
5751 return false;
5752 unsigned HalfElementWidth = ElementWidth / 2;
5753
5754 if (!I.hasNUses(2))
5755 return false;
5756 std::array<ExtractValueInst *, 2> OrigFields{};
5757 for (User *Usr : I.users()) {
5758 auto *E = dyn_cast<ExtractValueInst>(Usr);
5759 // The deinterleave result can only be used by extractions.
5760 if (!E || E->getNumIndices() != 1)
5761 return false;
5762 unsigned Idx = *E->idx_begin();
5763 // A single field cannot be extracted more than once.
5764 if (Idx >= 2 || OrigFields[Idx] || !E->hasNUses(2))
5765 return false;
5766 OrigFields[Idx] = E;
5767 }
5768
5769 // Find the merge instruction (i.e. OR) first.
5770 SmallVector<Instruction *, 2> MergeInsts;
5771 for (auto *FieldUsr : OrigFields[0]->users()) {
5772 if (!FieldUsr->hasOneUse() || !isa<Instruction>(FieldUsr->user_back()))
5773 return false;
5774 MergeInsts.push_back(cast<Instruction>(FieldUsr->user_back()));
5775 }
5776 assert(MergeInsts.size() == 2);
5777
5778 // Pattern match bottom-up from the merge instructions.
5779 auto MatchMerge = [&](void) -> bool {
5780 APInt LoMask = APInt::getLowBitsSet(ElementWidth, HalfElementWidth);
5781 APInt HiMask = APInt::getHighBitsSet(ElementWidth, HalfElementWidth);
5782 return match(MergeInsts[0],
5783 m_c_Or(m_And(m_Specific(OrigFields[0]), m_SpecificInt(LoMask)),
5784 m_Shl(m_Specific(OrigFields[1]),
5785 m_SpecificInt(HalfElementWidth)))) &&
5786 match(MergeInsts[1],
5787 m_c_Or(m_And(m_Specific(OrigFields[1]), m_SpecificInt(HiMask)),
5788 m_LShr(m_Specific(OrigFields[0]),
5789 m_SpecificInt(HalfElementWidth))));
5790 };
5791 if (!MatchMerge()) {
5792 std::swap(MergeInsts[0], MergeInsts[1]);
5793 if (!MatchMerge())
5794 return false;
5795 }
5796
5797 // Profitability check.
5798 InstructionCost OldCost =
5799 TTI.getInstructionCost(MergeInsts[0], CostKind) +
5800 TTI.getInstructionCost(cast<Instruction>(MergeInsts[0]->getOperand(0)),
5801 CostKind) +
5802 TTI.getInstructionCost(cast<Instruction>(MergeInsts[0]->getOperand(1)),
5803 CostKind);
5804 // There are two fields (assuming SHL has the same cost as LSHR).
5805 OldCost *= 2;
5806
5807 auto *NewFieldTy = VecTy->getWithNewBitWidth(HalfElementWidth);
5808 auto *NewVecTy =
5809 VectorType::getDoubleElementsVectorType(cast<VectorType>(NewFieldTy));
5810 InstructionCost NewCost =
5811 TTI.getCastInstrCost(Instruction::BitCast, VecTy, NewVecTy,
5813 TTI.getCastInstrCost(Instruction::BitCast, NewFieldTy,
5814 MergeInsts[0]->getType(), TTI::CastContextHint::None,
5815 CostKind) *
5816 2;
5817 if (OldCost <= NewCost || !NewCost.isValid()) {
5818 LLVM_DEBUG(
5819 dbgs() << "VC: New deinterleave2 sequence cost (" << NewCost << ")"
5820 << " is higher than that of the old one (" << OldCost << ")\n");
5821 return false;
5822 }
5823
5824 // Do the replacement.
5825 IRBuilder<> Builder(&I);
5826 Value *NewVecCast = Builder.CreateBitCast(DeinterleavedVal, NewVecTy);
5827 Value *NewDeinterleave = Builder.CreateIntrinsic(
5828 Intrinsic::vector_deinterleave2, {NewVecTy}, {NewVecCast});
5829 for (auto [Idx, MergeInst] : enumerate(MergeInsts)) {
5830 Value *NewField = Builder.CreateExtractValue(NewDeinterleave, Idx);
5831 NewField = Builder.CreateBitCast(NewField, MergeInst->getType());
5832 replaceValue(*MergeInst, *NewField);
5833 }
5834
5835 return true;
5836}
5837
/// Try to replace a no-op vector cast of a vp.load result with a single
/// vp.load that produces the cast's destination type directly.
///
/// Returns true if the cast was replaced, false if the pattern did not match
/// or the cost model rejected the rewrite.
///
/// NOTE(review): this listing is missing three source lines in this function
/// (marked below); confirm the exact text against the upstream file.
5838bool VectorCombine::foldBitcastOfVPLoad(Instruction &I) {
5839 const DataLayout &DL = I.getDataLayout();
5840 auto *Cast = dyn_cast<CastInst>(&I);
5841 if (!Cast || !Cast->isNoopCast(DL) || !isa<VectorType>(Cast->getDestTy()))
5842 return false;
5843
5844 // Fold away bit casts of the loaded value by loading the desired type,
5845 // if the mask is all-ones.
5846 Value *EVL;
5847 auto *II = dyn_cast<VPIntrinsic>(I.getOperand(0));
// NOTE(review): the line opening this `if (... match(...` is missing here.
// The surviving arguments bind an ignored value, an all-ones value and EVL;
// presumably these are the vp.load pointer, mask and length -- confirm.
5849 m_Value(), m_AllOnes(), m_Value(EVL)))))
5850 return false;
5851
5852 VectorType *OrigVecTy = cast<VectorType>(II->getType());
5853 Align OrigAlign =
5854 DL.getValueOrABITypeAlignment(II->getPointerAlignment(), OrigVecTy);
5855 ElementCount OrigVecCnt = OrigVecTy->getElementCount();
5856 VectorType *NewVecTy = cast<VectorType>(Cast->getDestTy());
5857 ElementCount NewVecCnt = NewVecTy->getElementCount();
5858
5859 // Right now we only support cases where the NewVec is longer, because for
5860 // cases where it's shorter, we have to be sure that EVL can be exactly
5861 // divided, otherwise it might yield incorrect results or even page faults
5862 // (if we round-up during the division).
5863 if (!(OrigVecCnt.isScalable() == NewVecCnt.isScalable() &&
5864 NewVecCnt.hasKnownScalarFactor(OrigVecCnt)))
5865 return false;
5866
// Old sequence: the original vp.load plus the cast.
5867 InstructionCost OldCost =
5868 TTI.getMemIntrinsicInstrCost({Intrinsic::vp_load, OrigVecTy,
5869 II->getMemoryPointerParam(), false,
5870 OrigAlign},
5871 CostKind) +
5872 TTI.getCastInstrCost(Instruction::BitCast, Cast->getType(), OrigVecTy,
// NOTE(review): two lines are missing here -- the tail of the getCastInstrCost
// call above and the start of the NewCost computation whose arguments follow
// (a vp.load of NewVecTy with the same pointer and alignment).
5875 {Intrinsic::vp_load, NewVecTy, II->getMemoryPointerParam(), false,
5876 OrigAlign},
5877 CostKind);
5878 LLVM_DEBUG(dbgs() << "foldBitcastOfVPLoad: OldCost=" << OldCost
5879 << " NewCost=" << NewCost << "\n");
5880 if (NewCost > OldCost || !NewCost.isValid())
5881 return false;
5882
// Scale the explicit vector length by the element-count ratio and build an
// all-true mask of the new element count.
5883 unsigned Factor = NewVecCnt.getKnownScalarFactor(OrigVecCnt);
5884 Value *NewEVL = Builder.CreateNUWMul(EVL, Builder.getInt32(Factor));
5885 Value *NewMask = Builder.CreateVectorSplat(NewVecCnt, Builder.getTrue());
5886 CallInst *NewVP = Builder.CreateIntrinsicWithoutFolding(
5887 NewVecTy, Intrinsic::vp_load,
5888 {II->getMemoryPointerParam(), NewMask, NewEVL});
5889 // Preserve the original alignment.
5890 NewVP->addParamAttrs(
5891 0, AttrBuilder(II->getContext()).addAlignmentAttr(OrigAlign));
5892 replaceValue(*Cast, *NewVP);
5893 return true;
5894}
5895
5896// Attempt to shrink loads that are only used by shufflevector instructions.
5897bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
5898 auto *OldLoad = dyn_cast<LoadInst>(&I);
5899 if (!OldLoad || !OldLoad->isSimple())
5900 return false;
5901
5902 auto *OldLoadTy = dyn_cast<FixedVectorType>(OldLoad->getType());
5903 if (!OldLoadTy)
5904 return false;
5905
5906 unsigned const OldNumElements = OldLoadTy->getNumElements();
5907
5908 // Search all uses of load. If all uses are shufflevector instructions, and
5909 // the second operands are all poison values, find the minimum and maximum
5910 // indices of the vector elements referenced by all shuffle masks.
5911 // Otherwise return `std::nullopt`.
5912 using IndexRange = std::pair<int, int>;
5913 auto GetIndexRangeInShuffles = [&]() -> std::optional<IndexRange> {
5914 IndexRange OutputRange = IndexRange(OldNumElements, -1);
5915 for (llvm::Use &Use : I.uses()) {
5916 // Ensure all uses match the required pattern.
5917 User *Shuffle = Use.getUser();
5918 ArrayRef<int> Mask;
5919
5920 if (!match(Shuffle,
5921 m_Shuffle(m_Specific(OldLoad), m_Undef(), m_Mask(Mask))))
5922 return std::nullopt;
5923
5924 // Ignore shufflevector instructions that have no uses.
5925 if (Shuffle->use_empty())
5926 continue;
5927
5928 // Find the min and max indices used by the shufflevector instruction.
5929 for (int Index : Mask) {
5930 if (Index >= 0 && Index < static_cast<int>(OldNumElements)) {
5931 OutputRange.first = std::min(Index, OutputRange.first);
5932 OutputRange.second = std::max(Index, OutputRange.second);
5933 }
5934 }
5935 }
5936
5937 if (OutputRange.second < OutputRange.first)
5938 return std::nullopt;
5939
5940 return OutputRange;
5941 };
5942
5943 // Get the range of vector elements used by shufflevector instructions.
5944 if (std::optional<IndexRange> Indices = GetIndexRangeInShuffles()) {
5945 unsigned const NewNumElements = Indices->second + 1u;
5946
5947 // If the range of vector elements is smaller than the full load, attempt
5948 // to create a smaller load.
5949 if (NewNumElements < OldNumElements) {
5950 IRBuilder Builder(&I);
5951 Builder.SetCurrentDebugLocation(I.getDebugLoc());
5952
5953 // Calculate costs of old and new ops.
5954 Type *ElemTy = OldLoadTy->getElementType();
5955 FixedVectorType *NewLoadTy = FixedVectorType::get(ElemTy, NewNumElements);
5956 Value *PtrOp = OldLoad->getPointerOperand();
5957
5959 Instruction::Load, OldLoad->getType(), OldLoad->getAlign(),
5960 OldLoad->getPointerAddressSpace(), CostKind);
5961 InstructionCost NewCost =
5962 TTI.getMemoryOpCost(Instruction::Load, NewLoadTy, OldLoad->getAlign(),
5963 OldLoad->getPointerAddressSpace(), CostKind);
5964
5965 using UseEntry = std::pair<ShuffleVectorInst *, std::vector<int>>;
5967 unsigned const MaxIndex = NewNumElements * 2u;
5968
5969 for (llvm::Use &Use : I.uses()) {
5970 auto *Shuffle = cast<ShuffleVectorInst>(Use.getUser());
5971
5972 // Ignore shufflevector instructions that have no uses.
5973 if (Shuffle->use_empty())
5974 continue;
5975
5976 ArrayRef<int> OldMask = Shuffle->getShuffleMask();
5977
5978 // Create entry for new use.
5979 NewUses.push_back({Shuffle, OldMask});
5980
5981 // Validate mask indices.
5982 for (int Index : OldMask) {
5983 if (Index >= static_cast<int>(MaxIndex))
5984 return false;
5985 }
5986
5987 // Update costs.
5988 OldCost +=
5990 OldLoadTy, OldMask, CostKind);
5991 NewCost +=
5993 NewLoadTy, OldMask, CostKind);
5994 }
5995
5996 LLVM_DEBUG(
5997 dbgs() << "Found a load used only by shufflevector instructions: "
5998 << I << "\n OldCost: " << OldCost
5999 << " vs NewCost: " << NewCost << "\n");
6000
6001 if (OldCost < NewCost || !NewCost.isValid())
6002 return false;
6003
6004 // Create new load of smaller vector.
6005 auto *NewLoad = cast<LoadInst>(
6006 Builder.CreateAlignedLoad(NewLoadTy, PtrOp, OldLoad->getAlign()));
6007 NewLoad->copyMetadata(I);
6008
6009 // Replace all uses.
6010 for (UseEntry &Use : NewUses) {
6011 ShuffleVectorInst *Shuffle = Use.first;
6012 std::vector<int> &NewMask = Use.second;
6013
6014 Builder.SetInsertPoint(Shuffle);
6015 Builder.SetCurrentDebugLocation(Shuffle->getDebugLoc());
6016 Value *NewShuffle = Builder.CreateShuffleVector(
6017 NewLoad, PoisonValue::get(NewLoadTy), NewMask);
6018
6019 replaceValue(*Shuffle, *NewShuffle, false);
6020 }
6021
6022 return true;
6023 }
6024 }
6025 return false;
6026}
6027
6028// Attempt to narrow a phi of shufflevector instructions where the two incoming
6029// values have the same operands but different masks. If the two shuffle masks
6030// are offsets of one another we can use one branch to rotate the incoming
6031// vector and perform one larger shuffle after the phi.
6032bool VectorCombine::shrinkPhiOfShuffles(Instruction &I) {
6033 auto *Phi = dyn_cast<PHINode>(&I);
6034 if (!Phi || Phi->getNumIncomingValues() != 2u)
6035 return false;
6036
6037 Value *Op = nullptr;
6038 ArrayRef<int> Mask0;
6039 ArrayRef<int> Mask1;
6040
6041 if (!match(Phi->getOperand(0u),
6042 m_OneUse(m_Shuffle(m_Value(Op), m_Poison(), m_Mask(Mask0)))) ||
6043 !match(Phi->getOperand(1u),
6044 m_OneUse(m_Shuffle(m_Specific(Op), m_Poison(), m_Mask(Mask1)))))
6045 return false;
6046
6047 auto *Shuf = cast<ShuffleVectorInst>(Phi->getOperand(0u));
6048
6049 // Ensure result vectors are wider than the argument vector.
6050 auto *InputVT = cast<FixedVectorType>(Op->getType());
6051 auto *ResultVT = cast<FixedVectorType>(Shuf->getType());
6052 auto const InputNumElements = InputVT->getNumElements();
6053
6054 if (InputNumElements >= ResultVT->getNumElements())
6055 return false;
6056
6057 // Take the difference of the two shuffle masks at each index. Ignore poison
6058 // values at the same index in both masks.
6059 SmallVector<int, 16> NewMask;
6060 NewMask.reserve(Mask0.size());
6061
6062 for (auto [M0, M1] : zip(Mask0, Mask1)) {
6063 if (M0 >= 0 && M1 >= 0)
6064 NewMask.push_back(M0 - M1);
6065 else if (M0 == -1 && M1 == -1)
6066 continue;
6067 else
6068 return false;
6069 }
6070
6071 // Ensure all elements of the new mask are equal. If the difference between
6072 // the incoming mask elements is the same, the two must be constant offsets
6073 // of one another.
6074 if (NewMask.empty() || !all_equal(NewMask))
6075 return false;
6076
6077 // Create new mask using difference of the two incoming masks.
6078 int MaskOffset = NewMask[0u];
6079 unsigned Index = (InputNumElements + MaskOffset) % InputNumElements;
6080 NewMask.clear();
6081
6082 for (unsigned I = 0u; I < InputNumElements; ++I) {
6083 NewMask.push_back(Index);
6084 Index = (Index + 1u) % InputNumElements;
6085 }
6086
6087 // Calculate costs for worst cases and compare.
6088 auto const Kind = TTI::SK_PermuteSingleSrc;
6089 auto OldCost =
6090 std::max(TTI.getShuffleCost(Kind, ResultVT, InputVT, Mask0, CostKind),
6091 TTI.getShuffleCost(Kind, ResultVT, InputVT, Mask1, CostKind));
6092 auto NewCost = TTI.getShuffleCost(Kind, InputVT, InputVT, NewMask, CostKind) +
6093 TTI.getShuffleCost(Kind, ResultVT, InputVT, Mask1, CostKind);
6094
6095 LLVM_DEBUG(dbgs() << "Found a phi of mergeable shuffles: " << I
6096 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
6097 << "\n");
6098
6099 if (NewCost > OldCost)
6100 return false;
6101
6102 // Create new shuffles and narrowed phi.
6103 auto Builder = IRBuilder(Shuf);
6104 Builder.SetCurrentDebugLocation(Shuf->getDebugLoc());
6105 auto *PoisonVal = PoisonValue::get(InputVT);
6106 auto *NewShuf0 = Builder.CreateShuffleVector(Op, PoisonVal, NewMask);
6107 Worklist.push(cast<Instruction>(NewShuf0));
6108
6109 Builder.SetInsertPoint(Phi);
6110 Builder.SetCurrentDebugLocation(Phi->getDebugLoc());
6111 auto *NewPhi = Builder.CreatePHI(NewShuf0->getType(), 2u);
6112 NewPhi->addIncoming(NewShuf0, Phi->getIncomingBlock(0u));
6113 NewPhi->addIncoming(Op, Phi->getIncomingBlock(1u));
6114
6115 Builder.SetInsertPoint(*NewPhi->getInsertionPointAfterDef());
6116 PoisonVal = PoisonValue::get(NewPhi->getType());
6117 auto *NewShuf1 = Builder.CreateShuffleVector(NewPhi, PoisonVal, Mask1);
6118
6119 replaceValue(*Phi, *NewShuf1);
6120 return true;
6121}
6122
6123/// This is the entry point for all transforms. Pass manager differences are
6124/// handled in the callers of this function.
6125bool VectorCombine::run() {
6127 return false;
6128
6129 // Don't attempt vectorization if the target does not support vectors.
6130 if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
6131 return false;
6132
6133 LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n");
6134
6135 auto FoldInst = [this](Instruction &I) {
6136 Builder.SetInsertPoint(&I);
6137 bool IsVectorType = isa<VectorType>(I.getType());
6138 bool IsFixedVectorType = isa<FixedVectorType>(I.getType());
6139 auto Opcode = I.getOpcode();
6140
6141 LLVM_DEBUG(dbgs() << "VC: Visiting: " << I << '\n');
6142
6143 // These folds should be beneficial regardless of when this pass is run
6144 // in the optimization pipeline.
6145 // The type checking is for run-time efficiency. We can avoid wasting time
6146 // dispatching to folding functions if there's no chance of matching.
6147 if (IsFixedVectorType) {
6148 switch (Opcode) {
6149 case Instruction::InsertElement:
6150 if (vectorizeLoadInsert(I))
6151 return true;
6152 break;
6153 case Instruction::ShuffleVector:
6154 if (widenSubvectorLoad(I))
6155 return true;
6156 break;
6157 default:
6158 break;
6159 }
6160 }
6161
6162 // This transform works with scalable and fixed vectors
6163 // TODO: Identify and allow other scalable transforms
6164 if (IsVectorType) {
6165 if (scalarizeOpOrCmp(I))
6166 return true;
6167 if (scalarizeLoad(I))
6168 return true;
6169 if (scalarizeExtExtract(I))
6170 return true;
6171 if (scalarizeVPIntrinsic(I))
6172 return true;
6173 if (foldInterleaveIntrinsics(I))
6174 return true;
6175 if (foldBitcastOfVPLoad(I))
6176 return true;
6177 }
6178
6179 if (foldDeinterleaveIntrinsics(I))
6180 return true;
6181
6182 if (Opcode == Instruction::Store)
6183 if (foldSingleElementStore(I))
6184 return true;
6185
6186 // If this is an early pipeline invocation of this pass, we are done.
6187 if (TryEarlyFoldsOnly)
6188 return false;
6189
6190 // Otherwise, try folds that improve codegen but may interfere with
6191 // early IR canonicalizations.
6192 // The type checking is for run-time efficiency. We can avoid wasting time
6193 // dispatching to folding functions if there's no chance of matching.
6194 if (IsFixedVectorType) {
6195 switch (Opcode) {
6196 case Instruction::InsertElement:
6197 if (foldInsExtFNeg(I))
6198 return true;
6199 if (foldInsExtBinop(I))
6200 return true;
6201 if (foldInsExtVectorToShuffle(I))
6202 return true;
6203 break;
6204 case Instruction::ShuffleVector:
6205 if (foldPermuteOfBinops(I))
6206 return true;
6207 if (foldShuffleOfBinops(I))
6208 return true;
6209 if (foldShuffleOfSelects(I))
6210 return true;
6211 if (foldShuffleOfCastops(I))
6212 return true;
6213 if (foldShuffleOfShuffles(I))
6214 return true;
6215 if (foldPermuteOfIntrinsic(I))
6216 return true;
6217 if (foldShufflesOfLengthChangingShuffles(I))
6218 return true;
6219 if (foldShuffleOfIntrinsics(I))
6220 return true;
6221 if (foldSelectShuffle(I))
6222 return true;
6223 if (foldShuffleToIdentity(I))
6224 return true;
6225 break;
6226 case Instruction::Load:
6227 if (shrinkLoadForShuffles(I))
6228 return true;
6229 break;
6230 case Instruction::BitCast:
6231 if (foldBitcastShuffle(I))
6232 return true;
6233 if (foldSelectsFromBitcast(I))
6234 return true;
6235 break;
6236 case Instruction::And:
6237 case Instruction::Or:
6238 case Instruction::Xor:
6239 if (foldBitOpOfCastops(I))
6240 return true;
6241 if (foldBitOpOfCastConstant(I))
6242 return true;
6243 break;
6244 case Instruction::PHI:
6245 if (shrinkPhiOfShuffles(I))
6246 return true;
6247 break;
6248 default:
6249 if (shrinkType(I))
6250 return true;
6251 break;
6252 }
6253 } else {
6254 switch (Opcode) {
6255 case Instruction::Call:
6256 if (foldShuffleFromReductions(I))
6257 return true;
6258 if (foldCastFromReductions(I))
6259 return true;
6260 break;
6261 case Instruction::ExtractElement:
6262 if (foldShuffleChainsToReduce(I))
6263 return true;
6264 break;
6265 case Instruction::ICmp:
6266 if (foldSignBitReductionCmp(I))
6267 return true;
6268 if (foldICmpEqZeroVectorReduce(I))
6269 return true;
6270 if (foldEquivalentReductionCmp(I))
6271 return true;
6272 if (foldReduceAddCmpZero(I))
6273 return true;
6274 [[fallthrough]];
6275 case Instruction::FCmp:
6276 if (foldExtractExtract(I))
6277 return true;
6278 break;
6279 case Instruction::Or:
6280 if (foldConcatOfBoolMasks(I))
6281 return true;
6282 [[fallthrough]];
6283 default:
6284 if (Instruction::isBinaryOp(Opcode)) {
6285 if (foldExtractExtract(I))
6286 return true;
6287 if (foldExtractedCmps(I))
6288 return true;
6289 if (foldBinopOfReductions(I))
6290 return true;
6291 }
6292 break;
6293 }
6294 }
6295 return false;
6296 };
6297
6298 bool MadeChange = false;
6299 for (BasicBlock &BB : F) {
6300 // Ignore unreachable basic blocks.
6301 if (!DT.isReachableFromEntry(&BB))
6302 continue;
6303 // Use early increment range so that we can erase instructions in loop.
6304 // make_early_inc_range is not applicable here, as the next iterator may
6305 // be invalidated by RecursivelyDeleteTriviallyDeadInstructions.
6306 // We manually maintain the next instruction and update it when it is about
6307 // to be deleted.
6308 Instruction *I = &BB.front();
6309 while (I) {
6310 NextInst = I->getNextNode();
6311 if (!I->isDebugOrPseudoInst())
6312 MadeChange |= FoldInst(*I);
6313 I = NextInst;
6314 }
6315 }
6316
6317 NextInst = nullptr;
6318
6319 while (!Worklist.isEmpty()) {
6320 Instruction *I = Worklist.removeOne();
6321 if (!I)
6322 continue;
6323
6326 continue;
6327 }
6328
6329 MadeChange |= FoldInst(*I);
6330 }
6331
6332 return MadeChange;
6333}
6334
6337 auto &AC = FAM.getResult<AssumptionAnalysis>(F);
6339 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
6340 AAResults &AA = FAM.getResult<AAManager>(F);
6341 const DataLayout *DL = &F.getDataLayout();
6344 VectorCombine Combiner(F, TTI, DT, AA, AC, DL, CostKind, TryEarlyFoldsOnly);
6345 if (!Combiner.run())
6346 return PreservedAnalyses::all();
6349 return PA;
6350}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static cl::opt< unsigned > MaxInstrsToScan("aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden, cl::desc("Max number of instructions to scan for aggressive instcombine."))
This is the interface for LLVM's primary stateless and local alias analysis.
#define X(NUM, ENUM, NAME)
Definition ELF.h:853
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static cl::opt< OutputCostKind > CostKind("cost-kind", cl::desc("Target cost kind"), cl::init(OutputCostKind::RecipThroughput), cl::values(clEnumValN(OutputCostKind::RecipThroughput, "throughput", "Reciprocal throughput"), clEnumValN(OutputCostKind::Latency, "latency", "Instruction latency"), clEnumValN(OutputCostKind::CodeSize, "code-size", "Code size"), clEnumValN(OutputCostKind::SizeAndLatency, "size-latency", "Code size and latency"), clEnumValN(OutputCostKind::All, "all", "Print all cost kinds")))
static cl::opt< IntrinsicCostStrategy > IntrinsicCost("intrinsic-cost-strategy", cl::desc("Costing strategy for intrinsic instructions"), cl::init(IntrinsicCostStrategy::InstructionCost), cl::values(clEnumValN(IntrinsicCostStrategy::InstructionCost, "instruction-cost", "Use TargetTransformInfo::getInstructionCost"), clEnumValN(IntrinsicCostStrategy::IntrinsicCost, "intrinsic-cost", "Use TargetTransformInfo::getIntrinsicInstrCost"), clEnumValN(IntrinsicCostStrategy::TypeBasedIntrinsicCost, "type-based-intrinsic-cost", "Calculate the intrinsic cost based only on argument types")))
This file defines the DenseMap class.
#define Check(C,...)
This is the interface for a simple mod/ref and alias analysis over globals.
Hexagon Common GEP
iv users
Definition IVUsers.cpp:48
static Value * getOpcode(Value &V, Type &Ty, InstrumentationConfig &IConf, InstrumentorIRBuilderTy &IIRB)
const size_t AbstractManglingParser< Derived, Alloc >::NumOps
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo, MemorySSAUpdater &MSSAU)
Definition LICM.cpp:1457
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define T1
MachineInstr unsigned OpIdx
uint64_t IntrinsicInst * II
FunctionAnalysisManager FAM
if(PassOpts->AAPipeline)
const SmallVectorImpl< MachineOperand > & Cond
unsigned OpIndex
This file contains some templates that are useful if you are working with the STL at all.
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
#define LLVM_DEBUG(...)
Definition Debug.h:119
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
This pass exposes codegen information to IR-level passes.
static bool isFreeConcat(ArrayRef< InstLane > Item, TTI::TargetCostKind CostKind, const TargetTransformInfo &TTI)
Detect concat of multiple values into a vector.
static void analyzeCostOfVecReduction(const IntrinsicInst &II, TTI::TargetCostKind CostKind, const TargetTransformInfo &TTI, InstructionCost &CostBeforeReduction, InstructionCost &CostAfterReduction)
static SmallVector< InstLane > generateInstLaneVectorFromOperand(ArrayRef< InstLane > Item, int Op)
static Value * createShiftShuffle(Value *Vec, unsigned OldIndex, unsigned NewIndex, IRBuilderBase &Builder)
Create a shuffle that translates (shifts) 1 element from the input vector to a new element location.
static Value * generateNewInstTree(ArrayRef< InstLane > Item, Use *From, FixedVectorType *Ty, const DenseSet< std::pair< Value *, Use * > > &IdentityLeafs, const DenseSet< std::pair< Value *, Use * > > &SplatLeafs, const DenseSet< std::pair< Value *, Use * > > &ConcatLeafs, IRBuilderBase &Builder, const TargetTransformInfo *TTI)
std::pair< Value *, int > InstLane
static bool isKnownNonPositive(const Value *V, const SimplifyQuery &SQ, unsigned Depth=0)
Used by foldReduceAddCmpZero to check if we can prove that a value is non-positive.
static Align computeAlignmentAfterScalarization(Align VectorAlignment, Type *ScalarType, Value *Idx, const DataLayout &DL)
The memory operation on a vector of ScalarType had alignment of VectorAlignment.
static bool feedsIntoVectorReduction(ShuffleVectorInst *SVI)
Returns true if this ShuffleVectorInst eventually feeds into a vector reduction intrinsic (e....
static cl::opt< bool > DisableVectorCombine("disable-vector-combine", cl::init(false), cl::Hidden, cl::desc("Disable all vector combine transforms"))
static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI)
static const unsigned InvalidIndex
static Value * translateExtract(ExtractElementInst *ExtElt, unsigned NewIndex, IRBuilderBase &Builder)
Given an extract element instruction with constant index operand, shuffle the source vector (shift th...
static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx, const SimplifyQuery &SQ)
Check if it is legal to scalarize a memory access to VecTy at index Idx.
static cl::opt< unsigned > MaxInstrsToScan("vector-combine-max-scan-instrs", cl::init(30), cl::Hidden, cl::desc("Max number of instructions to scan for vector combining."))
static cl::opt< bool > DisableBinopExtractShuffle("disable-binop-extract-shuffle", cl::init(false), cl::Hidden, cl::desc("Disable binop extract to shuffle transforms"))
static InstLane lookThroughShuffles(Value *V, int Lane)
static bool isMemModifiedBetween(BasicBlock::iterator Begin, BasicBlock::iterator End, const MemoryLocation &Loc, AAResults &AA)
static constexpr int Concat[]
Value * RHS
Value * LHS
A manager for alias analyses.
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1055
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1511
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
unsigned countl_one() const
Count the number of leading one bits.
Definition APInt.h:1638
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
static APInt getHighBitsSet(unsigned numBits, unsigned hiBitsSet)
Constructs an APInt value that has the top hiBitsSet bits set.
Definition APInt.h:297
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isOne() const
Determine if this is a value of 1.
Definition APInt.h:390
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1228
Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
const T & front() const
Get the first element.
Definition ArrayRef.h:144
size_t size() const
Get the array size.
Definition ArrayRef.h:141
A function analysis which provides an AssumptionCache.
A cache of @llvm.assume calls within a function.
LLVM_ABI bool hasAttribute(Attribute::AttrKind Kind) const
Return true if the attribute exists in this set.
InstListType::iterator iterator
Instruction iterators...
Definition BasicBlock.h:170
BinaryOps getOpcode() const
Definition InstrTypes.h:409
Represents analyses that only rely on functions' control flow.
Definition Analysis.h:73
Value * getArgOperand(unsigned i) const
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
void addParamAttrs(unsigned ArgNo, const AttrBuilder &B)
Adds attributes to the indicated argument.
static LLVM_ABI CastInst * Create(Instruction::CastOps, Value *S, Type *Ty, const Twine &Name="", InsertPosition InsertBefore=nullptr)
Provides a way to construct any of the CastInst subclasses using an opcode instead of the subclass's ...
static Type * makeCmpResultType(Type *opnd_type)
Create a result type for fcmp/icmp.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:740
bool isFPPredicate() const
Definition InstrTypes.h:845
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
Combiner implementation.
Definition Combiner.h:33
static LLVM_ABI Constant * getExtractElement(Constant *Vec, Constant *Idx, Type *OnlyIfReducedTy=nullptr)
static LLVM_ABI Constant * getBinOpIdentity(unsigned Opcode, Type *Ty, bool AllowRHSConstant=false, bool NSZ=false)
Return the identity constant for a binary opcode.
This is the shared class of boolean and integer constants.
Definition Constants.h:87
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
This class represents a range of values.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI ConstantRange binaryAnd(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a binary-and of a value in this ra...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
static LLVM_ABI Constant * getSplat(ElementCount EC, Constant *Elt)
Return a ConstantVector with the specified constant in each element.
static LLVM_ABI Constant * get(ArrayRef< Constant * > V)
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
ValueT lookup(const_arg_type_t< KeyT > Val) const
Return the entry for the specified key, or a default constructed value if no such entry exists.
Definition DenseMap.h:252
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:225
bool empty() const
Definition DenseMap.h:173
iterator end()
Definition DenseMap.h:143
Implements a dense probed hash-table based set.
Definition DenseSet.h:289
Analysis pass which computes a DominatorTree.
Definition Dominators.h:270
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:151
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
This instruction extracts a single (scalar) element from a VectorType value.
Convenience struct for specifying and reasoning about fast-math flags.
Definition FMF.h:23
bool noSignedZeros() const
Definition FMF.h:67
Class to represent fixed width SIMD vectors.
unsigned getNumElements() const
static FixedVectorType * getDoubleElementsVectorType(FixedVectorType *VTy)
static LLVM_ABI FixedVectorType * get(Type *ElementType, unsigned NumElts)
Definition Type.cpp:869
Predicate getSignedPredicate() const
For example, EQ->EQ, SLE->SLE, UGT->SGT, etc.
bool isEquality() const
Return true if this predicate is either EQ or NE.
Common base class shared among various IRBuilders.
Definition IRBuilder.h:114
LLVM_ABI CallInst * CreateIntrinsicWithoutFolding(Intrinsic::ID ID, ArrayRef< Type * > OverloadTypes, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="", ArrayRef< OperandBundleDef > OpBundles={})
Create a call to intrinsic ID with Args, mangled using OverloadTypes.
Value * CreateNUWMul(Value *LHS, Value *RHS, const Twine &Name="")
Definition IRBuilder.h:1521
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Definition IRBuilder.h:2669
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Definition IRBuilder.h:2657
LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)
Definition IRBuilder.h:1975
LLVM_ABI Value * CreateSelectFMF(Value *C, Value *True, Value *False, FMFSource FMFSource, const Twine &Name="", Instruction *MDFrom=nullptr)
LLVM_ABI Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")
Return a vector value that contains V broadcasted to NumElts elements.
Value * CreateExtractValue(Value *Agg, ArrayRef< unsigned > Idxs, const Twine &Name="")
Definition IRBuilder.h:2716
ConstantInt * getTrue()
Get the constant value for i1 true.
Definition IRBuilder.h:509
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
Value * CreateFreeze(Value *V, const Twine &Name="")
Definition IRBuilder.h:2735
void SetCurrentDebugLocation(const DebugLoc &L)
Set location information used by debugging information.
Definition IRBuilder.h:247
Value * CreateLShr(Value *LHS, Value *RHS, const Twine &Name="", bool isExact=false)
Definition IRBuilder.h:1584
Value * CreateCast(Instruction::CastOps Op, Value *V, Type *DestTy, const Twine &Name="", MDNode *FPMathTag=nullptr, FMFSource FMFSource={})
Definition IRBuilder.h:2318
Value * CreateIsNotNeg(Value *Arg, const Twine &Name="")
Return a boolean value testing if Arg > -1.
Definition IRBuilder.h:2759
Value * CreateInBoundsGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="")
Definition IRBuilder.h:2060
Value * CreatePointerBitCastOrAddrSpaceCast(Value *V, Type *DestTy, const Twine &Name="")
Definition IRBuilder.h:2343
ConstantInt * getInt64(uint64_t C)
Get a constant 64-bit value.
Definition IRBuilder.h:534
LLVM_ABI Value * CreateOrReduce(Value *Src)
Create a vector int OR reduction intrinsic of the source vector.
ConstantInt * getInt32(uint32_t C)
Get a constant 32-bit value.
Definition IRBuilder.h:529
Value * CreateCmp(CmpInst::Predicate Pred, Value *LHS, Value *RHS, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:2550
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Definition IRBuilder.h:2581
InstTy * Insert(InstTy *I, const Twine &Name="") const
Insert and return the specified instruction.
Definition IRBuilder.h:172
Value * CreateIsNeg(Value *Arg, const Twine &Name="")
Return a boolean value testing if Arg < 0.
Definition IRBuilder.h:2754
Value * CreateBitCast(Value *V, Type *DestTy, const Twine &Name="")
Definition IRBuilder.h:2284
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition IRBuilder.h:1958
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition IRBuilder.h:1563
LLVM_ABI Value * CreateNAryOp(unsigned Opc, ArrayRef< Value * > Ops, const Twine &Name="", MDNode *FPMathTag=nullptr)
Create either a UnaryOperator or BinaryOperator depending on Opc.
Value * CreateZExt(Value *V, Type *DestTy, const Twine &Name="", bool IsNonNeg=false)
Definition IRBuilder.h:2162
Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")
Definition IRBuilder.h:2691
Value * CreateAnd(Value *LHS, Value *RHS, const Twine &Name="")
Definition IRBuilder.h:1622
LLVM_ABI Value * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > OverloadTypes, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="", ArrayRef< OperandBundleDef > OpBundles={}, function_ref< void(CallInst *)> SetFn=[](CallInst *) {})
Variant to create a possibly constant-folded intrinsic.
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Definition IRBuilder.h:1971
Value * CreateTrunc(Value *V, Type *DestTy, const Twine &Name="", bool IsNUW=false, bool IsNSW=false)
Definition IRBuilder.h:2148
PointerType * getPtrTy(unsigned AddrSpace=0)
Fetch the type representing a pointer.
Definition IRBuilder.h:629
Value * CreateBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:1783
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition IRBuilder.h:207
Value * CreateFNegFMF(Value *V, FMFSource FMFSource, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:1896
Value * CreateICmp(CmpInst::Predicate P, Value *LHS, Value *RHS, const Twine &Name="")
Definition IRBuilder.h:2526
Value * CreateOr(Value *LHS, Value *RHS, const Twine &Name="", bool IsDisjoint=false)
Definition IRBuilder.h:1644
InstSimplifyFolder - Use InstructionSimplify to fold operations to existing values.
CostType getValue() const
This function is intended to be used as sparingly as possible, since the class provides the full rang...
void push(Instruction *I)
Push the instruction onto the worklist stack.
LLVM_ABI void setHasNoUnsignedWrap(bool b=true)
Set or clear the nuw flag on this instruction, which must be an operator which supports this flag.
LLVM_ABI void copyIRFlags(const Value *V, bool IncludeWrapFlags=true)
Convenience method to copy supported exact, fast-math, and (optionally) wrapping flags from V to this...
LLVM_ABI void setHasNoSignedWrap(bool b=true)
Set or clear the nsw flag on this instruction, which must be an operator which supports this flag.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
LLVM_ABI void andIRFlags(const Value *V)
Logical 'and' of any supported wrapping, exact, and fast-math flags of V and this instruction.
bool isBinaryOp() const
LLVM_ABI void setNonNeg(bool b=true)
Set or clear the nneg flag on this instruction, which must be a zext instruction.
LLVM_ABI bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instructi...
LLVM_ABI FastMathFlags getFastMathFlags() const LLVM_READONLY
Convenience function for getting all the fast-math flags, which must be an operator which supports th...
LLVM_ABI AAMDNodes getAAMetadata() const
Returns the AA metadata for this instruction.
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
bool isIdempotent() const
Return true if the instruction is idempotent:
LLVM_ABI void copyMetadata(const Instruction &SrcInst, ArrayRef< unsigned > WL=ArrayRef< unsigned >())
Copy metadata from SrcInst to this instruction.
LLVM_ABI bool hasAllowReassoc() const LLVM_READONLY
Determine whether the allow-reassociation flag is set.
bool isIntDivRem() const
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:350
unsigned getBitWidth() const
Get the number of bits in this IntegerType.
A wrapper class for inspecting calls to intrinsic functions.
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
An instruction for reading from memory.
unsigned getPointerAddressSpace() const
Returns the address space of the pointer operand.
void setAlignment(Align Align)
Type * getPointerOperandType() const
Align getAlign() const
Return the alignment of the access that is being performed.
Representation for a specific memory location.
static LLVM_ABI MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserveSet()
Mark an analysis set as preserved.
Definition Analysis.h:151
const SDValue & getOperand(unsigned Num) const
bool contains(const_arg_type key) const
Check if the SetVector contains the given key.
Definition SetVector.h:252
bool empty() const
Determine if the SetVector is empty or not.
Definition SetVector.h:100
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition SetVector.h:151
This instruction constructs a fixed permutation of two input vectors.
int getMaskValue(unsigned Elt) const
Return the shuffle mask value of this instruction for the given element index.
VectorType * getType() const
Overload to return most specific vector type.
static LLVM_ABI void getShuffleMask(const Constant *Mask, SmallVectorImpl< int > &Result)
Convert the input shuffle mask operand to a vector of integers.
static LLVM_ABI bool isIdentityMask(ArrayRef< int > Mask, int NumSrcElts)
Return true if this shuffle mask chooses elements from exactly one source vector without lane crossin...
static void commuteShuffleMask(MutableArrayRef< int > Mask, unsigned InVecNumElts)
Change values in a shuffle permute mask assuming the two vector operands of length InVecNumElts have ...
size_type size() const
Definition SmallPtrSet.h:99
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
void assign(size_type NumElts, ValueParamT Elt)
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
void setAlignment(Align Align)
Analysis pass providing the TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
static LLVM_ABI CastContextHint getCastContextHint(const Instruction *I)
Calculates a CastContextHint from I.
@ None
The insert/extract is not used with a load/store.
LLVM_ABI InstructionCost getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, OperandValueInfo Op1Info={OK_AnyValue, OP_None}, OperandValueInfo Op2Info={OK_AnyValue, OP_None}, const Instruction *I=nullptr) const
LLVM_ABI TypeSize getRegisterBitWidth(RegisterKind K) const
static LLVM_ABI OperandValueInfo commonOperandInfo(const Value *X, const Value *Y)
Collect common data between two OperandValueInfo inputs.
LLVM_ABI InstructionCost getMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, OperandValueInfo OpdInfo={OK_AnyValue, OP_None}, const Instruction *I=nullptr) const
LLVM_ABI bool allowVectorElementIndexingUsingGEP() const
Returns true if GEP should not be used to index into vectors for this target.
LLVM_ABI InstructionCost getShuffleCost(ShuffleKind Kind, VectorType *DstTy, VectorType *SrcTy, ArrayRef< int > Mask={}, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, int Index=0, VectorType *SubTp=nullptr, ArrayRef< const Value * > Args={}, const Instruction *CxtI=nullptr) const
LLVM_ABI InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, TTI::TargetCostKind CostKind) const
LLVM_ABI InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, std::optional< FastMathFlags > FMF, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput) const
Calculate the cost of vector reduction intrinsics.
LLVM_ABI InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, TTI::CastContextHint CCH, TTI::TargetCostKind CostKind=TTI::TCK_SizeAndLatency, const Instruction *I=nullptr) const
LLVM_ABI InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index=-1, const Value *Op0=nullptr, const Value *Op1=nullptr, TTI::VectorInstrContext VIC=TTI::VectorInstrContext::None) const
LLVM_ABI unsigned getRegisterClassForType(bool Vector, Type *Ty=nullptr) const
LLVM_ABI InstructionCost getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty, FastMathFlags FMF=FastMathFlags(), TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput) const
TargetCostKind
The kind of cost model.
@ TCK_RecipThroughput
Reciprocal throughput.
@ TCK_CodeSize
Instruction code size.
LLVM_ABI InstructionCost getArithmeticInstrCost(unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind=TTI::TCK_RecipThroughput, TTI::OperandValueInfo Opd1Info={TTI::OK_AnyValue, TTI::OP_None}, TTI::OperandValueInfo Opd2Info={TTI::OK_AnyValue, TTI::OP_None}, ArrayRef< const Value * > Args={}, const Instruction *CxtI=nullptr, const TargetLibraryInfo *TLibInfo=nullptr) const
This is an approximation of reciprocal throughput of a math/logic op.
LLVM_ABI InstructionCost getMemIntrinsicInstrCost(const MemIntrinsicCostAttributes &MICA, TTI::TargetCostKind CostKind) const
LLVM_ABI unsigned getMinVectorRegisterBitWidth() const
LLVM_ABI InstructionCost getAddressComputationCost(Type *PtrTy, ScalarEvolution *SE, const SCEV *Ptr, TTI::TargetCostKind CostKind) const
LLVM_ABI unsigned getNumberOfRegisters(unsigned ClassID) const
LLVM_ABI InstructionCost getInstructionCost(const User *U, ArrayRef< const Value * > Operands, TargetCostKind CostKind) const
Estimate the cost of a given IR user when lowered.
LLVM_ABI InstructionCost getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract, TTI::TargetCostKind CostKind, bool ForPoisonSrc=true, ArrayRef< Value * > VL={}, TTI::VectorInstrContext VIC=TTI::VectorInstrContext::None) const
Estimate the overhead of scalarizing an instruction.
ShuffleKind
The various kinds of shuffle patterns for vector queries.
@ SK_PermuteSingleSrc
Shuffle elements of single source vector with any shuffle mask.
@ SK_Broadcast
Broadcast element 0 to all other elements.
@ SK_PermuteTwoSrc
Merge elements from two source vectors into one with any shuffle mask.
@ SK_ExtractSubvector
ExtractSubvector Index indicates start offset.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:282
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition Type.h:368
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:197
LLVMContext & getContext() const
Return the LLVMContext in which this type was uniqued.
Definition Type.h:130
LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
Definition Type.cpp:232
bool isFloatingPointTy() const
Return true if this is one of the floating-point types.
Definition Type.h:186
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:257
bool isFPOrFPVectorTy() const
Return true if this is a FP type or a vector of FP.
Definition Type.h:227
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:267
Value * getOperand(unsigned i) const
Definition User.h:207
static LLVM_ABI bool isVPBinOp(Intrinsic::ID ID)
std::optional< unsigned > getFunctionalIntrinsicID() const
std::optional< unsigned > getFunctionalOpcode() const
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
const Value * stripAndAccumulateInBoundsConstantOffsets(const DataLayout &DL, APInt &Offset) const
This is a wrapper around stripAndAccumulateConstantOffsets with the in-bounds requirement set to fals...
Definition Value.h:727
LLVM_ABI bool hasOneUser() const
Return true if there is exactly one user of this value.
Definition Value.cpp:163
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition Value.h:439
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:553
iterator_range< user_iterator > users()
Definition Value.h:426
LLVM_ABI Align getPointerAlignment(const DataLayout &DL) const
Returns an alignment of the pointer value.
Definition Value.cpp:993
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI bool hasNUses(unsigned N) const
Return true if this Value has exactly N uses.
Definition Value.cpp:147
bool use_empty() const
Definition Value.h:346
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:319
bool user_empty() const
Definition Value.h:389
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &)
static LLVM_ABI VectorType * get(Type *ElementType, ElementCount EC)
This static method is the primary way to construct a VectorType.
Type * getElementType() const
std::pair< iterator, bool > insert(const ValueT &V)
Definition DenseSet.h:212
size_type size() const
Definition DenseSet.h:87
constexpr bool hasKnownScalarFactor(const FixedOrScalableQuantity &RHS) const
Returns true if there exists a value X where RHS.multiplyCoefficientBy(X) will result in a value whos...
Definition TypeSize.h:269
constexpr ScalarTy getKnownScalarFactor(const FixedOrScalableQuantity &RHS) const
Returns a value X where RHS.multiplyCoefficientBy(X) will result in a value whose quantity matches ou...
Definition TypeSize.h:277
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
const ParentTy * getParent() const
Definition ilist_node.h:34
self_iterator getIterator()
Definition ilist_node.h:123
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
Abstract Attribute helper functions.
Definition Attributor.h:165
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
constexpr char Attrs[]
Key for Kernel::Metadata::mAttrs.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2277
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2282
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI AttributeSet getFnAttributes(LLVMContext &C, ID id)
Return the function attributes for an intrinsic.
SpecificConstantMatch m_ZeroInt()
Convenience matchers for specific integer values.
BinaryOp_match< SpecificConstantMatch, SrcTy, TargetOpcode::G_SUB > m_Neg(const SrcTy &&Src)
Matches a register negated by a G_SUB.
OneUse_match< SubPat > m_OneUse(const SubPat &SP)
match_combine_and< Ty... > m_CombineAnd(const Ty &...Ps)
Combine pattern matchers matching all of Ps patterns.
cst_pred_ty< is_all_ones > m_AllOnes()
Match an integer or vector with all bits set.
BinaryOp_match< LHS, RHS, Instruction::And > m_And(const LHS &L, const RHS &R)
auto m_Cmp()
Matches any compare instruction and ignores it.
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::URem > m_URem(const LHS &L, const RHS &R)
auto m_Poison()
Match an arbitrary poison constant.
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
CastInst_match< OpTy, TruncInst > m_Trunc(const OpTy &Op)
Matches Trunc.
specific_intval< false > m_SpecificInt(const APInt &V)
Match a specific integer value or vector with all elements equal to the value.
bool match(Val *V, const Pattern &P)
match_bind< Instruction > m_Instruction(Instruction *&I)
Match an instruction, capturing it if we match.
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
DisjointOr_match< LHS, RHS > m_DisjointOr(const LHS &L, const RHS &R)
BinOpPred_match< LHS, RHS, is_right_shift_op > m_Shr(const LHS &L, const RHS &R)
Matches right-shift operations.
TwoOps_match< Val_t, Idx_t, Instruction::ExtractElement > m_ExtractElt(const Val_t &Val, const Idx_t &Idx)
Matches ExtractElementInst.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
auto m_BinOp()
Match an arbitrary binary operation and ignore it.
auto m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
auto m_Constant()
Match an arbitrary Constant and ignore it.
TwoOps_match< V1_t, V2_t, Instruction::ShuffleVector > m_Shuffle(const V1_t &v1, const V2_t &v2)
Matches ShuffleVectorInst independently of mask value.
cst_pred_ty< is_non_zero_int > m_NonZeroInt()
Match a non-zero integer or a vector with all non-zero elements.
OneOps_match< OpTy, Instruction::Load > m_Load(const OpTy &Op)
Matches LoadInst.
CastInst_match< OpTy, ZExtInst > m_ZExt(const OpTy &Op)
Matches ZExt.
OverflowingBinaryOp_match< LHS, RHS, Instruction::Shl, OverflowingBinaryOperator::NoUnsignedWrap > m_NUWShl(const LHS &L, const RHS &R)
auto m_AnyIntrinsic()
Matches any intrinsic call and ignores it.
OverflowingBinaryOp_match< LHS, RHS, Instruction::Mul, OverflowingBinaryOperator::NoUnsignedWrap > m_NUWMul(const LHS &L, const RHS &R)
BinOpPred_match< LHS, RHS, is_bitwiselogic_op, true > m_c_BitwiseLogic(const LHS &L, const RHS &R)
Matches bitwise logic operations in either order.
CastOperator_match< OpTy, Instruction::BitCast > m_BitCast(const OpTy &Op)
Matches BitCast.
match_combine_or< CastInst_match< OpTy, SExtInst >, NNegZExt_match< OpTy > > m_SExtLike(const OpTy &Op)
Match either "sext" or "zext nneg".
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
CmpClass_match< LHS, RHS, ICmpInst > m_ICmp(CmpPredicate &Pred, const LHS &L, const RHS &R)
match_combine_or< CastInst_match< OpTy, ZExtInst >, CastInst_match< OpTy, SExtInst > > m_ZExtOrSExt(const OpTy &Op)
FNeg_match< OpTy > m_FNeg(const OpTy &X)
Match 'fneg X' as 'fsub -0.0, X'.
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_Undef()
Match an arbitrary undef constant.
CastInst_match< OpTy, SExtInst > m_SExt(const OpTy &Op)
Matches SExt.
is_zero m_Zero()
Match any null constant or a vector with all elements equal to 0.
BinaryOp_match< LHS, RHS, Instruction::Or, true > m_c_Or(const LHS &L, const RHS &R)
Matches an Or with LHS and RHS in either order.
ThreeOps_match< Val_t, Elt_t, Idx_t, Instruction::InsertElement > m_InsertElt(const Val_t &Val, const Elt_t &Elt, const Idx_t &Idx)
Matches InsertElementInst.
m_Intrinsic_Ty< Opnd >::Ty m_Deinterleave2(const Opnd &Op)
auto m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
DXILDebugInfoMap run(Module &M)
@ User
could "use" a pointer
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
NodeAddr< UseNode * > Use
Definition RDFGraph.h:385
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:315
unsigned Log2_32_Ceil(uint32_t Value)
Return the ceil log base 2 of the specified value, 32 if the value is zero.
Definition MathExtras.h:344
@ Offset
Definition DWP.cpp:558
detail::zippy< detail::zip_shortest, T, U, Args... > zip(T &&t, U &&u, Args &&...args)
zip iterator for two or more iterable types.
Definition STLExtras.h:830
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
void stable_sort(R &&Range)
Definition STLExtras.h:2115
UnaryFunction for_each(R &&Range, UnaryFunction F)
Provide wrappers to std::for_each which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1731
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1738
LLVM_ABI Intrinsic::ID getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID)
Returns the min/max intrinsic used when expanding a min/max reduction.
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
Definition Local.cpp:535
LLVM_ABI SDValue peekThroughBitcasts(SDValue V)
Return the non-bitcasted source operand of V if it exists.
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B,...
Definition STLExtras.h:2553
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI Value * simplifyUnOp(unsigned Opcode, Value *Op, const SimplifyQuery &Q)
Given operand for a UnaryOperator, fold the result or return null.
scope_exit(Callable) -> scope_exit< Callable >
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
LLVM_ABI unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID)
Returns the arithmetic instruction opcode used when expanding a reduction.
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2207
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Value * simplifyCall(CallBase *Call, Value *Callee, ArrayRef< Value * > Args, const SimplifyQuery &Q)
Given a callsite, callee, and arguments, fold the result or return null.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition STLExtras.h:633
LLVM_ABI bool mustSuppressSpeculation(const LoadInst &LI)
Return true if speculation of the given load must be suppressed to avoid ordering or interfering with...
Definition Loads.cpp:445
LLVM_ABI bool widenShuffleMaskElts(int Scale, ArrayRef< int > Mask, SmallVectorImpl< int > &ScaledMask)
Try to transform a shuffle mask by replacing elements with the scaled index for an equivalent mask of...
LLVM_ABI bool isSafeToSpeculativelyExecute(const Instruction *I, const Instruction *CtxI=nullptr, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr, bool UseVariableInfo=true, bool IgnoreUBImplyingAttrs=true)
Return true if the instruction does not have any effects besides calculating the result and does not ...
LLVM_ABI Value * getSplatValue(const Value *V)
Get splat value if the input is a splat vector or return nullptr.
unsigned M1(unsigned Val)
Definition VE.h:377
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1745
LLVM_ABI bool isInstructionTriviallyDead(Instruction *I, const TargetLibraryInfo *TLI=nullptr)
Return true if the result produced by the instruction is not used, and the instruction will return.
Definition Local.cpp:403
LLVM_ABI bool isSplatValue(const Value *V, int Index=-1, unsigned Depth=0)
Return true if each element of the vector value V is poisoned or equal to every other non-poisoned el...
unsigned Log2_32(uint32_t Value)
Return the floor log base 2 of the specified value, -1 if the value is zero.
Definition MathExtras.h:331
auto reverse(ContainerTy &&C)
Definition STLExtras.h:407
constexpr bool isPowerOf2_32(uint32_t Value)
Return true if the argument is a power of two > 0.
Definition MathExtras.h:279
bool isModSet(const ModRefInfo MRI)
Definition ModRef.h:49
void sort(IteratorTy Start, IteratorTy End)
Definition STLExtras.h:1635
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI bool isSafeToLoadUnconditionally(Value *V, Align Alignment, const APInt &Size, const DataLayout &DL, Instruction *ScanFrom, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr)
Return true if we know that executing a load from this value cannot trap.
Definition Loads.cpp:449
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ABI void propagateIRFlags(Value *I, ArrayRef< Value * > VL, Value *OpValue=nullptr, bool IncludeWrapFlags=true)
Get the intersection (logical and) of all of the potential IR flags of each scalar operation (VL) tha...
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
MutableArrayRef(T &OneElt) -> MutableArrayRef< T >
constexpr int PoisonMaskElem
LLVM_ABI bool isSafeToSpeculativelyExecuteWithOpcode(unsigned Opcode, const Instruction *Inst, const Instruction *CtxI=nullptr, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr, bool UseVariableInfo=true, bool IgnoreUBImplyingAttrs=true)
This returns the same result as isSafeToSpeculativelyExecute if Opcode is the actual opcode of Inst.
@ Other
Any other memory.
Definition ModRef.h:68
TargetTransformInfo TTI
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
LLVM_ABI Value * simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, const SimplifyQuery &Q)
Given operands for a BinaryOperator, fold the result or return null.
LLVM_ABI void narrowShuffleMaskElts(int Scale, ArrayRef< int > Mask, SmallVectorImpl< int > &ScaledMask)
Replace each shuffle mask index with the scaled sequential indices for an equivalent mask of narrowed...
LLVM_ABI Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc)
Returns the reduction intrinsic id corresponding to the binary operation.
@ And
Bitwise or logical AND of integers.
LLVM_ABI bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx, const TargetTransformInfo *TTI)
Identifies if the vector form of the intrinsic has a scalar operand.
DWARFExpression::Operation Op
unsigned M0(unsigned Val)
Definition VE.h:376
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
LLVM_ABI Constant * getLosslessInvCast(Constant *C, Type *InvCastTo, unsigned CastOp, const DataLayout &DL, PreservedCastFlags *Flags=nullptr)
Try to cast C to InvC losslessly, satisfying CastOp(InvC) equals C, or CastOp(InvC) is a refined valu...
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1771
constexpr bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1946
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition Alignment.h:201
bool all_equal(std::initializer_list< T > Values)
Returns true if all Values in the initializer list are equal or the list is empty.
Definition STLExtras.h:2165
LLVM_ABI Value * simplifyCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS, const SimplifyQuery &Q)
Given operands for a CmpInst, fold the result or return null.
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI bool isKnownNonNegative(const Value *V, const SimplifyQuery &SQ, unsigned Depth=0)
Returns true if the given value is known to be non-negative.
LLVM_ABI bool isTriviallyVectorizable(Intrinsic::ID ID)
Identify if the intrinsic is trivially vectorizable.
LLVM_ABI Intrinsic::ID getMinMaxReductionIntrinsicID(Intrinsic::ID IID)
Returns the llvm.vector.reduce min/max intrinsic that corresponds to the intrinsic op.
LLVM_ABI ConstantRange computeConstantRange(const Value *V, bool ForSigned, const SimplifyQuery &SQ, unsigned Depth=0)
Determine the possible constant range of an integer or vector of integer value.
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:862
#define N
LLVM_ABI AAMDNodes adjustForAccess(unsigned AccessSize)
Create a new AAMDNode for accessing AccessSize bytes of this AAMDNode.
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition Alignment.h:39
unsigned countMaxActiveBits() const
Returns the maximum number of bits needed to represent all possible unsigned values with these known ...
Definition KnownBits.h:310
unsigned countMinLeadingZeros() const
Returns the minimum number of leading zero bits.
Definition KnownBits.h:262
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:146
const DataLayout & DL
const Instruction * CxtI
const DominatorTree * DT
SimplifyQuery getWithInstruction(const Instruction *I) const
AssumptionCache * AC