32 #define DEBUG_TYPE "lower-switch"
39 static bool IsInRanges(
const IntRange &R,
40 const std::vector<IntRange> &
Ranges) {
46 auto I = std::lower_bound(
47 Ranges.begin(), Ranges.end(), R,
48 [](
const IntRange &
A,
const IntRange &B) {
return A.High < B.High; });
49 return I != Ranges.end() &&
I->Low <= R.Low;
75 : Low(low), High(high), BB(bb) {}
78 typedef std::vector<CaseRange> CaseVector;
79 typedef std::vector<CaseRange>::iterator CaseItr;
83 BasicBlock *switchConvert(CaseItr Begin, CaseItr End,
87 const std::vector<IntRange> &UnreachableRanges);
90 unsigned Clusterify(CaseVector &Cases,
SwitchInst *SI);
96 bool operator () (
const LowerSwitch::CaseRange& C1,
97 const LowerSwitch::CaseRange& C2) {
99 const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low);
100 const ConstantInt* CI2 = cast<const ConstantInt>(C2.High);
108 "Lower SwitchInst's to branches",
false,
false)
114 return new LowerSwitch();
117 bool LowerSwitch::runOnFunction(
Function &
F) {
118 bool Changed =
false;
125 processSwitchInst(SI);
135 const LowerSwitch::CaseVector &C)
138 const LowerSwitch::CaseVector &C) {
141 for (LowerSwitch::CaseVector::const_iterator B = C.begin(),
142 E = C.end(); B != E; ) {
143 O << *B->Low <<
" -" << *B->High;
144 if (++B != E) O <<
", ";
161 unsigned NumMergedCases) {
168 unsigned LocalNumMergedCases = NumMergedCases;
169 for (; Idx != E; ++Idx) {
179 for (++Idx; LocalNumMergedCases > 0 && Idx < E; ++Idx)
182 LocalNumMergedCases--;
186 for (
auto III = Indices.
rbegin(), IIE = Indices.
rend(); III != IIE; ++III)
197 LowerSwitch::switchConvert(CaseItr Begin, CaseItr End,
ConstantInt *LowerBound,
201 const std::vector<IntRange> &UnreachableRanges) {
202 unsigned Size = End - Begin;
209 if (Begin->Low == LowerBound && Begin->High == UpperBound) {
210 unsigned NumMergedCases = 0;
211 if (LowerBound && UpperBound)
214 fixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases);
217 return newLeafBlock(*Begin, Val, OrigBlock, Default);
220 unsigned Mid = Size / 2;
221 std::vector<CaseRange> LHS(Begin, Begin + Mid);
223 std::vector<CaseRange> RHS(Begin + Mid, End);
226 CaseRange &Pivot = *(Begin + Mid);
228 << Pivot.Low->getValue()
229 <<
" -" << Pivot.High->getValue() <<
"\n");
240 NewLowerBound->getValue() - 1);
242 if (!UnreachableRanges.empty()) {
244 int64_t GapLow = LHS.back().High->getSExtValue() + 1;
245 int64_t GapHigh = NewLowerBound->getSExtValue() - 1;
246 IntRange Gap = { GapLow, GapHigh };
247 if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges))
248 NewUpperBound = LHS.back().High;
258 dbgs() <<
"RHS Bounds ==> ";
259 dbgs() << NewLowerBound->getSExtValue() <<
" - ";
272 Val, Pivot.Low,
"Pivot");
274 BasicBlock *LBranch = switchConvert(LHS.begin(), LHS.end(), LowerBound,
275 NewUpperBound, Val, NewNode, OrigBlock,
277 BasicBlock *RBranch = switchConvert(RHS.begin(), RHS.end(), NewLowerBound,
278 UpperBound, Val, NewNode, OrigBlock,
282 F->getBasicBlockList().insert(++FI, NewNode);
283 NewNode->getInstList().push_back(Comp);
306 if (Leaf.Low == Leaf.High) {
309 Leaf.Low,
"SwitchLeaf");
312 if (Leaf.Low->isMinValue(
true )) {
316 }
else if (Leaf.Low->isZero()) {
341 uint64_t Range = Leaf.High->getSExtValue() -
342 Leaf.Low->getSExtValue();
343 for (uint64_t j = 0; j < Range; ++j) {
348 assert(BlockIdx != -1 &&
"Switch didn't go to this successor??");
356 unsigned LowerSwitch::Clusterify(CaseVector& Cases,
SwitchInst *SI) {
357 unsigned numCmps = 0;
361 Cases.push_back(CaseRange(i.getCaseValue(), i.getCaseValue(),
362 i.getCaseSuccessor()));
364 std::sort(Cases.begin(), Cases.end(), CaseCmp());
367 if (Cases.size() >= 2) {
368 CaseItr
I = Cases.begin();
369 for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) {
370 int64_t nextValue = J->Low->getSExtValue();
371 int64_t currentValue = I->High->getSExtValue();
377 assert(nextValue > currentValue &&
"Cases should be strictly ascending");
378 if ((nextValue == currentValue + 1) && (currentBB == nextBB)) {
381 }
else if (++I != J) {
385 Cases.erase(std::next(I), Cases.end());
388 for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++
I, ++numCmps) {
389 if (I->Low != I->High)
400 void LowerSwitch::processSwitchInst(
SwitchInst *SI) {
416 unsigned numCmps = Clusterify(Cases, SI);
417 DEBUG(
dbgs() <<
"Clusterify finished. Total clusters: " << Cases.size()
418 <<
". Total compares: " << numCmps <<
"\n");
419 DEBUG(
dbgs() <<
"Cases: " << Cases <<
"\n");
424 std::vector<IntRange> UnreachableRanges;
430 assert(!Cases.empty());
431 LowerBound = Cases.front().Low;
432 UpperBound = Cases.back().High;
438 IntRange R = { INT64_MIN, INT64_MAX };
439 UnreachableRanges.push_back(R);
440 for (
const auto &I : Cases) {
441 int64_t Low = I.Low->getSExtValue();
442 int64_t High = I.High->getSExtValue();
444 IntRange &LastRange = UnreachableRanges.
back();
445 if (LastRange.Low == Low) {
447 UnreachableRanges.pop_back();
450 assert(Low > LastRange.Low);
451 LastRange.High = Low - 1;
453 if (High != INT64_MAX) {
454 IntRange R = { High + 1, INT64_MAX };
455 UnreachableRanges.push_back(R);
459 int64_t
N = High - Low + 1;
460 unsigned &Pop = Popularity[I.BB];
461 if ((Pop += N) > MaxPop) {
468 for (
auto I = UnreachableRanges.begin(), E = UnreachableRanges.end();
470 assert(I->Low <= I->High);
473 assert(Next->Low > I->High);
480 assert(MaxPop > 0 && PopSucc);
482 Cases.erase(std::remove_if(
483 Cases.begin(), Cases.end(),
484 [PopSucc](
const CaseRange &R) {
return R.BB == PopSucc; }),
506 assert(BlockIdx != -1 &&
"Switch didn't go to this successor??");
511 switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val,
512 OrigBlock, OrigBlock, NewDefault, UnreachableRanges);
iplist< Instruction >::iterator eraseFromParent()
eraseFromParent - This method unlinks 'this' from the containing basic block and deletes it...
void push_back(const T &Elt)
CaseIt case_end()
Returns a read/write iterator that points one past the last in the SwitchInst.
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
const Instruction & back() const
FunctionPass * createLowerSwitchPass()
void DeleteDeadBlock(BasicBlock *BB)
DeleteDeadBlock - Delete the specified block, which must have no predecessors.
CaseIt case_begin()
Returns a read/write iterator that points to the first case in SwitchInst.
const Function * getParent() const
Return the enclosing method, or null if none.
StringRef getName() const
Return a constant reference to the value's name.
iterator begin()
Instruction iterator methods.
static Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
Instruction * getFirstNonPHIOrDbg()
Returns a pointer to the first instruction in this block that is not a PHINode or a debug intrinsic...
Value * removeIncomingValue(unsigned Idx, bool DeletePHIIfEmpty=true)
removeIncomingValue - Remove an incoming value.
const APInt & getValue() const
Return the constant as an APInt value reference.
Instruction * getFirstNonPHI()
Returns a pointer to the first instruction in this block that is not a PHINode instruction.
AnalysisUsage & addPreservedID(const void *ID)
static BinaryOperator * CreateAdd(Value *S1, Value *S2, const Twine &Name, Instruction *InsertBefore, Value *FlagsOp)
unsigned getNumIncomingValues() const
getNumIncomingValues - Return the number of incoming edges
LLVM Basic Block Representation.
This is an important base class in LLVM.
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Interval::pred_iterator pred_begin(Interval *I)
pred_begin/pred_end - define methods so that Intervals may be used just like BasicBlocks can with the...
static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, unsigned NumMergedCases)
Represent the analysis usage information of a pass.
BasicBlock * getIncomingBlock(unsigned i) const
getIncomingBlock - Return incoming basic block number i.
const InstListType & getInstList() const
Return the underlying instruction list container.
This instruction compares its operands according to the predicate given to the constructor.
iterator insert(iterator where, NodeTy *New)
FunctionPass class - This class is used to implement most global optimizations.
Interval::pred_iterator pred_end(Interval *I)
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
LLVMContext & getContext() const
All values hold a context through their type.
iterator erase(iterator where)
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
const BasicBlockListType & getBasicBlockList() const
This is the shared class of boolean and integer constants.
void setIncomingBlock(unsigned i, BasicBlock *BB)
bool slt(const APInt &RHS) const
Signed less than comparison.
static Constant * get(Type *Ty, uint64_t V, bool isSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
static BranchInst * Create(BasicBlock *IfTrue, Instruction *InsertBefore=nullptr)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Value * getCondition() const
static Constant * getNeg(Constant *C, bool HasNUW=false, bool HasNSW=false)
BasicBlock * getDefaultDest() const
TerminatorInst * getTerminator()
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
raw_ostream & operator<<(raw_ostream &OS, const APInt &I)
unsigned getNumCases() const
getNumCases - return the number of 'cases' in this switch instruction, except the default case ...
SwitchInst - Multiway switch.
reverse_iterator rbegin()
LLVM Value Representation.
void initializeLowerSwitchPass(PassRegistry &)
This class implements an extremely fast bulk output stream that can only output to a stream...
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
int getBasicBlockIndex(const BasicBlock *BB) const
getBasicBlockIndex - Return the first index of the specified basic block in the value list for this P...
const BasicBlock * getParent() const
#define LLVM_ATTRIBUTE_USED