LLVM 20.0.0git
ScalarizeMaskedMemIntrin.cpp
Go to the documentation of this file.
1//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
2// intrinsics
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// This pass replaces masked memory intrinsics - when unsupported by the target
11// - with a chain of basic blocks, that deal with the elements one-by-one if the
12// appropriate mask bit is set.
13//
14//===----------------------------------------------------------------------===//
15
17#include "llvm/ADT/Twine.h"
21#include "llvm/IR/BasicBlock.h"
22#include "llvm/IR/Constant.h"
23#include "llvm/IR/Constants.h"
25#include "llvm/IR/Dominators.h"
26#include "llvm/IR/Function.h"
27#include "llvm/IR/IRBuilder.h"
28#include "llvm/IR/Instruction.h"
31#include "llvm/IR/Type.h"
32#include "llvm/IR/Value.h"
34#include "llvm/Pass.h"
38#include <cassert>
39#include <optional>
40
41using namespace llvm;
42
43#define DEBUG_TYPE "scalarize-masked-mem-intrin"
44
45namespace {
46
47class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
48public:
49 static char ID; // Pass identification, replacement for typeid
50
51 explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
54 }
55
56 bool runOnFunction(Function &F) override;
57
58 StringRef getPassName() const override {
59 return "Scalarize Masked Memory Intrinsics";
60 }
61
62 void getAnalysisUsage(AnalysisUsage &AU) const override {
65 }
66};
67
68} // end anonymous namespace
69
70static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
71 const TargetTransformInfo &TTI, const DataLayout &DL,
72 bool HasBranchDivergence, DomTreeUpdater *DTU);
73static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
75 const DataLayout &DL, bool HasBranchDivergence,
76 DomTreeUpdater *DTU);
77
78char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
79
80INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
81 "Scalarize unsupported masked memory intrinsics", false,
82 false)
85INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
86 "Scalarize unsupported masked memory intrinsics", false,
87 false)
88
90 return new ScalarizeMaskedMemIntrinLegacyPass();
91}
92
93static bool isConstantIntVector(Value *Mask) {
94 Constant *C = dyn_cast<Constant>(Mask);
95 if (!C)
96 return false;
97
98 unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
99 for (unsigned i = 0; i != NumElts; ++i) {
100 Constant *CElt = C->getAggregateElement(i);
101 if (!CElt || !isa<ConstantInt>(CElt))
102 return false;
103 }
104
105 return true;
106}
107
108static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
109 unsigned Idx) {
110 return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
111}
112
113// Translate a masked load intrinsic like
114// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
115// <16 x i1> %mask, <16 x i32> %passthru)
116// to a chain of basic blocks, with loading element one-by-one if
117// the appropriate mask bit is set
118//
119// %1 = bitcast i8* %addr to i32*
120// %2 = extractelement <16 x i1> %mask, i32 0
121// br i1 %2, label %cond.load, label %else
122//
123// cond.load: ; preds = %0
124// %3 = getelementptr i32* %1, i32 0
125// %4 = load i32* %3
126// %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
127// br label %else
128//
129// else: ; preds = %0, %cond.load
130// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ poison, %0 ]
131// %6 = extractelement <16 x i1> %mask, i32 1
132// br i1 %6, label %cond.load1, label %else2
133//
134// cond.load1: ; preds = %else
135// %7 = getelementptr i32* %1, i32 1
136// %8 = load i32* %7
137// %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
138// br label %else2
139//
140// else2: ; preds = %else, %cond.load1
141// %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
142// %10 = extractelement <16 x i1> %mask, i32 2
143// br i1 %10, label %cond.load4, label %else5
144//
145static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence,
146 CallInst *CI, DomTreeUpdater *DTU,
147 bool &ModifiedDT) {
148 Value *Ptr = CI->getArgOperand(0);
149 Value *Alignment = CI->getArgOperand(1);
150 Value *Mask = CI->getArgOperand(2);
151 Value *Src0 = CI->getArgOperand(3);
152
153 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
154 VectorType *VecType = cast<FixedVectorType>(CI->getType());
155
156 Type *EltTy = VecType->getElementType();
157
158 IRBuilder<> Builder(CI->getContext());
159 Instruction *InsertPt = CI;
160 BasicBlock *IfBlock = CI->getParent();
161
162 Builder.SetInsertPoint(InsertPt);
164
165 // Short-cut if the mask is all-true.
166 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
167 LoadInst *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
168 NewI->copyMetadata(*CI);
169 NewI->takeName(CI);
170 CI->replaceAllUsesWith(NewI);
171 CI->eraseFromParent();
172 return;
173 }
174
175 // Adjust alignment for the scalar instruction.
176 const Align AdjustedAlignVal =
177 commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
178 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
179
180 // The result vector
181 Value *VResult = Src0;
182
183 if (isConstantIntVector(Mask)) {
184 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
185 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
186 continue;
187 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
188 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
189 VResult = Builder.CreateInsertElement(VResult, Load, Idx);
190 }
191 CI->replaceAllUsesWith(VResult);
192 CI->eraseFromParent();
193 return;
194 }
195
196 // Optimize the case where the "masked load" is a predicated load - that is,
197 // where the mask is the splat of a non-constant scalar boolean. In that case,
198 // use that splated value as the guard on a conditional vector load.
199 if (isSplatValue(Mask, /*Index=*/0)) {
200 Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
201 Mask->getName() + ".first");
202 Instruction *ThenTerm =
203 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
204 /*BranchWeights=*/nullptr, DTU);
205
206 BasicBlock *CondBlock = ThenTerm->getParent();
207 CondBlock->setName("cond.load");
208 Builder.SetInsertPoint(CondBlock->getTerminator());
209 LoadInst *Load = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal,
210 CI->getName() + ".cond.load");
211 Load->copyMetadata(*CI);
212
213 BasicBlock *PostLoad = ThenTerm->getSuccessor(0);
214 Builder.SetInsertPoint(PostLoad, PostLoad->begin());
215 PHINode *Phi = Builder.CreatePHI(VecType, /*NumReservedValues=*/2);
216 Phi->addIncoming(Load, CondBlock);
217 Phi->addIncoming(Src0, IfBlock);
218 Phi->takeName(CI);
219
220 CI->replaceAllUsesWith(Phi);
221 CI->eraseFromParent();
222 ModifiedDT = true;
223 return;
224 }
225 // If the mask is not v1i1, use scalar bit test operations. This generates
226 // better results on X86 at least. However, don't do this on GPUs and other
227 // machines with divergence, as there each i1 needs a vector register.
228 Value *SclrMask = nullptr;
229 if (VectorWidth != 1 && !HasBranchDivergence) {
230 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
231 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
232 }
233
234 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
235 // Fill the "else" block, created in the previous iteration
236 //
237 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else,
238 // %else ] %mask_1 = and i16 %scalar_mask, i32 1 << Idx %cond = icmp ne i16
239 // %mask_1, 0 br i1 %mask_1, label %cond.load, label %else
240 //
241 // On GPUs, use
242 // %cond = extrectelement %mask, Idx
243 // instead
244 Value *Predicate;
245 if (SclrMask != nullptr) {
246 Value *Mask = Builder.getInt(APInt::getOneBitSet(
247 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
248 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
249 Builder.getIntN(VectorWidth, 0));
250 } else {
251 Predicate = Builder.CreateExtractElement(Mask, Idx);
252 }
253
254 // Create "cond" block
255 //
256 // %EltAddr = getelementptr i32* %1, i32 0
257 // %Elt = load i32* %EltAddr
258 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
259 //
260 Instruction *ThenTerm =
261 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
262 /*BranchWeights=*/nullptr, DTU);
263
264 BasicBlock *CondBlock = ThenTerm->getParent();
265 CondBlock->setName("cond.load");
266
267 Builder.SetInsertPoint(CondBlock->getTerminator());
268 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
269 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
270 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
271
272 // Create "else" block, fill it in the next iteration
273 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
274 NewIfBlock->setName("else");
275 BasicBlock *PrevIfBlock = IfBlock;
276 IfBlock = NewIfBlock;
277
278 // Create the phi to join the new and previous value.
279 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
280 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
281 Phi->addIncoming(NewVResult, CondBlock);
282 Phi->addIncoming(VResult, PrevIfBlock);
283 VResult = Phi;
284 }
285
286 CI->replaceAllUsesWith(VResult);
287 CI->eraseFromParent();
288
289 ModifiedDT = true;
290}
291
292// Translate a masked store intrinsic, like
293// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
294// <16 x i1> %mask)
295// to a chain of basic blocks, that stores element one-by-one if
296// the appropriate mask bit is set
297//
298// %1 = bitcast i8* %addr to i32*
299// %2 = extractelement <16 x i1> %mask, i32 0
300// br i1 %2, label %cond.store, label %else
301//
302// cond.store: ; preds = %0
303// %3 = extractelement <16 x i32> %val, i32 0
304// %4 = getelementptr i32* %1, i32 0
305// store i32 %3, i32* %4
306// br label %else
307//
308// else: ; preds = %0, %cond.store
309// %5 = extractelement <16 x i1> %mask, i32 1
310// br i1 %5, label %cond.store1, label %else2
311//
312// cond.store1: ; preds = %else
313// %6 = extractelement <16 x i32> %val, i32 1
314// %7 = getelementptr i32* %1, i32 1
315// store i32 %6, i32* %7
316// br label %else2
317// . . .
318static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence,
319 CallInst *CI, DomTreeUpdater *DTU,
320 bool &ModifiedDT) {
321 Value *Src = CI->getArgOperand(0);
322 Value *Ptr = CI->getArgOperand(1);
323 Value *Alignment = CI->getArgOperand(2);
324 Value *Mask = CI->getArgOperand(3);
325
326 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
327 auto *VecType = cast<VectorType>(Src->getType());
328
329 Type *EltTy = VecType->getElementType();
330
331 IRBuilder<> Builder(CI->getContext());
332 Instruction *InsertPt = CI;
333 Builder.SetInsertPoint(InsertPt);
335
336 // Short-cut if the mask is all-true.
337 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
338 StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
339 Store->takeName(CI);
340 Store->copyMetadata(*CI);
341 CI->eraseFromParent();
342 return;
343 }
344
345 // Adjust alignment for the scalar instruction.
346 const Align AdjustedAlignVal =
347 commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
348 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
349
350 if (isConstantIntVector(Mask)) {
351 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
352 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
353 continue;
354 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
355 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
356 Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
357 }
358 CI->eraseFromParent();
359 return;
360 }
361
362 // Optimize the case where the "masked store" is a predicated store - that is,
363 // when the mask is the splat of a non-constant scalar boolean. In that case,
364 // optimize to a conditional store.
365 if (isSplatValue(Mask, /*Index=*/0)) {
366 Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
367 Mask->getName() + ".first");
368 Instruction *ThenTerm =
369 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
370 /*BranchWeights=*/nullptr, DTU);
371 BasicBlock *CondBlock = ThenTerm->getParent();
372 CondBlock->setName("cond.store");
373 Builder.SetInsertPoint(CondBlock->getTerminator());
374
375 StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
376 Store->takeName(CI);
377 Store->copyMetadata(*CI);
378
379 CI->eraseFromParent();
380 ModifiedDT = true;
381 return;
382 }
383
384 // If the mask is not v1i1, use scalar bit test operations. This generates
385 // better results on X86 at least. However, don't do this on GPUs or other
386 // machines with branch divergence, as there each i1 takes up a register.
387 Value *SclrMask = nullptr;
388 if (VectorWidth != 1 && !HasBranchDivergence) {
389 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
390 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
391 }
392
393 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
394 // Fill the "else" block, created in the previous iteration
395 //
396 // %mask_1 = and i16 %scalar_mask, i32 1 << Idx
397 // %cond = icmp ne i16 %mask_1, 0
398 // br i1 %mask_1, label %cond.store, label %else
399 //
400 // On GPUs, use
401 // %cond = extrectelement %mask, Idx
402 // instead
403 Value *Predicate;
404 if (SclrMask != nullptr) {
405 Value *Mask = Builder.getInt(APInt::getOneBitSet(
406 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
407 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
408 Builder.getIntN(VectorWidth, 0));
409 } else {
410 Predicate = Builder.CreateExtractElement(Mask, Idx);
411 }
412
413 // Create "cond" block
414 //
415 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
416 // %EltAddr = getelementptr i32* %1, i32 0
417 // %store i32 %OneElt, i32* %EltAddr
418 //
419 Instruction *ThenTerm =
420 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
421 /*BranchWeights=*/nullptr, DTU);
422
423 BasicBlock *CondBlock = ThenTerm->getParent();
424 CondBlock->setName("cond.store");
425
426 Builder.SetInsertPoint(CondBlock->getTerminator());
427 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
428 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
429 Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
430
431 // Create "else" block, fill it in the next iteration
432 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
433 NewIfBlock->setName("else");
434
435 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
436 }
437 CI->eraseFromParent();
438
439 ModifiedDT = true;
440}
441
442// Translate a masked gather intrinsic like
443// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
444// <16 x i1> %Mask, <16 x i32> %Src)
445// to a chain of basic blocks, with loading element one-by-one if
446// the appropriate mask bit is set
447//
448// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
449// %Mask0 = extractelement <16 x i1> %Mask, i32 0
450// br i1 %Mask0, label %cond.load, label %else
451//
452// cond.load:
453// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
454// %Load0 = load i32, i32* %Ptr0, align 4
455// %Res0 = insertelement <16 x i32> poison, i32 %Load0, i32 0
456// br label %else
457//
458// else:
459// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [poison, %0]
460// %Mask1 = extractelement <16 x i1> %Mask, i32 1
461// br i1 %Mask1, label %cond.load1, label %else2
462//
463// cond.load1:
464// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
465// %Load1 = load i32, i32* %Ptr1, align 4
466// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
467// br label %else2
468// . . .
469// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
470// ret <16 x i32> %Result
472 bool HasBranchDivergence, CallInst *CI,
473 DomTreeUpdater *DTU, bool &ModifiedDT) {
474 Value *Ptrs = CI->getArgOperand(0);
475 Value *Alignment = CI->getArgOperand(1);
476 Value *Mask = CI->getArgOperand(2);
477 Value *Src0 = CI->getArgOperand(3);
478
479 auto *VecType = cast<FixedVectorType>(CI->getType());
480 Type *EltTy = VecType->getElementType();
481
482 IRBuilder<> Builder(CI->getContext());
483 Instruction *InsertPt = CI;
484 BasicBlock *IfBlock = CI->getParent();
485 Builder.SetInsertPoint(InsertPt);
486 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
487
489
490 // The result vector
491 Value *VResult = Src0;
492 unsigned VectorWidth = VecType->getNumElements();
493
494 // Shorten the way if the mask is a vector of constants.
495 if (isConstantIntVector(Mask)) {
496 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
497 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
498 continue;
499 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
500 LoadInst *Load =
501 Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
502 VResult =
503 Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
504 }
505 CI->replaceAllUsesWith(VResult);
506 CI->eraseFromParent();
507 return;
508 }
509
510 // If the mask is not v1i1, use scalar bit test operations. This generates
511 // better results on X86 at least. However, don't do this on GPUs or other
512 // machines with branch divergence, as there, each i1 takes up a register.
513 Value *SclrMask = nullptr;
514 if (VectorWidth != 1 && !HasBranchDivergence) {
515 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
516 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
517 }
518
519 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
520 // Fill the "else" block, created in the previous iteration
521 //
522 // %Mask1 = and i16 %scalar_mask, i32 1 << Idx
523 // %cond = icmp ne i16 %mask_1, 0
524 // br i1 %Mask1, label %cond.load, label %else
525 //
526 // On GPUs, use
527 // %cond = extrectelement %mask, Idx
528 // instead
529
530 Value *Predicate;
531 if (SclrMask != nullptr) {
532 Value *Mask = Builder.getInt(APInt::getOneBitSet(
533 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
534 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
535 Builder.getIntN(VectorWidth, 0));
536 } else {
537 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
538 }
539
540 // Create "cond" block
541 //
542 // %EltAddr = getelementptr i32* %1, i32 0
543 // %Elt = load i32* %EltAddr
544 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
545 //
546 Instruction *ThenTerm =
547 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
548 /*BranchWeights=*/nullptr, DTU);
549
550 BasicBlock *CondBlock = ThenTerm->getParent();
551 CondBlock->setName("cond.load");
552
553 Builder.SetInsertPoint(CondBlock->getTerminator());
554 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
555 LoadInst *Load =
556 Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
557 Value *NewVResult =
558 Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
559
560 // Create "else" block, fill it in the next iteration
561 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
562 NewIfBlock->setName("else");
563 BasicBlock *PrevIfBlock = IfBlock;
564 IfBlock = NewIfBlock;
565
566 // Create the phi to join the new and previous value.
567 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
568 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
569 Phi->addIncoming(NewVResult, CondBlock);
570 Phi->addIncoming(VResult, PrevIfBlock);
571 VResult = Phi;
572 }
573
574 CI->replaceAllUsesWith(VResult);
575 CI->eraseFromParent();
576
577 ModifiedDT = true;
578}
579
580// Translate a masked scatter intrinsic, like
581// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
582// <16 x i1> %Mask)
583// to a chain of basic blocks, that stores element one-by-one if
584// the appropriate mask bit is set.
585//
586// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
587// %Mask0 = extractelement <16 x i1> %Mask, i32 0
588// br i1 %Mask0, label %cond.store, label %else
589//
590// cond.store:
591// %Elt0 = extractelement <16 x i32> %Src, i32 0
592// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
593// store i32 %Elt0, i32* %Ptr0, align 4
594// br label %else
595//
596// else:
597// %Mask1 = extractelement <16 x i1> %Mask, i32 1
598// br i1 %Mask1, label %cond.store1, label %else2
599//
600// cond.store1:
601// %Elt1 = extractelement <16 x i32> %Src, i32 1
602// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
603// store i32 %Elt1, i32* %Ptr1, align 4
604// br label %else2
605// . . .
607 bool HasBranchDivergence, CallInst *CI,
608 DomTreeUpdater *DTU, bool &ModifiedDT) {
609 Value *Src = CI->getArgOperand(0);
610 Value *Ptrs = CI->getArgOperand(1);
611 Value *Alignment = CI->getArgOperand(2);
612 Value *Mask = CI->getArgOperand(3);
613
614 auto *SrcFVTy = cast<FixedVectorType>(Src->getType());
615
616 assert(
617 isa<VectorType>(Ptrs->getType()) &&
618 isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
619 "Vector of pointers is expected in masked scatter intrinsic");
620
621 IRBuilder<> Builder(CI->getContext());
622 Instruction *InsertPt = CI;
623 Builder.SetInsertPoint(InsertPt);
625
626 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
627 unsigned VectorWidth = SrcFVTy->getNumElements();
628
629 // Shorten the way if the mask is a vector of constants.
630 if (isConstantIntVector(Mask)) {
631 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
632 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
633 continue;
634 Value *OneElt =
635 Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
636 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
637 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
638 }
639 CI->eraseFromParent();
640 return;
641 }
642
643 // If the mask is not v1i1, use scalar bit test operations. This generates
644 // better results on X86 at least.
645 Value *SclrMask = nullptr;
646 if (VectorWidth != 1 && !HasBranchDivergence) {
647 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
648 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
649 }
650
651 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
652 // Fill the "else" block, created in the previous iteration
653 //
654 // %Mask1 = and i16 %scalar_mask, i32 1 << Idx
655 // %cond = icmp ne i16 %mask_1, 0
656 // br i1 %Mask1, label %cond.store, label %else
657 //
658 // On GPUs, use
659 // %cond = extrectelement %mask, Idx
660 // instead
661 Value *Predicate;
662 if (SclrMask != nullptr) {
663 Value *Mask = Builder.getInt(APInt::getOneBitSet(
664 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
665 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
666 Builder.getIntN(VectorWidth, 0));
667 } else {
668 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
669 }
670
671 // Create "cond" block
672 //
673 // %Elt1 = extractelement <16 x i32> %Src, i32 1
674 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
675 // %store i32 %Elt1, i32* %Ptr1
676 //
677 Instruction *ThenTerm =
678 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
679 /*BranchWeights=*/nullptr, DTU);
680
681 BasicBlock *CondBlock = ThenTerm->getParent();
682 CondBlock->setName("cond.store");
683
684 Builder.SetInsertPoint(CondBlock->getTerminator());
685 Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
686 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
687 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
688
689 // Create "else" block, fill it in the next iteration
690 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
691 NewIfBlock->setName("else");
692
693 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
694 }
695 CI->eraseFromParent();
696
697 ModifiedDT = true;
698}
699
701 bool HasBranchDivergence, CallInst *CI,
702 DomTreeUpdater *DTU, bool &ModifiedDT) {
703 Value *Ptr = CI->getArgOperand(0);
704 Value *Mask = CI->getArgOperand(1);
705 Value *PassThru = CI->getArgOperand(2);
706 Align Alignment = CI->getParamAlign(0).valueOrOne();
707
708 auto *VecType = cast<FixedVectorType>(CI->getType());
709
710 Type *EltTy = VecType->getElementType();
711
712 IRBuilder<> Builder(CI->getContext());
713 Instruction *InsertPt = CI;
714 BasicBlock *IfBlock = CI->getParent();
715
716 Builder.SetInsertPoint(InsertPt);
718
719 unsigned VectorWidth = VecType->getNumElements();
720
721 // The result vector
722 Value *VResult = PassThru;
723
724 // Adjust alignment for the scalar instruction.
725 const Align AdjustedAlignment =
726 commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);
727
728 // Shorten the way if the mask is a vector of constants.
729 // Create a build_vector pattern, with loads/poisons as necessary and then
730 // shuffle blend with the pass through value.
731 if (isConstantIntVector(Mask)) {
732 unsigned MemIndex = 0;
733 VResult = PoisonValue::get(VecType);
734 SmallVector<int, 16> ShuffleMask(VectorWidth, PoisonMaskElem);
735 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
736 Value *InsertElt;
737 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
738 InsertElt = PoisonValue::get(EltTy);
739 ShuffleMask[Idx] = Idx + VectorWidth;
740 } else {
741 Value *NewPtr =
742 Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
743 InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, AdjustedAlignment,
744 "Load" + Twine(Idx));
745 ShuffleMask[Idx] = Idx;
746 ++MemIndex;
747 }
748 VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
749 "Res" + Twine(Idx));
750 }
751 VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
752 CI->replaceAllUsesWith(VResult);
753 CI->eraseFromParent();
754 return;
755 }
756
757 // If the mask is not v1i1, use scalar bit test operations. This generates
758 // better results on X86 at least. However, don't do this on GPUs or other
759 // machines with branch divergence, as there, each i1 takes up a register.
760 Value *SclrMask = nullptr;
761 if (VectorWidth != 1 && !HasBranchDivergence) {
762 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
763 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
764 }
765
766 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
767 // Fill the "else" block, created in the previous iteration
768 //
769 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else,
770 // %else ] %mask_1 = extractelement <16 x i1> %mask, i32 Idx br i1 %mask_1,
771 // label %cond.load, label %else
772 //
773 // On GPUs, use
774 // %cond = extrectelement %mask, Idx
775 // instead
776
777 Value *Predicate;
778 if (SclrMask != nullptr) {
779 Value *Mask = Builder.getInt(APInt::getOneBitSet(
780 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
781 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
782 Builder.getIntN(VectorWidth, 0));
783 } else {
784 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
785 }
786
787 // Create "cond" block
788 //
789 // %EltAddr = getelementptr i32* %1, i32 0
790 // %Elt = load i32* %EltAddr
791 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
792 //
793 Instruction *ThenTerm =
794 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
795 /*BranchWeights=*/nullptr, DTU);
796
797 BasicBlock *CondBlock = ThenTerm->getParent();
798 CondBlock->setName("cond.load");
799
800 Builder.SetInsertPoint(CondBlock->getTerminator());
801 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, AdjustedAlignment);
802 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
803
804 // Move the pointer if there are more blocks to come.
805 Value *NewPtr;
806 if ((Idx + 1) != VectorWidth)
807 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
808
809 // Create "else" block, fill it in the next iteration
810 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
811 NewIfBlock->setName("else");
812 BasicBlock *PrevIfBlock = IfBlock;
813 IfBlock = NewIfBlock;
814
815 // Create the phi to join the new and previous value.
816 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
817 PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
818 ResultPhi->addIncoming(NewVResult, CondBlock);
819 ResultPhi->addIncoming(VResult, PrevIfBlock);
820 VResult = ResultPhi;
821
822 // Add a PHI for the pointer if this isn't the last iteration.
823 if ((Idx + 1) != VectorWidth) {
824 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
825 PtrPhi->addIncoming(NewPtr, CondBlock);
826 PtrPhi->addIncoming(Ptr, PrevIfBlock);
827 Ptr = PtrPhi;
828 }
829 }
830
831 CI->replaceAllUsesWith(VResult);
832 CI->eraseFromParent();
833
834 ModifiedDT = true;
835}
836
838 bool HasBranchDivergence, CallInst *CI,
839 DomTreeUpdater *DTU,
840 bool &ModifiedDT) {
841 Value *Src = CI->getArgOperand(0);
842 Value *Ptr = CI->getArgOperand(1);
843 Value *Mask = CI->getArgOperand(2);
844 Align Alignment = CI->getParamAlign(1).valueOrOne();
845
846 auto *VecType = cast<FixedVectorType>(Src->getType());
847
848 IRBuilder<> Builder(CI->getContext());
849 Instruction *InsertPt = CI;
850 BasicBlock *IfBlock = CI->getParent();
851
852 Builder.SetInsertPoint(InsertPt);
854
855 Type *EltTy = VecType->getElementType();
856
857 // Adjust alignment for the scalar instruction.
858 const Align AdjustedAlignment =
859 commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);
860
861 unsigned VectorWidth = VecType->getNumElements();
862
863 // Shorten the way if the mask is a vector of constants.
864 if (isConstantIntVector(Mask)) {
865 unsigned MemIndex = 0;
866 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
867 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
868 continue;
869 Value *OneElt =
870 Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
871 Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
872 Builder.CreateAlignedStore(OneElt, NewPtr, AdjustedAlignment);
873 ++MemIndex;
874 }
875 CI->eraseFromParent();
876 return;
877 }
878
879 // If the mask is not v1i1, use scalar bit test operations. This generates
880 // better results on X86 at least. However, don't do this on GPUs or other
881 // machines with branch divergence, as there, each i1 takes up a register.
882 Value *SclrMask = nullptr;
883 if (VectorWidth != 1 && !HasBranchDivergence) {
884 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
885 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
886 }
887
888 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
889 // Fill the "else" block, created in the previous iteration
890 //
891 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
892 // br i1 %mask_1, label %cond.store, label %else
893 //
894 // On GPUs, use
895 // %cond = extrectelement %mask, Idx
896 // instead
897 Value *Predicate;
898 if (SclrMask != nullptr) {
899 Value *Mask = Builder.getInt(APInt::getOneBitSet(
900 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
901 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
902 Builder.getIntN(VectorWidth, 0));
903 } else {
904 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
905 }
906
907 // Create "cond" block
908 //
909 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
910 // %EltAddr = getelementptr i32* %1, i32 0
911 // %store i32 %OneElt, i32* %EltAddr
912 //
913 Instruction *ThenTerm =
914 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
915 /*BranchWeights=*/nullptr, DTU);
916
917 BasicBlock *CondBlock = ThenTerm->getParent();
918 CondBlock->setName("cond.store");
919
920 Builder.SetInsertPoint(CondBlock->getTerminator());
921 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
922 Builder.CreateAlignedStore(OneElt, Ptr, AdjustedAlignment);
923
924 // Move the pointer if there are more blocks to come.
925 Value *NewPtr;
926 if ((Idx + 1) != VectorWidth)
927 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
928
929 // Create "else" block, fill it in the next iteration
930 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
931 NewIfBlock->setName("else");
932 BasicBlock *PrevIfBlock = IfBlock;
933 IfBlock = NewIfBlock;
934
935 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
936
937 // Add a PHI for the pointer if this isn't the last iteration.
938 if ((Idx + 1) != VectorWidth) {
939 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
940 PtrPhi->addIncoming(NewPtr, CondBlock);
941 PtrPhi->addIncoming(Ptr, PrevIfBlock);
942 Ptr = PtrPhi;
943 }
944 }
945 CI->eraseFromParent();
946
947 ModifiedDT = true;
948}
949
951 DomTreeUpdater *DTU,
952 bool &ModifiedDT) {
953 // If we extend histogram to return a result someday (like the updated vector)
954 // then we'll need to support it here.
955 assert(CI->getType()->isVoidTy() && "Histogram with non-void return.");
956 Value *Ptrs = CI->getArgOperand(0);
957 Value *Inc = CI->getArgOperand(1);
958 Value *Mask = CI->getArgOperand(2);
959
960 auto *AddrType = cast<FixedVectorType>(Ptrs->getType());
961 Type *EltTy = Inc->getType();
962
963 IRBuilder<> Builder(CI->getContext());
964 Instruction *InsertPt = CI;
965 Builder.SetInsertPoint(InsertPt);
966
968
969 // FIXME: Do we need to add an alignment parameter to the intrinsic?
970 unsigned VectorWidth = AddrType->getNumElements();
971
972 // Shorten the way if the mask is a vector of constants.
973 if (isConstantIntVector(Mask)) {
974 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
975 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
976 continue;
977 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
978 LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
979 Value *Add = Builder.CreateAdd(Load, Inc);
980 Builder.CreateStore(Add, Ptr);
981 }
982 CI->eraseFromParent();
983 return;
984 }
985
986 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
987 Value *Predicate =
988 Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
989
990 Instruction *ThenTerm =
991 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
992 /*BranchWeights=*/nullptr, DTU);
993
994 BasicBlock *CondBlock = ThenTerm->getParent();
995 CondBlock->setName("cond.histogram.update");
996
997 Builder.SetInsertPoint(CondBlock->getTerminator());
998 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
999 LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
1000 Value *Add = Builder.CreateAdd(Load, Inc);
1001 Builder.CreateStore(Add, Ptr);
1002
1003 // Create "else" block, fill it in the next iteration
1004 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
1005 NewIfBlock->setName("else");
1006 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
1007 }
1008
1009 CI->eraseFromParent();
1010 ModifiedDT = true;
1011}
1012
1014 DominatorTree *DT) {
1015 std::optional<DomTreeUpdater> DTU;
1016 if (DT)
1017 DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1018
1019 bool EverMadeChange = false;
1020 bool MadeChange = true;
1021 auto &DL = F.getDataLayout();
1022 bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
1023 while (MadeChange) {
1024 MadeChange = false;
1026 bool ModifiedDTOnIteration = false;
1027 MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
1028 HasBranchDivergence, DTU ? &*DTU : nullptr);
1029
1030 // Restart BB iteration if the dominator tree of the Function was changed
1031 if (ModifiedDTOnIteration)
1032 break;
1033 }
1034
1035 EverMadeChange |= MadeChange;
1036 }
1037 return EverMadeChange;
1038}
1039
1040bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
1041 auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
1042 DominatorTree *DT = nullptr;
1043 if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
1044 DT = &DTWP->getDomTree();
1045 return runImpl(F, TTI, DT);
1046}
1047
1050 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
1051 auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
1052 if (!runImpl(F, TTI, DT))
1053 return PreservedAnalyses::all();
1057 return PA;
1058}
1059
1060static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
1061 const TargetTransformInfo &TTI, const DataLayout &DL,
1062 bool HasBranchDivergence, DomTreeUpdater *DTU) {
1063 bool MadeChange = false;
1064
1065 BasicBlock::iterator CurInstIterator = BB.begin();
1066 while (CurInstIterator != BB.end()) {
1067 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
1068 MadeChange |=
1069 optimizeCallInst(CI, ModifiedDT, TTI, DL, HasBranchDivergence, DTU);
1070 if (ModifiedDT)
1071 return true;
1072 }
1073
1074 return MadeChange;
1075}
1076
// Scalarize CI when it is a masked memory intrinsic (or
// experimental.vector.histogram.add) that the target cannot handle as-is.
//
// Returns true if CI was handed to one of the scalarize* helpers (the helpers
// visible in this file erase CI after expanding it). Returns false if CI is
// not an intrinsic, is an intrinsic this pass does not handle, involves a
// scalable vector type, or is judged legal for the target.
// ModifiedDT is forwarded to the helpers, which set it when they split
// blocks; optimizeBlock stops its instruction walk once it is set.
//
// NOTE(review): this listing lacks source lines 1092, 1099, 1106, 1138, 1140,
// 1145 and 1147. Those are the opening lines of the legality checks for the
// histogram, load, store, expandload and compressstore cases. They presumably
// call the matching TTI.isLegalMasked* / isLegalMaskedVectorHistogram query.
// Confirm against the original file.
1077static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
1078 const TargetTransformInfo &TTI,
1079 const DataLayout &DL, bool HasBranchDivergence,
1080 DomTreeUpdater *DTU) {
1081 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
1082 if (II) {
1083 // The scalarization code below does not work for scalable vectors.
1084 if (isa<ScalableVectorType>(II->getType()) ||
1085 any_of(II->args(),
1086 [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
1087 return false;
1088 switch (II->getIntrinsicID()) {
1089 default:
1090 break;
1091 case Intrinsic::experimental_vector_histogram_add:
1093 CI->getArgOperand(1)->getType()))
1094 return false;
1095 scalarizeMaskedVectorHistogram(DL, CI, DTU, ModifiedDT);
1096 return true;
1097 case Intrinsic::masked_load:
1098 // Scalarize unsupported vector masked load
1100 CI->getType(),
1101 cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
1102 return false;
1103 scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1104 return true;
1105 case Intrinsic::masked_store:
1107 CI->getArgOperand(0)->getType(),
1108 cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
1109 return false;
1110 scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1111 return true;
1112 case Intrinsic::masked_gather: {
// The alignment comes from operand 1; when it is unset, fall back to the ABI
// alignment of the element type.
1113 MaybeAlign MA =
1114 cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
1115 Type *LoadTy = CI->getType();
1116 Align Alignment = DL.getValueOrABITypeAlignment(MA,
1117 LoadTy->getScalarType());
// Keep the gather only if it is legal and the target does not force it to
// be scalarized anyway.
1118 if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
1119 !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
1120 return false;
1121 scalarizeMaskedGather(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1122 return true;
1123 }
1124 case Intrinsic::masked_scatter: {
// The alignment comes from operand 2; when it is unset, fall back to the ABI
// alignment of the stored element type.
1125 MaybeAlign MA =
1126 cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
1127 Type *StoreTy = CI->getArgOperand(0)->getType();
1128 Align Alignment = DL.getValueOrABITypeAlignment(MA,
1129 StoreTy->getScalarType());
// Keep the scatter only if it is legal and the target does not force it to
// be scalarized anyway.
1130 if (TTI.isLegalMaskedScatter(StoreTy, Alignment) &&
1131 !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
1132 Alignment))
1133 return false;
1134 scalarizeMaskedScatter(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1135 return true;
1136 }
1137 case Intrinsic::masked_expandload:
1139 CI->getType(),
1141 return false;
1142 scalarizeMaskedExpandLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1143 return true;
1144 case Intrinsic::masked_compressstore:
1146 CI->getArgOperand(0)->getType(),
1148 return false;
1149 scalarizeMaskedCompressStore(DL, HasBranchDivergence, CI, DTU,
1150 ModifiedDT);
1151 return true;
1152 }
1153 }
1154
// Not an intrinsic, or an intrinsic this pass does not scalarize.
1155 return false;
1156}
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static Error unsupported(const char *Str, const Triple &T)
Definition: MachO.cpp:71
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
static bool runImpl(Function &F, const TargetLowering &TLI)
#define F(x, y, z)
Definition: MD5.cpp:55
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static void scalarizeMaskedExpandLoad(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT, const TargetTransformInfo &TTI, const DataLayout &DL, bool HasBranchDivergence, DomTreeUpdater *DTU)
static void scalarizeMaskedScatter(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth, unsigned Idx)
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, const TargetTransformInfo &TTI, const DataLayout &DL, bool HasBranchDivergence, DomTreeUpdater *DTU)
static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static void scalarizeMaskedCompressStore(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static void scalarizeMaskedGather(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
static bool runImpl(Function &F, const TargetTransformInfo &TTI, DominatorTree *DT)
static bool isConstantIntVector(Value *Mask)
#define DEBUG_TYPE
Scalarize unsupported masked memory intrinsics
static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT)
This pass exposes codegen information to IR-level passes.
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition: APInt.h:239
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
Definition: PassManager.h:429
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:410
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
AttributeSet getParamAttrs(unsigned ArgNo) const
The attributes for the argument or parameter at the given index are returned.
MaybeAlign getAlignment() const
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
iterator end()
Definition: BasicBlock.h:461
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:448
InstListType::iterator iterator
Instruction iterators...
Definition: BasicBlock.h:177
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:239
MaybeAlign getParamAlign(unsigned ArgNo) const
Extract the alignment for a call or parameter (0=unknown).
Definition: InstrTypes.h:1746
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1294
AttributeList getAttributes() const
Return the attributes for this call.
Definition: InstrTypes.h:1425
This class represents a function call, abstracting a target machine's calling convention.
This is an important base class in LLVM.
Definition: Constant.h:42
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:317
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:310
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
Value * CreateInsertElement(Type *VecTy, Value *NewElt, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2503
Value * CreateExtractElement(Value *Vec, Value *Idx, const Twine &Name="")
Definition: IRBuilder.h:2491
IntegerType * getIntNTy(unsigned N)
Fetch the type representing an N-bit integer.
Definition: IRBuilder.h:536
LoadInst * CreateAlignedLoad(Type *Ty, Value *Ptr, MaybeAlign Align, const char *Name)
Definition: IRBuilder.h:1830
Value * CreateConstInBoundsGEP1_32(Type *Ty, Value *Ptr, unsigned Idx0, const Twine &Name="")
Definition: IRBuilder.h:1912
void SetCurrentDebugLocation(DebugLoc L)
Set location information used by debugging information.
Definition: IRBuilder.h:217
Value * CreateICmpNE(Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:2277
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Definition: IRBuilder.h:2429
Value * CreateBitCast(Value *V, Type *DestTy, const Twine &Name="")
Definition: IRBuilder.h:2155
ConstantInt * getIntN(unsigned N, uint64_t C)
Get a constant N-bit value, zero extended or truncated from a 64-bit value.
Definition: IRBuilder.h:494
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition: IRBuilder.h:1813
Value * CreateShuffleVector(Value *V1, Value *V2, Value *Mask, const Twine &Name="")
Definition: IRBuilder.h:2525
Value * CreateAnd(Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:1498
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Definition: IRBuilder.h:1826
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:1350
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:177
StoreInst * CreateAlignedStore(Value *Val, Value *Ptr, MaybeAlign Align, bool isVolatile=false)
Definition: IRBuilder.h:1849
ConstantInt * getInt(const APInt &AI)
Get a constant integer value.
Definition: IRBuilder.h:499
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2697
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
Definition: Instruction.h:475
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:94
BasicBlock * getSuccessor(unsigned Idx) const LLVM_READONLY
Return the specified successor. This instruction must be a terminator.
void copyMetadata(const Instruction &SrcInst, ArrayRef< unsigned > WL=ArrayRef< unsigned >())
Copy metadata from SrcInst to this instruction.
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:48
An instruction for reading from memory.
Definition: Instructions.h:176
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition: Pass.cpp:98
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
Definition: Constants.cpp:1878
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
void preserve()
Mark an analysis as preserved.
Definition: Analysis.h:131
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
An instruction for storing to memory.
Definition: Instructions.h:292
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:51
Analysis pass providing the TargetTransformInfo.
Wrapper pass for TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
bool isLegalMaskedScatter(Type *DataType, Align Alignment) const
Return true if the target supports masked scatter.
bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) const
Return true if the target supports masked expand load.
bool hasBranchDivergence(const Function *F=nullptr) const
Return true if branch divergence exists.
bool isLegalMaskedGather(Type *DataType, Align Alignment) const
Return true if the target supports masked gather.
bool forceScalarizeMaskedGather(VectorType *Type, Align Alignment) const
Return true if the target forces scalarizing of llvm.masked.gather intrinsics.
bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) const
Return true if the target supports masked compress store.
bool isLegalMaskedStore(Type *DataType, Align Alignment) const
Return true if the target supports masked store.
bool isLegalMaskedVectorHistogram(Type *AddrType, Type *DataType) const
bool forceScalarizeMaskedScatter(VectorType *Type, Align Alignment) const
Return true if the target forces scalarizing of llvm.masked.scatter intrinsics.
bool isLegalMaskedLoad(Type *DataType, Align Alignment) const
Return true if the target supports masked load.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:81
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:139
Type * getScalarType() const
If this is a vector type, return the element type, otherwise return 'this'.
Definition: Type.h:355
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
void setName(const Twine &Name)
Change the name of the value.
Definition: Value.cpp:377
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1075
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
void takeName(Value *V)
Transfer the name from V to this value.
Definition: Value.cpp:383
const ParentTy * getParent() const
Definition: ilist_node.h:32
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:657
FunctionPass * createScalarizeMaskedMemIntrinLegacyPass()
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1746
bool isSplatValue(const Value *V, int Index=-1, unsigned Depth=0)
Return true if each element of the vector value V is poisoned or equal to every other non-poisoned el...
void initializeScalarizeMaskedMemIntrinLegacyPassPass(PassRegistry &)
constexpr int PoisonMaskElem
@ Add
Sum of integers.
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition: Alignment.h:212
Instruction * SplitBlockAndInsertIfThen(Value *Cond, BasicBlock::iterator SplitBefore, bool Unreachable, MDNode *BranchWeights=nullptr, DomTreeUpdater *DTU=nullptr, LoopInfo *LI=nullptr, BasicBlock *ThenBlock=nullptr)
Split the containing block at the specified instruction - everything before SplitBefore stays in the ...
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:117
Align valueOrOne() const
For convenience, returns a valid alignment or 1 if undefined.
Definition: Alignment.h:141
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)