LCOV - code coverage report
Current view: top level - lib/CodeGen - ScalarizeMaskedMemIntrin.cpp (source / functions) Hit Total Coverage
Test: llvm-toolchain.info Lines: 175 201 87.1 %
Date: 2018-10-20 13:21:21 Functions: 13 14 92.9 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
       2             : //                                    instrinsics
       3             : //
       4             : //                     The LLVM Compiler Infrastructure
       5             : //
       6             : // This file is distributed under the University of Illinois Open Source
       7             : // License. See LICENSE.TXT for details.
       8             : //
       9             : //===----------------------------------------------------------------------===//
      10             : //
      11             : // This pass replaces masked memory intrinsics - when unsupported by the target
      12             : // - with a chain of basic blocks, that deal with the elements one-by-one if the
      13             : // appropriate mask bit is set.
      14             : //
      15             : //===----------------------------------------------------------------------===//
      16             : 
      17             : #include "llvm/ADT/Twine.h"
      18             : #include "llvm/Analysis/TargetTransformInfo.h"
      19             : #include "llvm/CodeGen/TargetSubtargetInfo.h"
      20             : #include "llvm/IR/BasicBlock.h"
      21             : #include "llvm/IR/Constant.h"
      22             : #include "llvm/IR/Constants.h"
      23             : #include "llvm/IR/DerivedTypes.h"
      24             : #include "llvm/IR/Function.h"
      25             : #include "llvm/IR/IRBuilder.h"
      26             : #include "llvm/IR/InstrTypes.h"
      27             : #include "llvm/IR/Instruction.h"
      28             : #include "llvm/IR/Instructions.h"
      29             : #include "llvm/IR/IntrinsicInst.h"
      30             : #include "llvm/IR/Intrinsics.h"
      31             : #include "llvm/IR/Type.h"
      32             : #include "llvm/IR/Value.h"
      33             : #include "llvm/Pass.h"
      34             : #include "llvm/Support/Casting.h"
      35             : #include <algorithm>
      36             : #include <cassert>
      37             : 
      38             : using namespace llvm;
      39             : 
      40             : #define DEBUG_TYPE "scalarize-masked-mem-intrin"
      41             : 
      42             : namespace {
      43             : 
      44             : class ScalarizeMaskedMemIntrin : public FunctionPass {
      45             :   const TargetTransformInfo *TTI = nullptr;
      46             : 
      47             : public:
      48             :   static char ID; // Pass identification, replacement for typeid
      49             : 
      50       27457 :   explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
      51       27457 :     initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
      52       27457 :   }
      53             : 
      54             :   bool runOnFunction(Function &F) override;
      55             : 
      56          21 :   StringRef getPassName() const override {
      57          21 :     return "Scalarize Masked Memory Intrinsics";
      58             :   }
      59             : 
      60       27319 :   void getAnalysisUsage(AnalysisUsage &AU) const override {
      61             :     AU.addRequired<TargetTransformInfoWrapperPass>();
      62       27319 :   }
      63             : 
      64             : private:
      65             :   bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
      66             :   bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
      67             : };
      68             : 
      69             : } // end anonymous namespace
      70             : 
      71             : char ScalarizeMaskedMemIntrin::ID = 0;
      72             : 
      73      151909 : INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
      74             :                 "Scalarize unsupported masked memory intrinsics", false, false)
      75             : 
      76       27453 : FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
      77       27453 :   return new ScalarizeMaskedMemIntrin();
      78             : }
      79             : 
      80         129 : static bool isConstantIntVector(Value *Mask) {
      81             :   Constant *C = dyn_cast<Constant>(Mask);
      82             :   if (!C)
      83             :     return false;
      84             : 
      85          35 :   unsigned NumElts = Mask->getType()->getVectorNumElements();
      86         288 :   for (unsigned i = 0; i != NumElts; ++i) {
      87         253 :     Constant *CElt = C->getAggregateElement(i);
      88         253 :     if (!CElt || !isa<ConstantInt>(CElt))
      89             :       return false;
      90             :   }
      91             : 
      92             :   return true;
      93             : }
      94             : 
      95             : // Translate a masked load intrinsic like
      96             : // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
      97             : //                               <16 x i1> %mask, <16 x i32> %passthru)
      98             : // to a chain of basic blocks, with loading element one-by-one if
      99             : // the appropriate mask bit is set
     100             : //
     101             : //  %1 = bitcast i8* %addr to i32*
     102             : //  %2 = extractelement <16 x i1> %mask, i32 0
     103             : //  br i1 %2, label %cond.load, label %else
     104             : //
     105             : // cond.load:                                        ; preds = %0
     106             : //  %3 = getelementptr i32* %1, i32 0
     107             : //  %4 = load i32* %3
     108             : //  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
     109             : //  br label %else
     110             : //
     111             : // else:                                             ; preds = %0, %cond.load
     112             : //  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
     113             : //  %6 = extractelement <16 x i1> %mask, i32 1
     114             : //  br i1 %6, label %cond.load1, label %else2
     115             : //
     116             : // cond.load1:                                       ; preds = %else
     117             : //  %7 = getelementptr i32* %1, i32 1
     118             : //  %8 = load i32* %7
     119             : //  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
     120             : //  br label %else2
     121             : //
     122             : // else2:                                          ; preds = %else, %cond.load1
     123             : //  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
     124             : //  %10 = extractelement <16 x i1> %mask, i32 2
     125             : //  br i1 %10, label %cond.load4, label %else5
     126             : //
     127          16 : static void scalarizeMaskedLoad(CallInst *CI) {
     128          16 :   Value *Ptr = CI->getArgOperand(0);
     129             :   Value *Alignment = CI->getArgOperand(1);
     130             :   Value *Mask = CI->getArgOperand(2);
     131             :   Value *Src0 = CI->getArgOperand(3);
     132             : 
     133          16 :   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
     134          16 :   VectorType *VecType = cast<VectorType>(CI->getType());
     135             : 
     136          16 :   Type *EltTy = VecType->getElementType();
     137             : 
     138          16 :   IRBuilder<> Builder(CI->getContext());
     139             :   Instruction *InsertPt = CI;
     140          16 :   BasicBlock *IfBlock = CI->getParent();
     141             :   BasicBlock *CondBlock = nullptr;
     142             :   BasicBlock *PrevIfBlock = CI->getParent();
     143             : 
     144          16 :   Builder.SetInsertPoint(InsertPt);
     145          16 :   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
     146             : 
     147             :   // Short-cut if the mask is all-true.
     148          16 :   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
     149           1 :     Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
     150           1 :     CI->replaceAllUsesWith(NewI);
     151           1 :     CI->eraseFromParent();
     152           1 :     return;
     153             :   }
     154             : 
     155             :   // Adjust alignment for the scalar instruction.
     156          15 :   AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
     157             :   // Bitcast %addr fron i8* to EltTy*
     158             :   Type *NewPtrType =
     159          30 :       EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
     160          15 :   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
     161          15 :   unsigned VectorWidth = VecType->getNumElements();
     162             : 
     163             :   // The result vector
     164             :   Value *VResult = Src0;
     165             : 
     166          15 :   if (isConstantIntVector(Mask)) {
     167           6 :     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     168           4 :       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
     169             :         continue;
     170             :       Value *Gep =
     171           1 :           Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
     172           1 :       LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
     173             :       VResult =
     174           1 :           Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
     175             :     }
     176           2 :     CI->replaceAllUsesWith(VResult);
     177           2 :     CI->eraseFromParent();
     178           2 :     return;
     179             :   }
     180             : 
     181          35 :   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     182             :     // Fill the "else" block, created in the previous iteration
     183             :     //
     184             :     //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
     185             :     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
     186             :     //  br i1 %mask_1, label %cond.load, label %else
     187             :     //
     188             : 
     189             :     Value *Predicate =
     190          22 :         Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
     191             : 
     192             :     // Create "cond" block
     193             :     //
     194             :     //  %EltAddr = getelementptr i32* %1, i32 0
     195             :     //  %Elt = load i32* %EltAddr
     196             :     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
     197             :     //
     198          22 :     CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
     199          22 :     Builder.SetInsertPoint(InsertPt);
     200             : 
     201             :     Value *Gep =
     202          22 :         Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
     203          22 :     LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
     204          22 :     Value *NewVResult = Builder.CreateInsertElement(VResult, Load,
     205          22 :                                                     Builder.getInt32(Idx));
     206             : 
     207             :     // Create "else" block, fill it in the next iteration
     208             :     BasicBlock *NewIfBlock =
     209          22 :         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
     210          22 :     Builder.SetInsertPoint(InsertPt);
     211             :     Instruction *OldBr = IfBlock->getTerminator();
     212          22 :     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
     213          22 :     OldBr->eraseFromParent();
     214             :     PrevIfBlock = IfBlock;
     215             :     IfBlock = NewIfBlock;
     216             : 
     217             :     // Create the phi to join the new and previous value.
     218          22 :     PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
     219          22 :     Phi->addIncoming(NewVResult, CondBlock);
     220          22 :     Phi->addIncoming(VResult, PrevIfBlock);
     221             :     VResult = Phi;
     222             :   }
     223             : 
     224          13 :   CI->replaceAllUsesWith(VResult);
     225          13 :   CI->eraseFromParent();
     226             : }
     227             : 
     228             : // Translate a masked store intrinsic, like
     229             : // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
     230             : //                               <16 x i1> %mask)
     231             : // to a chain of basic blocks, that stores element one-by-one if
     232             : // the appropriate mask bit is set
     233             : //
     234             : //   %1 = bitcast i8* %addr to i32*
     235             : //   %2 = extractelement <16 x i1> %mask, i32 0
     236             : //   br i1 %2, label %cond.store, label %else
     237             : //
     238             : // cond.store:                                       ; preds = %0
     239             : //   %3 = extractelement <16 x i32> %val, i32 0
     240             : //   %4 = getelementptr i32* %1, i32 0
     241             : //   store i32 %3, i32* %4
     242             : //   br label %else
     243             : //
     244             : // else:                                             ; preds = %0, %cond.store
     245             : //   %5 = extractelement <16 x i1> %mask, i32 1
     246             : //   br i1 %5, label %cond.store1, label %else2
     247             : //
     248             : // cond.store1:                                      ; preds = %else
     249             : //   %6 = extractelement <16 x i32> %val, i32 1
     250             : //   %7 = getelementptr i32* %1, i32 1
     251             : //   store i32 %6, i32* %7
     252             : //   br label %else2
     253             : //   . . .
     254          10 : static void scalarizeMaskedStore(CallInst *CI) {
     255          10 :   Value *Src = CI->getArgOperand(0);
     256             :   Value *Ptr = CI->getArgOperand(1);
     257             :   Value *Alignment = CI->getArgOperand(2);
     258             :   Value *Mask = CI->getArgOperand(3);
     259             : 
     260          10 :   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
     261          10 :   VectorType *VecType = cast<VectorType>(Src->getType());
     262             : 
     263          10 :   Type *EltTy = VecType->getElementType();
     264             : 
     265          10 :   IRBuilder<> Builder(CI->getContext());
     266             :   Instruction *InsertPt = CI;
     267          10 :   BasicBlock *IfBlock = CI->getParent();
     268          10 :   Builder.SetInsertPoint(InsertPt);
     269          10 :   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
     270             : 
     271             :   // Short-cut if the mask is all-true.
     272          10 :   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
     273             :     Builder.CreateAlignedStore(Src, Ptr, AlignVal);
     274           1 :     CI->eraseFromParent();
     275           1 :     return;
     276             :   }
     277             : 
     278             :   // Adjust alignment for the scalar instruction.
     279           9 :   AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
     280             :   // Bitcast %addr fron i8* to EltTy*
     281             :   Type *NewPtrType =
     282          18 :       EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
     283           9 :   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
     284           9 :   unsigned VectorWidth = VecType->getNumElements();
     285             : 
     286           9 :   if (isConstantIntVector(Mask)) {
     287           6 :     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     288           4 :       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
     289             :         continue;
     290           1 :       Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
     291             :       Value *Gep =
     292           1 :           Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
     293             :       Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
     294             :     }
     295           2 :     CI->eraseFromParent();
     296           2 :     return;
     297             :   }
     298             : 
     299          21 :   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     300             :     // Fill the "else" block, created in the previous iteration
     301             :     //
     302             :     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
     303             :     //  br i1 %mask_1, label %cond.store, label %else
     304             :     //
     305             :     Value *Predicate =
     306          14 :         Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
     307             : 
     308             :     // Create "cond" block
     309             :     //
     310             :     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
     311             :     //  %EltAddr = getelementptr i32* %1, i32 0
     312             :     //  %store i32 %OneElt, i32* %EltAddr
     313             :     //
     314             :     BasicBlock *CondBlock =
     315          14 :         IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
     316          14 :     Builder.SetInsertPoint(InsertPt);
     317             : 
     318          14 :     Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
     319             :     Value *Gep =
     320          14 :         Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
     321             :     Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
     322             : 
     323             :     // Create "else" block, fill it in the next iteration
     324             :     BasicBlock *NewIfBlock =
     325          14 :         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
     326          14 :     Builder.SetInsertPoint(InsertPt);
     327             :     Instruction *OldBr = IfBlock->getTerminator();
     328          14 :     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
     329          14 :     OldBr->eraseFromParent();
     330             :     IfBlock = NewIfBlock;
     331             :   }
     332           7 :   CI->eraseFromParent();
     333             : }
     334             : 
     335             : // Translate a masked gather intrinsic like
     336             : // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
     337             : //                               <16 x i1> %Mask, <16 x i32> %Src)
     338             : // to a chain of basic blocks, with loading element one-by-one if
     339             : // the appropriate mask bit is set
     340             : //
     341             : // %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
     342             : // %Mask0 = extractelement <16 x i1> %Mask, i32 0
     343             : // br i1 %Mask0, label %cond.load, label %else
     344             : //
     345             : // cond.load:
     346             : // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
     347             : // %Load0 = load i32, i32* %Ptr0, align 4
     348             : // %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
     349             : // br label %else
     350             : //
     351             : // else:
     352             : // %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
     353             : // %Mask1 = extractelement <16 x i1> %Mask, i32 1
     354             : // br i1 %Mask1, label %cond.load1, label %else2
     355             : //
     356             : // cond.load1:
     357             : // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
     358             : // %Load1 = load i32, i32* %Ptr1, align 4
     359             : // %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
     360             : // br label %else2
     361             : // . . .
     362             : // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
     363             : // ret <16 x i32> %Result
     364          76 : static void scalarizeMaskedGather(CallInst *CI) {
     365          76 :   Value *Ptrs = CI->getArgOperand(0);
     366             :   Value *Alignment = CI->getArgOperand(1);
     367             :   Value *Mask = CI->getArgOperand(2);
     368             :   Value *Src0 = CI->getArgOperand(3);
     369             : 
     370          76 :   VectorType *VecType = cast<VectorType>(CI->getType());
     371             : 
     372          76 :   IRBuilder<> Builder(CI->getContext());
     373             :   Instruction *InsertPt = CI;
     374          76 :   BasicBlock *IfBlock = CI->getParent();
     375             :   BasicBlock *CondBlock = nullptr;
     376             :   BasicBlock *PrevIfBlock = CI->getParent();
     377          76 :   Builder.SetInsertPoint(InsertPt);
     378          76 :   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
     379             : 
     380          76 :   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
     381             : 
     382             :   // The result vector
     383             :   Value *VResult = Src0;
     384          76 :   unsigned VectorWidth = VecType->getNumElements();
     385             : 
     386             :   // Shorten the way if the mask is a vector of constants.
     387          76 :   if (isConstantIntVector(Mask)) {
     388         264 :     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     389         235 :       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
     390             :         continue;
     391         219 :       Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
     392         219 :                                                 "Ptr" + Twine(Idx));
     393             :       LoadInst *Load =
     394         219 :           Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
     395         219 :       VResult = Builder.CreateInsertElement(
     396         438 :           VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
     397             :     }
     398          29 :     CI->replaceAllUsesWith(VResult);
     399          29 :     CI->eraseFromParent();
     400             :     return;
     401             :   }
     402             : 
     403         348 :   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     404             :     // Fill the "else" block, created in the previous iteration
     405             :     //
     406             :     //  %Mask1 = extractelement <16 x i1> %Mask, i32 1
     407             :     //  br i1 %Mask1, label %cond.load, label %else
     408             :     //
     409             : 
     410         301 :     Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
     411         301 :                                                     "Mask" + Twine(Idx));
     412             : 
     413             :     // Create "cond" block
     414             :     //
     415             :     //  %EltAddr = getelementptr i32* %1, i32 0
     416             :     //  %Elt = load i32* %EltAddr
     417             :     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
     418             :     //
     419         301 :     CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
     420         301 :     Builder.SetInsertPoint(InsertPt);
     421             : 
     422         301 :     Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
     423         301 :                                               "Ptr" + Twine(Idx));
     424             :     LoadInst *Load =
     425         301 :         Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
     426         301 :     Value *NewVResult = Builder.CreateInsertElement(VResult, Load,
     427         301 :                                                     Builder.getInt32(Idx),
     428         301 :                                                     "Res" + Twine(Idx));
     429             : 
     430             :     // Create "else" block, fill it in the next iteration
     431         301 :     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
     432         301 :     Builder.SetInsertPoint(InsertPt);
     433             :     Instruction *OldBr = IfBlock->getTerminator();
     434         301 :     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
     435         301 :     OldBr->eraseFromParent();
     436             :     PrevIfBlock = IfBlock;
     437             :     IfBlock = NewIfBlock;
     438             : 
     439         301 :     PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
     440         301 :     Phi->addIncoming(NewVResult, CondBlock);
     441         301 :     Phi->addIncoming(VResult, PrevIfBlock);
     442             :     VResult = Phi;
     443             :   }
     444             : 
     445          47 :   CI->replaceAllUsesWith(VResult);
     446          47 :   CI->eraseFromParent();
     447             : }
     448             : 
     449             : // Translate a masked scatter intrinsic, like
     450             : // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
     451             : //                                  <16 x i1> %Mask)
     452             : // to a chain of basic blocks, that stores element one-by-one if
     453             : // the appropriate mask bit is set.
     454             : //
     455             : // %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
     456             : // %Mask0 = extractelement <16 x i1> %Mask, i32 0
     457             : // br i1 %Mask0, label %cond.store, label %else
     458             : //
     459             : // cond.store:
     460             : // %Elt0 = extractelement <16 x i32> %Src, i32 0
     461             : // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
     462             : // store i32 %Elt0, i32* %Ptr0, align 4
     463             : // br label %else
     464             : //
     465             : // else:
     466             : // %Mask1 = extractelement <16 x i1> %Mask, i32 1
     467             : // br i1 %Mask1, label %cond.store1, label %else2
     468             : //
     469             : // cond.store1:
     470             : // %Elt1 = extractelement <16 x i32> %Src, i32 1
     471             : // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
     472             : // store i32 %Elt1, i32* %Ptr1, align 4
     473             : // br label %else2
     474             : //   . . .
     475          29 : static void scalarizeMaskedScatter(CallInst *CI) {
     476          29 :   Value *Src = CI->getArgOperand(0);
     477             :   Value *Ptrs = CI->getArgOperand(1);
     478             :   Value *Alignment = CI->getArgOperand(2);
     479             :   Value *Mask = CI->getArgOperand(3);
     480             : 
     481             :   assert(isa<VectorType>(Src->getType()) &&
     482             :          "Unexpected data type in masked scatter intrinsic");
     483             :   assert(isa<VectorType>(Ptrs->getType()) &&
     484             :          isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
     485             :          "Vector of pointers is expected in masked scatter intrinsic");
     486             : 
     487          29 :   IRBuilder<> Builder(CI->getContext());
     488             :   Instruction *InsertPt = CI;
     489          29 :   BasicBlock *IfBlock = CI->getParent();
     490          29 :   Builder.SetInsertPoint(InsertPt);
     491          29 :   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
     492             : 
     493          29 :   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
     494          29 :   unsigned VectorWidth = Src->getType()->getVectorNumElements();
     495             : 
     496             :   // Shorten the way if the mask is a vector of constants.
     497          29 :   if (isConstantIntVector(Mask)) {
     498          12 :     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     499          10 :       if (cast<ConstantVector>(Mask)->getAggregateElement(Idx)->isNullValue())
     500             :         continue;
     501          10 :       Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
     502          10 :                                                    "Elt" + Twine(Idx));
     503          10 :       Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
     504          10 :                                                 "Ptr" + Twine(Idx));
     505             :       Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
     506             :     }
     507           2 :     CI->eraseFromParent();
     508             :     return;
     509             :   }
     510             : 
     511         180 :   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     512             :     // Fill the "else" block, created in the previous iteration
     513             :     //
     514             :     //  %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
     515             :     //  br i1 %Mask1, label %cond.store, label %else
     516             :     //
     517         153 :     Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
     518         153 :                                                     "Mask" + Twine(Idx));
     519             : 
     520             :     // Create "cond" block
     521             :     //
     522             :     //  %Elt1 = extractelement <16 x i32> %Src, i32 1
     523             :     //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
     524             :     //  %store i32 %Elt1, i32* %Ptr1
     525             :     //
     526         153 :     BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
     527         153 :     Builder.SetInsertPoint(InsertPt);
     528             : 
     529         153 :     Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
     530         153 :                                                  "Elt" + Twine(Idx));
     531         153 :     Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
     532         153 :                                               "Ptr" + Twine(Idx));
     533             :     Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
     534             : 
     535             :     // Create "else" block, fill it in the next iteration
     536         153 :     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
     537         153 :     Builder.SetInsertPoint(InsertPt);
     538             :     Instruction *OldBr = IfBlock->getTerminator();
     539         153 :     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
     540         153 :     OldBr->eraseFromParent();
     541             :     IfBlock = NewIfBlock;
     542             :   }
     543          27 :   CI->eraseFromParent();
     544             : }
     545             : 
     546      406591 : bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
     547             :   bool EverMadeChange = false;
     548             : 
     549      406591 :   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
     550             : 
     551             :   bool MadeChange = true;
     552      813313 :   while (MadeChange) {
     553             :     MadeChange = false;
     554     3630094 :     for (Function::iterator I = F.begin(); I != F.end();) {
     555             :       BasicBlock *BB = &*I++;
     556     3223503 :       bool ModifiedDTOnIteration = false;
     557     3223503 :       MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
     558             : 
     559             :       // Restart BB iteration if the dominator tree of the Function was changed
     560     3223503 :       if (ModifiedDTOnIteration)
     561             :         break;
     562             :     }
     563             : 
     564      406722 :     EverMadeChange |= MadeChange;
     565             :   }
     566             : 
     567      406591 :   return EverMadeChange;
     568             : }
     569             : 
     570     3223503 : bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
     571             :   bool MadeChange = false;
     572             : 
     573             :   BasicBlock::iterator CurInstIterator = BB.begin();
     574    40229555 :   while (CurInstIterator != BB.end()) {
     575             :     if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
     576     2391686 :       MadeChange |= optimizeCallInst(CI, ModifiedDT);
     577    37006183 :     if (ModifiedDT)
     578             :       return true;
     579             :   }
     580             : 
     581             :   return MadeChange;
     582             : }
     583             : 
     584           0 : bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
     585             :                                                 bool &ModifiedDT) {
     586             :   IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
     587             :   if (II) {
     588           0 :     switch (II->getIntrinsicID()) {
     589             :     default:
     590             :       break;
     591           0 :     case Intrinsic::masked_load:
     592             :       // Scalarize unsupported vector masked load
     593           0 :       if (!TTI->isLegalMaskedLoad(CI->getType())) {
     594           0 :         scalarizeMaskedLoad(CI);
     595           0 :         ModifiedDT = true;
     596           0 :         return true;
     597             :       }
     598           0 :       return false;
     599           0 :     case Intrinsic::masked_store:
     600           0 :       if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
     601           0 :         scalarizeMaskedStore(CI);
     602           0 :         ModifiedDT = true;
     603           0 :         return true;
     604             :       }
     605           0 :       return false;
     606           0 :     case Intrinsic::masked_gather:
     607           0 :       if (!TTI->isLegalMaskedGather(CI->getType())) {
     608           0 :         scalarizeMaskedGather(CI);
     609           0 :         ModifiedDT = true;
     610           0 :         return true;
     611             :       }
     612           0 :       return false;
     613           0 :     case Intrinsic::masked_scatter:
     614           0 :       if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
     615           0 :         scalarizeMaskedScatter(CI);
     616           0 :         ModifiedDT = true;
     617           0 :         return true;
     618             :       }
     619           0 :       return false;
     620             :     }
     621             :   }
     622             : 
     623             :   return false;
     624             : }

Generated by: LCOV version 1.13