32 #define DEBUG_TYPE "lower-switch"
39 static bool IsInRanges(
const IntRange &R,
40 const std::vector<IntRange> &Ranges) {
46 auto I = std::lower_bound(
47 Ranges.begin(), Ranges.end(), R,
48 [](
const IntRange &
A,
const IntRange &
B) {
return A.High <
B.High; });
49 return I != Ranges.end() &&
I->Low <= R.Low;
68 : Low(low),
High(high), BB(bb) {}
71 typedef std::vector<CaseRange> CaseVector;
72 typedef std::vector<CaseRange>::iterator CaseItr;
80 const std::vector<IntRange> &UnreachableRanges);
89 bool operator () (
const LowerSwitch::CaseRange& C1,
90 const LowerSwitch::CaseRange& C2) {
92 const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low);
93 const ConstantInt* CI2 = cast<const ConstantInt>(C2.High);
101 "Lower SwitchInst's to branches",
false,
false)
107 return new LowerSwitch();
110 bool LowerSwitch::runOnFunction(
Function &
F) {
111 bool Changed =
false;
119 if (DeleteList.
count(Cur))
124 processSwitchInst(SI, DeleteList);
137 const LowerSwitch::CaseVector &
C)
140 const LowerSwitch::CaseVector &
C) {
143 for (LowerSwitch::CaseVector::const_iterator
B = C.begin(),
144 E = C.end();
B !=
E; ) {
145 O << *
B->Low <<
" -" << *
B->High;
146 if (++
B !=
E) O <<
", ";
163 unsigned NumMergedCases) {
171 unsigned LocalNumMergedCases = NumMergedCases;
172 for (; Idx !=
E; ++Idx) {
182 for (++Idx; LocalNumMergedCases > 0 && Idx <
E; ++Idx)
185 LocalNumMergedCases--;
189 for (
unsigned III :
reverse(Indices))
200 LowerSwitch::switchConvert(CaseItr Begin, CaseItr
End,
ConstantInt *LowerBound,
204 const std::vector<IntRange> &UnreachableRanges) {
205 unsigned Size = End - Begin;
212 if (Begin->Low == LowerBound && Begin->High == UpperBound) {
213 unsigned NumMergedCases = 0;
214 if (LowerBound && UpperBound)
217 fixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases);
220 return newLeafBlock(*Begin, Val, OrigBlock, Default);
223 unsigned Mid = Size / 2;
224 std::vector<CaseRange> LHS(Begin, Begin + Mid);
226 std::vector<CaseRange> RHS(Begin + Mid, End);
229 CaseRange &Pivot = *(Begin + Mid);
231 << Pivot.Low->getValue()
232 <<
" -" << Pivot.High->getValue() <<
"\n");
243 NewLowerBound->getValue() - 1);
245 if (!UnreachableRanges.empty()) {
247 int64_t GapLow = LHS.back().High->getSExtValue() + 1;
248 int64_t GapHigh = NewLowerBound->getSExtValue() - 1;
249 IntRange Gap = { GapLow, GapHigh };
250 if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges))
251 NewUpperBound = LHS.back().High;
261 dbgs() <<
"RHS Bounds ==> ";
262 dbgs() << NewLowerBound->getSExtValue() <<
" - ";
275 Val, Pivot.Low,
"Pivot");
277 BasicBlock *LBranch = switchConvert(LHS.begin(), LHS.end(), LowerBound,
278 NewUpperBound, Val, NewNode, OrigBlock,
279 Default, UnreachableRanges);
280 BasicBlock *RBranch = switchConvert(RHS.begin(), RHS.end(), NewLowerBound,
281 UpperBound, Val, NewNode, OrigBlock,
282 Default, UnreachableRanges);
284 F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewNode);
305 if (Leaf.Low == Leaf.High) {
308 Leaf.Low,
"SwitchLeaf");
311 if (Leaf.Low->isMinValue(
true )) {
315 }
else if (Leaf.Low->isZero()) {
340 uint64_t Range = Leaf.High->getSExtValue() -
341 Leaf.Low->getSExtValue();
342 for (uint64_t j = 0; j < Range; ++j) {
347 assert(BlockIdx != -1 &&
"Switch didn't go to this successor??");
356 unsigned numCmps = 0;
360 Cases.push_back(CaseRange(
i.getCaseValue(),
i.getCaseValue(),
361 i.getCaseSuccessor()));
363 std::sort(Cases.begin(), Cases.end(), CaseCmp());
366 if (Cases.size() >= 2) {
367 CaseItr
I = Cases.begin();
368 for (CaseItr J = std::next(I),
E = Cases.end(); J !=
E; ++J) {
369 int64_t nextValue = J->Low->getSExtValue();
370 int64_t currentValue = I->High->getSExtValue();
376 assert(nextValue > currentValue &&
"Cases should be strictly ascending");
377 if ((nextValue == currentValue + 1) && (currentBB == nextBB)) {
380 }
else if (++I != J) {
384 Cases.erase(std::next(I), Cases.end());
387 for (CaseItr I=Cases.begin(),
E=Cases.end(); I!=
E; ++
I, ++numCmps) {
388 if (I->Low != I->High)
398 void LowerSwitch::processSwitchInst(
SwitchInst *SI,
415 unsigned numCmps = Clusterify(Cases, SI);
416 DEBUG(
dbgs() <<
"Clusterify finished. Total clusters: " << Cases.size()
417 <<
". Total compares: " << numCmps <<
"\n");
418 DEBUG(
dbgs() <<
"Cases: " << Cases <<
"\n");
423 std::vector<IntRange> UnreachableRanges;
430 LowerBound = Cases.front().Low;
431 UpperBound = Cases.back().High;
437 IntRange R = { INT64_MIN, INT64_MAX };
438 UnreachableRanges.push_back(R);
439 for (
const auto &I : Cases) {
440 int64_t Low = I.Low->getSExtValue();
441 int64_t
High = I.High->getSExtValue();
443 IntRange &LastRange = UnreachableRanges.back();
444 if (LastRange.Low == Low) {
446 UnreachableRanges.pop_back();
449 assert(Low > LastRange.Low);
450 LastRange.High = Low - 1;
452 if (High != INT64_MAX) {
453 IntRange R = { High + 1, INT64_MAX };
454 UnreachableRanges.push_back(R);
458 int64_t
N = High - Low + 1;
459 unsigned &Pop = Popularity[I.BB];
460 if ((Pop += N) > MaxPop) {
467 for (
auto I = UnreachableRanges.begin(),
E = UnreachableRanges.end();
469 assert(I->Low <= I->High);
472 assert(Next->Low > I->High);
479 assert(MaxPop > 0 && PopSucc);
483 [PopSucc](
const CaseRange &R) {
return R.BB == PopSucc; }),
505 assert(BlockIdx != -1 &&
"Switch didn't go to this successor??");
510 switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val,
511 OrigBlock, OrigBlock, NewDefault, UnreachableRanges);
522 DeleteList.
insert(OldDefault);
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
void push_back(const T &Elt)
CaseIt case_end()
Returns a read/write iterator that points one past the last in the SwitchInst.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
iterator erase(iterator where)
FunctionPass * createLowerSwitchPass()
auto remove_if(R &&Range, UnaryPredicate P) -> decltype(std::begin(Range))
Provide wrappers to std::remove_if which take ranges instead of having to pass begin/end explicitly...
void DeleteDeadBlock(BasicBlock *BB)
Delete the specified block, which must have no predecessors.
size_type count(PtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
CaseIt case_begin()
Returns a read/write iterator that points to the first case in the SwitchInst.
const Function * getParent() const
Return the enclosing method, or null if none.
StringRef getName() const
Return a constant reference to the value's name.
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
iterator begin()
Instruction iterator methods.
static Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
Instruction * getFirstNonPHIOrDbg()
Returns a pointer to the first instruction in this block that is not a PHINode or a debug intrinsic...
Value * removeIncomingValue(unsigned Idx, bool DeletePHIIfEmpty=true)
Remove an incoming value.
const APInt & getValue() const
Return the constant as an APInt value reference.
Instruction * getFirstNonPHI()
Returns a pointer to the first instruction in this block that is not a PHINode instruction.
auto reverse(ContainerTy &&C, typename std::enable_if< has_rbegin< ContainerTy >::value >::type *=nullptr) -> decltype(make_range(C.rbegin(), C.rend()))
static GCRegistry::Add< OcamlGC > B("ocaml","ocaml 3.10-compatible GC")
static BinaryOperator * CreateAdd(Value *S1, Value *S2, const Twine &Name, Instruction *InsertBefore, Value *FlagsOp)
static GCRegistry::Add< CoreCLRGC > E("coreclr","CoreCLR-compatible GC")
unsigned getNumIncomingValues() const
Return the number of incoming edges.
LLVM Basic Block Representation.
This is an important base class in LLVM.
This file contains the declarations for the subclasses of Constant, which represent the different fla...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Interval::pred_iterator pred_begin(Interval *I)
pred_begin/pred_end - define methods so that Intervals may be used just like BasicBlocks can with the...
static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, unsigned NumMergedCases)
Update the first occurrence of the "switch statement" BB in the PHI node with the "new" BB...
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
const InstListType & getInstList() const
Return the underlying instruction list container.
This instruction compares its operands according to the predicate given to the constructor.
static const unsigned End
FunctionPass class - This class is used to implement most global optimizations.
Interval::pred_iterator pred_end(Interval *I)
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
self_iterator getIterator()
LLVMContext & getContext() const
All values hold a context through their type.
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Iterator for intrusive lists based on ilist_node.
const BasicBlockListType & getBasicBlockList() const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements...
This is the shared class of boolean and integer constants.
void setIncomingBlock(unsigned i, BasicBlock *BB)
bool slt(const APInt &RHS) const
Signed less than comparison.
static Constant * get(Type *Ty, uint64_t V, bool isSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
static BranchInst * Create(BasicBlock *IfTrue, Instruction *InsertBefore=nullptr)
static GCRegistry::Add< ShadowStackGC > C("shadow-stack","Very portable GC for uncooperative code generators")
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
void push_back(pointer val)
Value * getCondition() const
static Constant * getNeg(Constant *C, bool HasNUW=false, bool HasNSW=false)
iterator insert(iterator where, pointer New)
BasicBlock * getDefaultDest() const
TerminatorInst * getTerminator()
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
raw_ostream & operator<<(raw_ostream &OS, const APInt &I)
unsigned getNumCases() const
Return the number of 'cases' in this switch instruction, excluding the default case.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
LLVM Value Representation.
void initializeLowerSwitchPass(PassRegistry &)
This class implements an extremely fast bulk output stream that can only output to a stream...
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
static GCRegistry::Add< ErlangGC > A("erlang","erlang-compatible garbage collector")
int getBasicBlockIndex(const BasicBlock *BB) const
Return the first index of the specified basic block in the value list for this PHI.
const BasicBlock * getParent() const
#define LLVM_ATTRIBUTE_USED