Line data Source code
1 : //===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===//
2 : //
3 : // The LLVM Compiler Infrastructure
4 : //
5 : // This file is distributed under the University of Illinois Open Source
6 : // License. See LICENSE.TXT for details.
7 : //
8 : //===----------------------------------------------------------------------===//
9 : //
10 : // Definition of BranchProbability shared by IR and Machine Instructions.
11 : //
12 : //===----------------------------------------------------------------------===//
13 :
14 : #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H
15 : #define LLVM_SUPPORT_BRANCHPROBABILITY_H
16 :
17 : #include "llvm/Support/DataTypes.h"
18 : #include <algorithm>
19 : #include <cassert>
20 : #include <climits>
21 : #include <numeric>
22 :
namespace llvm {

class raw_ostream;

// A branch probability: a non-negative fraction no greater than 1, kept in a
// fixed-point-like form. Only the numerator is stored; the denominator is
// always the constant 1<<31 (chosen for maximum precision). The numerator
// value UINT32_MAX is reserved to mark a probability as unknown.
class BranchProbability {
  // Numerator of the fraction.
  uint32_t N;

  // Denominator, identical for every probability.
  static const uint32_t D = 1u << 31;
  // Numerator value that marks a probability as unknown.
  static const uint32_t UnknownN = UINT32_MAX;

  // Wrap a raw numerator whose implied denominator is 1<<31. For internal use
  // only.
  explicit BranchProbability(uint32_t n) : N(n) {}

public:
  // A default-constructed probability is unknown.
  BranchProbability() : N(UnknownN) {}
  BranchProbability(uint32_t Numerator, uint32_t Denominator);

  bool isZero() const { return N == 0; }
  bool isUnknown() const { return N == UnknownN; }

  static BranchProbability getZero() { return BranchProbability(0); }
  static BranchProbability getOne() { return BranchProbability(D); }
  static BranchProbability getUnknown() { return BranchProbability(UnknownN); }
  // Build a probability from the given numerator, with 1<<31 as the
  // denominator.
  static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); }
  // Build a probability from a 64-bit numerator and denominator.
  static BranchProbability getBranchProbability(uint64_t Numerator,
                                                uint64_t Denominator);

  // Normalize the given probabilities so that their sum becomes approximately
  // one.
  template <class ProbabilityIter>
  static void normalizeProbabilities(ProbabilityIter Begin,
                                     ProbabilityIter End);

  uint32_t getNumerator() const { return N; }
  static uint32_t getDenominator() { return D; }

  // Return (1 - Probability).
  BranchProbability getCompl() const { return BranchProbability(D - N); }

  raw_ostream &print(raw_ostream &OS) const;

  void dump() const;

  /// Scale a large integer.
  ///
  /// Scales \c Num. Guarantees full precision. Returns the floor of the
  /// result.
  ///
  /// \return \c Num times \c this.
  uint64_t scale(uint64_t Num) const;

  /// Scale a large integer by the inverse.
  ///
  /// Scales \c Num by the inverse of \c this. Guarantees full precision.
  /// Returns the floor of the result.
  ///
  /// \return \c Num divided by \c this.
  uint64_t scaleByInverse(uint64_t Num) const;

  // Saturating addition: the result is clamped to one.
  BranchProbability &operator+=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Add in 64 bits so that a sum above D is detected instead of wrapping.
    const uint64_t Wide = static_cast<uint64_t>(N) + RHS.N;
    N = Wide > D ? D : static_cast<uint32_t>(Wide);
    return *this;
  }

  // Saturating subtraction: the result is clamped to zero.
  BranchProbability &operator-=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    if (N < RHS.N)
      N = 0;
    else
      N -= RHS.N;
    return *this;
  }

  // Multiply two probabilities, rounding to the nearest representable value.
  BranchProbability &operator*=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    const uint64_t Product = static_cast<uint64_t>(N) * RHS.N;
    // Adding D / 2 before dividing by D rounds to nearest.
    N = static_cast<uint32_t>((Product + D / 2) / D);
    return *this;
  }

  // Multiply by an integer factor; the result is clamped to one.
  BranchProbability &operator*=(uint32_t RHS) {
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Multiply in 64 bits so that a product above D is detected instead of
    // wrapping.
    const uint64_t Product = static_cast<uint64_t>(N) * RHS;
    N = Product > D ? D : static_cast<uint32_t>(Product);
    return *this;
  }

  // Divide by a non-zero integer; the numerator is truncated.
  BranchProbability &operator/=(uint32_t RHS) {
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    assert(RHS > 0 && "The divider cannot be zero.");
    N = N / RHS;
    return *this;
  }

  BranchProbability operator+(BranchProbability RHS) const {
    BranchProbability Result = *this;
    Result += RHS;
    return Result;
  }

  BranchProbability operator-(BranchProbability RHS) const {
    BranchProbability Result = *this;
    Result -= RHS;
    return Result;
  }

  BranchProbability operator*(BranchProbability RHS) const {
    BranchProbability Result = *this;
    Result *= RHS;
    return Result;
  }

  BranchProbability operator*(uint32_t RHS) const {
    BranchProbability Result = *this;
    Result *= RHS;
    return Result;
  }

  BranchProbability operator/(uint32_t RHS) const {
    BranchProbability Result = *this;
    Result /= RHS;
    return Result;
  }

  // Equality compares raw numerators and carries no assertion, so unknown
  // values may be compared for equality.
  bool operator==(BranchProbability RHS) const { return N == RHS.N; }
  bool operator!=(BranchProbability RHS) const { return N != RHS.N; }

  // Ordering comparisons assert that both operands are known.
  bool operator<(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return N < RHS.N;
  }

  bool operator>(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return N > RHS.N;
  }

  bool operator<=(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return N <= RHS.N;
  }

  bool operator>=(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return N >= RHS.N;
  }
};

// Stream a probability by forwarding to BranchProbability::print.
inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) {
  return Prob.print(OS);
}

// Normalize the probabilities in [Begin, End) in place.
//
// Unknown entries are replaced first: when the known entries sum to less than
// one, the remainder is shared evenly among the unknown entries; otherwise
// they become zero. If the total of the known entries still exceeds one, or
// there were no unknown entries, every entry is rescaled by the known total.
// A range whose entries are all zero is filled with a probability built from
// numerator 1 and the number of elements as denominator.
template <class ProbabilityIter>
void BranchProbability::normalizeProbabilities(ProbabilityIter Begin,
                                               ProbabilityIter End) {
  if (Begin == End)
    return;

  const uint64_t Denom = getDenominator();

  // First pass: total the known numerators and count the unknown entries.
  unsigned NumUnknown = 0;
  uint64_t KnownSum = 0;
  for (auto It = Begin; It != End; ++It) {
    if (It->isUnknown())
      ++NumUnknown;
    else
      KnownSum += It->N;
  }

  if (NumUnknown > 0) {
    // Share given to each unknown entry; stays zero when the known entries
    // already add up to one or more.
    BranchProbability Share = getZero();
    if (KnownSum < Denom)
      Share = getRaw((Denom - KnownSum) / NumUnknown);

    std::replace_if(
        Begin, End,
        [](const BranchProbability &Entry) { return Entry.isUnknown(); },
        Share);

    // The known entries did not exceed one, so nothing needs rescaling.
    if (KnownSum <= Denom)
      return;
  }

  if (KnownSum == 0) {
    // All entries are zero, so give each the same probability.
    const BranchProbability Uniform(1, std::distance(Begin, End));
    std::fill(Begin, End, Uniform);
    return;
  }

  // Rescale each numerator by Denom / KnownSum, rounding to nearest.
  for (auto It = Begin; It != End; ++It)
    It->N = (It->N * Denom + KnownSum / 2) / KnownSum;
}

} // end namespace llvm
230 :
231 : #endif
|