Line data Source code
1 : //===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
2 : //
3 : // The LLVM Compiler Infrastructure
4 : //
5 : // This file is distributed under the University of Illinois Open Source
6 : // License. See LICENSE.TXT for details.
7 : //
8 : //===----------------------------------------------------------------------===//
9 : //
10 : // This pass tries to expand memcmp() calls into optimally-sized loads and
11 : // compares for the target.
12 : //
13 : //===----------------------------------------------------------------------===//
14 :
15 : #include "llvm/ADT/Statistic.h"
16 : #include "llvm/Analysis/ConstantFolding.h"
17 : #include "llvm/Analysis/TargetLibraryInfo.h"
18 : #include "llvm/Analysis/TargetTransformInfo.h"
19 : #include "llvm/Analysis/ValueTracking.h"
20 : #include "llvm/CodeGen/TargetLowering.h"
21 : #include "llvm/CodeGen/TargetPassConfig.h"
22 : #include "llvm/CodeGen/TargetSubtargetInfo.h"
23 : #include "llvm/IR/IRBuilder.h"
24 :
25 : using namespace llvm;
26 :
27 : #define DEBUG_TYPE "expandmemcmp"
28 :
// Pass-level statistics, reported with -stats.
STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
STATISTIC(NumMemCmpGreaterThanMax,
          "Number of memcmp calls with size greater than max size");
STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");

// Command-line override for the number of loads emitted per basic block when
// the memcmp result is only compared (in)equal to zero. When the flag is not
// given, the target's preference (TLI->getMemcmpEqZeroLoadsPerBlock()) wins.
static cl::opt<unsigned> MemCmpEqZeroNumLoadsPerBlock(
    "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
    cl::desc("The number of loads per basic block for inline expansion of "
             "memcmp that is only being compared against zero."));
39 :
40 : namespace {
41 :
42 :
// This class provides helper functions to expand a memcmp library call into an
// inline expansion.
class MemCmpExpansion {
  // The single "res_block" that computes the non-zero memcmp result (-1/1)
  // once a difference has been found, plus the PHIs that collect the differing
  // loaded values from each load/compare block.
  struct ResultBlock {
    BasicBlock *BB = nullptr;
    PHINode *PhiSrc1 = nullptr;
    PHINode *PhiSrc2 = nullptr;

    ResultBlock() = default;
  };

  CallInst *const CI;          // The memcmp call being expanded.
  ResultBlock ResBlock;
  const uint64_t Size;         // Constant size argument of the memcmp call.
  unsigned MaxLoadSize;        // Largest load size used by the decomposition.
  uint64_t NumLoadsNonOneByte; // Number of distinct load sizes > 1 byte used.
  const uint64_t NumLoadsPerBlockForZeroCmp;
  std::vector<BasicBlock *> LoadCmpBlocks;
  BasicBlock *EndBlock;        // Block the expansion falls through to.
  PHINode *PhiRes;             // i32 PHI in EndBlock merging the result.
  const bool IsUsedForZeroCmp; // True if result is only compared against 0.
  const DataLayout &DL;
  IRBuilder<> Builder;
  // Represents the decomposition in blocks of the expansion. For example,
  // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
  // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}].
  // TODO(courbet): Involve the target more in this computation. On X86, 7
  // bytes can be done more efficiently with two overlapping 4-byte loads than
  // covering the interval with [{4, 0}, {2, 4}, {1, 6}].
  struct LoadEntry {
    LoadEntry(unsigned LoadSize, uint64_t Offset)
        : LoadSize(LoadSize), Offset(Offset) {
      assert(Offset % LoadSize == 0 && "invalid load entry");
    }

    // Index of this load when the base pointer is viewed as an array of
    // LoadSize-byte elements (Offset is a multiple of LoadSize by invariant).
    uint64_t getGEPIndex() const { return Offset / LoadSize; }

    // The size of the load for this block, in bytes.
    const unsigned LoadSize;
    // The offset of this load WRT the base pointer, in bytes.
    const uint64_t Offset;
  };
  SmallVector<LoadEntry, 8> LoadSequence;

  void createLoadCmpBlocks();
  void createResultBlock();
  void setupResultBlockPHINodes();
  void setupEndBlockPHINodes();
  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
  void emitLoadCompareBlock(unsigned BlockIndex);
  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                         unsigned &LoadIndex);
  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex);
  void emitMemCmpResultBlock();
  Value *getMemCmpExpansionZeroCase();
  Value *getMemCmpEqZeroOneBlock();
  Value *getMemCmpOneBlock();

public:
  MemCmpExpansion(CallInst *CI, uint64_t Size,
                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
                  unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
                  unsigned MaxLoadsPerBlockForZeroCmp, const DataLayout &TheDataLayout);

  unsigned getNumBlocks();
  // An empty load sequence means the expansion was abandoned (would need more
  // loads than the target allows).
  uint64_t getNumLoads() const { return LoadSequence.size(); }

  Value *getMemCmpExpansion();
};
112 :
// Initialize the basic block structure required for expansion of memcmp call
// with given maximum load size and memcmp size parameter.
// This structure includes:
// 1. A list of load compare blocks - LoadCmpBlocks.
// 2. An EndBlock, split from original instruction point, which is the block to
// return from.
// 3. ResultBlock, block to branch to for early exit when a
// LoadCmpBlock finds a difference.
MemCmpExpansion::MemCmpExpansion(
    CallInst *const CI, uint64_t Size,
    const TargetTransformInfo::MemCmpExpansionOptions &Options,
    const unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
    const unsigned MaxLoadsPerBlockForZeroCmp, const DataLayout &TheDataLayout)
    : CI(CI),
      Size(Size),
      MaxLoadSize(0),
      NumLoadsNonOneByte(0),
      NumLoadsPerBlockForZeroCmp(MaxLoadsPerBlockForZeroCmp),
      IsUsedForZeroCmp(IsUsedForZeroCmp),
      DL(TheDataLayout),
      Builder(CI) {
  assert(Size > 0 && "zero blocks");
  // Scale the max size down if the target can load more bytes than we need.
  // NOTE(review): this scan assumes Options.LoadSizes is sorted in decreasing
  // order and contains an entry <= Size (e.g. a 1-byte load, given Size > 0);
  // otherwise the indexing just below would run past the end -- confirm.
  size_t LoadSizeIndex = 0;
  while (LoadSizeIndex < Options.LoadSizes.size() &&
         Options.LoadSizes[LoadSizeIndex] > Size) {
    ++LoadSizeIndex;
  }
  this->MaxLoadSize = Options.LoadSizes[LoadSizeIndex];
  // Compute the decomposition: greedily cover Size bytes with as many loads of
  // the current (largest usable) size as fit, then move to the next smaller
  // size for the remainder.
  uint64_t CurSize = Size;
  uint64_t Offset = 0;
  while (CurSize && LoadSizeIndex < Options.LoadSizes.size()) {
    const unsigned LoadSize = Options.LoadSizes[LoadSizeIndex];
    assert(LoadSize > 0 && "zero load size");
    const uint64_t NumLoadsForThisSize = CurSize / LoadSize;
    if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
      // Do not expand if the total number of loads is larger than what the
      // target allows. Note that it's important that we exit before completing
      // the expansion to avoid using a ton of memory to store the expansion for
      // large sizes.
      LoadSequence.clear();
      return;
    }
    if (NumLoadsForThisSize > 0) {
      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
        LoadSequence.push_back({LoadSize, Offset});
        Offset += LoadSize;
      }
      if (LoadSize > 1) {
        ++NumLoadsNonOneByte;
      }
      CurSize = CurSize % LoadSize;
    }
    ++LoadSizeIndex;
  }
  assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
}
171 :
172 : unsigned MemCmpExpansion::getNumBlocks() {
173 424 : if (IsUsedForZeroCmp)
174 796 : return getNumLoads() / NumLoadsPerBlockForZeroCmp +
175 551 : (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0);
176 : return getNumLoads();
177 : }
178 :
179 69 : void MemCmpExpansion::createLoadCmpBlocks() {
180 355 : for (unsigned i = 0; i < getNumBlocks(); i++) {
181 286 : BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
182 143 : EndBlock->getParent(), EndBlock);
183 143 : LoadCmpBlocks.push_back(BB);
184 : }
185 69 : }
186 :
187 69 : void MemCmpExpansion::createResultBlock() {
188 138 : ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
189 : EndBlock->getParent(), EndBlock);
190 69 : }
191 :
// This function creates the IR instructions for loading and comparing 1 byte.
// It loads 1 byte from each source of the memcmp parameters with the given
// GEPIndex. It then subtracts the two loaded values and adds this result to the
// final phi node for selecting the memcmp result.
void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
                                               unsigned GEPIndex) {
  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Get the base address using the GEPIndex.
  if (GEPIndex != 0) {
    Source1 = Builder.CreateGEP(LoadSizeType, Source1,
                                ConstantInt::get(LoadSizeType, GEPIndex));
    Source2 = Builder.CreateGEP(LoadSizeType, Source2,
                                ConstantInt::get(LoadSizeType, GEPIndex));
  }

  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  // Zero-extend both bytes to i32 before subtracting, so the difference is
  // directly a valid i32 memcmp result (negative, zero, or positive).
  LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
  LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
  Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);

  PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);

  if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
    // Early exit branch if difference found to EndBlock. Otherwise, continue to
    // next LoadCmpBlock,
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
                                    ConstantInt::get(Diff->getType(), 0));
    BranchInst *CmpBr =
        BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
    Builder.Insert(CmpBr);
  } else {
    // The last block has an unconditional branch to EndBlock.
    BranchInst *CmpBr = BranchInst::Create(EndBlock);
    Builder.Insert(CmpBr);
  }
}
240 :
241 : /// Generate an equality comparison for one or more pairs of loaded values.
242 : /// This is used in the case where the memcmp() call is compared equal or not
243 : /// equal to zero.
244 179 : Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
245 : unsigned &LoadIndex) {
246 : assert(LoadIndex < getNumLoads() &&
247 : "getCompareLoadPairs() called with no remaining loads");
248 : std::vector<Value *> XorList, OrList;
249 : Value *Diff;
250 :
251 : const unsigned NumLoads =
252 369 : std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp);
253 :
254 : // For a single-block expansion, start inserting before the memcmp call.
255 179 : if (LoadCmpBlocks.empty())
256 158 : Builder.SetInsertPoint(CI);
257 : else
258 42 : Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
259 :
260 : Value *Cmp = nullptr;
261 : // If we have multiple loads per block, we need to generate a composite
262 : // comparison using xor+or. The type for the combinations is the largest load
263 : // type.
264 : IntegerType *const MaxLoadType =
265 179 : NumLoads == 1 ? nullptr
266 73 : : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
267 431 : for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
268 252 : const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
269 :
270 : IntegerType *LoadSizeType =
271 252 : IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
272 :
273 252 : Value *Source1 = CI->getArgOperand(0);
274 : Value *Source2 = CI->getArgOperand(1);
275 :
276 : // Cast source to LoadSizeType*.
277 252 : if (Source1->getType() != LoadSizeType)
278 756 : Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
279 252 : if (Source2->getType() != LoadSizeType)
280 756 : Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
281 :
282 : // Get the base address using a GEP.
283 252 : if (CurLoadEntry.Offset != 0) {
284 84 : Source1 = Builder.CreateGEP(
285 : LoadSizeType, Source1,
286 84 : ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
287 84 : Source2 = Builder.CreateGEP(
288 : LoadSizeType, Source2,
289 84 : ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
290 : }
291 :
292 : // Get a constant or load a value for each source address.
293 : Value *LoadSrc1 = nullptr;
294 : if (auto *Source1C = dyn_cast<Constant>(Source1))
295 4 : LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
296 252 : if (!LoadSrc1)
297 496 : LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
298 :
299 : Value *LoadSrc2 = nullptr;
300 : if (auto *Source2C = dyn_cast<Constant>(Source2))
301 76 : LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
302 76 : if (!LoadSrc2)
303 356 : LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
304 :
305 252 : if (NumLoads != 1) {
306 146 : if (LoadSizeType != MaxLoadType) {
307 45 : LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
308 45 : LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
309 : }
310 : // If we have multiple loads per block, we need to generate a composite
311 : // comparison using xor+or.
312 292 : Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
313 146 : Diff = Builder.CreateZExt(Diff, MaxLoadType);
314 146 : XorList.push_back(Diff);
315 : } else {
316 : // If there's only one load per block, we just compare the loaded values.
317 106 : Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
318 : }
319 : }
320 :
321 : auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
322 : std::vector<Value *> OutList;
323 : for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
324 : Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
325 : OutList.push_back(Or);
326 : }
327 : if (InList.size() % 2 != 0)
328 : OutList.push_back(InList.back());
329 : return OutList;
330 179 : };
331 :
332 179 : if (!Cmp) {
333 : // Pairwise OR the XOR results.
334 146 : OrList = pairWiseOr(XorList);
335 :
336 : // Pairwise OR the OR results until one result left.
337 146 : while (OrList.size() != 1) {
338 0 : OrList = pairWiseOr(OrList);
339 : }
340 146 : Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
341 : }
342 :
343 179 : return Cmp;
344 : }
345 :
// Emit the IR for one basic block of the zero-equality expansion: compare a
// group of load pairs and branch to ResBlock on mismatch, otherwise fall
// through to the next LoadCmpBlock (or EndBlock after the last group).
void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                                        unsigned &LoadIndex) {
  Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);

  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise,
  // continue to next LoadCmpBlock or EndBlock.
  BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}
366 :
// This function creates the IR instructions for loading and comparing using the
// given LoadSize. It loads the number of bytes specified by LoadSize from each
// source of the memcmp parameters. It then does a subtract to see if there was
// a difference in the loaded values. If a difference is found, it branches
// with an early exit to the ResultBlock for calculating which source was
// larger. Otherwise, it falls through to either the next LoadCmpBlock or
// the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled with
// a special case through emitLoadCompareByteBlock. The special handling can
// simply subtract the loaded values and add it to the result phi node.
void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
  // There is one load per block in this case, BlockIndex == LoadIndex.
  const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];

  if (CurLoadEntry.LoadSize == 1) {
    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex,
                                              CurLoadEntry.getGEPIndex());
    return;
  }

  Type *LoadSizeType =
      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Get the base address using a GEP.
  if (CurLoadEntry.Offset != 0) {
    Source1 = Builder.CreateGEP(
        LoadSizeType, Source1,
        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
    Source2 = Builder.CreateGEP(
        LoadSizeType, Source2,
        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
  }

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  // memcmp orders bytes most-significant-first, so on little-endian targets
  // the loaded values are byte-swapped before the unsigned comparison below.
  if (DL.isLittleEndian()) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  // Widen to the maximum load type so all incoming values of the res_block
  // PHIs share a single common type.
  if (LoadSizeType != MaxLoadType) {
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
  }

  // Add the loaded values to the phi nodes for calculating memcmp result only
  // if result is not used in a zero equality.
  if (!IsUsedForZeroCmp) {
    ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
    ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
  }

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise, continue
  // to next LoadCmpBlock or EndBlock.
  BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}
451 :
// This function populates the ResultBlock with a sequence to calculate the
// memcmp result. It compares the two loaded source values and returns -1 if
// src1 < src2 and 1 if src1 > src2.
void MemCmpExpansion::emitMemCmpResultBlock() {
  // Special case: if memcmp result is used in a zero equality, result does not
  // need to be calculated and can simply return 1.
  if (IsUsedForZeroCmp) {
    BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
    Builder.SetInsertPoint(ResBlock.BB, InsertPt);
    Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
    PhiRes->addIncoming(Res, ResBlock.BB);
    BranchInst *NewBr = BranchInst::Create(EndBlock);
    Builder.Insert(NewBr);
    return;
  }
  BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
  Builder.SetInsertPoint(ResBlock.BB, InsertPt);

  // The incoming PHI values are in memcmp (big-endian) byte order (see
  // emitLoadCompareBlock), so an unsigned compare gives the lexicographic
  // ordering.
  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
                                  ResBlock.PhiSrc2);

  Value *Res =
      Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
                           ConstantInt::get(Builder.getInt32Ty(), 1));

  BranchInst *NewBr = BranchInst::Create(EndBlock);
  Builder.Insert(NewBr);
  PhiRes->addIncoming(Res, ResBlock.BB);
}
481 :
482 59 : void MemCmpExpansion::setupResultBlockPHINodes() {
483 59 : Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
484 59 : Builder.SetInsertPoint(ResBlock.BB);
485 : // Note: this assumes one load per block.
486 59 : ResBlock.PhiSrc1 =
487 118 : Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
488 59 : ResBlock.PhiSrc2 =
489 59 : Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
490 59 : }
491 :
492 69 : void MemCmpExpansion::setupEndBlockPHINodes() {
493 138 : Builder.SetInsertPoint(&EndBlock->front());
494 138 : PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
495 69 : }
496 :
497 10 : Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
498 10 : unsigned LoadIndex = 0;
499 : // This loop populates each of the LoadCmpBlocks with the IR sequence to
500 : // handle multiple loads per block.
501 52 : for (unsigned I = 0; I < getNumBlocks(); ++I) {
502 21 : emitLoadCompareBlockMultipleLoads(I, LoadIndex);
503 : }
504 :
505 10 : emitMemCmpResultBlock();
506 10 : return PhiRes;
507 : }
508 :
509 : /// A memcmp expansion that compares equality with 0 and only has one block of
510 : /// load and compare can bypass the compare, branch, and phi IR that is required
511 : /// in the general case.
512 158 : Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
513 158 : unsigned LoadIndex = 0;
514 158 : Value *Cmp = getCompareLoadPairs(0, LoadIndex);
515 : assert(LoadIndex == getNumLoads() && "some entries were not consumed");
516 316 : return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
517 : }
518 :
/// A memcmp expansion that only has one block of load and compare can bypass
/// the compare, branch, and phi IR that is required in the general case.
Value *MemCmpExpansion::getMemCmpOneBlock() {
  Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  // memcmp compares most-significant bytes first, so byte-swap on
  // little-endian targets. A single byte has no order: Size == 1 skips it.
  if (DL.isLittleEndian() && Size != 1) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (Size < 4) {
    // The i8 and i16 cases don't need compares. We zext the loaded values and
    // subtract them to get the suitable negative, zero, or positive i32 result.
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
    return Builder.CreateSub(LoadSrc1, LoadSrc2);
  }

  // The result of memcmp is negative, zero, or positive, so produce that by
  // subtracting 2 extended compare bits: sub (ugt, ult).
  // If a target prefers to use selects to get -1/0/1, they should be able
  // to transform this later. The inverse transform (going from selects to math)
  // may not be possible in the DAG because the selects got converted into
  // branches before we got there.
  Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
  Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
  Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
  Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
  return Builder.CreateSub(ZextUGT, ZextULT);
}
563 :
// This function expands the memcmp call into an inline expansion and returns
// the memcmp result. Single-block forms are emitted inline before the call;
// multi-block forms first split the call's block and build the
// loadbb/res_block/endblock scaffolding.
Value *MemCmpExpansion::getMemCmpExpansion() {
  // Create the basic block framework for a multi-block expansion.
  if (getNumBlocks() != 1) {
    BasicBlock *StartBlock = CI->getParent();
    EndBlock = StartBlock->splitBasicBlock(CI, "endblock");
    setupEndBlockPHINodes();
    createResultBlock();

    // If return value of memcmp is not used in a zero equality, we need to
    // calculate which source was larger. The calculation requires the
    // two loaded source values of each load compare block.
    // These will be saved in the phi nodes created by setupResultBlockPHINodes.
    if (!IsUsedForZeroCmp) setupResultBlockPHINodes();

    // Create the number of required load compare basic blocks.
    createLoadCmpBlocks();

    // Update the terminator added by splitBasicBlock to branch to the first
    // LoadCmpBlock.
    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
  }

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Dispatch on the two independent axes: zero-equality vs full ordering,
  // and single-block vs multi-block.
  if (IsUsedForZeroCmp)
    return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
                               : getMemCmpExpansionZeroCase();

  if (getNumBlocks() == 1)
    return getMemCmpOneBlock();

  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlock(I);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}
604 :
605 : // This function checks to see if an expansion of memcmp can be generated.
606 : // It checks for constant compare size that is less than the max inline size.
607 : // If an expansion cannot occur, returns false to leave as a library call.
608 : // Otherwise, the library call is replaced with a new IR instruction sequence.
609 : /// We want to transform:
610 : /// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
611 : /// To:
612 : /// loadbb:
613 : /// %0 = bitcast i32* %buffer2 to i8*
614 : /// %1 = bitcast i32* %buffer1 to i8*
615 : /// %2 = bitcast i8* %1 to i64*
616 : /// %3 = bitcast i8* %0 to i64*
617 : /// %4 = load i64, i64* %2
618 : /// %5 = load i64, i64* %3
619 : /// %6 = call i64 @llvm.bswap.i64(i64 %4)
620 : /// %7 = call i64 @llvm.bswap.i64(i64 %5)
621 : /// %8 = sub i64 %6, %7
622 : /// %9 = icmp ne i64 %8, 0
623 : /// br i1 %9, label %res_block, label %loadbb1
624 : /// res_block: ; preds = %loadbb2,
625 : /// %loadbb1, %loadbb
626 : /// %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
627 : /// %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
628 : /// %10 = icmp ult i64 %phi.src1, %phi.src2
629 : /// %11 = select i1 %10, i32 -1, i32 1
630 : /// br label %endblock
631 : /// loadbb1: ; preds = %loadbb
632 : /// %12 = bitcast i32* %buffer2 to i8*
633 : /// %13 = bitcast i32* %buffer1 to i8*
634 : /// %14 = bitcast i8* %13 to i32*
635 : /// %15 = bitcast i8* %12 to i32*
636 : /// %16 = getelementptr i32, i32* %14, i32 2
637 : /// %17 = getelementptr i32, i32* %15, i32 2
638 : /// %18 = load i32, i32* %16
639 : /// %19 = load i32, i32* %17
640 : /// %20 = call i32 @llvm.bswap.i32(i32 %18)
641 : /// %21 = call i32 @llvm.bswap.i32(i32 %19)
642 : /// %22 = zext i32 %20 to i64
643 : /// %23 = zext i32 %21 to i64
644 : /// %24 = sub i64 %22, %23
645 : /// %25 = icmp ne i64 %24, 0
646 : /// br i1 %25, label %res_block, label %loadbb2
647 : /// loadbb2: ; preds = %loadbb1
648 : /// %26 = bitcast i32* %buffer2 to i8*
649 : /// %27 = bitcast i32* %buffer1 to i8*
650 : /// %28 = bitcast i8* %27 to i16*
651 : /// %29 = bitcast i8* %26 to i16*
652 : /// %30 = getelementptr i16, i16* %28, i16 6
653 : /// %31 = getelementptr i16, i16* %29, i16 6
654 : /// %32 = load i16, i16* %30
655 : /// %33 = load i16, i16* %31
656 : /// %34 = call i16 @llvm.bswap.i16(i16 %32)
657 : /// %35 = call i16 @llvm.bswap.i16(i16 %33)
658 : /// %36 = zext i16 %34 to i64
659 : /// %37 = zext i16 %35 to i64
660 : /// %38 = sub i64 %36, %37
661 : /// %39 = icmp ne i64 %38, 0
662 : /// br i1 %39, label %res_block, label %loadbb3
663 : /// loadbb3: ; preds = %loadbb2
664 : /// %40 = bitcast i32* %buffer2 to i8*
665 : /// %41 = bitcast i32* %buffer1 to i8*
666 : /// %42 = getelementptr i8, i8* %41, i8 14
667 : /// %43 = getelementptr i8, i8* %40, i8 14
668 : /// %44 = load i8, i8* %42
669 : /// %45 = load i8, i8* %43
670 : /// %46 = zext i8 %44 to i32
671 : /// %47 = zext i8 %45 to i32
672 : /// %48 = sub i32 %46, %47
673 : /// br label %endblock
674 : /// endblock: ; preds = %res_block,
675 : /// %loadbb3
676 : /// %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
677 : /// ret i32 %phi.res
// Returns true iff CI (a memcmp call) was replaced by an inline expansion.
// CI is erased on success, so the caller must not touch it afterwards.
static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
                         const TargetLowering *TLI, const DataLayout *DL) {
  NumMemCmpCalls++;

  // Early exit from expansion if -Oz.
  if (CI->getFunction()->optForMinSize())
    return false;

  // Early exit from expansion if size is not a constant.
  ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
  if (!SizeCast) {
    NumMemCmpNotConstant++;
    return false;
  }
  const uint64_t SizeVal = SizeCast->getZExtValue();

  // Nothing to expand for a zero-length compare.
  if (SizeVal == 0) {
    return false;
  }

  // TTI call to check if target would like to expand memcmp. Also, get the
  // available load sizes.
  const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
  const auto *const Options = TTI->enableMemCmpExpansion(IsUsedForZeroCmp);
  if (!Options) return false;

  const unsigned MaxNumLoads =
      TLI->getMaxExpandSizeMemcmp(CI->getFunction()->optForSize());

  // The command-line flag, when given, overrides the target's preferred
  // number of loads per block for the zero-equality form.
  unsigned NumLoadsPerBlock = MemCmpEqZeroNumLoadsPerBlock.getNumOccurrences()
                                  ? MemCmpEqZeroNumLoadsPerBlock
                                  : TLI->getMemcmpEqZeroLoadsPerBlock();

  MemCmpExpansion Expansion(CI, SizeVal, *Options, MaxNumLoads,
                            IsUsedForZeroCmp, NumLoadsPerBlock, *DL);

  // Don't expand if this will require more loads than desired by the target.
  if (Expansion.getNumLoads() == 0) {
    NumMemCmpGreaterThanMax++;
    return false;
  }

  NumMemCmpInlined++;

  Value *Res = Expansion.getMemCmpExpansion();

  // Replace call with result of expansion and erase call.
  CI->replaceAllUsesWith(Res);
  CI->eraseFromParent();

  return true;
}
730 :
731 :
732 :
// Legacy pass-manager function pass that expands eligible memcmp library
// calls into inline loads and compares.
class ExpandMemCmpPass : public FunctionPass {
public:
  static char ID;

  ExpandMemCmpPass() : FunctionPass(ID) {
    initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F)) return false;

    // TargetPassConfig is only present inside a codegen pipeline; without it
    // there is no TargetLowering, so nothing can be expanded.
    auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
    if (!TPC) {
      return false;
    }
    const TargetLowering* TL =
        TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();

    const TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
    const TargetTransformInfo *TTI =
        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto PA = runImpl(F, TLI, TTI, TL);
    return !PA.areAllPreserved();
  }

private:
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  // Expands all memcmp calls in F; returns none() iff changes were made.
  PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
                            const TargetTransformInfo *TTI,
                            const TargetLowering* TL);
  // Returns true if a change was made.
  bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                  const TargetTransformInfo *TTI, const TargetLowering* TL,
                  const DataLayout& DL);
};
774 :
// Scan BB for memcmp library calls and expand the first one that qualifies.
// Returns true as soon as one expansion succeeds: the expansion may split BB
// and erases the call, invalidating this iteration, so the caller restarts.
bool ExpandMemCmpPass::runOnBlock(
    BasicBlock &BB, const TargetLibraryInfo *TLI,
    const TargetTransformInfo *TTI, const TargetLowering* TL,
    const DataLayout& DL) {
  for (Instruction& I : BB) {
    CallInst *CI = dyn_cast<CallInst>(&I);
    if (!CI) {
      continue;
    }
    LibFunc Func;
    if (TLI->getLibFunc(ImmutableCallSite(CI), Func) &&
        Func == LibFunc_memcmp && expandMemCmp(CI, TTI, TL, &DL)) {
      return true;
    }
  }
  return false;
}
792 :
793 :
// Drives the expansion over the whole function. Because a successful
// expansion can split blocks and erase instructions, iteration restarts from
// the first block after every change rather than continuing with stale
// iterators.
PreservedAnalyses ExpandMemCmpPass::runImpl(
    Function &F, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI,
    const TargetLowering* TL) {
  const DataLayout& DL = F.getParent()->getDataLayout();
  bool MadeChanges = false;
  for (auto BBIt = F.begin(); BBIt != F.end();) {
    if (runOnBlock(*BBIt, TLI, TTI, TL, DL)) {
      MadeChanges = true;
      // If changes were made, restart the function from the beginning, since
      // the structure of the function was changed.
      BBIt = F.begin();
    } else {
      ++BBIt;
    }
  }
  return MadeChanges ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
811 :
812 : } // namespace
813 :
char ExpandMemCmpPass::ID = 0;
// Register the pass ("expandmemcmp") and its analysis dependencies with the
// legacy pass manager.
INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp",
                      "Expand memcmp() to load/stores", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp",
                    "Expand memcmp() to load/stores", false, false)
821 :
// Factory used by the codegen pipeline to instantiate this pass.
FunctionPass *llvm::createExpandMemCmpPass() {
  return new ExpandMemCmpPass();
}
|