// Extracted from the LLVM 9.0.0svn doxygen page for ScalarizeMaskedMemIntrin.cpp.
//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
// intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks, that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <cassert>

37 using namespace llvm;
38 
39 #define DEBUG_TYPE "scalarize-masked-mem-intrin"
40 
41 namespace {
42 
43 class ScalarizeMaskedMemIntrin : public FunctionPass {
44  const TargetTransformInfo *TTI = nullptr;
45 
46 public:
47  static char ID; // Pass identification, replacement for typeid
48 
49  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
51  }
52 
53  bool runOnFunction(Function &F) override;
54 
55  StringRef getPassName() const override {
56  return "Scalarize Masked Memory Intrinsics";
57  }
58 
59  void getAnalysisUsage(AnalysisUsage &AU) const override {
61  }
62 
63 private:
64  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
65  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
66 };
67 
68 } // end anonymous namespace
69 
71 
72 INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
73  "Scalarize unsupported masked memory intrinsics", false, false)
74 
76  return new ScalarizeMaskedMemIntrin();
77 }
78 
81  if (!C)
82  return false;
83 
84  unsigned NumElts = Mask->getType()->getVectorNumElements();
85  for (unsigned i = 0; i != NumElts; ++i) {
86  Constant *CElt = C->getAggregateElement(i);
87  if (!CElt || !isa<ConstantInt>(CElt))
88  return false;
89  }
90 
91  return true;
92 }
93 
94 // Translate a masked load intrinsic like
95 // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
96 // <16 x i1> %mask, <16 x i32> %passthru)
97 // to a chain of basic blocks, with loading element one-by-one if
98 // the appropriate mask bit is set
99 //
100 // %1 = bitcast i8* %addr to i32*
101 // %2 = extractelement <16 x i1> %mask, i32 0
102 // br i1 %2, label %cond.load, label %else
103 //
104 // cond.load: ; preds = %0
105 // %3 = getelementptr i32* %1, i32 0
106 // %4 = load i32* %3
107 // %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
108 // br label %else
109 //
110 // else: ; preds = %0, %cond.load
111 // %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
112 // %6 = extractelement <16 x i1> %mask, i32 1
113 // br i1 %6, label %cond.load1, label %else2
114 //
115 // cond.load1: ; preds = %else
116 // %7 = getelementptr i32* %1, i32 1
117 // %8 = load i32* %7
118 // %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
119 // br label %else2
120 //
121 // else2: ; preds = %else, %cond.load1
122 // %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
123 // %10 = extractelement <16 x i1> %mask, i32 2
124 // br i1 %10, label %cond.load4, label %else5
125 //
126 static void scalarizeMaskedLoad(CallInst *CI) {
127  Value *Ptr = CI->getArgOperand(0);
128  Value *Alignment = CI->getArgOperand(1);
129  Value *Mask = CI->getArgOperand(2);
130  Value *Src0 = CI->getArgOperand(3);
131 
132  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
133  VectorType *VecType = cast<VectorType>(CI->getType());
134 
135  Type *EltTy = VecType->getElementType();
136 
137  IRBuilder<> Builder(CI->getContext());
138  Instruction *InsertPt = CI;
139  BasicBlock *IfBlock = CI->getParent();
140 
141  Builder.SetInsertPoint(InsertPt);
142  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
143 
144  // Short-cut if the mask is all-true.
145  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
146  Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
147  CI->replaceAllUsesWith(NewI);
148  CI->eraseFromParent();
149  return;
150  }
151 
152  // Adjust alignment for the scalar instruction.
153  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
154  // Bitcast %addr fron i8* to EltTy*
155  Type *NewPtrType =
156  EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
157  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
158  unsigned VectorWidth = VecType->getNumElements();
159 
160  // The result vector
161  Value *VResult = Src0;
162 
163  if (isConstantIntVector(Mask)) {
164  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
165  if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
166  continue;
167  Value *Gep =
168  Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
169  LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
170  VResult =
171  Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
172  }
173  CI->replaceAllUsesWith(VResult);
174  CI->eraseFromParent();
175  return;
176  }
177 
178  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
179  // Fill the "else" block, created in the previous iteration
180  //
181  // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
182  // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
183  // br i1 %mask_1, label %cond.load, label %else
184  //
185 
186  Value *Predicate =
187  Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
188 
189  // Create "cond" block
190  //
191  // %EltAddr = getelementptr i32* %1, i32 0
192  // %Elt = load i32* %EltAddr
193  // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
194  //
195  BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
196  "cond.load");
197  Builder.SetInsertPoint(InsertPt);
198 
199  Value *Gep =
200  Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
201  LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
202  Value *NewVResult = Builder.CreateInsertElement(VResult, Load,
203  Builder.getInt32(Idx));
204 
205  // Create "else" block, fill it in the next iteration
206  BasicBlock *NewIfBlock =
207  CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
208  Builder.SetInsertPoint(InsertPt);
209  Instruction *OldBr = IfBlock->getTerminator();
210  BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
211  OldBr->eraseFromParent();
212  BasicBlock *PrevIfBlock = IfBlock;
213  IfBlock = NewIfBlock;
214 
215  // Create the phi to join the new and previous value.
216  PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
217  Phi->addIncoming(NewVResult, CondBlock);
218  Phi->addIncoming(VResult, PrevIfBlock);
219  VResult = Phi;
220  }
221 
222  CI->replaceAllUsesWith(VResult);
223  CI->eraseFromParent();
224 }
225 
226 // Translate a masked store intrinsic, like
227 // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
228 // <16 x i1> %mask)
229 // to a chain of basic blocks, that stores element one-by-one if
230 // the appropriate mask bit is set
231 //
232 // %1 = bitcast i8* %addr to i32*
233 // %2 = extractelement <16 x i1> %mask, i32 0
234 // br i1 %2, label %cond.store, label %else
235 //
236 // cond.store: ; preds = %0
237 // %3 = extractelement <16 x i32> %val, i32 0
238 // %4 = getelementptr i32* %1, i32 0
239 // store i32 %3, i32* %4
240 // br label %else
241 //
242 // else: ; preds = %0, %cond.store
243 // %5 = extractelement <16 x i1> %mask, i32 1
244 // br i1 %5, label %cond.store1, label %else2
245 //
246 // cond.store1: ; preds = %else
247 // %6 = extractelement <16 x i32> %val, i32 1
248 // %7 = getelementptr i32* %1, i32 1
249 // store i32 %6, i32* %7
250 // br label %else2
251 // . . .
252 static void scalarizeMaskedStore(CallInst *CI) {
253  Value *Src = CI->getArgOperand(0);
254  Value *Ptr = CI->getArgOperand(1);
255  Value *Alignment = CI->getArgOperand(2);
256  Value *Mask = CI->getArgOperand(3);
257 
258  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
259  VectorType *VecType = cast<VectorType>(Src->getType());
260 
261  Type *EltTy = VecType->getElementType();
262 
263  IRBuilder<> Builder(CI->getContext());
264  Instruction *InsertPt = CI;
265  BasicBlock *IfBlock = CI->getParent();
266  Builder.SetInsertPoint(InsertPt);
267  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
268 
269  // Short-cut if the mask is all-true.
270  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
271  Builder.CreateAlignedStore(Src, Ptr, AlignVal);
272  CI->eraseFromParent();
273  return;
274  }
275 
276  // Adjust alignment for the scalar instruction.
277  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
278  // Bitcast %addr fron i8* to EltTy*
279  Type *NewPtrType =
280  EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
281  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
282  unsigned VectorWidth = VecType->getNumElements();
283 
284  if (isConstantIntVector(Mask)) {
285  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
286  if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
287  continue;
288  Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
289  Value *Gep =
290  Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
291  Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
292  }
293  CI->eraseFromParent();
294  return;
295  }
296 
297  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
298  // Fill the "else" block, created in the previous iteration
299  //
300  // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
301  // br i1 %mask_1, label %cond.store, label %else
302  //
303  Value *Predicate =
304  Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
305 
306  // Create "cond" block
307  //
308  // %OneElt = extractelement <16 x i32> %Src, i32 Idx
309  // %EltAddr = getelementptr i32* %1, i32 0
310  // %store i32 %OneElt, i32* %EltAddr
311  //
312  BasicBlock *CondBlock =
313  IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
314  Builder.SetInsertPoint(InsertPt);
315 
316  Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
317  Value *Gep =
318  Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
319  Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
320 
321  // Create "else" block, fill it in the next iteration
322  BasicBlock *NewIfBlock =
323  CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
324  Builder.SetInsertPoint(InsertPt);
325  Instruction *OldBr = IfBlock->getTerminator();
326  BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
327  OldBr->eraseFromParent();
328  IfBlock = NewIfBlock;
329  }
330  CI->eraseFromParent();
331 }
332 
333 // Translate a masked gather intrinsic like
334 // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
335 // <16 x i1> %Mask, <16 x i32> %Src)
336 // to a chain of basic blocks, with loading element one-by-one if
337 // the appropriate mask bit is set
338 //
339 // %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
340 // %Mask0 = extractelement <16 x i1> %Mask, i32 0
341 // br i1 %Mask0, label %cond.load, label %else
342 //
343 // cond.load:
344 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
345 // %Load0 = load i32, i32* %Ptr0, align 4
346 // %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
347 // br label %else
348 //
349 // else:
350 // %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
351 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
352 // br i1 %Mask1, label %cond.load1, label %else2
353 //
354 // cond.load1:
355 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
356 // %Load1 = load i32, i32* %Ptr1, align 4
357 // %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
358 // br label %else2
359 // . . .
360 // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
361 // ret <16 x i32> %Result
362 static void scalarizeMaskedGather(CallInst *CI) {
363  Value *Ptrs = CI->getArgOperand(0);
364  Value *Alignment = CI->getArgOperand(1);
365  Value *Mask = CI->getArgOperand(2);
366  Value *Src0 = CI->getArgOperand(3);
367 
368  VectorType *VecType = cast<VectorType>(CI->getType());
369  Type *EltTy = VecType->getElementType();
370 
371  IRBuilder<> Builder(CI->getContext());
372  Instruction *InsertPt = CI;
373  BasicBlock *IfBlock = CI->getParent();
374  Builder.SetInsertPoint(InsertPt);
375  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
376 
377  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
378 
379  // The result vector
380  Value *VResult = Src0;
381  unsigned VectorWidth = VecType->getNumElements();
382 
383  // Shorten the way if the mask is a vector of constants.
384  if (isConstantIntVector(Mask)) {
385  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
386  if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
387  continue;
388  Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
389  "Ptr" + Twine(Idx));
390  LoadInst *Load =
391  Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
392  VResult = Builder.CreateInsertElement(
393  VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
394  }
395  CI->replaceAllUsesWith(VResult);
396  CI->eraseFromParent();
397  return;
398  }
399 
400  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
401  // Fill the "else" block, created in the previous iteration
402  //
403  // %Mask1 = extractelement <16 x i1> %Mask, i32 1
404  // br i1 %Mask1, label %cond.load, label %else
405  //
406 
407  Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
408  "Mask" + Twine(Idx));
409 
410  // Create "cond" block
411  //
412  // %EltAddr = getelementptr i32* %1, i32 0
413  // %Elt = load i32* %EltAddr
414  // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
415  //
416  BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
417  Builder.SetInsertPoint(InsertPt);
418 
419  Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
420  "Ptr" + Twine(Idx));
421  LoadInst *Load =
422  Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
423  Value *NewVResult = Builder.CreateInsertElement(VResult, Load,
424  Builder.getInt32(Idx),
425  "Res" + Twine(Idx));
426 
427  // Create "else" block, fill it in the next iteration
428  BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
429  Builder.SetInsertPoint(InsertPt);
430  Instruction *OldBr = IfBlock->getTerminator();
431  BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
432  OldBr->eraseFromParent();
433  BasicBlock *PrevIfBlock = IfBlock;
434  IfBlock = NewIfBlock;
435 
436  PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
437  Phi->addIncoming(NewVResult, CondBlock);
438  Phi->addIncoming(VResult, PrevIfBlock);
439  VResult = Phi;
440  }
441 
442  CI->replaceAllUsesWith(VResult);
443  CI->eraseFromParent();
444 }
445 
446 // Translate a masked scatter intrinsic, like
447 // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
448 // <16 x i1> %Mask)
449 // to a chain of basic blocks, that stores element one-by-one if
450 // the appropriate mask bit is set.
451 //
452 // %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
453 // %Mask0 = extractelement <16 x i1> %Mask, i32 0
454 // br i1 %Mask0, label %cond.store, label %else
455 //
456 // cond.store:
457 // %Elt0 = extractelement <16 x i32> %Src, i32 0
458 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
459 // store i32 %Elt0, i32* %Ptr0, align 4
460 // br label %else
461 //
462 // else:
463 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
464 // br i1 %Mask1, label %cond.store1, label %else2
465 //
466 // cond.store1:
467 // %Elt1 = extractelement <16 x i32> %Src, i32 1
468 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
469 // store i32 %Elt1, i32* %Ptr1, align 4
470 // br label %else2
471 // . . .
473  Value *Src = CI->getArgOperand(0);
474  Value *Ptrs = CI->getArgOperand(1);
475  Value *Alignment = CI->getArgOperand(2);
476  Value *Mask = CI->getArgOperand(3);
477 
478  assert(isa<VectorType>(Src->getType()) &&
479  "Unexpected data type in masked scatter intrinsic");
480  assert(isa<VectorType>(Ptrs->getType()) &&
481  isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
482  "Vector of pointers is expected in masked scatter intrinsic");
483 
484  IRBuilder<> Builder(CI->getContext());
485  Instruction *InsertPt = CI;
486  BasicBlock *IfBlock = CI->getParent();
487  Builder.SetInsertPoint(InsertPt);
488  Builder.SetCurrentDebugLocation(CI->getDebugLoc());
489 
490  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
491  unsigned VectorWidth = Src->getType()->getVectorNumElements();
492 
493  // Shorten the way if the mask is a vector of constants.
494  if (isConstantIntVector(Mask)) {
495  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
496  if (cast<ConstantVector>(Mask)->getAggregateElement(Idx)->isNullValue())
497  continue;
498  Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
499  "Elt" + Twine(Idx));
500  Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
501  "Ptr" + Twine(Idx));
502  Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
503  }
504  CI->eraseFromParent();
505  return;
506  }
507 
508  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
509  // Fill the "else" block, created in the previous iteration
510  //
511  // %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
512  // br i1 %Mask1, label %cond.store, label %else
513  //
514  Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
515  "Mask" + Twine(Idx));
516 
517  // Create "cond" block
518  //
519  // %Elt1 = extractelement <16 x i32> %Src, i32 1
520  // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
521  // %store i32 %Elt1, i32* %Ptr1
522  //
523  BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
524  Builder.SetInsertPoint(InsertPt);
525 
526  Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
527  "Elt" + Twine(Idx));
528  Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
529  "Ptr" + Twine(Idx));
530  Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
531 
532  // Create "else" block, fill it in the next iteration
533  BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
534  Builder.SetInsertPoint(InsertPt);
535  Instruction *OldBr = IfBlock->getTerminator();
536  BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
537  OldBr->eraseFromParent();
538  IfBlock = NewIfBlock;
539  }
540  CI->eraseFromParent();
541 }
542 
544  bool EverMadeChange = false;
545 
546  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
547 
548  bool MadeChange = true;
549  while (MadeChange) {
550  MadeChange = false;
551  for (Function::iterator I = F.begin(); I != F.end();) {
552  BasicBlock *BB = &*I++;
553  bool ModifiedDTOnIteration = false;
554  MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
555 
556  // Restart BB iteration if the dominator tree of the Function was changed
557  if (ModifiedDTOnIteration)
558  break;
559  }
560 
561  EverMadeChange |= MadeChange;
562  }
563 
564  return EverMadeChange;
565 }
566 
567 bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
568  bool MadeChange = false;
569 
570  BasicBlock::iterator CurInstIterator = BB.begin();
571  while (CurInstIterator != BB.end()) {
572  if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
573  MadeChange |= optimizeCallInst(CI, ModifiedDT);
574  if (ModifiedDT)
575  return true;
576  }
577 
578  return MadeChange;
579 }
580 
581 bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
582  bool &ModifiedDT) {
584  if (II) {
585  switch (II->getIntrinsicID()) {
586  default:
587  break;
588  case Intrinsic::masked_load:
589  // Scalarize unsupported vector masked load
590  if (!TTI->isLegalMaskedLoad(CI->getType())) {
592  ModifiedDT = true;
593  return true;
594  }
595  return false;
596  case Intrinsic::masked_store:
597  if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
599  ModifiedDT = true;
600  return true;
601  }
602  return false;
603  case Intrinsic::masked_gather:
604  if (!TTI->isLegalMaskedGather(CI->getType())) {
606  ModifiedDT = true;
607  return true;
608  }
609  return false;
610  case Intrinsic::masked_scatter:
611  if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
613  ModifiedDT = true;
614  return true;
615  }
616  return false;
617  }
618  }
619 
620  return false;
621 }
Type * getVectorElementType() const
Definition: Type.h:370
uint64_t CallInst * C
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:67
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
static void scalarizeMaskedScatter(CallInst *CI)
This class represents lattice values for constants.
Definition: AllocatorList.h:23
iterator end()
Definition: Function.h:657
This class represents a function call, abstracting a target machine's calling convention.
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:705
F(f)
An instruction for reading from memory.
Definition: Instructions.h:167
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.cpp:137
void initializeScalarizeMaskedMemIntrinPass(PassRegistry &)
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:268
FunctionPass * createScalarizeMaskedMemIntrinPass()
createScalarizeMaskedMemIntrinPass - Replace masked load, store, gather and scatter intrinsics with s...
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1155
AnalysisUsage & addRequired()
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:80
PointerType * getPointerTo(unsigned AddrSpace=0) const
Return a pointer to the current type.
Definition: Type.cpp:651
static void scalarizeMaskedGather(CallInst *CI)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:742
uint64_t getNumElements() const
Definition: DerivedTypes.h:390
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:244
#define DEBUG_TYPE
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:429
iterator begin()
Definition: Function.h:655
Constant * getAggregateElement(unsigned Elt) const
For aggregates (struct/array/vector) return the constant that corresponds to the specified element if...
Definition: Constants.cpp:334
constexpr uint64_t MinAlign(uint64_t A, uint64_t B)
A and B are either alignments or offsets.
Definition: MathExtras.h:609
static bool runOnFunction(Function &F, bool PostInlining)
Wrapper pass for TargetTransformInfo.
LLVM Basic Block Representation.
Definition: BasicBlock.h:57
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:45
This is an important base class in LLVM.
Definition: Constant.h:41
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Represent the analysis usage information of a pass.
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:284
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
Definition: IntrinsicInst.h:50
Iterator for intrusive lists based on ilist_node.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE, "Scalarize unsupported masked memory intrinsics", false, false) FunctionPass *llvm
iterator end()
Definition: BasicBlock.h:270
static BranchInst * Create(BasicBlock *IfTrue, Instruction *InsertBefore=nullptr)
unsigned getVectorNumElements() const
Definition: DerivedTypes.h:493
Class to represent vector types.
Definition: DerivedTypes.h:424
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
Definition: Instruction.h:324
static void scalarizeMaskedLoad(CallInst *CI)
#define I(x, y, z)
Definition: MD5.cpp:58
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:322
BasicBlock * splitBasicBlock(iterator I, const Twine &BBName="")
Split the basic block into two basic blocks at the specified instruction.
Definition: BasicBlock.cpp:407
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
LLVM Value Representation.
Definition: Value.h:72
std::underlying_type< E >::type Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
Definition: BitmaskEnum.h:80
Type * getElementType() const
Definition: DerivedTypes.h:391
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:48
This pass exposes codegen information to IR-level passes.
static void scalarizeMaskedStore(CallInst *CI)
static bool isConstantIntVector(Value *Mask)
A wrapper class for inspecting calls to intrinsic functions.
Definition: IntrinsicInst.h:43
const BasicBlock * getParent() const
Definition: Instruction.h:66