// Debug type tag used by LLVM's DEBUG()/STATISTIC machinery for this pass.
38 #define DEBUG_TYPE "expandmemcmp"
// Pass statistics. NOTE(review): this fragment elides lines, so some
// STATISTIC / cl::opt header lines below are only partially visible.
40 STATISTIC(NumMemCmpCalls,
"Number of memcmp calls");
41 STATISTIC(NumMemCmpNotConstant,
"Number of memcmp calls without constant size");
// Description string for the counter incremented as NumMemCmpGreaterThanMax
// further down; its STATISTIC header line is not visible in this fragment.
43 "Number of memcmp calls with size greater than max size");
44 STATISTIC(NumMemCmpInlined,
"Number of inlined memcmp calls");
// Command-line option: loads per basic block for the equality-with-zero form.
48 cl::desc(
"The number of loads per basic block for inline expansion of "
49 "memcmp that is only being compared against zero."));
// Command-line option: cap on total loads used by an expansion.
53 cl::desc(
"Set maximum number of loads used in expanded memcmp"));
// Command-line option: the same cap when optimizing for size (-Os/-Oz).
57 cl::desc(
"Set maximum number of loads used in expanded memcmp for -Os/Oz"));
// Helper class that plans and emits an inline expansion of a memcmp call
// as a sequence of wide loads and compares. NOTE(review): this class
// declaration is only partially visible; several members are elided.
64 class MemCmpExpansion {
// Shared "result" block state; default-constructed.
70 ResultBlock() =
default;
// Widest usable load size in bytes; set from LoadSizes.front() in the
// constructor below.
76 unsigned MaxLoadSize = 0;
// How many load/compare pairs share one basic block when the memcmp
// result is only tested against zero.
78 const uint64_t NumLoadsPerBlockForZeroCmp;
// One basic block per emitted load/compare group.
79 std::vector<BasicBlock *> LoadCmpBlocks;
// True when the call's result is only compared for (in)equality with 0.
82 const bool IsUsedForZeroCmp;
// One planned load: LoadSize bytes at byte offset Offset.
90 LoadEntry(
unsigned LoadSize,
uint64_t Offset)
// The precomputed sequence of loads covering the compared region.
100 LoadEntryVector LoadSequence;
// CFG construction helpers.
102 void createLoadCmpBlocks();
103 void createResultBlock();
104 void setupResultBlockPHINodes();
105 void setupEndBlockPHINodes();
// Builds the fused (XOR/OR-reduced) equality check for the loads of
// block BlockIndex; advances LoadIndex past the consumed entries.
106 Value *getCompareLoadPairs(
unsigned BlockIndex,
unsigned &LoadIndex);
107 void emitLoadCompareBlock(
unsigned BlockIndex);
108 void emitLoadCompareBlockMultipleLoads(
unsigned BlockIndex,
109 unsigned &LoadIndex);
110 void emitLoadCompareByteBlock(
unsigned BlockIndex,
unsigned OffsetBytes);
111 void emitMemCmpResultBlock();
// Expansion strategies: multi-block eq-zero, one-block eq-zero, and
// one-block ordered compare respectively.
112 Value *getMemCmpExpansionZeroCase();
113 Value *getMemCmpEqZeroOneBlock();
114 Value *getMemCmpOneBlock();
// A pair of loaded values (LHS/RHS of the memcmp) ready to be compared.
116 Value *Lhs =
nullptr;
117 Value *Rhs =
nullptr;
// Loads both sources at OffsetBytes as LoadSizeType, byte-swapping when
// NeedsBSwap and widening to CmpSizeType as needed (see definition below).
119 LoadPair getLoadPair(
Type *LoadSizeType,
bool NeedsBSwap,
Type *CmpSizeType,
120 unsigned OffsetBytes);
// Plans a load sequence greedily, largest load sizes first.
122 static LoadEntryVector
124 unsigned MaxNumLoads,
unsigned &NumLoadsNonOneByte);
// Plans a sequence of MaxLoadSize loads whose last load overlaps the
// previous one so no smaller load sizes are needed for the tail.
125 static LoadEntryVector
126 computeOverlappingLoadSequence(
uint64_t Size,
unsigned MaxLoadSize,
127 unsigned MaxNumLoads,
128 unsigned &NumLoadsNonOneByte);
// Constructor (parameter list partially elided in this fragment).
133 const bool IsUsedForZeroCmp,
const DataLayout &TheDataLayout,
// Number of load/compare basic blocks the expansion will create.
136 unsigned getNumBlocks();
// Total number of planned loads.
137 uint64_t getNumLoads()
const {
return LoadSequence.size(); }
// Entry point: builds the expansion and returns the memcmp result value.
139 Value *getMemCmpExpansion();
// computeGreedyLoadSequence (fragment): covers `Size` bytes using the
// largest available load sizes first, respecting the MaxNumLoads budget.
144 const unsigned MaxNumLoads,
unsigned &NumLoadsNonOneByte) {
145 NumLoadsNonOneByte = 0;
146 LoadEntryVector LoadSequence;
// Consume the remaining size with progressively smaller load sizes.
148 while (Size && !LoadSizes.
empty()) {
149 const unsigned LoadSize = LoadSizes.
front();
// Budget check; the over-budget handling is elided in this fragment.
151 if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
158 if (NumLoadsForThisSize > 0) {
159 for (
uint64_t I = 0;
I < NumLoadsForThisSize; ++
I) {
160 LoadSequence.push_back({LoadSize,
Offset});
// Loads wider than one byte are counted separately; the constructor
// later uses this count to size the result-block PHI nodes.
164 ++NumLoadsNonOneByte;
// computeOverlappingLoadSequence (fragment): covers `Size` bytes with
// NumNonOverlappingLoads full-width loads plus one final load shifted
// back so it overlaps the previous one, avoiding smaller tail loads.
173 MemCmpExpansion::computeOverlappingLoadSequence(
uint64_t Size,
174 const unsigned MaxLoadSize,
175 const unsigned MaxNumLoads,
176 unsigned &NumLoadsNonOneByte) {
// Overlapping is meaningless for sizes or load widths below 2 bytes.
178 if (Size < 2 || MaxLoadSize < 2)
183 const uint64_t NumNonOverlappingLoads =
Size / MaxLoadSize;
184 assert(NumNonOverlappingLoads &&
"there must be at least one load");
// `Size` now holds the tail not covered by the full-width loads.
187 Size =
Size - NumNonOverlappingLoads * MaxLoadSize;
// One extra (overlapping) load is required; respect the load budget.
194 if ((NumNonOverlappingLoads + 1) > MaxNumLoads)
198 LoadEntryVector LoadSequence;
200 for (
uint64_t I = 0;
I < NumNonOverlappingLoads; ++
I) {
201 LoadSequence.push_back({MaxLoadSize,
Offset});
// Final load starts (MaxLoadSize - Size) bytes early so it ends exactly
// at the end of the compared region, overlapping the previous load.
206 assert(Size > 0 && Size < MaxLoadSize &&
"broken invariant");
207 LoadSequence.push_back({MaxLoadSize,
Offset - (MaxLoadSize -
Size)});
208 NumLoadsNonOneByte = 1;
// Constructor (fragment): computes the greedy load plan, then optionally
// replaces it with an overlapping-load plan when the target allows it
// and that plan uses strictly fewer loads.
220 MemCmpExpansion::MemCmpExpansion(
223 const bool IsUsedForZeroCmp,
const DataLayout &TheDataLayout,
225 : CI(CI),
Size(
Size), NumLoadsPerBlockForZeroCmp(
Options.NumLoadsPerBlock),
226 IsUsedForZeroCmp(IsUsedForZeroCmp),
DL(TheDataLayout), DTU(DTU),
234 assert(!LoadSizes.
empty() &&
"cannot load Size bytes");
// Target-provided load sizes; the widest one comes first.
235 MaxLoadSize = LoadSizes.
front();
237 unsigned GreedyNumLoadsNonOneByte = 0;
238 LoadSequence = computeGreedyLoadSequence(
Size, LoadSizes,
Options.MaxNumLoads,
239 GreedyNumLoadsNonOneByte);
240 NumLoadsNonOneByte = GreedyNumLoadsNonOneByte;
241 assert(LoadSequence.size() <=
Options.MaxNumLoads &&
"broken invariant");
// Only consider overlapping loads when the greedy plan failed or would
// need more than two loads.
244 if (
Options.AllowOverlappingLoads &&
245 (LoadSequence.empty() || LoadSequence.size() > 2)) {
246 unsigned OverlappingNumLoadsNonOneByte = 0;
247 auto OverlappingLoads = computeOverlappingLoadSequence(
248 Size, MaxLoadSize,
Options.MaxNumLoads, OverlappingNumLoadsNonOneByte);
// Prefer the overlapping plan only when it is a strict improvement.
249 if (!OverlappingLoads.empty() &&
250 (LoadSequence.empty() ||
251 OverlappingLoads.size() < LoadSequence.size())) {
252 LoadSequence = OverlappingLoads;
253 NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte;
256 assert(LoadSequence.size() <=
Options.MaxNumLoads &&
"broken invariant");
// Number of basic blocks the expansion needs. In the eq-zero form,
// NumLoadsPerBlockForZeroCmp loads share one block (rounded up);
// otherwise every load gets its own block.
259 unsigned MemCmpExpansion::getNumBlocks() {
260 if (IsUsedForZeroCmp)
// Ceiling division: an extra block for a partial final group.
261 return getNumLoads() / NumLoadsPerBlockForZeroCmp +
262 (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0);
263 return getNumLoads();
// Creates one basic block per getNumBlocks() and records it in
// LoadCmpBlocks (block creation line elided in this fragment).
266 void MemCmpExpansion::createLoadCmpBlocks() {
267 for (
unsigned i = 0;
i < getNumBlocks();
i++) {
270 LoadCmpBlocks.push_back(
BB);
// Creates the shared "res_block" used to produce the final result.
274 void MemCmpExpansion::createResultBlock() {
275 ResBlock.BB = BasicBlock::Create(CI->
getContext(),
"res_block",
// getLoadPair (fragment): loads the LHS/RHS values for one planned load.
// Offsets both source pointers, loads LoadSizeType values, optionally
// byte-swaps them, and widens them to CmpSizeType when it differs.
279 MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(
Type *LoadSizeType,
282 unsigned OffsetBytes) {
// Advance both source pointers by OffsetBytes via i8 GEPs.
288 if (OffsetBytes > 0) {
289 auto *ByteType = Type::getInt8Ty(CI->
getContext());
290 LhsSource =
Builder.CreateConstGEP1_64(
291 ByteType,
Builder.CreateBitCast(LhsSource, ByteType->getPointerTo()),
293 RhsSource =
Builder.CreateConstGEP1_64(
294 ByteType,
Builder.CreateBitCast(RhsSource, ByteType->getPointerTo()),
// Loads from constant sources get a constant-folding attempt first
// (the folding path itself is elided in this fragment).
303 Value *Lhs =
nullptr;
304 if (
auto *
C = dyn_cast<Constant>(LhsSource))
307 Lhs =
Builder.CreateAlignedLoad(LoadSizeType, LhsSource, LhsAlign);
309 Value *Rhs =
nullptr;
310 if (
auto *
C = dyn_cast<Constant>(RhsSource))
313 Rhs =
Builder.CreateAlignedLoad(LoadSizeType, RhsSource, RhsAlign);
// Byte-swap both values via the bswap intrinsic (the NeedsBSwap guard
// line is elided in this fragment).
318 Intrinsic::bswap, LoadSizeType);
319 Lhs =
Builder.CreateCall(Bswap, Lhs);
320 Rhs =
Builder.CreateCall(Bswap, Rhs);
// Zero-extend to the comparison type when it is wider than the load type.
324 if (CmpSizeType !=
nullptr && CmpSizeType != LoadSizeType) {
325 Lhs =
Builder.CreateZExt(Lhs, CmpSizeType);
326 Rhs =
Builder.CreateZExt(Rhs, CmpSizeType);
// Emits a single-byte compare block: loads both bytes zero-extended to
// i32 (the difference computation itself is elided in this fragment).
335 void MemCmpExpansion::emitLoadCompareByteBlock(
unsigned BlockIndex,
336 unsigned OffsetBytes) {
339 const LoadPair Loads =
340 getLoadPair(Type::getInt8Ty(CI->
getContext()),
false,
341 Type::getInt32Ty(CI->
getContext()), OffsetBytes);
// Not the last block: a difference exits to the end block, equal bytes
// continue to the next load/compare block.
346 if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
352 BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
// Last block: unconditionally branch to the end block.
360 BranchInst *CmpBr = BranchInst::Create(EndBlock);
// Builds the fused equality check for up to NumLoadsPerBlockForZeroCmp
// load pairs: XORs each pair, reduces the XORs with a pairwise OR tree,
// and compares the result against zero (final compare elided here).
370 Value *MemCmpExpansion::getCompareLoadPairs(
unsigned BlockIndex,
371 unsigned &LoadIndex) {
372 assert(LoadIndex < getNumLoads() &&
373 "getCompareLoadPairs() called with no remaining loads");
374 std::vector<Value *> XorList, OrList;
375 Value *Diff =
nullptr;
// Number of load pairs this block consumes.
377 const unsigned NumLoads =
378 std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp);
// Insert-point selection differs when no dedicated blocks were created
// (the alternative branch is elided in this fragment).
381 if (LoadCmpBlocks.empty())
384 Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
// A single load pair skips the XOR tree entirely.
391 NumLoads == 1 ? nullptr
393 for (
unsigned i = 0;
i < NumLoads; ++
i, ++LoadIndex) {
394 const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
395 const LoadPair Loads = getLoadPair(
397 false, MaxLoadType, CurLoadEntry.Offset);
// Accumulate per-pair differences for the OR-reduction tree.
402 Diff =
Builder.CreateXor(Loads.Lhs, Loads.Rhs);
403 Diff =
Builder.CreateZExt(Diff, MaxLoadType);
404 XorList.push_back(Diff);
// Single-pair fast path: one direct inequality compare.
407 Cmp =
Builder.CreateICmpNE(Loads.Lhs, Loads.Rhs);
// ORs adjacent entries of InList; an odd trailing entry passes through.
411 auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
412 std::vector<Value *> OutList;
413 for (
unsigned i = 0;
i < InList.size() - 1;
i =
i + 2) {
415 OutList.push_back(Or);
417 if (InList.size() % 2 != 0)
418 OutList.push_back(InList.back());
// Reduce the XOR results down to a single value.
424 OrList = pairWiseOr(XorList);
427 while (OrList.size() != 1) {
428 OrList = pairWiseOr(OrList);
431 assert(Diff &&
"Failed to find comparison diff");
// Emits one multi-load block of the eq-zero expansion: on inequality,
// branch to the result block; otherwise fall through to the next
// load/compare block (or to the end block after the last one).
438 void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(
unsigned BlockIndex,
439 unsigned &LoadIndex) {
440 Value *
Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);
442 BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
444 : LoadCmpBlocks[BlockIndex + 1];
448 BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
// The final block feeds 0 ("all equal") into the end-block PHI.
457 if (BlockIndex == LoadCmpBlocks.size() - 1) {
459 PhiRes->
addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
// Emits one load/compare block of the ordered (non-zero-cmp) expansion:
// equal values fall through to the next block, unequal values branch to
// the shared result block.
472 void MemCmpExpansion::emitLoadCompareBlock(
unsigned BlockIndex) {
474 const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];
// One-byte loads use the dedicated byte-compare block instead.
476 if (CurLoadEntry.LoadSize == 1) {
477 MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset);
484 assert(CurLoadEntry.LoadSize <= MaxLoadSize &&
"Unexpected load type");
486 Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
// Byte-swap on little-endian targets so an unsigned integer compare
// matches memcmp's byte-wise (big-endian lexicographic) ordering.
488 const LoadPair Loads =
489 getLoadPair(LoadSizeType,
DL.isLittleEndian(), MaxLoadType,
490 CurLoadEntry.Offset);
// Feed the loaded values into the result block's PHIs so it can compute
// the final ordered result for whichever block found a difference.
494 if (!IsUsedForZeroCmp) {
495 ResBlock.PhiSrc1->addIncoming(Loads.Lhs, LoadCmpBlocks[BlockIndex]);
496 ResBlock.PhiSrc2->addIncoming(Loads.Rhs, LoadCmpBlocks[BlockIndex]);
499 Value *
Cmp =
Builder.CreateICmp(ICmpInst::ICMP_EQ, Loads.Lhs, Loads.Rhs);
500 BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
502 : LoadCmpBlocks[BlockIndex + 1];
506 BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
// Last block: everything compared equal, so the end-block PHI gets 0.
515 if (BlockIndex == LoadCmpBlocks.size() - 1) {
517 PhiRes->
addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
// Emits the shared result block and terminates it with a branch to the
// end block; the actual result computation in both forms is largely
// elided in this fragment.
524 void MemCmpExpansion::emitMemCmpResultBlock() {
527 if (IsUsedForZeroCmp) {
529 Builder.SetInsertPoint(ResBlock.BB, InsertPt);
532 BranchInst *NewBr = BranchInst::Create(EndBlock);
539 Builder.SetInsertPoint(ResBlock.BB, InsertPt);
549 BranchInst *NewBr = BranchInst::Create(EndBlock);
// Creates the two PHIs that collect the differing LHS/RHS values from
// the wide-load blocks; sized by NumLoadsNonOneByte.
555 void MemCmpExpansion::setupResultBlockPHINodes() {
557 Builder.SetInsertPoint(ResBlock.BB);
560 Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte,
"phi.src1");
562 Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte,
"phi.src2");
// Creates the end-block PHI that merges the memcmp result (body elided).
565 void MemCmpExpansion::setupEndBlockPHINodes() {
// Multi-block expansion of memcmp()==0: emit one multi-load block per
// getNumBlocks(), then the shared result block.
570 Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
571 unsigned LoadIndex = 0;
574 for (
unsigned I = 0;
I < getNumBlocks(); ++
I) {
575 emitLoadCompareBlockMultipleLoads(
I, LoadIndex);
578 emitMemCmpResultBlock();
// Single-block expansion of memcmp()==0: one fused compare covering
// every planned load, with no extra control flow.
585 Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
586 unsigned LoadIndex = 0;
587 Value *
Cmp = getCompareLoadPairs(0, LoadIndex);
588 assert(LoadIndex == getNumLoads() &&
"some entries were not consumed");
// Expands the entire memcmp as straight-line code in a single block.
594 Value *MemCmpExpansion::getMemCmpOneBlock() {
// memcmp orders buffers like big-endian unsigned integers, so loads on
// little-endian targets must be byte-swapped; a single byte needs none.
596 bool NeedsBSwap =
DL.isLittleEndian() &&
Size != 1;
// Path where a plain i32 subtraction of the loaded values suffices
// (its guard condition is elided in this fragment).
601 const LoadPair Loads =
602 getLoadPair(LoadSizeType, NeedsBSwap,
Builder.getInt32Ty(),
604 return Builder.CreateSub(Loads.Lhs, Loads.Rhs);
607 const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType,
// General ordered result via (Lhs > Rhs) - (Lhs < Rhs), which yields
// 1, -1, or 0 after zero-extension of the two compares.
615 Value *CmpUGT =
Builder.CreateICmpUGT(Loads.Lhs, Loads.Rhs);
616 Value *CmpULT =
Builder.CreateICmpULT(Loads.Lhs, Loads.Rhs);
619 return Builder.CreateSub(ZextUGT, ZextULT);
// Top-level driver: builds the CFG scaffolding (end block, result block,
// load/compare blocks) when more than one block is needed, then
// dispatches to the appropriate expansion strategy.
624 Value *MemCmpExpansion::getMemCmpExpansion() {
626 if (getNumBlocks() != 1) {
// Split at the call so everything after memcmp lives in "endblock".
628 EndBlock =
SplitBlock(StartBlock, CI, DTU,
nullptr,
629 nullptr,
"endblock");
630 setupEndBlockPHINodes();
// Only the ordered (non-zero) form needs the result-block PHIs.
637 if (!IsUsedForZeroCmp) setupResultBlockPHINodes();
640 createLoadCmpBlocks();
// Record the CFG edge change with the DomTreeUpdater.
647 {DominatorTree::Delete, StartBlock, EndBlock}});
652 if (IsUsedForZeroCmp)
653 return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
654 : getMemCmpExpansionZeroCase();
656 if (getNumBlocks() == 1)
657 return getMemCmpOneBlock();
// Ordered multi-block expansion: one compare block per planned load.
659 for (
unsigned I = 0;
I < getNumBlocks(); ++
I) {
660 emitLoadCompareBlock(
I);
663 emitMemCmpResultBlock();
// expandMemCmp and legacy-pass plumbing (heavily elided fragments).
// Calls with a non-constant size argument are counted and skipped.
753 NumMemCmpNotConstant++;
763 const bool IsUsedForZeroCmp =
// Build the expansion plan; an empty plan means the constant size
// exceeded the configured maximum, so the call is left alone.
781 MemCmpExpansion Expansion(CI, SizeVal,
Options, IsUsedForZeroCmp, *
DL, DTU);
784 if (Expansion.getNumLoads() == 0) {
785 NumMemCmpGreaterThanMax++;
791 Value *Res = Expansion.getMemCmpExpansion();
// Legacy-pass boilerplate: fetch the target lowering, TLI, TTI, PSI,
// BFI, and (if available) the dominator tree from the pass manager.
811 auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
816 TPC->getTM<
TargetMachine>().getSubtargetImpl(
F)->getTargetLowering();
819 &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
F);
821 &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
822 auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
824 &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() :
827 if (
auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
828 DT = &DTWP->getDomTree();
830 return !PA.areAllPreserved();
// Both memcmp and bcmp are expandable; the final argument tells
// expandMemCmp whether this is the bcmp (equality-only) form.
867 (Func == LibFunc_memcmp || Func == LibFunc_bcmp) &&
868 expandMemCmp(CI,
TTI, TL, &
DL, PSI,
BFI, DTU, Func == LibFunc_bcmp)) {
// Lazy update strategy batches dom-tree updates until queried.
882 DTU.
emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);
885 bool MadeChanges =
false;
// Iterate without a range-for: expansion may modify the block list.
886 for (
auto BBIt =
F.begin(); BBIt !=
F.end();) {
887 if (runOnBlock(*BBIt, TLI,
TTI, TL,
DL, PSI,
BFI,
// Pass registration and factory function (fragment).
911 "Expand memcmp() to load/stores",
false,
false)
921 return new ExpandMemCmpPass();