LLVM 12.0.0git
ScalarizeMaskedMemIntrin.cpp
1 //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
2 // intrinsics
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass replaces masked memory intrinsics - when unsupported by the target
11 // - with a chain of basic blocks that deal with the elements one-by-one if the
12 // appropriate mask bit is set.
13 //
14 //===----------------------------------------------------------------------===//
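//
// The legacy pass is registered under the DEBUG_TYPE name defined below, so it
// can be exercised on its own with, for example (the input file name here is
// only illustrative):
//
//   opt -scalarize-masked-mem-intrin -S masked-ops.ll
//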
15 
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Analysis/TargetTransformInfo.h"
19 #include "llvm/IR/BasicBlock.h"
20 #include "llvm/IR/Constant.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/DerivedTypes.h"
23 #include "llvm/IR/Function.h"
24 #include "llvm/IR/IRBuilder.h"
25 #include "llvm/IR/InstrTypes.h"
26 #include "llvm/IR/Instruction.h"
27 #include "llvm/IR/Instructions.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/Intrinsics.h"
30 #include "llvm/IR/Type.h"
31 #include "llvm/IR/Value.h"
32 #include "llvm/InitializePasses.h"
33 #include "llvm/Pass.h"
34 #include "llvm/Support/Casting.h"
35 #include <algorithm>
36 #include <cassert>
37 
38 using namespace llvm;
39 
40 #define DEBUG_TYPE "scalarize-masked-mem-intrin"
41 
42 namespace {
43 
44 class ScalarizeMaskedMemIntrin : public FunctionPass {
45  const TargetTransformInfo *TTI = nullptr;
46  const DataLayout *DL = nullptr;
47 
48 public:
49  static char ID; // Pass identification, replacement for typeid
50 
51  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
52  initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
53  }
54 
55  bool runOnFunction(Function &F) override;
56 
57  StringRef getPassName() const override {
58  return "Scalarize Masked Memory Intrinsics";
59  }
60 
61  void getAnalysisUsage(AnalysisUsage &AU) const override {
62  AU.addRequired<TargetTransformInfoWrapperPass>();
63  }
64 
65 private:
66  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
67  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
68 };
69 
70 } // end anonymous namespace
71 
72 char ScalarizeMaskedMemIntrin::ID = 0;
73 
74 INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
75  "Scalarize unsupported masked memory intrinsics", false, false)
76 
77 FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
78  return new ScalarizeMaskedMemIntrin();
79 }
80 
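// Returns true if Mask is a constant vector in which every element is a
// ConstantInt, i.e. each lane's predicate can be inspected at compile time.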
81 static bool isConstantIntVector(Value *Mask) {
82  Constant *C = dyn_cast<Constant>(Mask);
83  if (!C)
84  return false;
85 
86  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
87  for (unsigned i = 0; i != NumElts; ++i) {
88  Constant *CElt = C->getAggregateElement(i);
89  if (!CElt || !isa<ConstantInt>(CElt))
90  return false;
91  }
92 
93  return true;
94 }
95 
96 // Translate a masked load intrinsic like
97 // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
98 // <16 x i1> %mask, <16 x i32> %passthru)
99 // to a chain of basic blocks, loading the elements one-by-one if
100 // the appropriate mask bit is set
101 //
102 // %1 = bitcast i8* %addr to i32*
103 // %2 = extractelement <16 x i1> %mask, i32 0
104 // br i1 %2, label %cond.load, label %else
105 //
106 // cond.load: ; preds = %0
107 // %3 = getelementptr i32* %1, i32 0
108 // %4 = load i32* %3
109 // %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
110 // br label %else
111 //
112 // else: ; preds = %0, %cond.load
113 // %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
114 // %6 = extractelement <16 x i1> %mask, i32 1
115 // br i1 %6, label %cond.load1, label %else2
116 //
117 // cond.load1: ; preds = %else
118 // %7 = getelementptr i32* %1, i32 1
119 // %8 = load i32* %7
120 // %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
121 // br label %else2
122 //
123 // else2: ; preds = %else, %cond.load1
124 // %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
125 // %10 = extractelement <16 x i1> %mask, i32 2
126 // br i1 %10, label %cond.load4, label %else5
127 //
128 static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
129  Value *Ptr = CI->getArgOperand(0);
130  Value *Alignment = CI->getArgOperand(1);
131  Value *Mask = CI->getArgOperand(2);
132  Value *Src0 = CI->getArgOperand(3);
133 
134  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
135  VectorType *VecType = cast<FixedVectorType>(CI->getType());
136 
137  Type *EltTy = VecType->getElementType();
138 
139  IRBuilder<> Builder(CI->getContext());
140  Instruction *InsertPt = CI;
141  BasicBlock *IfBlock = CI->getParent();
142 
143  Builder.SetInsertPoint(InsertPt);
144  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
145 
146  // Short-cut if the mask is all-true.
147  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
148  Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
149  CI->replaceAllUsesWith(NewI);
150  CI->eraseFromParent();
151  return;
152  }
153 
154  // Adjust alignment for the scalar instruction.
155  const Align AdjustedAlignVal =
156  commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
157  // Bitcast %addr from i8* to EltTy*
158  Type *NewPtrType =
159  EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
160  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
161  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
162 
163  // The result vector
164  Value *VResult = Src0;
165 
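  // Shorten the way if the mask is a vector of constants.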
166  if (isConstantIntVector(Mask)) {
167  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
168  if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
169  continue;
170  Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
171  LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
172  VResult = Builder.CreateInsertElement(VResult, Load, Idx);
173  }
174  CI->replaceAllUsesWith(VResult);
175  CI->eraseFromParent();
176  return;
177  }
178 
179  // If the mask is not v1i1, use scalar bit test operations. This generates
180  // better results on X86 at least.
181  Value *SclrMask;
182  if (VectorWidth != 1) {
183  Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
184  SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
185  }
186 
187  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
188  // Fill the "else" block, created in the previous iteration
189  //
190  // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
 191  // %mask_1 = and i16 %scalar_mask, i16 1 << Idx
 192  // %cond = icmp ne i16 %mask_1, 0
 193  // br i1 %cond, label %cond.load, label %else
194  //
195  Value *Predicate;
196  if (VectorWidth != 1) {
197  Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
198  Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
199  Builder.getIntN(VectorWidth, 0));
200  } else {
201  Predicate = Builder.CreateExtractElement(Mask, Idx);
202  }
203 
204  // Create "cond" block
205  //
206  // %EltAddr = getelementptr i32* %1, i32 0
207  // %Elt = load i32* %EltAddr
208  // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
209  //
210  BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
211  "cond.load");
212  Builder.SetInsertPoint(InsertPt);
213 
214  Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
215  LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
216  Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
217 
218  // Create "else" block, fill it in the next iteration
219  BasicBlock *NewIfBlock =
220  CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
221  Builder.SetInsertPoint(InsertPt);
222  Instruction *OldBr = IfBlock->getTerminator();
223  BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
224  OldBr->eraseFromParent();
225  BasicBlock *PrevIfBlock = IfBlock;
226  IfBlock = NewIfBlock;
227 
228  // Create the phi to join the new and previous value.
229  PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
230  Phi->addIncoming(NewVResult, CondBlock);
231  Phi->addIncoming(VResult, PrevIfBlock);
232  VResult = Phi;
233  }
234 
235  CI->replaceAllUsesWith(VResult);
236  CI->eraseFromParent();
237 
238  ModifiedDT = true;
239 }
240 
241 // Translate a masked store intrinsic, like
242 // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
243 // <16 x i1> %mask)
244 // to a chain of basic blocks that store the elements one-by-one if
245 // the appropriate mask bit is set
246 //
247 // %1 = bitcast i8* %addr to i32*
248 // %2 = extractelement <16 x i1> %mask, i32 0
249 // br i1 %2, label %cond.store, label %else
250 //
251 // cond.store: ; preds = %0
252 // %3 = extractelement <16 x i32> %val, i32 0
253 // %4 = getelementptr i32* %1, i32 0
254 // store i32 %3, i32* %4
255 // br label %else
256 //
257 // else: ; preds = %0, %cond.store
258 // %5 = extractelement <16 x i1> %mask, i32 1
259 // br i1 %5, label %cond.store1, label %else2
260 //
261 // cond.store1: ; preds = %else
262 // %6 = extractelement <16 x i32> %val, i32 1
263 // %7 = getelementptr i32* %1, i32 1
264 // store i32 %6, i32* %7
265 // br label %else2
266 // . . .
267 static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
268  Value *Src = CI->getArgOperand(0);
269  Value *Ptr = CI->getArgOperand(1);
270  Value *Alignment = CI->getArgOperand(2);
271  Value *Mask = CI->getArgOperand(3);
272 
273  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
274  auto *VecType = cast<VectorType>(Src->getType());
275 
276  Type *EltTy = VecType->getElementType();
277 
278  IRBuilder<> Builder(CI->getContext());
279  Instruction *InsertPt = CI;
280  BasicBlock *IfBlock = CI->getParent();
281  Builder.SetInsertPoint(InsertPt);
282  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
283 
284  // Short-cut if the mask is all-true.
285  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
286  Builder.CreateAlignedStore(Src, Ptr, AlignVal);
287  CI->eraseFromParent();
288  return;
289  }
290 
291  // Adjust alignment for the scalar instruction.
292  const Align AdjustedAlignVal =
293  commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
294  // Bitcast %addr from i8* to EltTy*
295  Type *NewPtrType =
296  EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
297  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
298  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
299 
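  // Shorten the way if the mask is a vector of constants.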
300  if (isConstantIntVector(Mask)) {
301  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
302  if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
303  continue;
304  Value *OneElt = Builder.CreateExtractElement(Src, Idx);
305  Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
306  Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
307  }
308  CI->eraseFromParent();
309  return;
310  }
311 
312  // If the mask is not v1i1, use scalar bit test operations. This generates
313  // better results on X86 at least.
314  Value *SclrMask;
315  if (VectorWidth != 1) {
316  Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
317  SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
318  }
319 
320  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
321  // Fill the "else" block, created in the previous iteration
322  //
 323  // %mask_1 = and i16 %scalar_mask, i16 1 << Idx
 324  // %cond = icmp ne i16 %mask_1, 0
 325  // br i1 %cond, label %cond.store, label %else
326  //
327  Value *Predicate;
328  if (VectorWidth != 1) {
329  Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
330  Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
331  Builder.getIntN(VectorWidth, 0));
332  } else {
333  Predicate = Builder.CreateExtractElement(Mask, Idx);
334  }
335 
336  // Create "cond" block
337  //
338  // %OneElt = extractelement <16 x i32> %Src, i32 Idx
339  // %EltAddr = getelementptr i32* %1, i32 0
 340  // store i32 %OneElt, i32* %EltAddr
341  //
342  BasicBlock *CondBlock =
343  IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
344  Builder.SetInsertPoint(InsertPt);
345 
346  Value *OneElt = Builder.CreateExtractElement(Src, Idx);
347  Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
348  Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
349 
350  // Create "else" block, fill it in the next iteration
351  BasicBlock *NewIfBlock =
352  CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
353  Builder.SetInsertPoint(InsertPt);
354  Instruction *OldBr = IfBlock->getTerminator();
355  BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
356  OldBr->eraseFromParent();
357  IfBlock = NewIfBlock;
358  }
359  CI->eraseFromParent();
360 
361  ModifiedDT = true;
362 }
363 
364 // Translate a masked gather intrinsic like
365 // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
366 // <16 x i1> %Mask, <16 x i32> %Src)
368 // to a chain of basic blocks, loading the elements one-by-one if
368 // the appropriate mask bit is set
369 //
370 // %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
371 // %Mask0 = extractelement <16 x i1> %Mask, i32 0
372 // br i1 %Mask0, label %cond.load, label %else
373 //
374 // cond.load:
375 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
376 // %Load0 = load i32, i32* %Ptr0, align 4
377 // %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
378 // br label %else
379 //
380 // else:
381 // %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
382 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
383 // br i1 %Mask1, label %cond.load1, label %else2
384 //
385 // cond.load1:
386 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
387 // %Load1 = load i32, i32* %Ptr1, align 4
388 // %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
389 // br label %else2
390 // . . .
391 // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
392 // ret <16 x i32> %Result
393 static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
394  Value *Ptrs = CI->getArgOperand(0);
395  Value *Alignment = CI->getArgOperand(1);
396  Value *Mask = CI->getArgOperand(2);
397  Value *Src0 = CI->getArgOperand(3);
398 
399  auto *VecType = cast<FixedVectorType>(CI->getType());
400  Type *EltTy = VecType->getElementType();
401 
402  IRBuilder<> Builder(CI->getContext());
403  Instruction *InsertPt = CI;
404  BasicBlock *IfBlock = CI->getParent();
405  Builder.SetInsertPoint(InsertPt);
406  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
407 
408  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
409 
410  // The result vector
411  Value *VResult = Src0;
412  unsigned VectorWidth = VecType->getNumElements();
413 
414  // Shorten the way if the mask is a vector of constants.
415  if (isConstantIntVector(Mask)) {
416  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
417  if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
418  continue;
419  Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
420  LoadInst *Load =
421  Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
422  VResult =
423  Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
424  }
425  CI->replaceAllUsesWith(VResult);
426  CI->eraseFromParent();
427  return;
428  }
429 
430  // If the mask is not v1i1, use scalar bit test operations. This generates
431  // better results on X86 at least.
432  Value *SclrMask;
433  if (VectorWidth != 1) {
434  Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
435  SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
436  }
437 
438  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
439  // Fill the "else" block, created in the previous iteration
440  //
 441  // %Mask1 = and i16 %scalar_mask, i16 1 << Idx
 442  // %cond = icmp ne i16 %Mask1, 0
 443  // br i1 %cond, label %cond.load, label %else
444  //
445 
446  Value *Predicate;
447  if (VectorWidth != 1) {
448  Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
449  Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
450  Builder.getIntN(VectorWidth, 0));
451  } else {
452  Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
453  }
454 
455  // Create "cond" block
456  //
457  // %EltAddr = getelementptr i32* %1, i32 0
458  // %Elt = load i32* %EltAddr
459  // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
460  //
461  BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
462  Builder.SetInsertPoint(InsertPt);
463 
464  Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
465  LoadInst *Load =
466  Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
467  Value *NewVResult =
468  Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
469 
470  // Create "else" block, fill it in the next iteration
471  BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
472  Builder.SetInsertPoint(InsertPt);
473  Instruction *OldBr = IfBlock->getTerminator();
474  BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
475  OldBr->eraseFromParent();
476  BasicBlock *PrevIfBlock = IfBlock;
477  IfBlock = NewIfBlock;
478 
479  PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
480  Phi->addIncoming(NewVResult, CondBlock);
481  Phi->addIncoming(VResult, PrevIfBlock);
482  VResult = Phi;
483  }
484 
485  CI->replaceAllUsesWith(VResult);
486  CI->eraseFromParent();
487 
488  ModifiedDT = true;
489 }
490 
491 // Translate a masked scatter intrinsic, like
492 // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
493 // <16 x i1> %Mask)
494 // to a chain of basic blocks that store the elements one-by-one if
495 // the appropriate mask bit is set.
496 //
497 // %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
498 // %Mask0 = extractelement <16 x i1> %Mask, i32 0
499 // br i1 %Mask0, label %cond.store, label %else
500 //
501 // cond.store:
502 // %Elt0 = extractelement <16 x i32> %Src, i32 0
503 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
504 // store i32 %Elt0, i32* %Ptr0, align 4
505 // br label %else
506 //
507 // else:
508 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
509 // br i1 %Mask1, label %cond.store1, label %else2
510 //
511 // cond.store1:
512 // %Elt1 = extractelement <16 x i32> %Src, i32 1
513 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
514 // store i32 %Elt1, i32* %Ptr1, align 4
515 // br label %else2
516 // . . .
517 static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
518  Value *Src = CI->getArgOperand(0);
519  Value *Ptrs = CI->getArgOperand(1);
520  Value *Alignment = CI->getArgOperand(2);
521  Value *Mask = CI->getArgOperand(3);
522 
523  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());
524 
525  assert(
526  isa<VectorType>(Ptrs->getType()) &&
527  isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
528  "Vector of pointers is expected in masked scatter intrinsic");
529 
530  IRBuilder<> Builder(CI->getContext());
531  Instruction *InsertPt = CI;
532  BasicBlock *IfBlock = CI->getParent();
533  Builder.SetInsertPoint(InsertPt);
534  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
535 
536  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
537  unsigned VectorWidth = SrcFVTy->getNumElements();
538 
539  // Shorten the way if the mask is a vector of constants.
540  if (isConstantIntVector(Mask)) {
541  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
542  if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
543  continue;
544  Value *OneElt =
545  Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
546  Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
547  Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
548  }
549  CI->eraseFromParent();
550  return;
551  }
552 
553  // If the mask is not v1i1, use scalar bit test operations. This generates
554  // better results on X86 at least.
555  Value *SclrMask;
556  if (VectorWidth != 1) {
557  Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
558  SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
559  }
560 
561  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
562  // Fill the "else" block, created in the previous iteration
563  //
 564  // %Mask1 = and i16 %scalar_mask, i16 1 << Idx
 565  // %cond = icmp ne i16 %Mask1, 0
 566  // br i1 %cond, label %cond.store, label %else
567  //
568  Value *Predicate;
569  if (VectorWidth != 1) {
570  Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
571  Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
572  Builder.getIntN(VectorWidth, 0));
573  } else {
574  Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
575  }
576 
577  // Create "cond" block
578  //
579  // %Elt1 = extractelement <16 x i32> %Src, i32 1
580  // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
 581  // store i32 %Elt1, i32* %Ptr1
582  //
583  BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
584  Builder.SetInsertPoint(InsertPt);
585 
586  Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
587  Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
588  Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
589 
590  // Create "else" block, fill it in the next iteration
591  BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
592  Builder.SetInsertPoint(InsertPt);
593  Instruction *OldBr = IfBlock->getTerminator();
594  BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
595  OldBr->eraseFromParent();
596  IfBlock = NewIfBlock;
597  }
598  CI->eraseFromParent();
599 
600  ModifiedDT = true;
601 }
602 
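// Translate a masked expandload intrinsic, like
//   <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %mask,
//                                             <16 x i32> %passthru)
// into a chain of basic blocks that read from consecutive memory locations:
// the pointer is advanced only for lanes whose mask bit is set, and lanes
// whose bit is clear keep the corresponding %passthru element.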
603 static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
604  Value *Ptr = CI->getArgOperand(0);
605  Value *Mask = CI->getArgOperand(1);
606  Value *PassThru = CI->getArgOperand(2);
607 
608  auto *VecType = cast<FixedVectorType>(CI->getType());
609 
610  Type *EltTy = VecType->getElementType();
611 
612  IRBuilder<> Builder(CI->getContext());
613  Instruction *InsertPt = CI;
614  BasicBlock *IfBlock = CI->getParent();
615 
616  Builder.SetInsertPoint(InsertPt);
617  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
618 
619  unsigned VectorWidth = VecType->getNumElements();
620 
621  // The result vector
622  Value *VResult = PassThru;
623 
624  // Shorten the way if the mask is a vector of constants.
625  if (isConstantIntVector(Mask)) {
626  unsigned MemIndex = 0;
627  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
628  if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
629  continue;
630  Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
631  LoadInst *Load = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
632  "Load" + Twine(Idx));
633  VResult =
634  Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
635  ++MemIndex;
636  }
637  CI->replaceAllUsesWith(VResult);
638  CI->eraseFromParent();
639  return;
640  }
641 
642  // If the mask is not v1i1, use scalar bit test operations. This generates
643  // better results on X86 at least.
644  Value *SclrMask;
645  if (VectorWidth != 1) {
646  Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
647  SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
648  }
649 
650  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
651  // Fill the "else" block, created in the previous iteration
652  //
653  // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
654  // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
655  // br i1 %mask_1, label %cond.load, label %else
656  //
657 
658  Value *Predicate;
659  if (VectorWidth != 1) {
660  Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
661  Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
662  Builder.getIntN(VectorWidth, 0));
663  } else {
664  Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
665  }
666 
667  // Create "cond" block
668  //
669  // %EltAddr = getelementptr i32* %1, i32 0
670  // %Elt = load i32* %EltAddr
671  // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
672  //
673  BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
674  "cond.load");
675  Builder.SetInsertPoint(InsertPt);
676 
677  LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
678  Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
679 
680  // Move the pointer if there are more blocks to come.
681  Value *NewPtr;
682  if ((Idx + 1) != VectorWidth)
683  NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
684 
685  // Create "else" block, fill it in the next iteration
686  BasicBlock *NewIfBlock =
687  CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
688  Builder.SetInsertPoint(InsertPt);
689  Instruction *OldBr = IfBlock->getTerminator();
690  BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
691  OldBr->eraseFromParent();
692  BasicBlock *PrevIfBlock = IfBlock;
693  IfBlock = NewIfBlock;
694 
695  // Create the phi to join the new and previous value.
696  PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
697  ResultPhi->addIncoming(NewVResult, CondBlock);
698  ResultPhi->addIncoming(VResult, PrevIfBlock);
699  VResult = ResultPhi;
700 
701  // Add a PHI for the pointer if this isn't the last iteration.
702  if ((Idx + 1) != VectorWidth) {
703  PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
704  PtrPhi->addIncoming(NewPtr, CondBlock);
705  PtrPhi->addIncoming(Ptr, PrevIfBlock);
706  Ptr = PtrPhi;
707  }
708  }
709 
710  CI->replaceAllUsesWith(VResult);
711  CI->eraseFromParent();
712 
713  ModifiedDT = true;
714 }
715 
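// Translate a masked compressstore intrinsic, like
//   void @llvm.masked.compressstore.v16i32(<16 x i32> %value, i32* %ptr,
//                                          <16 x i1> %mask)
// into a chain of basic blocks that pack the selected elements into
// consecutive memory locations, advancing the pointer only for lanes whose
// mask bit is set.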
716 static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
717  Value *Src = CI->getArgOperand(0);
718  Value *Ptr = CI->getArgOperand(1);
719  Value *Mask = CI->getArgOperand(2);
720 
721  auto *VecType = cast<FixedVectorType>(Src->getType());
722 
723  IRBuilder<> Builder(CI->getContext());
724  Instruction *InsertPt = CI;
725  BasicBlock *IfBlock = CI->getParent();
726 
727  Builder.SetInsertPoint(InsertPt);
728  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
729 
730  Type *EltTy = VecType->getElementType();
731 
732  unsigned VectorWidth = VecType->getNumElements();
733 
734  // Shorten the way if the mask is a vector of constants.
735  if (isConstantIntVector(Mask)) {
736  unsigned MemIndex = 0;
737  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
738  if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
739  continue;
740  Value *OneElt =
741  Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
742  Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
743  Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
744  ++MemIndex;
745  }
746  CI->eraseFromParent();
747  return;
748  }
749 
750  // If the mask is not v1i1, use scalar bit test operations. This generates
751  // better results on X86 at least.
752  Value *SclrMask;
753  if (VectorWidth != 1) {
754  Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
755  SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
756  }
757 
758  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
759  // Fill the "else" block, created in the previous iteration
760  //
761  // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
762  // br i1 %mask_1, label %cond.store, label %else
763  //
764  Value *Predicate;
765  if (VectorWidth != 1) {
766  Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
767  Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
768  Builder.getIntN(VectorWidth, 0));
769  } else {
770  Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
771  }
772 
773  // Create "cond" block
774  //
775  // %OneElt = extractelement <16 x i32> %Src, i32 Idx
776  // %EltAddr = getelementptr i32* %1, i32 0
 777  // store i32 %OneElt, i32* %EltAddr
778  //
779  BasicBlock *CondBlock =
780  IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
781  Builder.SetInsertPoint(InsertPt);
782 
783  Value *OneElt = Builder.CreateExtractElement(Src, Idx);
784  Builder.CreateAlignedStore(OneElt, Ptr, Align(1));
785 
786  // Move the pointer if there are more blocks to come.
787  Value *NewPtr;
788  if ((Idx + 1) != VectorWidth)
789  NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
790 
791  // Create "else" block, fill it in the next iteration
792  BasicBlock *NewIfBlock =
793  CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
794  Builder.SetInsertPoint(InsertPt);
795  Instruction *OldBr = IfBlock->getTerminator();
796  BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
797  OldBr->eraseFromParent();
798  BasicBlock *PrevIfBlock = IfBlock;
799  IfBlock = NewIfBlock;
800 
801  // Add a PHI for the pointer if this isn't the last iteration.
802  if ((Idx + 1) != VectorWidth) {
803  PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
804  PtrPhi->addIncoming(NewPtr, CondBlock);
805  PtrPhi->addIncoming(Ptr, PrevIfBlock);
806  Ptr = PtrPhi;
807  }
808  }
809  CI->eraseFromParent();
810 
811  ModifiedDT = true;
812 }
813 
814 bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
815  bool EverMadeChange = false;
816 
817  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
818  DL = &F.getParent()->getDataLayout();
819 
820  bool MadeChange = true;
821  while (MadeChange) {
822  MadeChange = false;
823  for (Function::iterator I = F.begin(); I != F.end();) {
824  BasicBlock *BB = &*I++;
825  bool ModifiedDTOnIteration = false;
826  MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
827 
828  // Restart BB iteration if the dominator tree of the Function was changed
829  if (ModifiedDTOnIteration)
830  break;
831  }
832 
833  EverMadeChange |= MadeChange;
834  }
835 
836  return EverMadeChange;
837 }
838 
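// Scan the block for calls to masked memory intrinsics and scalarize those
// the target reports as unsupported; returns true if the block was changed.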
839 bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
840  bool MadeChange = false;
841 
842  BasicBlock::iterator CurInstIterator = BB.begin();
843  while (CurInstIterator != BB.end()) {
844  if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
845  MadeChange |= optimizeCallInst(CI, ModifiedDT);
846  if (ModifiedDT)
847  return true;
848  }
849 
850  return MadeChange;
851 }
852 
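// If CI is a masked load/store/gather/scatter/expandload/compressstore that
// the target cannot handle natively, emit the scalarized form and erase CI;
// returns true when that happened.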
853 bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
854  bool &ModifiedDT) {
855  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
856  if (II) {
857  switch (II->getIntrinsicID()) {
858  default:
859  break;
860  case Intrinsic::masked_load:
861  // Scalarize unsupported vector masked load
862  if (TTI->isLegalMaskedLoad(
863  CI->getType(),
864  cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
865  return false;
866  scalarizeMaskedLoad(CI, ModifiedDT);
867  return true;
868  case Intrinsic::masked_store:
869  if (TTI->isLegalMaskedStore(
870  CI->getArgOperand(0)->getType(),
871  cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
872  return false;
873  scalarizeMaskedStore(CI, ModifiedDT);
874  return true;
875  case Intrinsic::masked_gather: {
876  unsigned AlignmentInt =
877  cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
878  Type *LoadTy = CI->getType();
879  Align Alignment =
880  DL->getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), LoadTy);
881  if (TTI->isLegalMaskedGather(LoadTy, Alignment))
882  return false;
883  scalarizeMaskedGather(CI, ModifiedDT);
884  return true;
885  }
886  case Intrinsic::masked_scatter: {
887  unsigned AlignmentInt =
888  cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
889  Type *StoreTy = CI->getArgOperand(0)->getType();
890  Align Alignment =
891  DL->getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), StoreTy);
892  if (TTI->isLegalMaskedScatter(StoreTy, Alignment))
893  return false;
894  scalarizeMaskedScatter(CI, ModifiedDT);
895  return true;
896  }
897  case Intrinsic::masked_expandload:
898  if (TTI->isLegalMaskedExpandLoad(CI->getType()))
899  return false;
900  scalarizeMaskedExpandLoad(CI, ModifiedDT);
901  return true;
902  case Intrinsic::masked_compressstore:
903  if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
904  return false;
905  scalarizeMaskedCompressStore(CI, ModifiedDT);
906  return true;
907  }
908  }
909 
910  return false;
911 }