//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics that are unsupported by the
// target with a chain of basic blocks that handle the elements one by one if
// the appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

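// A typical input this pass rewrites is a masked load such as the following
// (in LLVM 9 the pointee type is still mangled into the intrinsic name):
//
//   %res = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(
//              <4 x i32>* %p, i32 4, <4 x i1> %mask, <4 x i32> %passthru)
//
// The legacy pass is registered under DEBUG_TYPE, so it should be runnable in
// isolation with `opt -scalarize-masked-mem-intrin` (flag name inferred from
// the INITIALIZE_PASS registration below, not independently verified here).
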
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};

} // end anonymous namespace

char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}
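
// Targets without native support for some of these intrinsics typically
// schedule this pass before instruction selection. A minimal sketch, assuming
// a hypothetical MyTargetPassConfig (addIRPasses is the real TargetPassConfig
// extension point; the class name here is illustrative):
//
//   void MyTargetPassConfig::addIRPasses() {
//     addPass(createScalarizeMaskedMemIntrinPass());
//     TargetPassConfig::addIRPasses();
//   }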

static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = Mask->getType()->getVectorNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}
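
// For example, a literal mask like <2 x i1> <i1 true, i1 false> satisfies
// isConstantIntVector, while a mask only known at run time (say, the result
// of an icmp) does not; neither does a constant whose elements are not plain
// ConstantInts, such as a vector containing undef elements.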

// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                              <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, with loading element one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks, that stores element one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//  %3 = extractelement <16 x i32> %val, i32 0
//  %4 = getelementptr i32* %1, i32 0
//  store i32 %3, i32* %4
//  br label %else
//
// else:                                             ; preds = %0, %cond.store
//  %5 = extractelement <16 x i1> %mask, i32 1
//  br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//  %6 = extractelement <16 x i32> %val, i32 1
//  %7 = getelementptr i32* %1, i32 1
//  store i32 %6, i32* %7
//  br label %else2
//   . . .
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                       <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, with loading element one-by-one if
// the appropriate mask bit is set
//
//  %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  %Load0 = load i32, i32* %Ptr0, align 4
//  %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
//  br label %else
//
// else:
//  %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ undef, %0 ]
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  %Load1 = load i32, i32* %Ptr1, align 4
//  %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
//  br label %else2
//   . . .
//  %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
//  ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = cast<VectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = extractelement <16 x i1> %Mask, i32 1
    //  br i1 %Mask1, label %cond.load, label %else
    //

    Value *Predicate =
        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, that stores element one-by-one if
// the appropriate mask bit is set.
//
//  %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
//  %Elt0 = extractelement <16 x i32> %Src, i32 0
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  store i32 %Elt0, i32* %Ptr0, align 4
//  br label %else
//
// else:
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
//  %Elt1 = extractelement <16 x i32> %Src, i32 1
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  store i32 %Elt1, i32* %Ptr1, align 4
//  br label %else2
//   . . .
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
    //  br i1 %Mask1, label %cond.store, label %else
    //
    Value *Predicate =
        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

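// Translate a masked expandload intrinsic, like
//   <4 x i32> @llvm.masked.expandload.v4i32(i32* %ptr, <4 x i1> %mask,
//                                           <4 x i32> %passthru)
// into a chain of basic blocks in the same way as the masked-load lowering
// above, except that the element pointer is advanced only past lanes whose
// mask bit is set. A sketch of one scalarized iteration (value names are
// illustrative, not the exact output of the code below):
//
// cond.load:
//   %Load0 = load i32, i32* %ptr, align 1
//   %Res0 = insertelement <4 x i32> %passthru, i32 %Load0, i32 0
//   %NewPtr0 = getelementptr inbounds i32, i32* %ptr, i32 1
//   br label %else
//
// else:
//   %res.phi.else = phi <4 x i32> [ %Res0, %cond.load ], [ %passthru, %0 ]
//   %ptr.phi.else = phi i32* [ %NewPtr0, %cond.load ], [ %ptr, %0 ]
//   ...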
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

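// Translate a masked compressstore intrinsic, like
//   void @llvm.masked.compressstore.v4i32(<4 x i32> %value, i32* %ptr,
//                                         <4 x i1> %mask)
// into a chain of basic blocks in the same way as the masked-store lowering
// above, except that the enabled elements are written contiguously starting
// at %ptr, which is advanced only past lanes whose mask bit is set. A sketch
// of one scalarized iteration (illustrative names, not the exact output):
//
// cond.store:
//   %Elt0 = extractelement <4 x i32> %value, i32 0
//   store i32 %Elt0, i32* %ptr, align 1
//   %NewPtr0 = getelementptr inbounds i32, i32* %ptr, i32 1
//   br label %else
//
// else:
//   %ptr.phi.else = phi i32* [ %NewPtr0, %cond.store ], [ %ptr, %0 ]
//   ...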
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getVectorElementType();

  unsigned VectorWidth = VecType->getNumElements();

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, 1);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  bool EverMadeChange = false;

  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }

  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI->isLegalMaskedLoad(CI->getType()))
        return false;
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_gather:
      if (TTI->isLegalMaskedGather(CI->getType()))
        return false;
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_scatter:
      if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_expandload:
      if (TTI->isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
  }

  return false;
}