21using namespace SwitchCG;
26 const APInt &LowCase = Clusters[
First].Low->getValue();
27 const APInt &HighCase = Clusters[
Last].High->getValue();
33 return (HighCase - LowCase).getLimitedValue((
UINT64_MAX - 1) / 100) + 1;
48 std::optional<SDLoc> SL,
57 for (
unsigned i = 1, e = Clusters.size(); i < e; ++i)
58 assert(Clusters[i - 1].
High->getValue().slt(Clusters[i].Low->getValue()));
61 assert(TLI &&
"TLI not set!");
66 const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;
69 const int64_t
N = Clusters.size();
70 if (
N < 2 ||
N < MinJumpTableEntries)
75 for (
unsigned i = 0; i <
N; ++i) {
76 const APInt &
Hi = Clusters[i].High->getValue();
77 const APInt &
Lo = Clusters[i].Low->getValue();
78 TotalCases[i] = (
Hi -
Lo).getLimitedValue() + 1;
80 TotalCases[i] += TotalCases[i - 1];
91 if (
buildJumpTable(Clusters, 0,
N - 1, SI, SL, DefaultMBB, JTCluster)) {
92 Clusters[0] = JTCluster;
120 enum PartitionScores :
unsigned {
128 MinPartitions[
N - 1] = 1;
129 LastElement[
N - 1] =
N - 1;
130 PartitionsScore[
N - 1] = PartitionScores::SingleCase;
133 for (int64_t i =
N - 2; i >= 0; i--) {
136 MinPartitions[i] = MinPartitions[i + 1] + 1;
138 PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;
141 for (int64_t j =
N - 1; j > i; j--) {
149 unsigned NumPartitions = 1 + (j ==
N - 1 ? 0 : MinPartitions[j + 1]);
150 unsigned Score = j ==
N - 1 ? 0 : PartitionsScore[j + 1];
151 int64_t NumEntries = j - i + 1;
154 Score += PartitionScores::SingleCase;
155 else if (NumEntries <= SmallNumberOfEntries)
156 Score += PartitionScores::FewCases;
157 else if (NumEntries >= MinJumpTableEntries)
158 Score += PartitionScores::Table;
162 if (NumPartitions < MinPartitions[i] ||
163 (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
164 MinPartitions[i] = NumPartitions;
166 PartitionsScore[i] = Score;
173 unsigned DstIndex = 0;
181 if (NumClusters >= MinJumpTableEntries &&
183 Clusters[DstIndex++] = JTCluster;
186 std::memmove(&Clusters[DstIndex++], &Clusters[
I],
sizeof(Clusters[
I]));
189 Clusters.resize(DstIndex);
195 const std::optional<SDLoc> &SL,
201 unsigned NumCmps = 0;
202 std::vector<MachineBasicBlock*> Table;
211 Prob += Clusters[
I].Prob;
212 const APInt &
Low = Clusters[
I].Low->getValue();
213 const APInt &
High = Clusters[
I].High->getValue();
214 NumCmps += (
Low ==
High) ? 1 : 2;
217 const APInt &PreviousHigh = Clusters[
I - 1].High->getValue();
219 uint64_t Gap = (
Low - PreviousHigh).getLimitedValue() - 1;
221 Table.push_back(DefaultMBB);
224 for (
uint64_t J = 0; J < ClusterSize; ++J)
225 Table.push_back(Clusters[
I].MBB);
226 JTProbs[Clusters[
I].MBB] += Clusters[
I].Prob;
229 unsigned NumDests = JTProbs.
size();
230 if (TLI->isSuitableForBitTests(NumDests, NumCmps,
231 Clusters[
First].Low->getValue(),
232 Clusters[
Last].High->getValue(), *
DL)) {
246 if (
Done.count(Succ))
248 addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
259 Clusters[
Last].High->getValue(), SI->getCondition(),
261 JTCases.emplace_back(std::move(JTH), std::move(JT));
264 JTCases.size() - 1, Prob);
275 assert(!Clusters.empty());
279 for (
unsigned i = 1; i < Clusters.size(); ++i)
280 assert(Clusters[i-1].
High->getValue().slt(Clusters[i].Low->getValue()));
288 EVT PTy = TLI->getPointerTy(*
DL);
289 if (!TLI->isOperationLegal(
ISD::SHL, PTy))
293 const int64_t
N = Clusters.size();
303 MinPartitions[
N - 1] = 1;
304 LastElement[
N - 1] =
N - 1;
307 for (int64_t i =
N - 2; i >= 0; --i) {
310 MinPartitions[i] = MinPartitions[i + 1] + 1;
315 for (int64_t j = std::min(
N - 1, i +
BitWidth - 1); j > i; --j) {
319 if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
320 Clusters[j].High->getValue(), *
DL))
325 bool RangesOnly =
true;
326 BitVector Dests(FuncInfo.MF->getNumBlockIDs());
327 for (int64_t k = i; k <= j; k++) {
334 if (!RangesOnly || Dests.
count() > 3)
338 unsigned NumPartitions = 1 + (j ==
N - 1 ? 0 : MinPartitions[j + 1]);
339 if (NumPartitions < MinPartitions[i]) {
341 MinPartitions[i] = NumPartitions;
348 unsigned DstIndex = 0;
355 if (buildBitTests(Clusters,
First,
Last, SI, BitTestCluster)) {
356 Clusters[DstIndex++] = BitTestCluster;
359 std::memmove(&Clusters[DstIndex], &Clusters[
First],
360 sizeof(Clusters[0]) * NumClusters);
361 DstIndex += NumClusters;
364 Clusters.resize(DstIndex);
375 BitVector Dests(FuncInfo.MF->getNumBlockIDs());
376 unsigned NumCmps = 0;
380 NumCmps += (Clusters[
I].Low == Clusters[
I].High) ? 1 : 2;
382 unsigned NumDests = Dests.
count();
388 if (!TLI->isSuitableForBitTests(NumDests, NumCmps,
Low,
High, *
DL))
394 const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
396 "Case range must fit in bit mask!");
400 bool ContiguousRange =
true;
402 if (Clusters[
I].
Low->getValue() != Clusters[
I - 1].High->getValue() + 1) {
403 ContiguousRange =
false;
413 ContiguousRange =
false;
421 for (
unsigned i =
First; i <=
Last; ++i) {
424 for (j = 0; j < CBV.size(); ++j)
425 if (CBV[j].BB == Clusters[i].
MBB)
433 uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
434 uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
436 CB->
Mask |= (-1ULL >> (63 - (
Hi -
Lo))) <<
Lo;
439 TotalProb += Clusters[i].Prob;
448 return a.
Bits > b.Bits;
449 return a.
Mask < b.Mask;
452 for (
auto &CB : CBV) {
454 FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
457 BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
458 SI->getCondition(),
Register(), MVT::Other,
false,
459 ContiguousRange,
nullptr,
nullptr, std::move(BTI),
463 BitTestCases.size() - 1, TotalProb);
470 assert(
CC.Low ==
CC.High &&
"Input clusters must be single-case");
478 const unsigned N = Clusters.size();
479 unsigned DstIndex = 0;
480 for (
unsigned SrcIndex = 0; SrcIndex <
N; ++SrcIndex) {
485 if (DstIndex != 0 && Clusters[DstIndex - 1].
MBB == Succ &&
486 (CaseVal->
getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
489 Clusters[DstIndex - 1].High = CaseVal;
490 Clusters[DstIndex - 1].Prob +=
CC.Prob;
492 std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
493 sizeof(Clusters[SrcIndex]));
496 Clusters.resize(DstIndex);
503 if (
X.Prob !=
CC.Prob)
504 return X.Prob >
CC.Prob;
507 return X.Low->getValue().slt(
CC.Low->getValue());
516 auto LeftProb = LastLeft->Prob + W.DefaultProb / 2;
517 auto RightProb = FirstRight->Prob + W.DefaultProb / 2;
524 while (LastLeft + 1 < FirstRight) {
525 if (LeftProb < RightProb || (LeftProb == RightProb && (
I & 1)))
526 LeftProb += (++LastLeft)->Prob;
528 RightProb += (--FirstRight)->Prob;
538 unsigned NumLeft = LastLeft - W.FirstCluster + 1;
539 unsigned NumRight = W.LastCluster - FirstRight + 1;
541 if (std::min(NumLeft, NumRight) < 3 && std::max(NumLeft, NumRight) > 3) {
545 if (NumLeft < NumRight) {
548 unsigned RightSideRank = caseClusterRank(
CC, FirstRight, W.LastCluster);
549 unsigned LeftSideRank = caseClusterRank(
CC, W.FirstCluster, LastLeft);
550 if (LeftSideRank <= RightSideRank) {
557 assert(NumRight < NumLeft);
560 unsigned LeftSideRank = caseClusterRank(
CC, W.FirstCluster, LastLeft);
561 unsigned RightSideRank = caseClusterRank(
CC, FirstRight, W.LastCluster);
562 if (RightSideRank <= LeftSideRank) {
573 assert(LastLeft + 1 == FirstRight);
574 assert(LastLeft >= W.FirstCluster);
575 assert(FirstRight <= W.LastCluster);
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file describes how to lower LLVM code to machine code.
Class for arbitrary precision integers.
unsigned getBitWidth() const
Return the number of bits in the APInt.
bool slt(const APInt &RHS) const
Signed less than comparison.
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
size_type count() const
count - Returns the number of bits which are set.
BlockFrequencyInfo pass uses BlockFrequencyInfoImpl implementation to estimate IR basic block frequen...
static BranchProbability getZero()
This is the shared class of boolean and integer constants.
const APInt & getValue() const
Return the constant as an APInt value reference.
void normalizeSuccProbs()
Normalize probabilities of all successors so that the sum of them becomes one.
int getNumber() const
MachineBasicBlocks are uniquely numbered at the function level, unless they're not in a MachineFuncti...
MachineJumpTableInfo * getOrCreateJumpTableInfo(unsigned JTEntryKind)
getOrCreateJumpTableInfo - Get the JumpTableInfo for this function, if it does already exist,...
MachineBasicBlock * CreateMachineBasicBlock(const BasicBlock *BB=nullptr, std::optional< UniqueBBID > BBID=std::nullopt)
CreateMachineBasicBlock - Allocate a new MachineBasicBlock.
unsigned createJumpTableIndex(const std::vector< MachineBasicBlock * > &DestBBs)
createJumpTableIndex - Create a new jump table.
Analysis providing profile information.
Wrapper class representing virtual and physical registers.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last, const SwitchInst *SI, CaseCluster &BTCluster)
Build a bit test cluster from Clusters[First..Last].
unsigned caseClusterRank(const CaseCluster &CC, CaseClusterIt First, CaseClusterIt Last)
Determine the rank by weight of CC in [First,Last].
void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI, std::optional< SDLoc > SL, MachineBasicBlock *DefaultMBB, ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI)
void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI)
SplitWorkItemInfo computeSplitWorkItemInfo(const SwitchWorkListItem &W)
Compute information to balance the tree based on branch probabilities to create a near-optimal (in te...
bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First, unsigned Last, const SwitchInst *SI, const std::optional< SDLoc > &SL, MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster)
virtual unsigned getMinimumJumpTableEntries() const
Return lower limit for number of blocks in a jump table.
virtual bool isSuitableForJumpTable(const SwitchInst *SI, uint64_t NumCases, uint64_t Range, ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI) const
Return true if lowering to a jump table is suitable for a set of case clusters which may contain NumC...
virtual bool areJTsAllowed(const Function *Fn) const
Return true if lowering to a jump table is allowed.
CodeGenOptLevel getOptLevel() const
Returns the optimization level: None, Less, Default, or Aggressive.
@ C
The default llvm calling convention, compatible with C.
@ SHL
Shift and rotation operations.
uint64_t getJumpTableNumCases(const SmallVectorImpl< unsigned > &TotalCases, unsigned First, unsigned Last)
Return the number of cases within a range.
std::vector< CaseCluster > CaseClusterVector
void sortAndRangeify(CaseClusterVector &Clusters)
Sort Clusters and merge adjacent cases.
CaseClusterVector::iterator CaseClusterIt
uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First, unsigned Last)
Return the range of values within a range.
std::vector< CaseBits > CaseBitsVector
@ CC_Range
A cluster of adjacent case labels with the same destination, or just one case.
@ CC_JumpTable
A cluster of cases suitable for jump table lowering.
This is an optimization pass for GlobalISel generic memory operations.
@ Low
Lower the current thread's priority such that it does not affect foreground tasks significantly.
void sort(IteratorTy Start, IteratorTy End)
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
constexpr unsigned BitWidth
TypeSize getSizeInBits() const
Return the size of the specified value type in bits.
BranchProbability ExtraProb
A cluster of case labels.
static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High, unsigned JTCasesIndex, BranchProbability Prob)
static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High, unsigned BTCasesIndex, BranchProbability Prob)