LCOV - code coverage report
Current view: top level - lib/CodeGen - ScalarizeMaskedMemIntrin.cpp (source / functions) Hit Total Coverage
Test: llvm-toolchain.info Lines: 240 267 89.9 %
Date: 2017-09-14 15:23:50 Functions: 14 15 93.3 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //=== ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem      ===//
       2             : //===                                instrinsics                           ===//
       3             : //
       4             : //                     The LLVM Compiler Infrastructure
       5             : //
       6             : // This file is distributed under the University of Illinois Open Source
       7             : // License. See LICENSE.TXT for details.
       8             : //
       9             : //===----------------------------------------------------------------------===//
      10             : //
      11             : // This pass replaces masked memory intrinsics - when unsupported by the target
      12             : // - with a chain of basic blocks, that deal with the elements one-by-one if the
      13             : // appropriate mask bit is set.
      14             : //
      15             : //===----------------------------------------------------------------------===//
      16             : 
      17             : #include "llvm/Analysis/TargetTransformInfo.h"
      18             : #include "llvm/IR/IRBuilder.h"
      19             : #include "llvm/IR/IntrinsicInst.h"
      20             : #include "llvm/Target/TargetSubtargetInfo.h"
      21             : 
      22             : using namespace llvm;
      23             : 
      24             : #define DEBUG_TYPE "scalarize-masked-mem-intrin"
      25             : 
      26             : namespace {
      27             : 
      28       33642 : class ScalarizeMaskedMemIntrin : public FunctionPass {
      29             :   const TargetTransformInfo *TTI;
      30             : 
      31             : public:
      32             :   static char ID; // Pass identification, replacement for typeid
      33       33848 :   explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID), TTI(nullptr) {
      34       16924 :     initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
      35       16924 :   }
      36             :   bool runOnFunction(Function &F) override;
      37             : 
      38          23 :   StringRef getPassName() const override {
      39          23 :     return "Scalarize Masked Memory Intrinsics";
      40             :   }
      41             : 
      42       16886 :   void getAnalysisUsage(AnalysisUsage &AU) const override {
      43       16886 :     AU.addRequired<TargetTransformInfoWrapperPass>();
      44       16886 :   }
      45             : 
      46             : private:
      47             :   bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
      48             :   bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
      49             : };
      50             : } // namespace
      51             : 
      52             : char ScalarizeMaskedMemIntrin::ID = 0;
      53      291852 : INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
      54             :                 "Scalarize unsupported masked memory intrinsics", false, false)
      55             : 
      56       16923 : FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
      57       16923 :   return new ScalarizeMaskedMemIntrin();
      58             : }
      59             : 
      60             : // Translate a masked load intrinsic like
      61             : // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
      62             : //                               <16 x i1> %mask, <16 x i32> %passthru)
      63             : // to a chain of basic blocks, with loading element one-by-one if
      64             : // the appropriate mask bit is set
      65             : //
      66             : //  %1 = bitcast i8* %addr to i32*
      67             : //  %2 = extractelement <16 x i1> %mask, i32 0
      68             : //  %3 = icmp eq i1 %2, true
      69             : //  br i1 %3, label %cond.load, label %else
      70             : //
      71             : // cond.load:                                        ; preds = %0
      72             : //  %4 = getelementptr i32* %1, i32 0
      73             : //  %5 = load i32* %4
      74             : //  %6 = insertelement <16 x i32> undef, i32 %5, i32 0
      75             : //  br label %else
      76             : //
      77             : // else:                                             ; preds = %0, %cond.load
      78             : //  %res.phi.else = phi <16 x i32> [ %6, %cond.load ], [ undef, %0 ]
      79             : //  %7 = extractelement <16 x i1> %mask, i32 1
      80             : //  %8 = icmp eq i1 %7, true
      81             : //  br i1 %8, label %cond.load1, label %else2
      82             : //
      83             : // cond.load1:                                       ; preds = %else
      84             : //  %9 = getelementptr i32* %1, i32 1
      85             : //  %10 = load i32* %9
      86             : //  %11 = insertelement <16 x i32> %res.phi.else, i32 %10, i32 1
      87             : //  br label %else2
      88             : //
      89             : // else2:                                          ; preds = %else, %cond.load1
      90             : //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
      91             : //  %12 = extractelement <16 x i1> %mask, i32 2
      92             : //  %13 = icmp eq i1 %12, true
      93             : //  br i1 %13, label %cond.load4, label %else5
      94             : //
      95           1 : static void scalarizeMaskedLoad(CallInst *CI) {
      96           1 :   Value *Ptr = CI->getArgOperand(0);
      97           1 :   Value *Alignment = CI->getArgOperand(1);
      98           1 :   Value *Mask = CI->getArgOperand(2);
      99           1 :   Value *Src0 = CI->getArgOperand(3);
     100             : 
     101           2 :   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
     102           2 :   VectorType *VecType = dyn_cast<VectorType>(CI->getType());
     103             :   assert(VecType && "Unexpected return type of masked load intrinsic");
     104             : 
     105           2 :   Type *EltTy = CI->getType()->getVectorElementType();
     106             : 
     107           3 :   IRBuilder<> Builder(CI->getContext());
     108           1 :   Instruction *InsertPt = CI;
     109           1 :   BasicBlock *IfBlock = CI->getParent();
     110           1 :   BasicBlock *CondBlock = nullptr;
     111           1 :   BasicBlock *PrevIfBlock = CI->getParent();
     112             : 
     113           1 :   Builder.SetInsertPoint(InsertPt);
     114           4 :   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
     115             : 
     116             :   // Short-cut if the mask is all-true.
     117             :   bool IsAllOnesMask =
     118           1 :       isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
     119             : 
     120             :   if (IsAllOnesMask) {
     121           0 :     Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
     122           0 :     CI->replaceAllUsesWith(NewI);
     123           0 :     CI->eraseFromParent();
     124           0 :     return;
     125             :   }
     126             : 
     127             :   // Adjust alignment for the scalar instruction.
     128           2 :   AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
     129             :   // Bitcast %addr fron i8* to EltTy*
     130             :   Type *NewPtrType =
     131           3 :       EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
     132           2 :   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
     133           1 :   unsigned VectorWidth = VecType->getNumElements();
     134             : 
     135           1 :   Value *UndefVal = UndefValue::get(VecType);
     136             : 
     137             :   // The result vector
     138           1 :   Value *VResult = UndefVal;
     139             : 
     140           2 :   if (isa<ConstantVector>(Mask)) {
     141           0 :     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     142           0 :       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
     143           0 :         continue;
     144             :       Value *Gep =
     145           0 :           Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
     146           0 :       LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
     147           0 :       VResult =
     148           0 :           Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
     149             :     }
     150           0 :     Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
     151           0 :     CI->replaceAllUsesWith(NewI);
     152           0 :     CI->eraseFromParent();
     153           0 :     return;
     154             :   }
     155             : 
     156             :   PHINode *Phi = nullptr;
     157             :   Value *PrevPhi = UndefVal;
     158             : 
     159           9 :   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     160             : 
     161             :     // Fill the "else" block, created in the previous iteration
     162             :     //
     163             :     //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
     164             :     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
     165             :     //  %to_load = icmp eq i1 %mask_1, true
     166             :     //  br i1 %to_load, label %cond.load, label %else
     167             :     //
     168           4 :     if (Idx > 0) {
     169           3 :       Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
     170           3 :       Phi->addIncoming(VResult, CondBlock);
     171           3 :       Phi->addIncoming(PrevPhi, PrevIfBlock);
     172           3 :       PrevPhi = Phi;
     173           3 :       VResult = Phi;
     174             :     }
     175             : 
     176             :     Value *Predicate =
     177           4 :         Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
     178           8 :     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
     179           4 :                                     ConstantInt::get(Predicate->getType(), 1));
     180             : 
     181             :     // Create "cond" block
     182             :     //
     183             :     //  %EltAddr = getelementptr i32* %1, i32 0
     184             :     //  %Elt = load i32* %EltAddr
     185             :     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
     186             :     //
     187           8 :     CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
     188           4 :     Builder.SetInsertPoint(InsertPt);
     189             : 
     190             :     Value *Gep =
     191           4 :         Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
     192           8 :     LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
     193           4 :     VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
     194             : 
     195             :     // Create "else" block, fill it in the next iteration
     196             :     BasicBlock *NewIfBlock =
     197           8 :         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
     198           4 :     Builder.SetInsertPoint(InsertPt);
     199           4 :     Instruction *OldBr = IfBlock->getTerminator();
     200           4 :     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
     201           4 :     OldBr->eraseFromParent();
     202           4 :     PrevIfBlock = IfBlock;
     203           4 :     IfBlock = NewIfBlock;
     204             :   }
     205             : 
     206           1 :   Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
     207           1 :   Phi->addIncoming(VResult, CondBlock);
     208           1 :   Phi->addIncoming(PrevPhi, PrevIfBlock);
     209           1 :   Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
     210           1 :   CI->replaceAllUsesWith(NewI);
     211           1 :   CI->eraseFromParent();
     212             : }
     213             : 
     214             : // Translate a masked store intrinsic, like
     215             : // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
     216             : //                               <16 x i1> %mask)
     217             : // to a chain of basic blocks, that stores element one-by-one if
     218             : // the appropriate mask bit is set
     219             : //
     220             : //   %1 = bitcast i8* %addr to i32*
     221             : //   %2 = extractelement <16 x i1> %mask, i32 0
     222             : //   %3 = icmp eq i1 %2, true
     223             : //   br i1 %3, label %cond.store, label %else
     224             : //
     225             : // cond.store:                                       ; preds = %0
     226             : //   %4 = extractelement <16 x i32> %val, i32 0
     227             : //   %5 = getelementptr i32* %1, i32 0
     228             : //   store i32 %4, i32* %5
     229             : //   br label %else
     230             : //
     231             : // else:                                             ; preds = %0, %cond.store
     232             : //   %6 = extractelement <16 x i1> %mask, i32 1
     233             : //   %7 = icmp eq i1 %6, true
     234             : //   br i1 %7, label %cond.store1, label %else2
     235             : //
     236             : // cond.store1:                                      ; preds = %else
     237             : //   %8 = extractelement <16 x i32> %val, i32 1
     238             : //   %9 = getelementptr i32* %1, i32 1
     239             : //   store i32 %8, i32* %9
     240             : //   br label %else2
     241             : //   . . .
     242           1 : static void scalarizeMaskedStore(CallInst *CI) {
     243           1 :   Value *Src = CI->getArgOperand(0);
     244           1 :   Value *Ptr = CI->getArgOperand(1);
     245           1 :   Value *Alignment = CI->getArgOperand(2);
     246           1 :   Value *Mask = CI->getArgOperand(3);
     247             : 
     248           2 :   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
     249           2 :   VectorType *VecType = dyn_cast<VectorType>(Src->getType());
     250             :   assert(VecType && "Unexpected data type in masked store intrinsic");
     251             : 
     252           1 :   Type *EltTy = VecType->getElementType();
     253             : 
     254           3 :   IRBuilder<> Builder(CI->getContext());
     255           1 :   Instruction *InsertPt = CI;
     256           1 :   BasicBlock *IfBlock = CI->getParent();
     257           1 :   Builder.SetInsertPoint(InsertPt);
     258           4 :   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
     259             : 
     260             :   // Short-cut if the mask is all-true.
     261             :   bool IsAllOnesMask =
     262           1 :       isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
     263             : 
     264             :   if (IsAllOnesMask) {
     265           0 :     Builder.CreateAlignedStore(Src, Ptr, AlignVal);
     266           0 :     CI->eraseFromParent();
     267           0 :     return;
     268             :   }
     269             : 
     270             :   // Adjust alignment for the scalar instruction.
     271           2 :   AlignVal = std::max(AlignVal, VecType->getScalarSizeInBits() / 8);
     272             :   // Bitcast %addr fron i8* to EltTy*
     273             :   Type *NewPtrType =
     274           3 :       EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
     275           2 :   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
     276           1 :   unsigned VectorWidth = VecType->getNumElements();
     277             : 
     278           2 :   if (isa<ConstantVector>(Mask)) {
     279           0 :     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     280           0 :       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
     281           0 :         continue;
     282           0 :       Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
     283             :       Value *Gep =
     284           0 :           Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
     285           0 :       Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
     286             :     }
     287           0 :     CI->eraseFromParent();
     288           0 :     return;
     289             :   }
     290             : 
     291           9 :   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     292             : 
     293             :     // Fill the "else" block, created in the previous iteration
     294             :     //
     295             :     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
     296             :     //  %to_store = icmp eq i1 %mask_1, true
     297             :     //  br i1 %to_store, label %cond.store, label %else
     298             :     //
     299             :     Value *Predicate =
     300           4 :         Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
     301           8 :     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
     302           4 :                                     ConstantInt::get(Predicate->getType(), 1));
     303             : 
     304             :     // Create "cond" block
     305             :     //
     306             :     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
     307             :     //  %EltAddr = getelementptr i32* %1, i32 0
     308             :     //  %store i32 %OneElt, i32* %EltAddr
     309             :     //
     310             :     BasicBlock *CondBlock =
     311           8 :         IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
     312           4 :     Builder.SetInsertPoint(InsertPt);
     313             : 
     314           4 :     Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
     315             :     Value *Gep =
     316           4 :         Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
     317           8 :     Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
     318             : 
     319             :     // Create "else" block, fill it in the next iteration
     320             :     BasicBlock *NewIfBlock =
     321           8 :         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
     322           4 :     Builder.SetInsertPoint(InsertPt);
     323           4 :     Instruction *OldBr = IfBlock->getTerminator();
     324           4 :     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
     325           4 :     OldBr->eraseFromParent();
     326           4 :     IfBlock = NewIfBlock;
     327             :   }
     328           1 :   CI->eraseFromParent();
     329             : }
     330             : 
     331             : // Translate a masked gather intrinsic like
     332             : // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
     333             : //                               <16 x i1> %Mask, <16 x i32> %Src)
     334             : // to a chain of basic blocks, with loading element one-by-one if
     335             : // the appropriate mask bit is set
     336             : //
     337             : // % Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
     338             : // % Mask0 = extractelement <16 x i1> %Mask, i32 0
     339             : // % ToLoad0 = icmp eq i1 % Mask0, true
     340             : // br i1 % ToLoad0, label %cond.load, label %else
     341             : //
     342             : // cond.load:
     343             : // % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
     344             : // % Load0 = load i32, i32* % Ptr0, align 4
     345             : // % Res0 = insertelement <16 x i32> undef, i32 % Load0, i32 0
     346             : // br label %else
     347             : //
     348             : // else:
     349             : // %res.phi.else = phi <16 x i32>[% Res0, %cond.load], [undef, % 0]
     350             : // % Mask1 = extractelement <16 x i1> %Mask, i32 1
     351             : // % ToLoad1 = icmp eq i1 % Mask1, true
     352             : // br i1 % ToLoad1, label %cond.load1, label %else2
     353             : //
     354             : // cond.load1:
     355             : // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
     356             : // % Load1 = load i32, i32* % Ptr1, align 4
     357             : // % Res1 = insertelement <16 x i32> %res.phi.else, i32 % Load1, i32 1
     358             : // br label %else2
     359             : // . . .
     360             : // % Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
     361             : // ret <16 x i32> %Result
     362          42 : static void scalarizeMaskedGather(CallInst *CI) {
     363          42 :   Value *Ptrs = CI->getArgOperand(0);
     364          42 :   Value *Alignment = CI->getArgOperand(1);
     365          42 :   Value *Mask = CI->getArgOperand(2);
     366          42 :   Value *Src0 = CI->getArgOperand(3);
     367             : 
     368          84 :   VectorType *VecType = dyn_cast<VectorType>(CI->getType());
     369             : 
     370             :   assert(VecType && "Unexpected return type of masked load intrinsic");
     371             : 
     372         113 :   IRBuilder<> Builder(CI->getContext());
     373          42 :   Instruction *InsertPt = CI;
     374          42 :   BasicBlock *IfBlock = CI->getParent();
     375          42 :   BasicBlock *CondBlock = nullptr;
     376          42 :   BasicBlock *PrevIfBlock = CI->getParent();
     377          42 :   Builder.SetInsertPoint(InsertPt);
     378          84 :   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
     379             : 
     380         168 :   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
     381             : 
     382          42 :   Value *UndefVal = UndefValue::get(VecType);
     383             : 
     384             :   // The result vector
     385          42 :   Value *VResult = UndefVal;
     386          42 :   unsigned VectorWidth = VecType->getNumElements();
     387             : 
     388             :   // Shorten the way if the mask is a vector of constants.
     389          84 :   bool IsConstMask = isa<ConstantVector>(Mask);
     390             : 
     391          42 :   if (IsConstMask) {
     392         297 :     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     393         426 :       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
     394          13 :         continue;
     395         129 :       Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
     396         516 :                                                 "Ptr" + Twine(Idx));
     397             :       LoadInst *Load =
     398         516 :           Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
     399         129 :       VResult = Builder.CreateInsertElement(
     400         516 :           VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
     401             :     }
     402          13 :     Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
     403          13 :     CI->replaceAllUsesWith(NewI);
     404          13 :     CI->eraseFromParent();
     405          13 :     return;
     406             :   }
     407             : 
     408             :   PHINode *Phi = nullptr;
     409             :   Value *PrevPhi = UndefVal;
     410             : 
     411         485 :   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     412             : 
     413             :     // Fill the "else" block, created in the previous iteration
     414             :     //
     415             :     //  %Mask1 = extractelement <16 x i1> %Mask, i32 1
     416             :     //  %ToLoad1 = icmp eq i1 %Mask1, true
     417             :     //  br i1 %ToLoad1, label %cond.load, label %else
     418             :     //
     419         228 :     if (Idx > 0) {
     420         199 :       Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
     421         199 :       Phi->addIncoming(VResult, CondBlock);
     422         199 :       Phi->addIncoming(PrevPhi, PrevIfBlock);
     423         199 :       PrevPhi = Phi;
     424         199 :       VResult = Phi;
     425             :     }
     426             : 
     427         228 :     Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
     428         912 :                                                     "Mask" + Twine(Idx));
     429             :     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
     430         228 :                                     ConstantInt::get(Predicate->getType(), 1),
     431         912 :                                     "ToLoad" + Twine(Idx));
     432             : 
     433             :     // Create "cond" block
     434             :     //
     435             :     //  %EltAddr = getelementptr i32* %1, i32 0
     436             :     //  %Elt = load i32* %EltAddr
     437             :     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
     438             :     //
     439         456 :     CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
     440         228 :     Builder.SetInsertPoint(InsertPt);
     441             : 
     442         228 :     Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
     443         912 :                                               "Ptr" + Twine(Idx));
     444             :     LoadInst *Load =
     445         912 :         Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
     446         228 :     VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx),
     447         912 :                                           "Res" + Twine(Idx));
     448             : 
     449             :     // Create "else" block, fill it in the next iteration
     450         456 :     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
     451         228 :     Builder.SetInsertPoint(InsertPt);
     452         228 :     Instruction *OldBr = IfBlock->getTerminator();
     453         228 :     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
     454         228 :     OldBr->eraseFromParent();
     455         228 :     PrevIfBlock = IfBlock;
     456         228 :     IfBlock = NewIfBlock;
     457             :   }
     458             : 
     459          29 :   Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
     460          29 :   Phi->addIncoming(VResult, CondBlock);
     461          29 :   Phi->addIncoming(PrevPhi, PrevIfBlock);
     462          29 :   Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
     463          29 :   CI->replaceAllUsesWith(NewI);
     464          29 :   CI->eraseFromParent();
     465             : }
     466             : 
     467             : // Translate a masked scatter intrinsic, like
     468             : // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
     469             : //                                  <16 x i1> %Mask)
     470             : // to a chain of basic blocks, that stores element one-by-one if
     471             : // the appropriate mask bit is set.
     472             : //
     473             : // % Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
     474             : // % Mask0 = extractelement <16 x i1> % Mask, i32 0
     475             : // % ToStore0 = icmp eq i1 % Mask0, true
     476             : // br i1 %ToStore0, label %cond.store, label %else
     477             : //
     478             : // cond.store:
     479             : // % Elt0 = extractelement <16 x i32> %Src, i32 0
     480             : // % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
     481             : // store i32 %Elt0, i32* % Ptr0, align 4
     482             : // br label %else
     483             : //
     484             : // else:
     485             : // % Mask1 = extractelement <16 x i1> % Mask, i32 1
     486             : // % ToStore1 = icmp eq i1 % Mask1, true
     487             : // br i1 % ToStore1, label %cond.store1, label %else2
     488             : //
     489             : // cond.store1:
     490             : // % Elt1 = extractelement <16 x i32> %Src, i32 1
     491             : // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
     492             : // store i32 % Elt1, i32* % Ptr1, align 4
     493             : // br label %else2
     494             : //   . . .
     495          13 : static void scalarizeMaskedScatter(CallInst *CI) {
     496          13 :   Value *Src = CI->getArgOperand(0);
     497          13 :   Value *Ptrs = CI->getArgOperand(1);
     498          13 :   Value *Alignment = CI->getArgOperand(2);
     499          13 :   Value *Mask = CI->getArgOperand(3);
     500             : 
     501             :   assert(isa<VectorType>(Src->getType()) &&
     502             :          "Unexpected data type in masked scatter intrinsic");
     503             :   assert(isa<VectorType>(Ptrs->getType()) &&
     504             :          isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
     505             :          "Vector of pointers is expected in masked scatter intrinsic");
     506             : 
     507          37 :   IRBuilder<> Builder(CI->getContext());
     508          13 :   Instruction *InsertPt = CI;
     509          13 :   BasicBlock *IfBlock = CI->getParent();
     510          13 :   Builder.SetInsertPoint(InsertPt);
     511          52 :   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
     512             : 
     513          26 :   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
     514          26 :   unsigned VectorWidth = Src->getType()->getVectorNumElements();
     515             : 
     516             :   // Shorten the way if the mask is a vector of constants.
     517          26 :   bool IsConstMask = isa<ConstantVector>(Mask);
     518             : 
     519          13 :   if (IsConstMask) {
     520          22 :     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     521          30 :       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
     522           0 :         continue;
     523          10 :       Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
     524          40 :                                                    "Elt" + Twine(Idx));
     525          10 :       Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
     526          40 :                                                 "Ptr" + Twine(Idx));
     527             :       Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
     528             :     }
     529           2 :     CI->eraseFromParent();
     530           2 :     return;
     531             :   }
     532         235 :   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     533             :     // Fill the "else" block, created in the previous iteration
     534             :     //
     535             :     //  % Mask1 = extractelement <16 x i1> % Mask, i32 Idx
     536             :     //  % ToStore = icmp eq i1 % Mask1, true
     537             :     //  br i1 % ToStore, label %cond.store, label %else
     538             :     //
     539         112 :     Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
     540         448 :                                                     "Mask" + Twine(Idx));
     541             :     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
     542         112 :                                     ConstantInt::get(Predicate->getType(), 1),
     543         448 :                                     "ToStore" + Twine(Idx));
     544             : 
     545             :     // Create "cond" block
     546             :     //
     547             :     //  % Elt1 = extractelement <16 x i32> %Src, i32 1
     548             :     //  % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
     549             :     //  %store i32 % Elt1, i32* % Ptr1
     550             :     //
     551         224 :     BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
     552         112 :     Builder.SetInsertPoint(InsertPt);
     553             : 
     554         112 :     Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
     555         448 :                                                  "Elt" + Twine(Idx));
     556         112 :     Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
     557         448 :                                               "Ptr" + Twine(Idx));
     558         112 :     Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
     559             : 
     560             :     // Create "else" block, fill it in the next iteration
     561         224 :     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
     562         112 :     Builder.SetInsertPoint(InsertPt);
     563         112 :     Instruction *OldBr = IfBlock->getTerminator();
     564         112 :     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
     565         112 :     OldBr->eraseFromParent();
     566         112 :     IfBlock = NewIfBlock;
     567             :   }
     568          11 :   CI->eraseFromParent();
     569             : }
     570             : 
     571      143217 : bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
     572      143217 :   if (skipFunction(F))
     573             :     return false;
     574             : 
     575      142856 :   bool EverMadeChange = false;
     576             : 
     577      142856 :   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
     578             : 
     579      142856 :   bool MadeChange = true;
     580      428682 :   while (MadeChange) {
     581      142913 :     MadeChange = false;
     582      857042 :     for (Function::iterator I = F.begin(); I != F.end();) {
     583      856995 :       BasicBlock *BB = &*I++;
     584      285665 :       bool ModifiedDTOnIteration = false;
     585      285665 :       MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
     586             : 
     587             :       // Restart BB iteration if the dominator tree of the Function was changed
     588      285665 :       if (ModifiedDTOnIteration)
     589             :         break;
     590             :     }
     591             : 
     592      142913 :     EverMadeChange |= MadeChange;
     593             :   }
     594             : 
     595             :   return EverMadeChange;
     596             : }
     597             : 
     598      285665 : bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
     599      285665 :   bool MadeChange = false;
     600             : 
     601      285665 :   BasicBlock::iterator CurInstIterator = BB.begin();
     602     2212913 :   while (CurInstIterator != BB.end()) {
     603     6083565 :     if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
     604      301650 :       MadeChange |= optimizeCallInst(CI, ModifiedDT);
     605     1927305 :     if (ModifiedDT)
     606             :       return true;
     607             :   }
     608             : 
     609             :   return MadeChange;
     610             : }
     611             : 
     612      301650 : bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
     613             :                                                 bool &ModifiedDT) {
     614             : 
     615      138302 :   IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
     616             :   if (II) {
     617      138302 :     switch (II->getIntrinsicID()) {
     618             :     default:
     619             :       break;
     620         212 :     case Intrinsic::masked_load: {
     621             :       // Scalarize unsupported vector masked load
     622         212 :       if (!TTI->isLegalMaskedLoad(CI->getType())) {
     623           1 :         scalarizeMaskedLoad(CI);
     624           1 :         ModifiedDT = true;
     625             :         return true;
     626             :       }
     627             :       return false;
     628             :     }
     629         101 :     case Intrinsic::masked_store: {
     630         202 :       if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
     631           1 :         scalarizeMaskedStore(CI);
     632           1 :         ModifiedDT = true;
     633             :         return true;
     634             :       }
     635             :       return false;
     636             :     }
     637         217 :     case Intrinsic::masked_gather: {
     638         217 :       if (!TTI->isLegalMaskedGather(CI->getType())) {
     639          42 :         scalarizeMaskedGather(CI);
     640          42 :         ModifiedDT = true;
     641             :         return true;
     642             :       }
     643             :       return false;
     644             :     }
     645          74 :     case Intrinsic::masked_scatter: {
     646         148 :       if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
     647          13 :         scalarizeMaskedScatter(CI);
     648          13 :         ModifiedDT = true;
     649             :         return true;
     650             :       }
     651             :       return false;
     652             :     }
     653             :     }
     654             :   }
     655             : 
     656             :   return false;
     657             : }

Generated by: LCOV version 1.13