//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one by one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//
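//
// For example (an illustrative sketch; the vector width and value names are
// invented for this comment), a call such as
//
//   %r = call <2 x i32> @llvm.masked.load.v2i32.p0v2i32(<2 x i32>* %p, i32 4,
//                                                       <2 x i1> %m,
//                                                       <2 x i32> %passthru)
//
// is expanded, when the target reports the operation as unsupported, into a
// branch per element that loads element i only if bit i of %m is set. The
// comments above each scalarize* function below show the full expansions.
//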

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};

} // end anonymous namespace

char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}

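// Return true if Mask is a constant vector whose elements are all
// ConstantInt, i.e. the mask value is known for every lane.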
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = Mask->getType()->getVectorNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

// Translate a masked load intrinsic like
//   <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                                <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks that load the elements one by one if the
// appropriate mask bit is set.
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                      ; preds = %0
//  %3 = getelementptr i32, i32* %1, i32 0
//  %4 = load i32, i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                           ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                     ; preds = %else
//  %7 = getelementptr i32, i32* %1, i32 1
//  %8 = load i32, i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*.
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  // The result vector.
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }
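  // For example, with an 8-wide mask the bitcast above produces
  //   %scalar_mask = bitcast <8 x i1> %mask to i8
  // and each iteration of the loop below then tests one bit of %scalar_mask.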

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i16 (1 << Idx)
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32, i32* %1, i32 0
    //  %Elt = load i32, i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
//   void @llvm.masked.store(<16 x i32> %val, <16 x i32>* %addr, i32 align,
//                           <16 x i1> %mask)
// to a chain of basic blocks that store the elements one by one if the
// appropriate mask bit is set.
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.store, label %else
//
// cond.store:                                     ; preds = %0
//  %3 = extractelement <16 x i32> %val, i32 0
//  %4 = getelementptr i32, i32* %1, i32 0
//  store i32 %3, i32* %4
//  br label %else
//
// else:                                           ; preds = %0, %cond.store
//  %5 = extractelement <16 x i1> %mask, i32 1
//  br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                    ; preds = %else
//  %6 = extractelement <16 x i32> %val, i32 1
//  %7 = getelementptr i32, i32* %1, i32 1
//  store i32 %6, i32* %7
//  br label %else2
//  . . .
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*.
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i16 (1 << Idx)
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32, i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
//   <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                         <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks that load the elements one by one if the
// appropriate mask bit is set.
//
//  %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  %Load0 = load i32, i32* %Ptr0, align 4
//  %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
//  br label %else
//
// else:
//  %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ undef, %0 ]
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  %Load1 = load i32, i32* %Ptr1, align 4
//  %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
//  br label %else2
//  . . .
//  %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
//  ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = cast<VectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector.
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Take a shortcut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i16 (1 << Idx)
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32, i32* %1, i32 0
    //  %Elt = load i32, i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
//   void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                    <16 x i1> %Mask)
// to a chain of basic blocks that store the elements one by one if the
// appropriate mask bit is set.
//
//  %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
//  %Elt0 = extractelement <16 x i32> %Src, i32 0
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  store i32 %Elt0, i32* %Ptr0, align 4
//  br label %else
//
// else:
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
//  %Elt1 = extractelement <16 x i32> %Src, i32 1
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  store i32 %Elt1, i32* %Ptr1, align 4
//  br label %else2
//  . . .
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Take a shortcut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i16 (1 << Idx)
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

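// Translate a masked expandload intrinsic, like
//   <16 x i32> @llvm.masked.expandload.v16i32(i32* %Ptr, <16 x i1> %Mask,
//                                             <16 x i32> %PassThru)
// to a chain of basic blocks in the same style as the masked load above,
// except that the pointer only advances past elements whose mask bit is set,
// so consecutive enabled lanes read consecutive memory locations.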
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector.
  Value *VResult = PassThru;

  // Take a shortcut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, NewPtr, 1, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
      ++MemIndex;
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i16 (1 << Idx)
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32, i32* %1, i32 0
    //  %Elt = load i32, i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}
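// Translate a masked compressstore intrinsic, like
//   void @llvm.masked.compressstore.v16i32(<16 x i32> %Src, i32* %Ptr,
//                                          <16 x i1> %Mask)
// to a chain of basic blocks in the same style as the masked store above,
// except that the pointer only advances past elements whose mask bit is set,
// so the enabled lanes are packed into consecutive memory locations.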
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getVectorElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Take a shortcut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, 1);
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i16 (1 << Idx)
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32, i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, 1);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  bool EverMadeChange = false;

  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

      // Restart BB iteration if the dominator tree of the Function was
      // changed.
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }

  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI->isLegalMaskedLoad(CI->getType()))
        return false;
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_gather:
      if (TTI->isLegalMaskedGather(CI->getType()))
        return false;
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_scatter:
      if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_expandload:
      if (TTI->isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
  }

  return false;
}