ScalarizeMaskedMemIntrin.cpp
1//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
2// intrinsics
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// This pass replaces masked memory intrinsics - when unsupported by the target
11// - with a chain of basic blocks that deal with the elements one by one if the
12// appropriate mask bit is set.
13//
14//===----------------------------------------------------------------------===//
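//
// For example, a masked load such as (operand layout as handled below; the
// exact textual IR may differ between LLVM versions)
//   %r = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr align 4 %p,
//                                                  <4 x i1> %m, <4 x i32> %pt)
// is rewritten, when the target reports it as illegal, into a per-element
// branch/load/insertelement chain (see scalarizeMaskedLoad below). The pass
// can be exercised on its own with something like:
//   opt -passes=scalarize-masked-mem-intrin -S in.ll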
15
16#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
17#include "llvm/ADT/Twine.h"
18#include "llvm/Analysis/DomTreeUpdater.h"
19#include "llvm/Analysis/TargetTransformInfo.h"
20#include "llvm/Analysis/VectorUtils.h"
21#include "llvm/IR/BasicBlock.h"
22#include "llvm/IR/Constant.h"
23#include "llvm/IR/Constants.h"
24#include "llvm/IR/DerivedTypes.h"
25#include "llvm/IR/Dominators.h"
26#include "llvm/IR/Function.h"
27#include "llvm/IR/IRBuilder.h"
28#include "llvm/IR/Instruction.h"
29#include "llvm/IR/Instructions.h"
30#include "llvm/IR/IntrinsicInst.h"
31#include "llvm/IR/Type.h"
32#include "llvm/IR/Value.h"
33#include "llvm/InitializePasses.h"
34#include "llvm/Pass.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Transforms/Utils/BasicBlockUtils.h"
37#include "llvm/Transforms/Utils/Local.h"
38#include <cassert>
39#include <optional>
40
41using namespace llvm;
42
43#define DEBUG_TYPE "scalarize-masked-mem-intrin"
44
45namespace {
46
47class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
48public:
49 static char ID; // Pass identification, replacement for typeid
50
51 explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
52 initializeScalarizeMaskedMemIntrinLegacyPassPass(
53 *PassRegistry::getPassRegistry());
54 }
55
56 bool runOnFunction(Function &F) override;
57
58 StringRef getPassName() const override {
59 return "Scalarize Masked Memory Intrinsics";
60 }
61
62 void getAnalysisUsage(AnalysisUsage &AU) const override {
63 AU.addRequired<TargetTransformInfoWrapperPass>();
64 AU.addPreserved<DominatorTreeWrapperPass>();
65 }
66};
67
68} // end anonymous namespace
69
70static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
71 const TargetTransformInfo &TTI, const DataLayout &DL,
72 bool HasBranchDivergence, DomTreeUpdater *DTU);
73static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
74 const TargetTransformInfo &TTI,
75 const DataLayout &DL, bool HasBranchDivergence,
76 DomTreeUpdater *DTU);
77
78char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
79
80INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
81 "Scalarize unsupported masked memory intrinsics", false,
82 false)
83INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
84INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
85INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
86 "Scalarize unsupported masked memory intrinsics", false,
87 false)
88
89FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
90 return new ScalarizeMaskedMemIntrinLegacyPass();
91}
92
93static bool isConstantIntVector(Value *Mask) {
94 Constant *C = dyn_cast<Constant>(Mask);
95 if (!C)
96 return false;
97
98 unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
99 for (unsigned i = 0; i != NumElts; ++i) {
100 Constant *CElt = C->getAggregateElement(i);
101 if (!CElt || !isa<ConstantInt>(CElt))
102 return false;
103 }
104
105 return true;
106}
107
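// Map a logical lane index to the bit it occupies after the mask is bitcast
// to an integer. A worked example (lane numbering assumed as below): for a
// <4 x i1> mask bitcast to i4, lane 0 tests bit 0 on a little-endian target,
// while on a big-endian target lane 0 tests bit 3, lane 1 tests bit 2, and
// so on.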
108static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
109 unsigned Idx) {
110 return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
111}
112
113// Translate a masked load intrinsic like
114// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr,
115// <16 x i1> %mask, <16 x i32> %passthru)
116// to a chain of basic blocks, loading the elements one by one if
117// the appropriate mask bit is set
118//
119// %1 = bitcast i8* %addr to i32*
120// %2 = extractelement <16 x i1> %mask, i32 0
121// br i1 %2, label %cond.load, label %else
122//
123// cond.load: ; preds = %0
124// %3 = getelementptr i32* %1, i32 0
125// %4 = load i32* %3
126// %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
127// br label %else
128//
129// else: ; preds = %0, %cond.load
130// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ poison, %0 ]
131// %6 = extractelement <16 x i1> %mask, i32 1
132// br i1 %6, label %cond.load1, label %else2
133//
134// cond.load1: ; preds = %else
135// %7 = getelementptr i32* %1, i32 1
136// %8 = load i32* %7
137// %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
138// br label %else2
139//
140// else2: ; preds = %else, %cond.load1
141// %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
142// %10 = extractelement <16 x i1> %mask, i32 2
143// br i1 %10, label %cond.load4, label %else5
144//
145static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence,
146 CallInst *CI, DomTreeUpdater *DTU,
147 bool &ModifiedDT) {
148 Value *Ptr = CI->getArgOperand(0);
149 Value *Mask = CI->getArgOperand(1);
150 Value *Src0 = CI->getArgOperand(2);
151
152 const Align AlignVal = CI->getParamAlign(0).valueOrOne();
153 VectorType *VecType = cast<FixedVectorType>(CI->getType());
154
155 Type *EltTy = VecType->getElementType();
156
157 IRBuilder<> Builder(CI->getContext());
158 Instruction *InsertPt = CI;
159 BasicBlock *IfBlock = CI->getParent();
160
161 Builder.SetInsertPoint(InsertPt);
162 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
163
164 // Short-cut if the mask is all-true.
165 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
166 LoadInst *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
167 NewI->copyMetadata(*CI);
168 NewI->takeName(CI);
169 CI->replaceAllUsesWith(NewI);
170 CI->eraseFromParent();
171 return;
172 }
173
174 // Adjust alignment for the scalar instruction.
175 const Align AdjustedAlignVal =
176 commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
177 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
178
179 // The result vector
180 Value *VResult = Src0;
181
182 if (isConstantIntVector(Mask)) {
183 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
184 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
185 continue;
186 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
187 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
188 VResult = Builder.CreateInsertElement(VResult, Load, Idx);
189 }
190 CI->replaceAllUsesWith(VResult);
191 CI->eraseFromParent();
192 return;
193 }
194
195 // Optimize the case where the "masked load" is a predicated load - that is,
196 // where the mask is the splat of a non-constant scalar boolean. In that case,
197// use that splatted value as the guard on a conditional vector load.
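 // A rough sketch of the resulting IR (block and value names are illustrative):
 //   %mask.first = extractelement <16 x i1> %mask, i64 0
 //   br i1 %mask.first, label %cond.load, label %post
 // cond.load:
 //   %v = load <16 x i32>, ptr %addr, align A
 //   br label %post
 // post:
 //   %res = phi <16 x i32> [ %v, %cond.load ], [ %passthru, %entry ]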
198 if (isSplatValue(Mask, /*Index=*/0)) {
199 Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
200 Mask->getName() + ".first");
201 Instruction *ThenTerm =
202 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
203 /*BranchWeights=*/nullptr, DTU);
204
205 BasicBlock *CondBlock = ThenTerm->getParent();
206 CondBlock->setName("cond.load");
207 Builder.SetInsertPoint(CondBlock->getTerminator());
208 LoadInst *Load = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal,
209 CI->getName() + ".cond.load");
210 Load->copyMetadata(*CI);
211
212 BasicBlock *PostLoad = ThenTerm->getSuccessor(0);
213 Builder.SetInsertPoint(PostLoad, PostLoad->begin());
214 PHINode *Phi = Builder.CreatePHI(VecType, /*NumReservedValues=*/2);
215 Phi->addIncoming(Load, CondBlock);
216 Phi->addIncoming(Src0, IfBlock);
217 Phi->takeName(CI);
218
219 CI->replaceAllUsesWith(Phi);
220 CI->eraseFromParent();
221 ModifiedDT = true;
222 return;
223 }
224 // If the mask is not v1i1, use scalar bit test operations. This generates
225 // better results on X86 at least. However, don't do this on GPUs and other
226 // machines with divergence, as there each i1 needs a vector register.
227 Value *SclrMask = nullptr;
228 if (VectorWidth != 1 && !HasBranchDivergence) {
229 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
230 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
231 }
232
233 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
234 // Fill the "else" block, created in the previous iteration
235 //
236 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
237 // %mask_1 = and i16 %scalar_mask, (1 << Idx)
238 // %cond = icmp ne i16 %mask_1, 0 ; br i1 %cond, label %cond.load, label %else
239 //
240 // On GPUs, use
241 // %cond = extractelement <16 x i1> %mask, i32 Idx
242 // instead
243 Value *Predicate;
244 if (SclrMask != nullptr) {
245 Value *Mask = Builder.getInt(APInt::getOneBitSet(
246 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
247 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
248 Builder.getIntN(VectorWidth, 0));
249 } else {
250 Predicate = Builder.CreateExtractElement(Mask, Idx);
251 }
252
253 // Create "cond" block
254 //
255 // %EltAddr = getelementptr i32* %1, i32 0
256 // %Elt = load i32* %EltAddr
257 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
258 //
259 Instruction *ThenTerm =
260 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
261 /*BranchWeights=*/nullptr, DTU);
262
263 BasicBlock *CondBlock = ThenTerm->getParent();
264 CondBlock->setName("cond.load");
265
266 Builder.SetInsertPoint(CondBlock->getTerminator());
267 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
268 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
269 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
270
271 // Create "else" block, fill it in the next iteration
272 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
273 NewIfBlock->setName("else");
274 BasicBlock *PrevIfBlock = IfBlock;
275 IfBlock = NewIfBlock;
276
277 // Create the phi to join the new and previous value.
278 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
279 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
280 Phi->addIncoming(NewVResult, CondBlock);
281 Phi->addIncoming(VResult, PrevIfBlock);
282 VResult = Phi;
283 }
284
285 CI->replaceAllUsesWith(VResult);
286 CI->eraseFromParent();
287
288 ModifiedDT = true;
289}
290
291// Translate a masked store intrinsic, like
292// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr,
293// <16 x i1> %mask)
294// to a chain of basic blocks, storing the elements one by one if
295// the appropriate mask bit is set
296//
297// %1 = bitcast i8* %addr to i32*
298// %2 = extractelement <16 x i1> %mask, i32 0
299// br i1 %2, label %cond.store, label %else
300//
301// cond.store: ; preds = %0
302// %3 = extractelement <16 x i32> %val, i32 0
303// %4 = getelementptr i32* %1, i32 0
304// store i32 %3, i32* %4
305// br label %else
306//
307// else: ; preds = %0, %cond.store
308// %5 = extractelement <16 x i1> %mask, i32 1
309// br i1 %5, label %cond.store1, label %else2
310//
311// cond.store1: ; preds = %else
312// %6 = extractelement <16 x i32> %val, i32 1
313// %7 = getelementptr i32* %1, i32 1
314// store i32 %6, i32* %7
315// br label %else2
316// . . .
317static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence,
318 CallInst *CI, DomTreeUpdater *DTU,
319 bool &ModifiedDT) {
320 Value *Src = CI->getArgOperand(0);
321 Value *Ptr = CI->getArgOperand(1);
322 Value *Mask = CI->getArgOperand(2);
323
324 const Align AlignVal = CI->getParamAlign(1).valueOrOne();
325 auto *VecType = cast<VectorType>(Src->getType());
326
327 Type *EltTy = VecType->getElementType();
328
329 IRBuilder<> Builder(CI->getContext());
330 Instruction *InsertPt = CI;
331 Builder.SetInsertPoint(InsertPt);
332 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
333
334 // Short-cut if the mask is all-true.
335 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
336 StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
337 Store->takeName(CI);
338 Store->copyMetadata(*CI);
339 CI->eraseFromParent();
340 return;
341 }
342
343 // Adjust alignment for the scalar instruction.
344 const Align AdjustedAlignVal =
345 commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
346 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
347
348 if (isConstantIntVector(Mask)) {
349 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
350 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
351 continue;
352 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
353 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
354 Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
355 }
356 CI->eraseFromParent();
357 return;
358 }
359
360 // Optimize the case where the "masked store" is a predicated store - that is,
361 // when the mask is the splat of a non-constant scalar boolean. In that case,
362 // optimize to a conditional store.
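 // A rough sketch of the resulting IR (block and value names are illustrative):
 //   %mask.first = extractelement <16 x i1> %mask, i64 0
 //   br i1 %mask.first, label %cond.store, label %post
 // cond.store:
 //   store <16 x i32> %val, ptr %addr, align A
 //   br label %post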
363 if (isSplatValue(Mask, /*Index=*/0)) {
364 Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
365 Mask->getName() + ".first");
366 Instruction *ThenTerm =
367 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
368 /*BranchWeights=*/nullptr, DTU);
369 BasicBlock *CondBlock = ThenTerm->getParent();
370 CondBlock->setName("cond.store");
371 Builder.SetInsertPoint(CondBlock->getTerminator());
372
373 StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
374 Store->takeName(CI);
375 Store->copyMetadata(*CI);
376
377 CI->eraseFromParent();
378 ModifiedDT = true;
379 return;
380 }
381
382 // If the mask is not v1i1, use scalar bit test operations. This generates
383 // better results on X86 at least. However, don't do this on GPUs or other
384 // machines with branch divergence, as there each i1 takes up a register.
385 Value *SclrMask = nullptr;
386 if (VectorWidth != 1 && !HasBranchDivergence) {
387 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
388 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
389 }
390
391 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
392 // Fill the "else" block, created in the previous iteration
393 //
394 // %mask_1 = and i16 %scalar_mask, (1 << Idx)
395 // %cond = icmp ne i16 %mask_1, 0
396 // br i1 %cond, label %cond.store, label %else
397 //
398 // On GPUs, use
399 // %cond = extractelement <16 x i1> %mask, i32 Idx
400 // instead
401 Value *Predicate;
402 if (SclrMask != nullptr) {
403 Value *Mask = Builder.getInt(APInt::getOneBitSet(
404 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
405 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
406 Builder.getIntN(VectorWidth, 0));
407 } else {
408 Predicate = Builder.CreateExtractElement(Mask, Idx);
409 }
410
411 // Create "cond" block
412 //
413 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
414 // %EltAddr = getelementptr i32* %1, i32 0
415 // %store i32 %OneElt, i32* %EltAddr
416 //
417 Instruction *ThenTerm =
418 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
419 /*BranchWeights=*/nullptr, DTU);
420
421 BasicBlock *CondBlock = ThenTerm->getParent();
422 CondBlock->setName("cond.store");
423
424 Builder.SetInsertPoint(CondBlock->getTerminator());
425 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
426 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
427 Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
428
429 // Create "else" block, fill it in the next iteration
430 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
431 NewIfBlock->setName("else");
432
433 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
434 }
435 CI->eraseFromParent();
436
437 ModifiedDT = true;
438}
439
440// Translate a masked gather intrinsic like
441// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
442// <16 x i1> %Mask, <16 x i32> %Src)
443// to a chain of basic blocks, loading the elements one by one if
444// the appropriate mask bit is set
445//
446// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
447// %Mask0 = extractelement <16 x i1> %Mask, i32 0
448// br i1 %Mask0, label %cond.load, label %else
449//
450// cond.load:
451// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
452// %Load0 = load i32, i32* %Ptr0, align 4
453// %Res0 = insertelement <16 x i32> poison, i32 %Load0, i32 0
454// br label %else
455//
456// else:
457// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [poison, %0]
458// %Mask1 = extractelement <16 x i1> %Mask, i32 1
459// br i1 %Mask1, label %cond.load1, label %else2
460//
461// cond.load1:
462// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
463// %Load1 = load i32, i32* %Ptr1, align 4
464// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
465// br label %else2
466// . . .
467// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
468// ret <16 x i32> %Result
469static void scalarizeMaskedGather(const DataLayout &DL,
470 bool HasBranchDivergence, CallInst *CI,
471 DomTreeUpdater *DTU, bool &ModifiedDT) {
472 Value *Ptrs = CI->getArgOperand(0);
473 Value *Mask = CI->getArgOperand(1);
474 Value *Src0 = CI->getArgOperand(2);
475
476 auto *VecType = cast<FixedVectorType>(CI->getType());
477 Type *EltTy = VecType->getElementType();
478
479 IRBuilder<> Builder(CI->getContext());
480 Instruction *InsertPt = CI;
481 BasicBlock *IfBlock = CI->getParent();
482 Builder.SetInsertPoint(InsertPt);
483 Align AlignVal = CI->getParamAlign(0).valueOrOne();
484
485 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
486
487 // The result vector
488 Value *VResult = Src0;
489 unsigned VectorWidth = VecType->getNumElements();
490
491 // Shorten the way if the mask is a vector of constants.
492 if (isConstantIntVector(Mask)) {
493 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
494 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
495 continue;
496 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
497 LoadInst *Load =
498 Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
499 VResult =
500 Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
501 }
502 CI->replaceAllUsesWith(VResult);
503 CI->eraseFromParent();
504 return;
505 }
506
507 // If the mask is not v1i1, use scalar bit test operations. This generates
508 // better results on X86 at least. However, don't do this on GPUs or other
509 // machines with branch divergence, as there, each i1 takes up a register.
510 Value *SclrMask = nullptr;
511 if (VectorWidth != 1 && !HasBranchDivergence) {
512 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
513 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
514 }
515
516 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
517 // Fill the "else" block, created in the previous iteration
518 //
519 // %Mask1 = and i16 %scalar_mask, (1 << Idx)
520 // %cond = icmp ne i16 %Mask1, 0
521 // br i1 %cond, label %cond.load, label %else
522 //
523 // On GPUs, use
524 // %cond = extractelement <16 x i1> %Mask, i32 Idx
525 // instead
526
527 Value *Predicate;
528 if (SclrMask != nullptr) {
529 Value *Mask = Builder.getInt(APInt::getOneBitSet(
530 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
531 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
532 Builder.getIntN(VectorWidth, 0));
533 } else {
534 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
535 }
536
537 // Create "cond" block
538 //
539 // %EltAddr = getelementptr i32* %1, i32 0
540 // %Elt = load i32* %EltAddr
541 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
542 //
543 Instruction *ThenTerm =
544 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
545 /*BranchWeights=*/nullptr, DTU);
546
547 BasicBlock *CondBlock = ThenTerm->getParent();
548 CondBlock->setName("cond.load");
549
550 Builder.SetInsertPoint(CondBlock->getTerminator());
551 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
552 LoadInst *Load =
553 Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
554 Value *NewVResult =
555 Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
556
557 // Create "else" block, fill it in the next iteration
558 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
559 NewIfBlock->setName("else");
560 BasicBlock *PrevIfBlock = IfBlock;
561 IfBlock = NewIfBlock;
562
563 // Create the phi to join the new and previous value.
564 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
565 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
566 Phi->addIncoming(NewVResult, CondBlock);
567 Phi->addIncoming(VResult, PrevIfBlock);
568 VResult = Phi;
569 }
570
571 CI->replaceAllUsesWith(VResult);
572 CI->eraseFromParent();
573
574 ModifiedDT = true;
575}
576
577// Translate a masked scatter intrinsic, like
578// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
579// <16 x i1> %Mask)
580// to a chain of basic blocks, storing the elements one by one if
581// the appropriate mask bit is set.
582//
583// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
584// %Mask0 = extractelement <16 x i1> %Mask, i32 0
585// br i1 %Mask0, label %cond.store, label %else
586//
587// cond.store:
588// %Elt0 = extractelement <16 x i32> %Src, i32 0
589// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
590// store i32 %Elt0, i32* %Ptr0, align 4
591// br label %else
592//
593// else:
594// %Mask1 = extractelement <16 x i1> %Mask, i32 1
595// br i1 %Mask1, label %cond.store1, label %else2
596//
597// cond.store1:
598// %Elt1 = extractelement <16 x i32> %Src, i32 1
599// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
600// store i32 %Elt1, i32* %Ptr1, align 4
601// br label %else2
602// . . .
603static void scalarizeMaskedScatter(const DataLayout &DL,
604 bool HasBranchDivergence, CallInst *CI,
605 DomTreeUpdater *DTU, bool &ModifiedDT) {
606 Value *Src = CI->getArgOperand(0);
607 Value *Ptrs = CI->getArgOperand(1);
608 Value *Mask = CI->getArgOperand(2);
609
610 auto *SrcFVTy = cast<FixedVectorType>(Src->getType());
611
612 assert(
613 isa<VectorType>(Ptrs->getType()) &&
614 isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
615 "Vector of pointers is expected in masked scatter intrinsic");
616
617 IRBuilder<> Builder(CI->getContext());
618 Instruction *InsertPt = CI;
619 Builder.SetInsertPoint(InsertPt);
620 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
621
622 Align AlignVal = CI->getParamAlign(1).valueOrOne();
623 unsigned VectorWidth = SrcFVTy->getNumElements();
624
625 // Shorten the way if the mask is a vector of constants.
626 if (isConstantIntVector(Mask)) {
627 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
628 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
629 continue;
630 Value *OneElt =
631 Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
632 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
633 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
634 }
635 CI->eraseFromParent();
636 return;
637 }
638
639 // If the mask is not v1i1, use scalar bit test operations. This generates
640 // better results on X86 at least.
641 Value *SclrMask = nullptr;
642 if (VectorWidth != 1 && !HasBranchDivergence) {
643 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
644 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
645 }
646
647 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
648 // Fill the "else" block, created in the previous iteration
649 //
650 // %Mask1 = and i16 %scalar_mask, (1 << Idx)
651 // %cond = icmp ne i16 %Mask1, 0
652 // br i1 %cond, label %cond.store, label %else
653 //
654 // On GPUs, use
655 // %cond = extractelement <16 x i1> %Mask, i32 Idx
656 // instead
657 Value *Predicate;
658 if (SclrMask != nullptr) {
659 Value *Mask = Builder.getInt(APInt::getOneBitSet(
660 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
661 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
662 Builder.getIntN(VectorWidth, 0));
663 } else {
664 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
665 }
666
667 // Create "cond" block
668 //
669 // %Elt1 = extractelement <16 x i32> %Src, i32 1
670 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
671 // %store i32 %Elt1, i32* %Ptr1
672 //
673 Instruction *ThenTerm =
674 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
675 /*BranchWeights=*/nullptr, DTU);
676
677 BasicBlock *CondBlock = ThenTerm->getParent();
678 CondBlock->setName("cond.store");
679
680 Builder.SetInsertPoint(CondBlock->getTerminator());
681 Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
682 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
683 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
684
685 // Create "else" block, fill it in the next iteration
686 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
687 NewIfBlock->setName("else");
688
689 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
690 }
691 CI->eraseFromParent();
692
693 ModifiedDT = true;
694}
695
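// Translate a masked expandload intrinsic, e.g.
//   <16 x i32> @llvm.masked.expandload.v16i32(ptr %addr, <16 x i1> %Mask,
//                                             <16 x i32> %PassThru)
// which loads the enabled elements from consecutive memory locations starting
// at %addr and places them into the corresponding result lanes, taking the
// disabled lanes from %PassThru. (Sketch only; the argument order matches the
// operand accesses below, and the textual IR may differ by LLVM version.)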
696static void scalarizeMaskedExpandLoad(const DataLayout &DL,
697 bool HasBranchDivergence, CallInst *CI,
698 DomTreeUpdater *DTU, bool &ModifiedDT) {
699 Value *Ptr = CI->getArgOperand(0);
700 Value *Mask = CI->getArgOperand(1);
701 Value *PassThru = CI->getArgOperand(2);
702 Align Alignment = CI->getParamAlign(0).valueOrOne();
703
704 auto *VecType = cast<FixedVectorType>(CI->getType());
705
706 Type *EltTy = VecType->getElementType();
707
708 IRBuilder<> Builder(CI->getContext());
709 Instruction *InsertPt = CI;
710 BasicBlock *IfBlock = CI->getParent();
711
712 Builder.SetInsertPoint(InsertPt);
713 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
714
715 unsigned VectorWidth = VecType->getNumElements();
716
717 // The result vector
718 Value *VResult = PassThru;
719
720 // Adjust alignment for the scalar instruction.
721 const Align AdjustedAlignment =
722 commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);
723
724 // Shorten the way if the mask is a vector of constants.
725 // Create a build_vector pattern, with loads/poisons as necessary and then
726 // shuffle blend with the pass through value.
727 if (isConstantIntVector(Mask)) {
728 unsigned MemIndex = 0;
729 VResult = PoisonValue::get(VecType);
730 SmallVector<int, 16> ShuffleMask(VectorWidth, PoisonMaskElem);
731 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
732 Value *InsertElt;
733 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
734 InsertElt = PoisonValue::get(EltTy);
735 ShuffleMask[Idx] = Idx + VectorWidth;
736 } else {
737 Value *NewPtr =
738 Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
739 InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, AdjustedAlignment,
740 "Load" + Twine(Idx));
741 ShuffleMask[Idx] = Idx;
742 ++MemIndex;
743 }
744 VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
745 "Res" + Twine(Idx));
746 }
747 VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
748 CI->replaceAllUsesWith(VResult);
749 CI->eraseFromParent();
750 return;
751 }
752
753 // If the mask is not v1i1, use scalar bit test operations. This generates
754 // better results on X86 at least. However, don't do this on GPUs or other
755 // machines with branch divergence, as there, each i1 takes up a register.
756 Value *SclrMask = nullptr;
757 if (VectorWidth != 1 && !HasBranchDivergence) {
758 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
759 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
760 }
761
762 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
763 // Fill the "else" block, created in the previous iteration
764 //
765 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
766 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
767 // br i1 %mask_1, label %cond.load, label %else
768 //
769 // On GPUs, use
770 // %cond = extractelement <16 x i1> %mask, i32 Idx
771 // instead
772
773 Value *Predicate;
774 if (SclrMask != nullptr) {
775 Value *Mask = Builder.getInt(APInt::getOneBitSet(
776 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
777 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
778 Builder.getIntN(VectorWidth, 0));
779 } else {
780 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
781 }
782
783 // Create "cond" block
784 //
785 // %EltAddr = getelementptr i32* %1, i32 0
786 // %Elt = load i32* %EltAddr
787 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
788 //
789 Instruction *ThenTerm =
790 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
791 /*BranchWeights=*/nullptr, DTU);
792
793 BasicBlock *CondBlock = ThenTerm->getParent();
794 CondBlock->setName("cond.load");
795
796 Builder.SetInsertPoint(CondBlock->getTerminator());
797 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, AdjustedAlignment);
798 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
799
800 // Move the pointer if there are more blocks to come.
801 Value *NewPtr;
802 if ((Idx + 1) != VectorWidth)
803 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
804
805 // Create "else" block, fill it in the next iteration
806 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
807 NewIfBlock->setName("else");
808 BasicBlock *PrevIfBlock = IfBlock;
809 IfBlock = NewIfBlock;
810
811 // Create the phi to join the new and previous value.
812 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
813 PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
814 ResultPhi->addIncoming(NewVResult, CondBlock);
815 ResultPhi->addIncoming(VResult, PrevIfBlock);
816 VResult = ResultPhi;
817
818 // Add a PHI for the pointer if this isn't the last iteration.
819 if ((Idx + 1) != VectorWidth) {
820 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
821 PtrPhi->addIncoming(NewPtr, CondBlock);
822 PtrPhi->addIncoming(Ptr, PrevIfBlock);
823 Ptr = PtrPhi;
824 }
825 }
826
827 CI->replaceAllUsesWith(VResult);
828 CI->eraseFromParent();
829
830 ModifiedDT = true;
831}
832
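// Translate a masked compressstore intrinsic, e.g.
//   void @llvm.masked.compressstore.v16i32(<16 x i32> %Src, ptr %addr,
//                                          <16 x i1> %Mask)
// which stores only the enabled elements of %Src to consecutive memory
// locations starting at %addr. (Sketch only; the argument order matches the
// operand accesses below.)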
833static void scalarizeMaskedCompressStore(const DataLayout &DL,
834 bool HasBranchDivergence, CallInst *CI,
835 DomTreeUpdater *DTU,
836 bool &ModifiedDT) {
837 Value *Src = CI->getArgOperand(0);
838 Value *Ptr = CI->getArgOperand(1);
839 Value *Mask = CI->getArgOperand(2);
840 Align Alignment = CI->getParamAlign(1).valueOrOne();
841
842 auto *VecType = cast<FixedVectorType>(Src->getType());
843
844 IRBuilder<> Builder(CI->getContext());
845 Instruction *InsertPt = CI;
846 BasicBlock *IfBlock = CI->getParent();
847
848 Builder.SetInsertPoint(InsertPt);
849 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
850
851 Type *EltTy = VecType->getElementType();
852
853 // Adjust alignment for the scalar instruction.
854 const Align AdjustedAlignment =
855 commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);
856
857 unsigned VectorWidth = VecType->getNumElements();
858
859 // Shorten the way if the mask is a vector of constants.
860 if (isConstantIntVector(Mask)) {
861 unsigned MemIndex = 0;
862 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
863 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
864 continue;
865 Value *OneElt =
866 Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
867 Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
868 Builder.CreateAlignedStore(OneElt, NewPtr, AdjustedAlignment);
869 ++MemIndex;
870 }
871 CI->eraseFromParent();
872 return;
873 }
874
875 // If the mask is not v1i1, use scalar bit test operations. This generates
876 // better results on X86 at least. However, don't do this on GPUs or other
877 // machines with branch divergence, as there, each i1 takes up a register.
878 Value *SclrMask = nullptr;
879 if (VectorWidth != 1 && !HasBranchDivergence) {
880 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
881 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
882 }
883
884 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
885 // Fill the "else" block, created in the previous iteration
886 //
887 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
888 // br i1 %mask_1, label %cond.store, label %else
889 //
890 // On GPUs, use
891 // %cond = extractelement <16 x i1> %mask, i32 Idx
892 // instead
893 Value *Predicate;
894 if (SclrMask != nullptr) {
895 Value *Mask = Builder.getInt(APInt::getOneBitSet(
896 VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
897 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
898 Builder.getIntN(VectorWidth, 0));
899 } else {
900 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
901 }
902
903 // Create "cond" block
904 //
905 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
906 // %EltAddr = getelementptr i32* %1, i32 0
907 // %store i32 %OneElt, i32* %EltAddr
908 //
909 Instruction *ThenTerm =
910 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
911 /*BranchWeights=*/nullptr, DTU);
912
913 BasicBlock *CondBlock = ThenTerm->getParent();
914 CondBlock->setName("cond.store");
915
916 Builder.SetInsertPoint(CondBlock->getTerminator());
917 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
918 Builder.CreateAlignedStore(OneElt, Ptr, AdjustedAlignment);
919
920 // Move the pointer if there are more blocks to come.
921 Value *NewPtr;
922 if ((Idx + 1) != VectorWidth)
923 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
924
925 // Create "else" block, fill it in the next iteration
926 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
927 NewIfBlock->setName("else");
928 BasicBlock *PrevIfBlock = IfBlock;
929 IfBlock = NewIfBlock;
930
931 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
932
933 // Add a PHI for the pointer if this isn't the last iteration.
934 if ((Idx + 1) != VectorWidth) {
935 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
936 PtrPhi->addIncoming(NewPtr, CondBlock);
937 PtrPhi->addIncoming(Ptr, PrevIfBlock);
938 Ptr = PtrPhi;
939 }
940 }
941 CI->eraseFromParent();
942
943 ModifiedDT = true;
944}
945
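// Scalarize a histogram intrinsic, e.g.
//   void @llvm.experimental.vector.histogram.add.v16p0.i32(<16 x ptr> %Ptrs,
//                                                          i32 %Inc,
//                                                          <16 x i1> %Mask)
// by loading, updating and storing each bucket whose mask bit is set.
// (Sketch only; the mangled intrinsic name here is illustrative.)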
946static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI,
947 DomTreeUpdater *DTU,
948 bool &ModifiedDT) {
949 // If we extend histogram to return a result someday (like the updated vector)
950 // then we'll need to support it here.
951 assert(CI->getType()->isVoidTy() && "Histogram with non-void return.");
952 Value *Ptrs = CI->getArgOperand(0);
953 Value *Inc = CI->getArgOperand(1);
954 Value *Mask = CI->getArgOperand(2);
955
956 auto *AddrType = cast<FixedVectorType>(Ptrs->getType());
957 Type *EltTy = Inc->getType();
958
959 IRBuilder<> Builder(CI->getContext());
960 Instruction *InsertPt = CI;
961 Builder.SetInsertPoint(InsertPt);
962
963 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
964
965 // FIXME: Do we need to add an alignment parameter to the intrinsic?
966 unsigned VectorWidth = AddrType->getNumElements();
967 auto CreateHistogramUpdateValue = [&](IntrinsicInst *CI, Value *Load,
968 Value *Inc) -> Value * {
969 Value *UpdateOp;
970 switch (CI->getIntrinsicID()) {
971 case Intrinsic::experimental_vector_histogram_add:
972 UpdateOp = Builder.CreateAdd(Load, Inc);
973 break;
974 case Intrinsic::experimental_vector_histogram_uadd_sat:
975 UpdateOp =
976 Builder.CreateIntrinsic(Intrinsic::uadd_sat, {EltTy}, {Load, Inc});
977 break;
978 case Intrinsic::experimental_vector_histogram_umin:
979 UpdateOp = Builder.CreateIntrinsic(Intrinsic::umin, {EltTy}, {Load, Inc});
980 break;
981 case Intrinsic::experimental_vector_histogram_umax:
982 UpdateOp = Builder.CreateIntrinsic(Intrinsic::umax, {EltTy}, {Load, Inc});
983 break;
984
985 default:
986 llvm_unreachable("Unexpected histogram intrinsic");
987 }
988 return UpdateOp;
989 };
990
991 // Shorten the way if the mask is a vector of constants.
992 if (isConstantIntVector(Mask)) {
993 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
994 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
995 continue;
996 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
997 LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
998 Value *Update =
999 CreateHistogramUpdateValue(cast<IntrinsicInst>(CI), Load, Inc);
1000 Builder.CreateStore(Update, Ptr);
1001 }
1002 CI->eraseFromParent();
1003 return;
1004 }
1005
1006 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
1007 Value *Predicate =
1008 Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
1009
1010 Instruction *ThenTerm =
1011 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
1012 /*BranchWeights=*/nullptr, DTU);
1013
1014 BasicBlock *CondBlock = ThenTerm->getParent();
1015 CondBlock->setName("cond.histogram.update");
1016
1017 Builder.SetInsertPoint(CondBlock->getTerminator());
1018 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
1019 LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
1020 Value *UpdateOp =
1021 CreateHistogramUpdateValue(cast<IntrinsicInst>(CI), Load, Inc);
1022 Builder.CreateStore(UpdateOp, Ptr);
1023
1024 // Create "else" block, fill it in the next iteration
1025 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
1026 NewIfBlock->setName("else");
1027 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
1028 }
1029
1030 CI->eraseFromParent();
1031 ModifiedDT = true;
1032}
1033
1034static bool runImpl(Function &F, const TargetTransformInfo &TTI,
1035 DominatorTree *DT) {
1036 std::optional<DomTreeUpdater> DTU;
1037 if (DT)
1038 DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1039
1040 bool EverMadeChange = false;
1041 bool MadeChange = true;
1042 auto &DL = F.getDataLayout();
1043 bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
1044 while (MadeChange) {
1045 MadeChange = false;
1046 for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
1047 bool ModifiedDTOnIteration = false;
1048 MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
1049 HasBranchDivergence, DTU ? &*DTU : nullptr);
1050
1051 // Restart BB iteration if the dominator tree of the Function was changed
1052 if (ModifiedDTOnIteration)
1053 break;
1054 }
1055
1056 EverMadeChange |= MadeChange;
1057 }
1058 return EverMadeChange;
1059}
1060
1061bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
1062 auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
1063 DominatorTree *DT = nullptr;
1064 if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
1065 DT = &DTWP->getDomTree();
1066 return runImpl(F, TTI, DT);
1067}
1068
1069PreservedAnalyses
1070ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
1071 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
1072 DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
1073 if (!runImpl(F, TTI, DT))
1074 return PreservedAnalyses::all();
1075
1076 PreservedAnalyses PA;
1077 PA.preserve<DominatorTreeAnalysis>();
1078 return PA;
1079}
1080
1081static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
1082 const TargetTransformInfo &TTI, const DataLayout &DL,
1083 bool HasBranchDivergence, DomTreeUpdater *DTU) {
1084 bool MadeChange = false;
1085
1086 BasicBlock::iterator CurInstIterator = BB.begin();
1087 while (CurInstIterator != BB.end()) {
1088 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
1089 MadeChange |=
1090 optimizeCallInst(CI, ModifiedDT, TTI, DL, HasBranchDivergence, DTU);
1091 if (ModifiedDT)
1092 return true;
1093 }
1094
1095 return MadeChange;
1096}
1097
1098static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
1099 const TargetTransformInfo &TTI,
1100 const DataLayout &DL, bool HasBranchDivergence,
1101 DomTreeUpdater *DTU) {
1102 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
1103 if (II) {
1104 // The scalarization code below does not work for scalable vectors.
1105 if (isa<ScalableVectorType>(II->getType()) ||
1106 any_of(II->args(),
1107 [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
1108 return false;
1109 switch (II->getIntrinsicID()) {
1110 default:
1111 break;
1112 case Intrinsic::experimental_vector_histogram_add:
1113 case Intrinsic::experimental_vector_histogram_uadd_sat:
1114 case Intrinsic::experimental_vector_histogram_umin:
1115 case Intrinsic::experimental_vector_histogram_umax:
1116 if (TTI.isLegalMaskedVectorHistogram(CI->getArgOperand(0)->getType(),
1117 CI->getArgOperand(1)->getType()))
1118 return false;
1119 scalarizeMaskedVectorHistogram(DL, CI, DTU, ModifiedDT);
1120 return true;
1121 case Intrinsic::masked_load:
1122 // Scalarize unsupported vector masked load
1123 if (TTI.isLegalMaskedLoad(
1124 CI->getType(), CI->getParamAlign(0).valueOrOne(),
1125 cast<PointerType>(CI->getArgOperand(0)->getType())
1126 ->getAddressSpace()))
1127 return false;
1128 scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1129 return true;
1130 case Intrinsic::masked_store:
1131 if (TTI.isLegalMaskedStore(
1132 CI->getArgOperand(0)->getType(),
1133 CI->getParamAlign(1).valueOrOne(),
1134 cast<PointerType>(CI->getArgOperand(1)->getType())
1135 ->getAddressSpace()))
1136 return false;
1137 scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1138 return true;
1139 case Intrinsic::masked_gather: {
1140 Align Alignment = CI->getParamAlign(0).valueOrOne();
1141 Type *LoadTy = CI->getType();
1142 if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
1143 !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
1144 return false;
1145 scalarizeMaskedGather(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1146 return true;
1147 }
1148 case Intrinsic::masked_scatter: {
1149 Align Alignment = CI->getParamAlign(1).valueOrOne();
1150 Type *StoreTy = CI->getArgOperand(0)->getType();
1151 if (TTI.isLegalMaskedScatter(StoreTy, Alignment) &&
1152 !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
1153 Alignment))
1154 return false;
1155 scalarizeMaskedScatter(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1156 return true;
1157 }
1158 case Intrinsic::masked_expandload:
1159 if (TTI.isLegalMaskedExpandLoad(
1160 CI->getType(),
1161 CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
1162 return false;
1163 scalarizeMaskedExpandLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
1164 return true;
1165 case Intrinsic::masked_compressstore:
1166 if (TTI.isLegalMaskedCompressStore(
1167 CI->getArgOperand(0)->getType(),
1168 CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
1169 return false;
1170 scalarizeMaskedCompressStore(DL, HasBranchDivergence, CI, DTU,
1171 ModifiedDT);
1172 return true;
1173 }
1174 }
1175
1176 return false;
1177}