47#define DEBUG_TYPE "lower-switch"
59bool IsInRanges(
const IntRange &R,
const std::vector<IntRange> &Ranges) {
66 Ranges, R, [](IntRange
A, IntRange
B) {
return A.High.slt(
B.High); });
67 return I !=
Ranges.end() &&
I->Low.sle(
R.Low);
76 :
Low(low),
High(high), BB(bb) {}
79using CaseVector = std::vector<CaseRange>;
80using CaseItr = std::vector<CaseRange>::iterator;
85 bool operator()(
const CaseRange &C1,
const CaseRange &C2) {
86 const ConstantInt *CI1 = cast<const ConstantInt>(C1.Low);
87 const ConstantInt *CI2 = cast<const ConstantInt>(C2.High);
97 for (CaseVector::const_iterator
B =
C.begin(), E =
C.end();
B != E;) {
98 O <<
"[" <<
B->Low->getValue() <<
", " <<
B->High->getValue() <<
"]";
117 const APInt &NumMergedCases) {
118 for (
auto &
I : SuccBB->
phis()) {
123 APInt LocalNumMergedCases = NumMergedCases;
124 for (;
Idx != E && NewBB; ++
Idx) {
138 for (; LocalNumMergedCases.
ugt(0) &&
Idx < E; ++
Idx)
141 LocalNumMergedCases -= 1;
163 if (Leaf.Low == Leaf.High) {
166 new ICmpInst(NewLeaf, ICmpInst::ICMP_EQ, Val, Leaf.Low,
"SwitchLeaf");
169 if (Leaf.Low == LowerBound) {
171 Comp =
new ICmpInst(NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High,
173 }
else if (Leaf.High == UpperBound) {
175 Comp =
new ICmpInst(NewLeaf, ICmpInst::ICMP_SGE, Val, Leaf.Low,
177 }
else if (Leaf.Low->isZero()) {
179 Comp =
new ICmpInst(NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High,
185 Val, NegLo, Val->
getName() +
".off", NewLeaf);
187 Comp =
new ICmpInst(NewLeaf, ICmpInst::ICMP_ULE,
Add, UpperBound,
208 APInt Range = Leaf.High->getValue() - Leaf.Low->getValue();
214 assert(BlockIdx != -1 &&
"Switch didn't go to this successor??");
230 const std::vector<IntRange> &UnreachableRanges) {
231 assert(LowerBound && UpperBound &&
"Bounds must be initialized");
239 if (Begin->Low == LowerBound && Begin->High == UpperBound) {
241 FixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases);
244 return NewLeafBlock(*Begin, Val, LowerBound, UpperBound, OrigBlock,
248 unsigned Mid =
Size / 2;
249 std::vector<CaseRange>
LHS(Begin, Begin + Mid);
251 std::vector<CaseRange>
RHS(Begin + Mid,
End);
254 CaseRange &Pivot = *(Begin + Mid);
255 LLVM_DEBUG(
dbgs() <<
"Pivot ==> [" << Pivot.Low->getValue() <<
", "
256 << Pivot.High->getValue() <<
"]\n");
269 if (!UnreachableRanges.empty()) {
271 APInt GapLow =
LHS.back().High->getValue() + 1;
273 IntRange Gap = {GapLow, GapHigh};
274 if (GapHigh.
sge(GapLow) && IsInRanges(Gap, UnreachableRanges))
275 NewUpperBound =
LHS.back().High;
279 << NewUpperBound->
getValue() <<
"]\n"
280 <<
"RHS Bounds ==> [" << NewLowerBound->
getValue() <<
", "
281 << UpperBound->
getValue() <<
"]\n");
291 SwitchConvert(
LHS.begin(),
LHS.end(), LowerBound, NewUpperBound, Val,
292 NewNode, OrigBlock,
Default, UnreachableRanges);
294 SwitchConvert(
RHS.begin(),
RHS.end(), NewLowerBound, UpperBound, Val,
295 NewNode, OrigBlock,
Default, UnreachableRanges);
307unsigned Clusterify(CaseVector &Cases,
SwitchInst *SI) {
308 unsigned NumSimpleCases = 0;
311 for (
auto Case :
SI->cases()) {
312 if (Case.getCaseSuccessor() ==
SI->getDefaultDest())
314 Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(),
315 Case.getCaseSuccessor()));
322 if (Cases.size() >= 2) {
323 CaseItr
I = Cases.begin();
324 for (CaseItr J = std::next(
I), E = Cases.end(); J != E; ++J) {
325 const APInt &nextValue = J->Low->getValue();
326 const APInt ¤tValue =
I->High->getValue();
333 "Cases should be strictly ascending");
334 if ((nextValue == currentValue + 1) && (currentBB == nextBB)) {
337 }
else if (++
I != J) {
341 Cases.erase(std::next(
I), Cases.end());
344 return NumSimpleCases;
354 Value *Val =
SI->getCondition();
359 if ((OrigBlock != &
F->getEntryBlock() &&
pred_empty(OrigBlock)) ||
361 DeleteList.
insert(OrigBlock);
367 const unsigned NumSimpleCases = Clusterify(Cases, SI);
374 LLVM_DEBUG(
dbgs() <<
"Clusterify finished. Total clusters: " << Cases.size()
375 <<
". Total non-default cases: " << NumSimpleCases
376 <<
"\nCase clusters: " << Cases <<
"\n");
382 FixPhis(
Default, OrigBlock, OrigBlock, UnsignedMax);
383 SI->eraseFromParent();
389 bool DefaultIsUnreachableFromSwitch =
false;
391 if (isa<UnreachableInst>(
Default->getFirstNonPHIOrDbg())) {
395 LowerBound = Cases.front().Low;
396 UpperBound = Cases.back().High;
397 DefaultIsUnreachableFromSwitch =
true;
420 const APInt &
Low = Cases.front().Low->getValue();
421 const APInt &
High = Cases.back().High->getValue();
425 LowerBound = ConstantInt::get(
SI->getContext(), Min);
426 UpperBound = ConstantInt::get(
SI->getContext(), Max);
427 DefaultIsUnreachableFromSwitch = (Min + (NumSimpleCases - 1) == Max);
430 std::vector<IntRange> UnreachableRanges;
432 if (DefaultIsUnreachableFromSwitch) {
434 APInt MaxPop(UnsignedZero);
439 IntRange
R = {SignedMin, SignedMax};
440 UnreachableRanges.push_back(R);
441 for (
const auto &
I : Cases) {
445 IntRange &LastRange = UnreachableRanges.back();
446 if (LastRange.Low.eq(
Low)) {
448 UnreachableRanges.pop_back();
452 LastRange.High =
Low - 1;
454 if (
High.ne(SignedMax)) {
455 IntRange
R = {
High + 1, SignedMax};
456 UnreachableRanges.push_back(R);
460 assert(
High.sge(
Low) &&
"Popularity shouldn't be negative.");
464 if ((Pop +=
N).ugt(MaxPop)) {
471 for (
auto I = UnreachableRanges.begin(), E = UnreachableRanges.end();
476 assert(Next->Low.sgt(
I->High));
483 const unsigned NumDefaultEdges =
SI->getNumCases() + 1 - NumSimpleCases;
484 for (
unsigned I = 0;
I < NumDefaultEdges; ++
I)
485 Default->removePredecessor(OrigBlock);
491 [PopSucc](
const CaseRange &R) {
return R.BB == PopSucc; });
496 SI->eraseFromParent();
499 if (!MaxPop.isZero())
500 for (
APInt I(UnsignedZero);
I.ult(MaxPop - 1); ++
I)
508 Val =
SI->getCondition();
512 SwitchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val,
513 OrigBlock, OrigBlock,
Default, UnreachableRanges);
521 FixPhis(
Default, OrigBlock,
nullptr, UnsignedMax);
528 SI->eraseFromParent();
532 DeleteList.
insert(OldDefault);
536 bool Changed =
false;
543 if (DeleteList.
count(&Cur))
546 if (
SwitchInst *SI = dyn_cast<SwitchInst>(Cur.getTerminator())) {
548 ProcessSwitchInst(SI, DeleteList, AC, LVI);
579char LowerSwitchLegacyPass::ID = 0;
585 "Lower SwitchInst's to branches",
false,
false)
593 return new LowerSwitchLegacyPass();
596bool LowerSwitchLegacyPass::runOnFunction(
Function &
F) {
597 LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI();
598 auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>();
600 return LowerSwitch(
F, LVI, AC);
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static cl::opt< ITMode > IT(cl::desc("IT block support"), cl::Hidden, cl::init(DefaultIT), cl::values(clEnumValN(DefaultIT, "arm-default-it", "Generate any type of IT block"), clEnumValN(RestrictedIT, "arm-restrict-it", "Disallow complex IT blocks")))
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
#define LLVM_ATTRIBUTE_USED
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
This file defines the DenseMap class.
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This header defines various interfaces for pass management in LLVM.
Lower SwitchInst s to branches
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
Class for arbitrary precision integers.
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
bool sgt(const APInt &RHS) const
Signed greater than comparison.
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
bool slt(const APInt &RHS) const
Signed less than comparison.
bool sge(const APInt &RHS) const
Signed greater or equal comparison.
A container for analyses that lazily runs them and caches their results.
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
LLVM Basic Block Representation.
iterator begin()
Instruction iterator methods.
iterator_range< const_phi_iterator > phis() const
Returns a range that iterates over the phis in the basic block.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Function * getParent() const
Return the enclosing method, or null if none.
InstListType::iterator iterator
Instruction iterators...
void removePredecessor(BasicBlock *Pred, bool KeepOneInputPHIs=false)
Update PHI nodes in this BasicBlock before removal of predecessor Pred.
static BranchInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
static Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static Constant * getNeg(Constant *C, bool HasNSW=false)
This is the shared class of boolean and integer constants.
const APInt & getValue() const
Return the constant as an APInt value reference.
This class represents a range of values.
static ConstantRange fromKnownBits(const KnownBits &Known, bool IsSigned)
Initialize a range based on a known bits constraint.
APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
This is an important base class in LLVM.
A parsed version of the target data layout string in and methods for querying it.
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
This instruction compares its operands according to the predicate given to the constructor.
InstListType::iterator insertInto(BasicBlock *ParentBB, InstListType::iterator It)
Inserts an unlinked instruction into ParentBB at position It and returns the iterator of the inserted...
Class to represent integer types.
Analysis to compute lazy value information.
Wrapper around LazyValueInfo.
This pass computes, caches, and vends lazy value constraint information.
void eraseBlock(BasicBlock *BB)
Inform the analysis cache that we have erased a block.
ConstantRange getConstantRange(Value *V, Instruction *CxtI, bool UndefAllowed)
Return the ConstantRange constraint that is known to hold for the specified value at the specified in...
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
void setIncomingBlock(unsigned i, BasicBlock *BB)
Value * removeIncomingValue(unsigned Idx, bool DeletePHIIfEmpty=true)
Remove an incoming value.
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
int getBasicBlockIndex(const BasicBlock *BB) const
Return the first index of the specified basic block in the value list for this PHI.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
LLVM Value Representation.
LLVMContext & getContext() const
All values hold a context through their type.
StringRef getName() const
Return a constant reference to the value's name.
self_iterator getIterator()
This class implements an extremely fast bulk output stream that can only output to a stream.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
@ C
The default llvm calling convention, compatible with C.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
This is an optimization pass for GlobalISel generic memory operations.
@ Low
Lower the current thread's priority such that it does not affect foreground tasks significantly.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
void DeleteDeadBlock(BasicBlock *BB, DomTreeUpdater *DTU=nullptr, bool KeepOneInputPHIs=false)
Delete the specified block, which must have no predecessors.
auto reverse(ContainerTy &&C)
void sort(IteratorTy Start, IteratorTy End)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
auto lower_bound(R &&Range, T &&Value)
Provide wrappers to std::lower_bound which take ranges instead of having to pass begin/end explicitly...
void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
constexpr unsigned BitWidth
void erase_if(Container &C, UnaryPredicate P)
Provide a container algorithm similar to C++ Library Fundamentals v2's erase_if which is equivalent t...
FunctionPass * createLowerSwitchPass()
bool pred_empty(const BasicBlock *BB)
void initializeLowerSwitchLegacyPassPass(PassRegistry &)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)