LCOV - code coverage report
Current view: top level - lib/CodeGen - ScalarizeMaskedMemIntrin.cpp (source / functions) Hit Total Coverage
Test: llvm-toolchain.info Lines: 189 214 88.3 %
Date: 2018-02-23 15:42:53 Functions: 14 15 93.3 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
       2             : //                                    instrinsics
       3             : //
       4             : //                     The LLVM Compiler Infrastructure
       5             : //
       6             : // This file is distributed under the University of Illinois Open Source
       7             : // License. See LICENSE.TXT for details.
       8             : //
       9             : //===----------------------------------------------------------------------===//
      10             : //
      11             : // This pass replaces masked memory intrinsics - when unsupported by the target
      12             : // - with a chain of basic blocks, that deal with the elements one-by-one if the
      13             : // appropriate mask bit is set.
      14             : //
      15             : //===----------------------------------------------------------------------===//
      16             : 
      17             : #include "llvm/ADT/Twine.h"
      18             : #include "llvm/Analysis/TargetTransformInfo.h"
      19             : #include "llvm/CodeGen/TargetSubtargetInfo.h"
      20             : #include "llvm/IR/BasicBlock.h"
      21             : #include "llvm/IR/Constant.h"
      22             : #include "llvm/IR/Constants.h"
      23             : #include "llvm/IR/DerivedTypes.h"
      24             : #include "llvm/IR/Function.h"
      25             : #include "llvm/IR/IRBuilder.h"
      26             : #include "llvm/IR/InstrTypes.h"
      27             : #include "llvm/IR/Instruction.h"
      28             : #include "llvm/IR/Instructions.h"
      29             : #include "llvm/IR/IntrinsicInst.h"
      30             : #include "llvm/IR/Intrinsics.h"
      31             : #include "llvm/IR/Type.h"
      32             : #include "llvm/IR/Value.h"
      33             : #include "llvm/Pass.h"
      34             : #include "llvm/Support/Casting.h"
      35             : #include <algorithm>
      36             : #include <cassert>
      37             : 
      38             : using namespace llvm;
      39             : 
      40             : #define DEBUG_TYPE "scalarize-masked-mem-intrin"
      41             : 
      42             : namespace {
      43             : 
      44       37396 : class ScalarizeMaskedMemIntrin : public FunctionPass {
      45             :   const TargetTransformInfo *TTI = nullptr;
      46             : 
      47             : public:
      48             :   static char ID; // Pass identification, replacement for typeid
      49             : 
      50       37610 :   explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
      51       18805 :     initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
      52       18805 :   }
      53             : 
      54             :   bool runOnFunction(Function &F) override;
      55             : 
      56          25 :   StringRef getPassName() const override {
      57          25 :     return "Scalarize Masked Memory Intrinsics";
      58             :   }
      59             : 
      60       18720 :   void getAnalysisUsage(AnalysisUsage &AU) const override {
      61             :     AU.addRequired<TargetTransformInfoWrapperPass>();
      62       18720 :   }
      63             : 
      64             : private:
      65             :   bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
      66             :   bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
      67             : };
      68             : 
      69             : } // end anonymous namespace
      70             : 
      71             : char ScalarizeMaskedMemIntrin::ID = 0;
      72             : 
      73      233635 : INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
      74             :                 "Scalarize unsupported masked memory intrinsics", false, false)
      75             : 
      76       18804 : FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
      77       18804 :   return new ScalarizeMaskedMemIntrin();
      78             : }
      79             : 
      80             : // Translate a masked load intrinsic like
      81             : // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
      82             : //                               <16 x i1> %mask, <16 x i32> %passthru)
      83             : // to a chain of basic blocks, with loading element one-by-one if
      84             : // the appropriate mask bit is set
      85             : //
      86             : //  %1 = bitcast i8* %addr to i32*
      87             : //  %2 = extractelement <16 x i1> %mask, i32 0
      88             : //  %3 = icmp eq i1 %2, true
      89             : //  br i1 %3, label %cond.load, label %else
      90             : //
      91             : // cond.load:                                        ; preds = %0
      92             : //  %4 = getelementptr i32* %1, i32 0
      93             : //  %5 = load i32* %4
      94             : //  %6 = insertelement <16 x i32> undef, i32 %5, i32 0
      95             : //  br label %else
      96             : //
      97             : // else:                                             ; preds = %0, %cond.load
      98             : //  %res.phi.else = phi <16 x i32> [ %6, %cond.load ], [ undef, %0 ]
      99             : //  %7 = extractelement <16 x i1> %mask, i32 1
     100             : //  %8 = icmp eq i1 %7, true
     101             : //  br i1 %8, label %cond.load1, label %else2
     102             : //
     103             : // cond.load1:                                       ; preds = %else
     104             : //  %9 = getelementptr i32* %1, i32 1
     105             : //  %10 = load i32* %9
     106             : //  %11 = insertelement <16 x i32> %res.phi.else, i32 %10, i32 1
     107             : //  br label %else2
     108             : //
     109             : // else2:                                          ; preds = %else, %cond.load1
     110             : //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
     111             : //  %12 = extractelement <16 x i1> %mask, i32 2
     112             : //  %13 = icmp eq i1 %12, true
     113             : //  br i1 %13, label %cond.load4, label %else5
     114             : //
     115           5 : static void scalarizeMaskedLoad(CallInst *CI) {
     116           5 :   Value *Ptr = CI->getArgOperand(0);
     117             :   Value *Alignment = CI->getArgOperand(1);
     118             :   Value *Mask = CI->getArgOperand(2);
     119             :   Value *Src0 = CI->getArgOperand(3);
     120             : 
     121           5 :   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
     122           5 :   VectorType *VecType = dyn_cast<VectorType>(CI->getType());
     123             :   assert(VecType && "Unexpected return type of masked load intrinsic");
     124             : 
     125           5 :   Type *EltTy = CI->getType()->getVectorElementType();
     126             : 
     127           5 :   IRBuilder<> Builder(CI->getContext());
     128             :   Instruction *InsertPt = CI;
     129           5 :   BasicBlock *IfBlock = CI->getParent();
     130             :   BasicBlock *CondBlock = nullptr;
     131             :   BasicBlock *PrevIfBlock = CI->getParent();
     132             : 
     133           5 :   Builder.SetInsertPoint(InsertPt);
     134           5 :   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
     135             : 
     136             :   // Short-cut if the mask is all-true.
     137             :   bool IsAllOnesMask =
     138           5 :       isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
     139             : 
     140             :   if (IsAllOnesMask) {
     141           0 :     Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
     142           0 :     CI->replaceAllUsesWith(NewI);
     143           0 :     CI->eraseFromParent();
     144           0 :     return;
     145             :   }
     146             : 
     147             :   // Adjust alignment for the scalar instruction.
     148          10 :   AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
     149             :   // Bitcast %addr fron i8* to EltTy*
     150             :   Type *NewPtrType =
     151          10 :       EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
     152           5 :   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
     153           5 :   unsigned VectorWidth = VecType->getNumElements();
     154             : 
     155           5 :   Value *UndefVal = UndefValue::get(VecType);
     156             : 
     157             :   // The result vector
     158             :   Value *VResult = UndefVal;
     159             : 
     160           5 :   if (isa<ConstantVector>(Mask)) {
     161           0 :     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     162           0 :       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
     163           0 :         continue;
     164             :       Value *Gep =
     165           0 :           Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
     166           0 :       LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
     167           0 :       VResult =
     168           0 :           Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
     169             :     }
     170           0 :     Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
     171           0 :     CI->replaceAllUsesWith(NewI);
     172           0 :     CI->eraseFromParent();
     173           0 :     return;
     174             :   }
     175             : 
     176             :   PHINode *Phi = nullptr;
     177             :   Value *PrevPhi = UndefVal;
     178             : 
     179          21 :   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     180             :     // Fill the "else" block, created in the previous iteration
     181             :     //
     182             :     //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
     183             :     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
     184             :     //  %to_load = icmp eq i1 %mask_1, true
     185             :     //  br i1 %to_load, label %cond.load, label %else
     186             :     //
     187           8 :     if (Idx > 0) {
     188           3 :       Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
     189           3 :       Phi->addIncoming(VResult, CondBlock);
     190           3 :       Phi->addIncoming(PrevPhi, PrevIfBlock);
     191             :       PrevPhi = Phi;
     192             :       VResult = Phi;
     193             :     }
     194             : 
     195             :     Value *Predicate =
     196           8 :         Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
     197           8 :     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
     198           8 :                                     ConstantInt::get(Predicate->getType(), 1));
     199             : 
     200             :     // Create "cond" block
     201             :     //
     202             :     //  %EltAddr = getelementptr i32* %1, i32 0
     203             :     //  %Elt = load i32* %EltAddr
     204             :     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
     205             :     //
     206           8 :     CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
     207           8 :     Builder.SetInsertPoint(InsertPt);
     208             : 
     209             :     Value *Gep =
     210           8 :         Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
     211           8 :     LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
     212           8 :     VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
     213             : 
     214             :     // Create "else" block, fill it in the next iteration
     215             :     BasicBlock *NewIfBlock =
     216           8 :         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
     217           8 :     Builder.SetInsertPoint(InsertPt);
     218             :     Instruction *OldBr = IfBlock->getTerminator();
     219           8 :     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
     220           8 :     OldBr->eraseFromParent();
     221             :     PrevIfBlock = IfBlock;
     222             :     IfBlock = NewIfBlock;
     223             :   }
     224             : 
     225           5 :   Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
     226           5 :   Phi->addIncoming(VResult, CondBlock);
     227           5 :   Phi->addIncoming(PrevPhi, PrevIfBlock);
     228           5 :   Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
     229           5 :   CI->replaceAllUsesWith(NewI);
     230           5 :   CI->eraseFromParent();
     231             : }
     232             : 
     233             : // Translate a masked store intrinsic, like
     234             : // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
     235             : //                               <16 x i1> %mask)
     236             : // to a chain of basic blocks, that stores element one-by-one if
     237             : // the appropriate mask bit is set
     238             : //
     239             : //   %1 = bitcast i8* %addr to i32*
     240             : //   %2 = extractelement <16 x i1> %mask, i32 0
     241             : //   %3 = icmp eq i1 %2, true
     242             : //   br i1 %3, label %cond.store, label %else
     243             : //
     244             : // cond.store:                                       ; preds = %0
     245             : //   %4 = extractelement <16 x i32> %val, i32 0
     246             : //   %5 = getelementptr i32* %1, i32 0
     247             : //   store i32 %4, i32* %5
     248             : //   br label %else
     249             : //
     250             : // else:                                             ; preds = %0, %cond.store
     251             : //   %6 = extractelement <16 x i1> %mask, i32 1
     252             : //   %7 = icmp eq i1 %6, true
     253             : //   br i1 %7, label %cond.store1, label %else2
     254             : //
     255             : // cond.store1:                                      ; preds = %else
     256             : //   %8 = extractelement <16 x i32> %val, i32 1
     257             : //   %9 = getelementptr i32* %1, i32 1
     258             : //   store i32 %8, i32* %9
     259             : //   br label %else2
     260             : //   . . .
     261           5 : static void scalarizeMaskedStore(CallInst *CI) {
     262           5 :   Value *Src = CI->getArgOperand(0);
     263             :   Value *Ptr = CI->getArgOperand(1);
     264             :   Value *Alignment = CI->getArgOperand(2);
     265             :   Value *Mask = CI->getArgOperand(3);
     266             : 
     267           5 :   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
     268           5 :   VectorType *VecType = dyn_cast<VectorType>(Src->getType());
     269             :   assert(VecType && "Unexpected data type in masked store intrinsic");
     270             : 
     271           5 :   Type *EltTy = VecType->getElementType();
     272             : 
     273           5 :   IRBuilder<> Builder(CI->getContext());
     274             :   Instruction *InsertPt = CI;
     275           5 :   BasicBlock *IfBlock = CI->getParent();
     276           5 :   Builder.SetInsertPoint(InsertPt);
     277           5 :   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
     278             : 
     279             :   // Short-cut if the mask is all-true.
     280             :   bool IsAllOnesMask =
     281           5 :       isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
     282             : 
     283             :   if (IsAllOnesMask) {
     284             :     Builder.CreateAlignedStore(Src, Ptr, AlignVal);
     285           0 :     CI->eraseFromParent();
     286           0 :     return;
     287             :   }
     288             : 
     289             :   // Adjust alignment for the scalar instruction.
     290          10 :   AlignVal = std::max(AlignVal, VecType->getScalarSizeInBits() / 8);
     291             :   // Bitcast %addr fron i8* to EltTy*
     292             :   Type *NewPtrType =
     293          10 :       EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
     294           5 :   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
     295           5 :   unsigned VectorWidth = VecType->getNumElements();
     296             : 
     297           5 :   if (isa<ConstantVector>(Mask)) {
     298           0 :     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     299           0 :       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
     300           0 :         continue;
     301           0 :       Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
     302             :       Value *Gep =
     303           0 :           Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
     304             :       Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
     305             :     }
     306           0 :     CI->eraseFromParent();
     307           0 :     return;
     308             :   }
     309             : 
     310          21 :   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     311             :     // Fill the "else" block, created in the previous iteration
     312             :     //
     313             :     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
     314             :     //  %to_store = icmp eq i1 %mask_1, true
     315             :     //  br i1 %to_store, label %cond.store, label %else
     316             :     //
     317             :     Value *Predicate =
     318           8 :         Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
     319           8 :     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
     320           8 :                                     ConstantInt::get(Predicate->getType(), 1));
     321             : 
     322             :     // Create "cond" block
     323             :     //
     324             :     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
     325             :     //  %EltAddr = getelementptr i32* %1, i32 0
     326             :     //  %store i32 %OneElt, i32* %EltAddr
     327             :     //
     328             :     BasicBlock *CondBlock =
     329           8 :         IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
     330           8 :     Builder.SetInsertPoint(InsertPt);
     331             : 
     332           8 :     Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
     333             :     Value *Gep =
     334           8 :         Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
     335             :     Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
     336             : 
     337             :     // Create "else" block, fill it in the next iteration
     338             :     BasicBlock *NewIfBlock =
     339           8 :         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
     340           8 :     Builder.SetInsertPoint(InsertPt);
     341             :     Instruction *OldBr = IfBlock->getTerminator();
     342           8 :     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
     343           8 :     OldBr->eraseFromParent();
     344             :     IfBlock = NewIfBlock;
     345             :   }
     346           5 :   CI->eraseFromParent();
     347             : }
     348             : 
     349             : // Translate a masked gather intrinsic like
     350             : // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
     351             : //                               <16 x i1> %Mask, <16 x i32> %Src)
     352             : // to a chain of basic blocks, with loading element one-by-one if
     353             : // the appropriate mask bit is set
     354             : //
     355             : // % Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
     356             : // % Mask0 = extractelement <16 x i1> %Mask, i32 0
     357             : // % ToLoad0 = icmp eq i1 % Mask0, true
     358             : // br i1 % ToLoad0, label %cond.load, label %else
     359             : //
     360             : // cond.load:
     361             : // % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
     362             : // % Load0 = load i32, i32* % Ptr0, align 4
     363             : // % Res0 = insertelement <16 x i32> undef, i32 % Load0, i32 0
     364             : // br label %else
     365             : //
     366             : // else:
     367             : // %res.phi.else = phi <16 x i32>[% Res0, %cond.load], [undef, % 0]
     368             : // % Mask1 = extractelement <16 x i1> %Mask, i32 1
     369             : // % ToLoad1 = icmp eq i1 % Mask1, true
     370             : // br i1 % ToLoad1, label %cond.load1, label %else2
     371             : //
     372             : // cond.load1:
     373             : // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
     374             : // % Load1 = load i32, i32* % Ptr1, align 4
     375             : // % Res1 = insertelement <16 x i32> %res.phi.else, i32 % Load1, i32 1
     376             : // br label %else2
     377             : // . . .
     378             : // % Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
     379             : // ret <16 x i32> %Result
     380          69 : static void scalarizeMaskedGather(CallInst *CI) {
     381          69 :   Value *Ptrs = CI->getArgOperand(0);
     382             :   Value *Alignment = CI->getArgOperand(1);
     383             :   Value *Mask = CI->getArgOperand(2);
     384             :   Value *Src0 = CI->getArgOperand(3);
     385             : 
     386          69 :   VectorType *VecType = dyn_cast<VectorType>(CI->getType());
     387             : 
     388             :   assert(VecType && "Unexpected return type of masked load intrinsic");
     389             : 
     390          69 :   IRBuilder<> Builder(CI->getContext());
     391             :   Instruction *InsertPt = CI;
     392          69 :   BasicBlock *IfBlock = CI->getParent();
     393             :   BasicBlock *CondBlock = nullptr;
     394             :   BasicBlock *PrevIfBlock = CI->getParent();
     395          69 :   Builder.SetInsertPoint(InsertPt);
     396          69 :   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
     397             : 
     398          69 :   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
     399             : 
     400          69 :   Value *UndefVal = UndefValue::get(VecType);
     401             : 
     402             :   // The result vector
     403             :   Value *VResult = UndefVal;
     404          69 :   unsigned VectorWidth = VecType->getNumElements();
     405             : 
     406             :   // Shorten the way if the mask is a vector of constants.
     407             :   bool IsConstMask = isa<ConstantVector>(Mask);
     408             : 
     409          69 :   if (IsConstMask) {
     410         418 :     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     411         394 :       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
     412          13 :         continue;
     413         184 :       Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
     414         184 :                                                 "Ptr" + Twine(Idx));
     415             :       LoadInst *Load =
     416         184 :           Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
     417         184 :       VResult = Builder.CreateInsertElement(
     418         368 :           VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
     419             :     }
     420          24 :     Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
     421          24 :     CI->replaceAllUsesWith(NewI);
     422          24 :     CI->eraseFromParent();
     423             :     return;
     424             :   }
     425             : 
     426             :   PHINode *Phi = nullptr;
     427             :   Value *PrevPhi = UndefVal;
     428             : 
     429         635 :   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     430             :     // Fill the "else" block, created in the previous iteration
     431             :     //
     432             :     //  %Mask1 = extractelement <16 x i1> %Mask, i32 1
     433             :     //  %ToLoad1 = icmp eq i1 %Mask1, true
     434             :     //  br i1 %ToLoad1, label %cond.load, label %else
     435             :     //
     436         295 :     if (Idx > 0) {
     437         250 :       Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
     438         250 :       Phi->addIncoming(VResult, CondBlock);
     439         250 :       Phi->addIncoming(PrevPhi, PrevIfBlock);
     440             :       PrevPhi = Phi;
     441             :       VResult = Phi;
     442             :     }
     443             : 
     444         295 :     Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
     445         295 :                                                     "Mask" + Twine(Idx));
     446             :     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
     447         295 :                                     ConstantInt::get(Predicate->getType(), 1),
     448         295 :                                     "ToLoad" + Twine(Idx));
     449             : 
     450             :     // Create "cond" block
     451             :     //
     452             :     //  %EltAddr = getelementptr i32* %1, i32 0
     453             :     //  %Elt = load i32* %EltAddr
     454             :     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
     455             :     //
     456         295 :     CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
     457         295 :     Builder.SetInsertPoint(InsertPt);
     458             : 
     459         295 :     Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
     460         295 :                                               "Ptr" + Twine(Idx));
     461             :     LoadInst *Load =
     462         295 :         Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
     463         295 :     VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx),
     464         295 :                                           "Res" + Twine(Idx));
     465             : 
     466             :     // Create "else" block, fill it in the next iteration
     467         295 :     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
     468         295 :     Builder.SetInsertPoint(InsertPt);
     469             :     Instruction *OldBr = IfBlock->getTerminator();
     470         295 :     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
     471         295 :     OldBr->eraseFromParent();
     472             :     PrevIfBlock = IfBlock;
     473             :     IfBlock = NewIfBlock;
     474             :   }
     475             : 
     476          45 :   Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
     477          45 :   Phi->addIncoming(VResult, CondBlock);
     478          45 :   Phi->addIncoming(PrevPhi, PrevIfBlock);
     479          45 :   Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
     480          45 :   CI->replaceAllUsesWith(NewI);
     481          45 :   CI->eraseFromParent();
     482             : }
     483             : 
     484             : // Translate a masked scatter intrinsic, like
     485             : // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
     486             : //                                  <16 x i1> %Mask)
     487             : // to a chain of basic blocks, that stores element one-by-one if
     488             : // the appropriate mask bit is set.
     489             : //
     490             : // % Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
     491             : // % Mask0 = extractelement <16 x i1> % Mask, i32 0
     492             : // % ToStore0 = icmp eq i1 % Mask0, true
     493             : // br i1 %ToStore0, label %cond.store, label %else
     494             : //
     495             : // cond.store:
     496             : // % Elt0 = extractelement <16 x i32> %Src, i32 0
     497             : // % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
     498             : // store i32 %Elt0, i32* % Ptr0, align 4
     499             : // br label %else
     500             : //
     501             : // else:
     502             : // % Mask1 = extractelement <16 x i1> % Mask, i32 1
     503             : // % ToStore1 = icmp eq i1 % Mask1, true
     504             : // br i1 % ToStore1, label %cond.store1, label %else2
     505             : //
     506             : // cond.store1:
     507             : // % Elt1 = extractelement <16 x i32> %Src, i32 1
     508             : // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
     509             : // store i32 % Elt1, i32* % Ptr1, align 4
     510             : // br label %else2
     511             : //   . . .
     512          22 : static void scalarizeMaskedScatter(CallInst *CI) {
     513          22 :   Value *Src = CI->getArgOperand(0);
     514             :   Value *Ptrs = CI->getArgOperand(1);
     515             :   Value *Alignment = CI->getArgOperand(2);
     516             :   Value *Mask = CI->getArgOperand(3);
     517             : 
     518             :   assert(isa<VectorType>(Src->getType()) &&
     519             :          "Unexpected data type in masked scatter intrinsic");
     520             :   assert(isa<VectorType>(Ptrs->getType()) &&
     521             :          isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
     522             :          "Vector of pointers is expected in masked scatter intrinsic");
     523             : 
     524          22 :   IRBuilder<> Builder(CI->getContext());
     525             :   Instruction *InsertPt = CI;
     526          22 :   BasicBlock *IfBlock = CI->getParent();
     527          22 :   Builder.SetInsertPoint(InsertPt);
     528          22 :   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
     529             : 
     530          22 :   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
     531          22 :   unsigned VectorWidth = Src->getType()->getVectorNumElements();
     532             : 
     533             :   // Shorten the way if the mask is a vector of constants.
     534             :   bool IsConstMask = isa<ConstantVector>(Mask);
     535             : 
     536          22 :   if (IsConstMask) {
     537          22 :     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     538          20 :       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
     539           0 :         continue;
     540          10 :       Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
     541          10 :                                                    "Elt" + Twine(Idx));
     542          10 :       Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
     543          10 :                                                 "Ptr" + Twine(Idx));
     544             :       Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
     545             :     }
     546           2 :     CI->eraseFromParent();
     547             :     return;
     548             :   }
     549         294 :   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     550             :     // Fill the "else" block, created in the previous iteration
     551             :     //
     552             :     //  % Mask1 = extractelement <16 x i1> % Mask, i32 Idx
     553             :     //  % ToStore = icmp eq i1 % Mask1, true
     554             :     //  br i1 % ToStore, label %cond.store, label %else
     555             :     //
     556         137 :     Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
     557         137 :                                                     "Mask" + Twine(Idx));
     558             :     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
     559         137 :                                     ConstantInt::get(Predicate->getType(), 1),
     560         137 :                                     "ToStore" + Twine(Idx));
     561             : 
     562             :     // Create "cond" block
     563             :     //
     564             :     //  % Elt1 = extractelement <16 x i32> %Src, i32 1
     565             :     //  % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
     566             :     //  %store i32 % Elt1, i32* % Ptr1
     567             :     //
     568         137 :     BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
     569         137 :     Builder.SetInsertPoint(InsertPt);
     570             : 
     571         137 :     Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
     572         137 :                                                  "Elt" + Twine(Idx));
     573         137 :     Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
     574         137 :                                               "Ptr" + Twine(Idx));
     575             :     Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
     576             : 
     577             :     // Create "else" block, fill it in the next iteration
     578         137 :     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
     579         137 :     Builder.SetInsertPoint(InsertPt);
     580             :     Instruction *OldBr = IfBlock->getTerminator();
     581         137 :     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
     582         137 :     OldBr->eraseFromParent();
     583             :     IfBlock = NewIfBlock;
     584             :   }
     585          20 :   CI->eraseFromParent();
     586             : }
     587             : 
     588      167599 : bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
     589      167599 :   if (skipFunction(F))
     590             :     return false;
     591             : 
     592             :   bool EverMadeChange = false;
     593             : 
     594      167134 :   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
     595             : 
     596             :   bool MadeChange = true;
     597      501604 :   while (MadeChange) {
     598             :     MadeChange = false;
     599      799061 :     for (Function::iterator I = F.begin(); I != F.end();) {
     600             :       BasicBlock *BB = &*I++;
     601      316014 :       bool ModifiedDTOnIteration = false;
     602      316014 :       MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
     603             : 
     604             :       // Restart BB iteration if the dominator tree of the Function was changed
     605      316014 :       if (ModifiedDTOnIteration)
     606             :         break;
     607             :     }
     608             : 
     609      167235 :     EverMadeChange |= MadeChange;
     610             :   }
     611             : 
     612             :   return EverMadeChange;
     613             : }
     614             : 
     615      316014 : bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
     616             :   bool MadeChange = false;
     617             : 
     618             :   BasicBlock::iterator CurInstIterator = BB.begin();
     619     2415192 :   while (CurInstIterator != BB.end()) {
     620             :     if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
     621      329330 :       MadeChange |= optimizeCallInst(CI, ModifiedDT);
     622     2099279 :     if (ModifiedDT)
     623             :       return true;
     624             :   }
     625             : 
     626             :   return MadeChange;
     627             : }
     628             : 
     629      329330 : bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
     630             :                                                 bool &ModifiedDT) {
     631             :   IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
     632             :   if (II) {
     633      157505 :     switch (II->getIntrinsicID()) {
     634             :     default:
     635             :       break;
     636         227 :     case Intrinsic::masked_load:
     637             :       // Scalarize unsupported vector masked load
     638         227 :       if (!TTI->isLegalMaskedLoad(CI->getType())) {
     639           5 :         scalarizeMaskedLoad(CI);
     640           5 :         ModifiedDT = true;
     641             :         return true;
     642             :       }
     643             :       return false;
     644         107 :     case Intrinsic::masked_store:
     645         214 :       if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
     646           5 :         scalarizeMaskedStore(CI);
     647           5 :         ModifiedDT = true;
     648             :         return true;
     649             :       }
     650             :       return false;
     651         357 :     case Intrinsic::masked_gather:
     652         357 :       if (!TTI->isLegalMaskedGather(CI->getType())) {
     653          69 :         scalarizeMaskedGather(CI);
     654          69 :         ModifiedDT = true;
     655             :         return true;
     656             :       }
     657             :       return false;
     658         107 :     case Intrinsic::masked_scatter:
     659         214 :       if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
     660          22 :         scalarizeMaskedScatter(CI);
     661          22 :         ModifiedDT = true;
     662             :         return true;
     663             :       }
     664             :       return false;
     665             :     }
     666             :   }
     667             : 
     668             :   return false;
     669             : }

Generated by: LCOV version 1.13