LCOV - code coverage report
Current view: top level - lib/CodeGen - ExpandMemCmp.cpp (source / functions)
Test: llvm-toolchain.info
Date: 2018-10-20 13:21:21
Coverage:             Hit     Total    Coverage
    Lines:            267     284      94.0 %
    Functions:        20      23       87.0 %
Legend: Lines: hit / not hit

          Line data    Source code
       1             : //===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
       2             : //
       3             : //                     The LLVM Compiler Infrastructure
       4             : //
       5             : // This file is distributed under the University of Illinois Open Source
       6             : // License. See LICENSE.TXT for details.
       7             : //
       8             : //===----------------------------------------------------------------------===//
       9             : //
      10             : // This pass tries to expand memcmp() calls into optimally-sized loads and
      11             : // compares for the target.
      12             : //
      13             : //===----------------------------------------------------------------------===//
      14             : 
      15             : #include "llvm/ADT/Statistic.h"
      16             : #include "llvm/Analysis/ConstantFolding.h"
      17             : #include "llvm/Analysis/TargetLibraryInfo.h"
      18             : #include "llvm/Analysis/TargetTransformInfo.h"
      19             : #include "llvm/Analysis/ValueTracking.h"
      20             : #include "llvm/CodeGen/TargetLowering.h"
      21             : #include "llvm/CodeGen/TargetPassConfig.h"
      22             : #include "llvm/CodeGen/TargetSubtargetInfo.h"
      23             : #include "llvm/IR/IRBuilder.h"
      24             : 
      25             : using namespace llvm;
      26             : 
      27             : #define DEBUG_TYPE "expandmemcmp"
      28             : 
      29             : STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
      30             : STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
      31             : STATISTIC(NumMemCmpGreaterThanMax,
      32             :           "Number of memcmp calls with size greater than max size");
      33             : STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");
      34             : 
      35             : static cl::opt<unsigned> MemCmpEqZeroNumLoadsPerBlock(
      36             :     "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
      37             :     cl::desc("The number of loads per basic block for inline expansion of "
      38             :              "memcmp that is only being compared against zero."));
      39             : 
      40             : namespace {
      41             : 
      42             : 
      43             : // This class provides helper functions to expand a memcmp library call into an
      44             : // inline expansion.
      45             : class MemCmpExpansion {
      46             :   struct ResultBlock {
      47             :     BasicBlock *BB = nullptr;
      48             :     PHINode *PhiSrc1 = nullptr;
      49             :     PHINode *PhiSrc2 = nullptr;
      50             : 
      51         395 :     ResultBlock() = default;
      52             :   };
      53             : 
      54             :   CallInst *const CI;
      55             :   ResultBlock ResBlock;
      56             :   const uint64_t Size;
      57             :   unsigned MaxLoadSize;
      58             :   uint64_t NumLoadsNonOneByte;
      59             :   const uint64_t NumLoadsPerBlockForZeroCmp;
      60             :   std::vector<BasicBlock *> LoadCmpBlocks;
      61             :   BasicBlock *EndBlock;
      62             :   PHINode *PhiRes;
      63             :   const bool IsUsedForZeroCmp;
      64             :   const DataLayout &DL;
      65             :   IRBuilder<> Builder;
      66             :   // Represents the decomposition in blocks of the expansion. For example,
      67             :   // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
       68             : // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}].
      69             :   // TODO(courbet): Involve the target more in this computation. On X86, 7
       70             : // bytes can be done more efficiently with two overlapping 4-byte loads than
       71             : // covering the interval with [{4, 0}, {2, 4}, {1, 6}].
      72             :   struct LoadEntry {
      73             :     LoadEntry(unsigned LoadSize, uint64_t Offset)
      74         489 :         : LoadSize(LoadSize), Offset(Offset) {
      75             :       assert(Offset % LoadSize == 0 && "invalid load entry");
      76             :     }
      77             : 
      78         264 :     uint64_t getGEPIndex() const { return Offset / LoadSize; }
      79             : 
      80             :     // The size of the load for this block, in bytes.
      81             :     const unsigned LoadSize;
      82             :     // The offset of this load WRT the base pointer, in bytes.
      83             :     const uint64_t Offset;
      84             :   };
      85             :   SmallVector<LoadEntry, 8> LoadSequence;
      86             : 
      87             :   void createLoadCmpBlocks();
      88             :   void createResultBlock();
      89             :   void setupResultBlockPHINodes();
      90             :   void setupEndBlockPHINodes();
      91             :   Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
      92             :   void emitLoadCompareBlock(unsigned BlockIndex);
      93             :   void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
      94             :                                          unsigned &LoadIndex);
      95             :   void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex);
      96             :   void emitMemCmpResultBlock();
      97             :   Value *getMemCmpExpansionZeroCase();
      98             :   Value *getMemCmpEqZeroOneBlock();
      99             :   Value *getMemCmpOneBlock();
     100             : 
     101             :  public:
     102             :   MemCmpExpansion(CallInst *CI, uint64_t Size,
     103             :                   const TargetTransformInfo::MemCmpExpansionOptions &Options,
     104             :                   unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
     105             :                   unsigned MaxLoadsPerBlockForZeroCmp, const DataLayout &TheDataLayout);
     106             : 
     107             :   unsigned getNumBlocks();
     108         858 :   uint64_t getNumLoads() const { return LoadSequence.size(); }
     109             : 
     110             :   Value *getMemCmpExpansion();
     111             : };
     112             : 
      113             : // Initialize the basic block structure required for expansion of a memcmp call
      114             : // with the given maximum load size and memcmp size parameter.
      115             : // This structure includes:
      116             : // 1. A list of load compare blocks - LoadCmpBlocks.
      117             : // 2. An EndBlock, split from the original instruction point, which is the
      118             : // block to return from.
      119             : // 3. A ResultBlock, the block to branch to for an early exit when a
      120             : // LoadCmpBlock finds a difference.
     121         395 : MemCmpExpansion::MemCmpExpansion(
     122             :     CallInst *const CI, uint64_t Size,
     123             :     const TargetTransformInfo::MemCmpExpansionOptions &Options,
     124             :     const unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
     125         395 :     const unsigned MaxLoadsPerBlockForZeroCmp, const DataLayout &TheDataLayout)
     126             :     : CI(CI),
     127             :       Size(Size),
     128             :       MaxLoadSize(0),
     129             :       NumLoadsNonOneByte(0),
     130             :       NumLoadsPerBlockForZeroCmp(MaxLoadsPerBlockForZeroCmp),
     131             :       IsUsedForZeroCmp(IsUsedForZeroCmp),
     132             :       DL(TheDataLayout),
     133        1185 :       Builder(CI) {
     134             :   assert(Size > 0 && "zero blocks");
     135             :   // Scale the max size down if the target can load more bytes than we need.
     136             :   size_t LoadSizeIndex = 0;
     137        1340 :   while (LoadSizeIndex < Options.LoadSizes.size() &&
     138         670 :          Options.LoadSizes[LoadSizeIndex] > Size) {
     139         275 :     ++LoadSizeIndex;
     140             :   }
     141         395 :   this->MaxLoadSize = Options.LoadSizes[LoadSizeIndex];
     142             :   // Compute the decomposition.
     143             :   uint64_t CurSize = Size;
     144             :   uint64_t Offset = 0;
     145         868 :   while (CurSize && LoadSizeIndex < Options.LoadSizes.size()) {
     146         602 :     const unsigned LoadSize = Options.LoadSizes[LoadSizeIndex];
     147             :     assert(LoadSize > 0 && "zero load size");
     148         602 :     const uint64_t NumLoadsForThisSize = CurSize / LoadSize;
     149        1204 :     if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
     150             :       // Do not expand if the total number of loads is larger than what the
     151             :       // target allows. Note that it's important that we exit before completing
     152             :       // the expansion to avoid using a ton of memory to store the expansion for
     153             :       // large sizes.
     154             :       LoadSequence.clear();
     155         129 :       return;
     156             :     }
     157         473 :     if (NumLoadsForThisSize > 0) {
     158         910 :       for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
     159         489 :         LoadSequence.push_back({LoadSize, Offset});
     160         489 :         Offset += LoadSize;
     161             :       }
     162         421 :       if (LoadSize > 1) {
     163         360 :         ++NumLoadsNonOneByte;
     164             :       }
     165         421 :       CurSize = CurSize % LoadSize;
     166             :     }
     167         473 :     ++LoadSizeIndex;
     168             :   }
     169             :   assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
     170             : }
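
The loop above is the core sizing decision of the pass: it greedily covers Size bytes with the largest allowed load first, falls back to smaller loads for the remainder, and abandons the expansion entirely once the target's load budget would be exceeded. A minimal standalone sketch of the same logic, assuming LoadSizes is sorted in decreasing order; the names decompose and LoadPlan are hypothetical and not part of this file:

    #include <cstdint>
    #include <utility>
    #include <vector>

    // {LoadSize, Offset} pairs, mirroring LoadEntry above.
    using LoadPlan = std::vector<std::pair<unsigned, uint64_t>>;

    // Returns an empty plan when the load budget is exceeded, mirroring the
    // early return in the constructor.
    LoadPlan decompose(uint64_t Size, const std::vector<unsigned> &LoadSizes,
                       unsigned MaxNumLoads) {
      LoadPlan Plan;
      size_t Idx = 0;
      while (Idx < LoadSizes.size() && LoadSizes[Idx] > Size)
        ++Idx; // skip load sizes wider than the whole comparison
      uint64_t CurSize = Size, Offset = 0;
      for (; CurSize != 0 && Idx < LoadSizes.size(); ++Idx) {
        const unsigned LoadSize = LoadSizes[Idx];
        const uint64_t NumLoads = CurSize / LoadSize;
        if (Plan.size() + NumLoads > MaxNumLoads)
          return {}; // too many loads: keep the library call instead
        for (uint64_t I = 0; I < NumLoads; ++I) {
          Plan.push_back({LoadSize, Offset});
          Offset += LoadSize;
        }
        CurSize %= LoadSize;
      }
      return Plan;
    }
    // decompose(33, {16, 8, 4, 2, 1}, 8) == {{16, 0}, {16, 16}, {1, 32}},
    // matching the LoadSequence example in the class comment above.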
     171             : 
     172             : unsigned MemCmpExpansion::getNumBlocks() {
     173         424 :   if (IsUsedForZeroCmp)
     174         796 :     return getNumLoads() / NumLoadsPerBlockForZeroCmp +
     175         551 :            (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0);
     176             :   return getNumLoads();
     177             : }
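
In other words, the zero-equality case packs up to NumLoadsPerBlockForZeroCmp loads into each block and rounds the block count up, while the general case keeps one load per block. A quick worked example with made-up values:

    unsigned NumLoads = 5, LoadsPerBlock = 2;      // hypothetical values
    unsigned ZeroCmpBlocks =
        NumLoads / LoadsPerBlock + (NumLoads % LoadsPerBlock != 0 ? 1 : 0); // 3
    unsigned GeneralBlocks = NumLoads;                                      // 5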
     178             : 
     179          69 : void MemCmpExpansion::createLoadCmpBlocks() {
     180         355 :   for (unsigned i = 0; i < getNumBlocks(); i++) {
     181         286 :     BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
     182         143 :                                         EndBlock->getParent(), EndBlock);
     183         143 :     LoadCmpBlocks.push_back(BB);
     184             :   }
     185          69 : }
     186             : 
     187          69 : void MemCmpExpansion::createResultBlock() {
     188         138 :   ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
     189             :                                    EndBlock->getParent(), EndBlock);
     190          69 : }
     191             : 
     192             : // This function creates the IR instructions for loading and comparing 1 byte.
     193             : // It loads 1 byte from each source of the memcmp parameters with the given
     194             : // GEPIndex. It then subtracts the two loaded values and adds this result to the
     195             : // final phi node for selecting the memcmp result.
     196          30 : void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
     197             :                                                unsigned GEPIndex) {
     198          30 :   Value *Source1 = CI->getArgOperand(0);
     199             :   Value *Source2 = CI->getArgOperand(1);
     200             : 
     201          60 :   Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
     202          30 :   Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
     203             :   // Cast source to LoadSizeType*.
     204          30 :   if (Source1->getType() != LoadSizeType)
     205          90 :     Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
     206          30 :   if (Source2->getType() != LoadSizeType)
     207          90 :     Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
     208             : 
     209             :   // Get the base address using the GEPIndex.
     210          30 :   if (GEPIndex != 0) {
     211          30 :     Source1 = Builder.CreateGEP(LoadSizeType, Source1,
     212          30 :                                 ConstantInt::get(LoadSizeType, GEPIndex));
     213          30 :     Source2 = Builder.CreateGEP(LoadSizeType, Source2,
     214          30 :                                 ConstantInt::get(LoadSizeType, GEPIndex));
     215             :   }
     216             : 
     217          60 :   Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
     218          30 :   Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
     219             : 
     220          60 :   LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
     221          60 :   LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
     222          30 :   Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);
     223             : 
     224          60 :   PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);
     225             : 
     226          60 :   if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
      227             :     // Early exit branch to EndBlock if a difference was found; otherwise,
      228             :     // continue to the next LoadCmpBlock.
     229           0 :     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
     230           0 :                                     ConstantInt::get(Diff->getType(), 0));
     231             :     BranchInst *CmpBr =
     232           0 :         BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
     233           0 :     Builder.Insert(CmpBr);
     234             :   } else {
     235             :     // The last block has an unconditional branch to EndBlock.
     236          30 :     BranchInst *CmpBr = BranchInst::Create(EndBlock);
     237          30 :     Builder.Insert(CmpBr);
     238             :   }
     239          30 : }
     240             : 
     241             : /// Generate an equality comparison for one or more pairs of loaded values.
     242             : /// This is used in the case where the memcmp() call is compared equal or not
     243             : /// equal to zero.
     244         179 : Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
     245             :                                             unsigned &LoadIndex) {
     246             :   assert(LoadIndex < getNumLoads() &&
     247             :          "getCompareLoadPairs() called with no remaining loads");
     248             :   std::vector<Value *> XorList, OrList;
     249             :   Value *Diff;
     250             : 
     251             :   const unsigned NumLoads =
     252         369 :       std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp);
     253             : 
     254             :   // For a single-block expansion, start inserting before the memcmp call.
     255         179 :   if (LoadCmpBlocks.empty())
     256         158 :     Builder.SetInsertPoint(CI);
     257             :   else
     258          42 :     Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
     259             : 
     260             :   Value *Cmp = nullptr;
     261             :   // If we have multiple loads per block, we need to generate a composite
     262             :   // comparison using xor+or. The type for the combinations is the largest load
     263             :   // type.
     264             :   IntegerType *const MaxLoadType =
     265         179 :       NumLoads == 1 ? nullptr
     266          73 :                     : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
     267         431 :   for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
     268         252 :     const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
     269             : 
     270             :     IntegerType *LoadSizeType =
     271         252 :         IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
     272             : 
     273         252 :     Value *Source1 = CI->getArgOperand(0);
     274             :     Value *Source2 = CI->getArgOperand(1);
     275             : 
     276             :     // Cast source to LoadSizeType*.
     277         252 :     if (Source1->getType() != LoadSizeType)
     278         756 :       Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
     279         252 :     if (Source2->getType() != LoadSizeType)
     280         756 :       Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
     281             : 
     282             :     // Get the base address using a GEP.
     283         252 :     if (CurLoadEntry.Offset != 0) {
     284          84 :       Source1 = Builder.CreateGEP(
     285             :           LoadSizeType, Source1,
     286          84 :           ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
     287          84 :       Source2 = Builder.CreateGEP(
     288             :           LoadSizeType, Source2,
     289          84 :           ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
     290             :     }
     291             : 
     292             :     // Get a constant or load a value for each source address.
     293             :     Value *LoadSrc1 = nullptr;
     294             :     if (auto *Source1C = dyn_cast<Constant>(Source1))
     295           4 :       LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
     296         252 :     if (!LoadSrc1)
     297         496 :       LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
     298             : 
     299             :     Value *LoadSrc2 = nullptr;
     300             :     if (auto *Source2C = dyn_cast<Constant>(Source2))
     301          76 :       LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
     302          76 :     if (!LoadSrc2)
     303         356 :       LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
     304             : 
     305         252 :     if (NumLoads != 1) {
     306         146 :       if (LoadSizeType != MaxLoadType) {
     307          45 :         LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
     308          45 :         LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
     309             :       }
     310             :       // If we have multiple loads per block, we need to generate a composite
     311             :       // comparison using xor+or.
     312         292 :       Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
     313         146 :       Diff = Builder.CreateZExt(Diff, MaxLoadType);
     314         146 :       XorList.push_back(Diff);
     315             :     } else {
     316             :       // If there's only one load per block, we just compare the loaded values.
     317         106 :       Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
     318             :     }
     319             :   }
     320             : 
     321             :   auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
     322             :     std::vector<Value *> OutList;
     323             :     for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
     324             :       Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
     325             :       OutList.push_back(Or);
     326             :     }
     327             :     if (InList.size() % 2 != 0)
     328             :       OutList.push_back(InList.back());
     329             :     return OutList;
     330         179 :   };
     331             : 
     332         179 :   if (!Cmp) {
     333             :     // Pairwise OR the XOR results.
     334         146 :     OrList = pairWiseOr(XorList);
     335             : 
     336             :     // Pairwise OR the OR results until one result left.
     337         146 :     while (OrList.size() != 1) {
     338           0 :       OrList = pairWiseOr(OrList);
     339             :     }
     340         146 :     Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
     341             :   }
     342             : 
     343         179 :   return Cmp;
     344             : }
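
When several loads land in one block, the helper above widens each loaded pair to the largest load type as needed, XORs the pairs, and OR-reduces the XOR results pairwise so that a single icmp ne against zero answers "did any pair differ?". A scalar C++ analogue of that reduction; the name anyPairDiffers is hypothetical, and both vectors are assumed to have the same non-zero length:

    #include <cstdint>
    #include <utility>
    #include <vector>

    bool anyPairDiffers(const std::vector<uint64_t> &Src1,
                        const std::vector<uint64_t> &Src2) {
      std::vector<uint64_t> XorList;
      for (size_t I = 0; I < Src1.size(); ++I)
        XorList.push_back(Src1[I] ^ Src2[I]); // zero iff the pair is equal
      // Pairwise OR until one value remains, mirroring pairWiseOr above.
      while (XorList.size() > 1) {
        std::vector<uint64_t> OrList;
        for (size_t I = 0; I + 1 < XorList.size(); I += 2)
          OrList.push_back(XorList[I] | XorList[I + 1]);
        if (XorList.size() % 2 != 0)
          OrList.push_back(XorList.back());
        XorList = std::move(OrList);
      }
      return XorList[0] != 0; // the final icmp ne ..., 0
    }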
     345             : 
     346          21 : void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
     347             :                                                         unsigned &LoadIndex) {
     348          21 :   Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);
     349             : 
     350          21 :   BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
     351          21 :                            ? EndBlock
     352          11 :                            : LoadCmpBlocks[BlockIndex + 1];
     353             :   // Early exit branch if difference found to ResultBlock. Otherwise,
     354             :   // continue to next LoadCmpBlock or EndBlock.
     355          21 :   BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
     356          42 :   Builder.Insert(CmpBr);
     357             : 
      358             :   // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0
     359             :   // since early exit to ResultBlock was not taken (no difference was found in
     360             :   // any of the bytes).
     361          42 :   if (BlockIndex == LoadCmpBlocks.size() - 1) {
     362          10 :     Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
     363          20 :     PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
     364             :   }
     365          21 : }
     366             : 
      367             : // This function creates the IR instructions for loading and comparing using the
     368             : // given LoadSize. It loads the number of bytes specified by LoadSize from each
     369             : // source of the memcmp parameters. It then does a subtract to see if there was
     370             : // a difference in the loaded values. If a difference is found, it branches
     371             : // with an early exit to the ResultBlock for calculating which source was
      372             : // larger. Otherwise, it falls through to either the next LoadCmpBlock or
     373             : // the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled with
     374             : // a special case through emitLoadCompareByteBlock. The special handling can
     375             : // simply subtract the loaded values and add it to the result phi node.
     376         122 : void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
     377             :   // There is one load per block in this case, BlockIndex == LoadIndex.
     378         122 :   const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];
     379             : 
     380         122 :   if (CurLoadEntry.LoadSize == 1) {
     381          30 :     MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex,
     382          30 :                                               CurLoadEntry.getGEPIndex());
     383          30 :     return;
     384             :   }
     385             : 
     386             :   Type *LoadSizeType =
     387          92 :       IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
     388          92 :   Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
     389             :   assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");
     390             : 
     391          92 :   Value *Source1 = CI->getArgOperand(0);
     392             :   Value *Source2 = CI->getArgOperand(1);
     393             : 
     394         184 :   Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
     395             :   // Cast source to LoadSizeType*.
     396          92 :   if (Source1->getType() != LoadSizeType)
     397         276 :     Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
     398          92 :   if (Source2->getType() != LoadSizeType)
     399         276 :     Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
     400             : 
     401             :   // Get the base address using a GEP.
     402          92 :   if (CurLoadEntry.Offset != 0) {
     403          33 :     Source1 = Builder.CreateGEP(
     404             :         LoadSizeType, Source1,
     405          33 :         ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
     406          33 :     Source2 = Builder.CreateGEP(
     407             :         LoadSizeType, Source2,
     408          33 :         ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
     409             :   }
     410             : 
     411             :   // Load LoadSizeType from the base address.
     412         184 :   Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
     413          92 :   Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
     414             : 
     415          92 :   if (DL.isLittleEndian()) {
     416          87 :     Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
     417             :                                                 Intrinsic::bswap, LoadSizeType);
     418          87 :     LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
     419          87 :     LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
     420             :   }
     421             : 
     422          92 :   if (LoadSizeType != MaxLoadType) {
     423          16 :     LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
     424          16 :     LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
     425             :   }
     426             : 
     427             :   // Add the loaded values to the phi nodes for calculating memcmp result only
     428             :   // if result is not used in a zero equality.
     429          92 :   if (!IsUsedForZeroCmp) {
     430         184 :     ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
     431         184 :     ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
     432             :   }
     433             : 
     434          92 :   Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
     435          92 :   BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
     436          92 :                            ? EndBlock
     437          63 :                            : LoadCmpBlocks[BlockIndex + 1];
     438             :   // Early exit branch if difference found to ResultBlock. Otherwise, continue
     439             :   // to next LoadCmpBlock or EndBlock.
     440          92 :   BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
     441          92 :   Builder.Insert(CmpBr);
     442             : 
      443             :   // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0
     444             :   // since early exit to ResultBlock was not taken (no difference was found in
     445             :   // any of the bytes).
     446         184 :   if (BlockIndex == LoadCmpBlocks.size() - 1) {
     447          29 :     Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
     448          58 :     PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
     449             :   }
     450             : }
     451             : 
     452             : // This function populates the ResultBlock with a sequence to calculate the
     453             : // memcmp result. It compares the two loaded source values and returns -1 if
     454             : // src1 < src2 and 1 if src1 > src2.
     455          69 : void MemCmpExpansion::emitMemCmpResultBlock() {
     456             :   // Special case: if memcmp result is used in a zero equality, result does not
     457             :   // need to be calculated and can simply return 1.
     458          69 :   if (IsUsedForZeroCmp) {
     459          10 :     BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
     460          10 :     Builder.SetInsertPoint(ResBlock.BB, InsertPt);
     461          10 :     Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
     462          10 :     PhiRes->addIncoming(Res, ResBlock.BB);
     463          10 :     BranchInst *NewBr = BranchInst::Create(EndBlock);
     464          20 :     Builder.Insert(NewBr);
     465             :     return;
     466             :   }
     467          59 :   BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
     468          59 :   Builder.SetInsertPoint(ResBlock.BB, InsertPt);
     469             : 
     470          59 :   Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
     471          59 :                                   ResBlock.PhiSrc2);
     472             : 
     473             :   Value *Res =
     474          59 :       Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
     475          59 :                            ConstantInt::get(Builder.getInt32Ty(), 1));
     476             : 
     477          59 :   BranchInst *NewBr = BranchInst::Create(EndBlock);
     478          59 :   Builder.Insert(NewBr);
     479          59 :   PhiRes->addIncoming(Res, ResBlock.BB);
     480             : }
     481             : 
     482          59 : void MemCmpExpansion::setupResultBlockPHINodes() {
     483          59 :   Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
     484          59 :   Builder.SetInsertPoint(ResBlock.BB);
     485             :   // Note: this assumes one load per block.
     486          59 :   ResBlock.PhiSrc1 =
     487         118 :       Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
     488          59 :   ResBlock.PhiSrc2 =
     489          59 :       Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
     490          59 : }
     491             : 
     492          69 : void MemCmpExpansion::setupEndBlockPHINodes() {
     493         138 :   Builder.SetInsertPoint(&EndBlock->front());
     494         138 :   PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
     495          69 : }
     496             : 
     497          10 : Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
     498          10 :   unsigned LoadIndex = 0;
     499             :   // This loop populates each of the LoadCmpBlocks with the IR sequence to
     500             :   // handle multiple loads per block.
     501          52 :   for (unsigned I = 0; I < getNumBlocks(); ++I) {
     502          21 :     emitLoadCompareBlockMultipleLoads(I, LoadIndex);
     503             :   }
     504             : 
     505          10 :   emitMemCmpResultBlock();
     506          10 :   return PhiRes;
     507             : }
     508             : 
      509             : /// A memcmp expansion whose result is only compared for equality with zero and
      510             : /// that has a single block of loads and compares can bypass the compare,
      511             : /// branch, and phi IR that is required in the general case.
     512         158 : Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
     513         158 :   unsigned LoadIndex = 0;
     514         158 :   Value *Cmp = getCompareLoadPairs(0, LoadIndex);
     515             :   assert(LoadIndex == getNumLoads() && "some entries were not consumed");
     516         316 :   return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
     517             : }
     518             : 
     519             : /// A memcmp expansion that only has one block of load and compare can bypass
     520             : /// the compare, branch, and phi IR that is required in the general case.
     521          39 : Value *MemCmpExpansion::getMemCmpOneBlock() {
     522          39 :   Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
     523          39 :   Value *Source1 = CI->getArgOperand(0);
     524             :   Value *Source2 = CI->getArgOperand(1);
     525             : 
     526             :   // Cast source to LoadSizeType*.
     527          39 :   if (Source1->getType() != LoadSizeType)
     528         117 :     Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
     529          39 :   if (Source2->getType() != LoadSizeType)
     530         117 :     Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
     531             : 
     532             :   // Load LoadSizeType from the base address.
     533          78 :   Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
     534          39 :   Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
     535             : 
     536          39 :   if (DL.isLittleEndian() && Size != 1) {
     537          37 :     Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
     538             :                                                 Intrinsic::bswap, LoadSizeType);
     539          37 :     LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
     540          37 :     LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
     541             :   }
     542             : 
     543          39 :   if (Size < 4) {
     544             :     // The i8 and i16 cases don't need compares. We zext the loaded values and
     545             :     // subtract them to get the suitable negative, zero, or positive i32 result.
     546          15 :     LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
     547          15 :     LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
     548          15 :     return Builder.CreateSub(LoadSrc1, LoadSrc2);
     549             :   }
     550             : 
     551             :   // The result of memcmp is negative, zero, or positive, so produce that by
     552             :   // subtracting 2 extended compare bits: sub (ugt, ult).
     553             :   // If a target prefers to use selects to get -1/0/1, they should be able
     554             :   // to transform this later. The inverse transform (going from selects to math)
     555             :   // may not be possible in the DAG because the selects got converted into
     556             :   // branches before we got there.
     557          24 :   Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
     558          24 :   Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
     559          48 :   Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
     560          48 :   Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
     561          24 :   return Builder.CreateSub(ZextUGT, ZextULT);
     562             : }
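
For sizes of at least four bytes, the single-block path produces the negative/zero/positive memcmp result without any branch: it is simply zext(a > b) - zext(a < b) over the loaded values, where the bswap applied above on little-endian targets makes this unsigned integer compare agree with byte-wise ordering. The same computation as a small C++ helper; the name memcmpSign is hypothetical and host integers stand in for the loaded values:

    #include <cstdint>

    // Sign of memcmp over two already byte-swap-normalized words:
    // returns 1, 0, or -1, exactly like the sub (ugt, ult) sequence above.
    int memcmpSign(uint64_t A, uint64_t B) {
      return static_cast<int>(A > B) - static_cast<int>(A < B);
    }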
     563             : 
     564             : // This function expands the memcmp call into an inline expansion and returns
     565             : // the memcmp result.
     566         266 : Value *MemCmpExpansion::getMemCmpExpansion() {
     567             :   // Create the basic block framework for a multi-block expansion.
     568         266 :   if (getNumBlocks() != 1) {
     569          69 :     BasicBlock *StartBlock = CI->getParent();
     570          69 :     EndBlock = StartBlock->splitBasicBlock(CI, "endblock");
     571          69 :     setupEndBlockPHINodes();
     572          69 :     createResultBlock();
     573             : 
     574             :     // If return value of memcmp is not used in a zero equality, we need to
     575             :     // calculate which source was larger. The calculation requires the
     576             :     // two loaded source values of each load compare block.
     577             :     // These will be saved in the phi nodes created by setupResultBlockPHINodes.
     578          69 :     if (!IsUsedForZeroCmp) setupResultBlockPHINodes();
     579             : 
     580             :     // Create the number of required load compare basic blocks.
     581          69 :     createLoadCmpBlocks();
     582             : 
     583             :     // Update the terminator added by splitBasicBlock to branch to the first
     584             :     // LoadCmpBlock.
     585          69 :     StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
     586             :   }
     587             : 
     588         266 :   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
     589             : 
     590         266 :   if (IsUsedForZeroCmp)
     591         168 :     return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
     592          10 :                                : getMemCmpExpansionZeroCase();
     593             : 
     594          98 :   if (getNumBlocks() == 1)
     595          39 :     return getMemCmpOneBlock();
     596             : 
     597         303 :   for (unsigned I = 0; I < getNumBlocks(); ++I) {
     598         122 :     emitLoadCompareBlock(I);
     599             :   }
     600             : 
     601          59 :   emitMemCmpResultBlock();
     602          59 :   return PhiRes;
     603             : }
     604             : 
     605             : // This function checks to see if an expansion of memcmp can be generated.
      606             : // It checks for a constant compare size that is less than the max inline size.
      607             : // If an expansion cannot occur, returns false to leave it as a library call.
     608             : // Otherwise, the library call is replaced with a new IR instruction sequence.
     609             : /// We want to transform:
     610             : /// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
     611             : /// To:
     612             : /// loadbb:
     613             : ///  %0 = bitcast i32* %buffer2 to i8*
     614             : ///  %1 = bitcast i32* %buffer1 to i8*
     615             : ///  %2 = bitcast i8* %1 to i64*
     616             : ///  %3 = bitcast i8* %0 to i64*
     617             : ///  %4 = load i64, i64* %2
     618             : ///  %5 = load i64, i64* %3
     619             : ///  %6 = call i64 @llvm.bswap.i64(i64 %4)
     620             : ///  %7 = call i64 @llvm.bswap.i64(i64 %5)
     621             : ///  %8 = sub i64 %6, %7
     622             : ///  %9 = icmp ne i64 %8, 0
     623             : ///  br i1 %9, label %res_block, label %loadbb1
     624             : /// res_block:                                        ; preds = %loadbb2,
     625             : /// %loadbb1, %loadbb
     626             : ///  %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
     627             : ///  %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
     628             : ///  %10 = icmp ult i64 %phi.src1, %phi.src2
     629             : ///  %11 = select i1 %10, i32 -1, i32 1
     630             : ///  br label %endblock
     631             : /// loadbb1:                                          ; preds = %loadbb
     632             : ///  %12 = bitcast i32* %buffer2 to i8*
     633             : ///  %13 = bitcast i32* %buffer1 to i8*
     634             : ///  %14 = bitcast i8* %13 to i32*
     635             : ///  %15 = bitcast i8* %12 to i32*
     636             : ///  %16 = getelementptr i32, i32* %14, i32 2
     637             : ///  %17 = getelementptr i32, i32* %15, i32 2
     638             : ///  %18 = load i32, i32* %16
     639             : ///  %19 = load i32, i32* %17
     640             : ///  %20 = call i32 @llvm.bswap.i32(i32 %18)
     641             : ///  %21 = call i32 @llvm.bswap.i32(i32 %19)
     642             : ///  %22 = zext i32 %20 to i64
     643             : ///  %23 = zext i32 %21 to i64
     644             : ///  %24 = sub i64 %22, %23
     645             : ///  %25 = icmp ne i64 %24, 0
     646             : ///  br i1 %25, label %res_block, label %loadbb2
     647             : /// loadbb2:                                          ; preds = %loadbb1
     648             : ///  %26 = bitcast i32* %buffer2 to i8*
     649             : ///  %27 = bitcast i32* %buffer1 to i8*
     650             : ///  %28 = bitcast i8* %27 to i16*
     651             : ///  %29 = bitcast i8* %26 to i16*
     652             : ///  %30 = getelementptr i16, i16* %28, i16 6
     653             : ///  %31 = getelementptr i16, i16* %29, i16 6
     654             : ///  %32 = load i16, i16* %30
     655             : ///  %33 = load i16, i16* %31
     656             : ///  %34 = call i16 @llvm.bswap.i16(i16 %32)
     657             : ///  %35 = call i16 @llvm.bswap.i16(i16 %33)
     658             : ///  %36 = zext i16 %34 to i64
     659             : ///  %37 = zext i16 %35 to i64
     660             : ///  %38 = sub i64 %36, %37
     661             : ///  %39 = icmp ne i64 %38, 0
     662             : ///  br i1 %39, label %res_block, label %loadbb3
     663             : /// loadbb3:                                          ; preds = %loadbb2
     664             : ///  %40 = bitcast i32* %buffer2 to i8*
     665             : ///  %41 = bitcast i32* %buffer1 to i8*
     666             : ///  %42 = getelementptr i8, i8* %41, i8 14
     667             : ///  %43 = getelementptr i8, i8* %40, i8 14
     668             : ///  %44 = load i8, i8* %42
     669             : ///  %45 = load i8, i8* %43
     670             : ///  %46 = zext i8 %44 to i32
     671             : ///  %47 = zext i8 %45 to i32
     672             : ///  %48 = sub i32 %46, %47
     673             : ///  br label %endblock
     674             : /// endblock:                                         ; preds = %res_block,
     675             : /// %loadbb3
     676             : ///  %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
     677             : ///  ret i32 %phi.res
     678         597 : static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
     679             :                          const TargetLowering *TLI, const DataLayout *DL) {
     680             :   NumMemCmpCalls++;
     681             : 
     682             :   // Early exit from expansion if -Oz.
     683        1194 :   if (CI->getFunction()->optForMinSize())
     684             :     return false;
     685             : 
     686             :   // Early exit from expansion if size is not a constant.
     687         489 :   ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
     688             :   if (!SizeCast) {
     689             :     NumMemCmpNotConstant++;
     690             :     return false;
     691             :   }
     692             :   const uint64_t SizeVal = SizeCast->getZExtValue();
     693             : 
     694         422 :   if (SizeVal == 0) {
     695             :     return false;
     696             :   }
     697             : 
     698             :   // TTI call to check if target would like to expand memcmp. Also, get the
     699             :   // available load sizes.
     700         409 :   const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
     701         409 :   const auto *const Options = TTI->enableMemCmpExpansion(IsUsedForZeroCmp);
     702         409 :   if (!Options) return false;
     703             : 
     704             :   const unsigned MaxNumLoads =
     705         395 :       TLI->getMaxExpandSizeMemcmp(CI->getFunction()->optForSize());
     706             : 
     707         395 :   unsigned NumLoadsPerBlock = MemCmpEqZeroNumLoadsPerBlock.getNumOccurrences()
     708         395 :                                   ? MemCmpEqZeroNumLoadsPerBlock
     709         335 :                                   : TLI->getMemcmpEqZeroLoadsPerBlock();
     710             : 
     711             :   MemCmpExpansion Expansion(CI, SizeVal, *Options, MaxNumLoads,
     712         790 :                             IsUsedForZeroCmp, NumLoadsPerBlock, *DL);
     713             : 
     714             :   // Don't expand if this will require more loads than desired by the target.
     715         395 :   if (Expansion.getNumLoads() == 0) {
     716             :     NumMemCmpGreaterThanMax++;
     717             :     return false;
     718             :   }
     719             : 
     720             :   NumMemCmpInlined++;
     721             : 
     722         266 :   Value *Res = Expansion.getMemCmpExpansion();
     723             : 
     724             :   // Replace call with result of expansion and erase call.
     725         266 :   CI->replaceAllUsesWith(Res);
     726         266 :   CI->eraseFromParent();
     727             : 
     728         266 :   return true;
     729             : }
     730             : 
     731             : 
     732             : 
     733             : class ExpandMemCmpPass : public FunctionPass {
     734             : public:
     735             :   static char ID;
     736             : 
     737       20210 :   ExpandMemCmpPass() : FunctionPass(ID) {
     738       20210 :     initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry());
     739             :   }
     740             : 
     741      198167 :   bool runOnFunction(Function &F) override {
     742      198167 :     if (skipFunction(F)) return false;
     743             : 
     744      197980 :     auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
     745      197980 :     if (!TPC) {
     746             :       return false;
     747             :     }
     748             :     const TargetLowering* TL =
     749      197980 :         TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();
     750             : 
     751             :     const TargetLibraryInfo *TLI =
     752      197980 :         &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
     753             :     const TargetTransformInfo *TTI =
     754      197980 :         &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
     755      197980 :     auto PA = runImpl(F, TLI, TTI, TL);
     756      197980 :     return !PA.areAllPreserved();
     757             :   }
     758             : 
     759             : private:
     760       20077 :   void getAnalysisUsage(AnalysisUsage &AU) const override {
     761             :     AU.addRequired<TargetLibraryInfoWrapperPass>();
     762             :     AU.addRequired<TargetTransformInfoWrapperPass>();
     763       20077 :     FunctionPass::getAnalysisUsage(AU);
     764       20077 :   }
     765             : 
     766             :   PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
     767             :                             const TargetTransformInfo *TTI,
     768             :                             const TargetLowering* TL);
     769             :   // Returns true if a change was made.
     770             :   bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
     771             :                   const TargetTransformInfo *TTI, const TargetLowering* TL,
     772             :                   const DataLayout& DL);
     773             : };
     774             : 
     775           0 : bool ExpandMemCmpPass::runOnBlock(
     776             :     BasicBlock &BB, const TargetLibraryInfo *TLI,
     777             :     const TargetTransformInfo *TTI, const TargetLowering* TL,
     778             :     const DataLayout& DL) {
     779           0 :   for (Instruction& I : BB) {
     780             :     CallInst *CI = dyn_cast<CallInst>(&I);
     781           0 :     if (!CI) {
     782           0 :       continue;
     783             :     }
     784             :     LibFunc Func;
     785           0 :     if (TLI->getLibFunc(ImmutableCallSite(CI), Func) &&
     786           0 :         Func == LibFunc_memcmp && expandMemCmp(CI, TTI, TL, &DL)) {
     787           0 :       return true;
     788             :     }
     789             :   }
     790             :   return false;
     791             : }
     792             : 
     793             : 
     794           0 : PreservedAnalyses ExpandMemCmpPass::runImpl(
     795             :     Function &F, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI,
     796             :     const TargetLowering* TL) {
     797           0 :   const DataLayout& DL = F.getParent()->getDataLayout();
     798             :   bool MadeChanges = false;
     799           0 :   for (auto BBIt = F.begin(); BBIt != F.end();) {
     800           0 :     if (runOnBlock(*BBIt, TLI, TTI, TL, DL)) {
     801             :       MadeChanges = true;
     802             :       // If changes were made, restart the function from the beginning, since
     803             :       // the structure of the function was changed.
     804             :       BBIt = F.begin();
     805             :     } else {
     806             :       ++BBIt;
     807             :     }
     808             :   }
     809           0 :   return MadeChanges ? PreservedAnalyses::none() : PreservedAnalyses::all();
     810             : }
     811             : 
     812             : } // namespace
     813             : 
     814             : char ExpandMemCmpPass::ID = 0;
     815       39044 : INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp",
     816             :                       "Expand memcmp() to load/stores", false, false)
     817       39044 : INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
     818       39044 : INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
     819      123360 : INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp",
     820             :                     "Expand memcmp() to load/stores", false, false)
     821             : 
     822       20207 : FunctionPass *llvm::createExpandMemCmpPass() {
     823       20207 :   return new ExpandMemCmpPass();
     824             : }
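
The pass registers under the command-line name "expandmemcmp" (see INITIALIZE_PASS_BEGIN above) and only does useful work when a TargetPassConfig is available, since it needs the TargetLowering and the target's TTI hooks. A hedged sketch of driving it standalone with the legacy pass manager; the header locations and setup details are assumptions, not taken from this file:

    #include "llvm/Analysis/TargetTransformInfo.h"
    #include "llvm/CodeGen/Passes.h"          // createExpandMemCmpPass (assumed header)
    #include "llvm/CodeGen/TargetPassConfig.h"
    #include "llvm/IR/LegacyPassManager.h"
    #include "llvm/IR/Module.h"
    #include "llvm/Target/TargetMachine.h"

    // Assumes an already-configured LLVMTargetMachine; error handling omitted.
    void expandMemCmpIn(llvm::Module &M, llvm::LLVMTargetMachine &TM) {
      llvm::legacy::PassManager PM;
      // Target TTI so enableMemCmpExpansion() can opt in.
      PM.add(llvm::createTargetTransformInfoWrapperPass(TM.getTargetIRAnalysis()));
      // TargetPassConfig, without which runOnFunction() above bails out.
      PM.add(TM.createPassConfig(PM));
      PM.add(llvm::createExpandMemCmpPass());
      PM.run(M);
    }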

Generated by: LCOV version 1.13