Line data Source code
1 : //===- SyncDependenceAnalysis.cpp - Divergent Branch Dependence Calculation
2 : //--===//
3 : //
4 : // The LLVM Compiler Infrastructure
5 : //
6 : // This file is distributed under the University of Illinois Open Source
7 : // License. See LICENSE.TXT for details.
8 : //
9 : //===----------------------------------------------------------------------===//
10 : //
11 : // This file implements an algorithm that returns for a divergent branch
12 : // the set of basic blocks whose phi nodes become divergent due to divergent
13 : // control. These are the blocks that are reachable by two disjoint paths from
14 : // the branch, or loop exits that have a reaching path that is disjoint from a
15 : // path to the loop latch.
16 : //
17 : // The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
18 : // control-induced divergence in phi nodes.
19 : //
20 : // -- Summary --
21 : // The SyncDependenceAnalysis lazily computes sync dependences [3].
22 : // The analysis evaluates the disjoint path criterion [2] by a reduction
23 : // to SSA construction. The SSA construction algorithm is implemented as
24 : // a simple data-flow analysis [1].
25 : //
26 : // [1] "A Simple, Fast Dominance Algorithm", SPI '01, Cooper, Harvey and Kennedy
27 : // [2] "Efficiently Computing Static Single Assignment Form
28 : // and the Control Dependence Graph", TOPLAS '91,
29 : // Cytron, Ferrante, Rosen, Wegman and Zadeck
30 : // [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack
31 : // [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira
32 : //
33 : // -- Sync dependence --
34 : // Sync dependence [4] characterizes the control flow aspect of the
35 : // propagation of branch divergence. For example,
36 : //
37 : // %cond = icmp slt i32 %tid, 10
38 : // br i1 %cond, label %then, label %else
39 : // then:
40 : // br label %merge
41 : // else:
42 : // br label %merge
43 : // merge:
44 : // %a = phi i32 [ 0, %then ], [ 1, %else ]
45 : //
46 : // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
47 : // because %tid is not on its use-def chains, %a is sync dependent on %tid
48 : // because the branch "br i1 %cond" depends on %tid and affects which value %a
49 : // is assigned to.
50 : //
51 : // -- Reduction to SSA construction --
52 : // There are two disjoint paths from A to X, if a certain variant of SSA
53 : // construction places a phi node in X under the following set-up scheme [2].
54 : //
55 : // This variant of SSA construction ignores incoming undef values.
56 : // That is, paths from the entry without a definition do not result in
57 : // phi nodes.
58 : //
59 : // entry
60 : // / \
61 : // A \
62 : // / \ Y
63 : // B C /
64 : // \ / \ /
65 : // D E
66 : // \ /
67 : // F
68 : // Assume that A contains a divergent branch. We are interested
69 : // in the set of all blocks where each block is reachable from A
70 : // via two disjoint paths. This would be the set {D, F} in this
71 : // case.
72 : // To generally reduce this query to SSA construction we introduce
73 : // a virtual variable x and assign to x different values in each
74 : // successor block of A.
75 : // entry
76 : // / \
77 : // A \
78 : // / \ Y
79 : // x = 0 x = 1 /
80 : // \ / \ /
81 : // D E
82 : // \ /
83 : // F
84 : // Our flavor of SSA construction for x will construct the following
85 : // entry
86 : // / \
87 : // A \
88 : // / \ Y
89 : // x0 = 0 x1 = 1 /
90 : // \ / \ /
91 : // x2=phi E
92 : // \ /
93 : // x3=phi
94 : // The blocks D and F contain phi nodes and are thus each reachable
95 : // by two disjoint paths from A.
96 : //
97 : // -- Remarks --
98 : // In case of loop exits we need to check the disjoint path criterion for loops
99 : // [2]. To this end, we check whether the definition of x differs between the
100 : // loop exit and the loop header (_after_ SSA construction).
101 : //
102 : //===----------------------------------------------------------------------===//
103 : #include "llvm/ADT/PostOrderIterator.h"
104 : #include "llvm/ADT/SmallPtrSet.h"
105 : #include "llvm/Analysis/PostDominators.h"
106 : #include "llvm/Analysis/SyncDependenceAnalysis.h"
107 : #include "llvm/IR/BasicBlock.h"
108 : #include "llvm/IR/CFG.h"
109 : #include "llvm/IR/Dominators.h"
110 : #include "llvm/IR/Function.h"
111 :
112 : #include <stack>
113 : #include <unordered_set>
114 :
115 : #define DEBUG_TYPE "sync-dependence"
116 :
117 : namespace llvm {
118 :
// Definition of the shared empty block set. Both join_blocks() overloads
// return a reference to it when the queried node has no exits / successors.
119 : ConstBlockSet SyncDependenceAnalysis::EmptyBlockSet;
120 :
// Constructs the analysis from a dominator tree, a post-dominator tree and
// the loop info.
// FuncRPOT is initialised from the function that contains DT's root block
// (DT.getRoot()->getParent()); DT, PDT and LI initialise the members of the
// same name. No join points are computed here; see join_blocks().
121 13 : SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT,
122 : const PostDominatorTree &PDT,
123 13 : const LoopInfo &LI)
124 26 : : FuncRPOT(DT.getRoot()->getParent()), DT(DT), PDT(PDT), LI(LI) {}
125 :
// Destructor with an empty body; nothing is released explicitly here.
126 13 : SyncDependenceAnalysis::~SyncDependenceAnalysis() {}
127 :
// Shorthand for the function-level reverse post-order traversal type that
// DivergencePropagator holds by reference and iterates over.
128 : using FunctionRPOT = ReversePostOrderTraversal<const Function *>;
129 :
// Helper that finds the join points of a divergent node by propagating the
// "definitions" of a virtual variable from the node's successors through the
// CFG in reverse post-order (see "Reduction to SSA construction" in the file
// header).
// An instance is meant for a single computeJoinPoints() call: that call moves
// JoinBlocks out of the propagator and asserts that it is still set on entry.
130 : // divergence propagator for reducible CFGs
131 : struct DivergencePropagator {
132 : const FunctionRPOT &FuncRPOT;
133 : const DominatorTree &DT;
134 : const PostDominatorTree &PDT;
135 : const LoopInfo &LI;
136 :
137 : // identified join points (handed to the caller by computeJoinPoints)
138 : std::unique_ptr<ConstBlockSet> JoinBlocks;
139 :
140 : // reached loop exits (by a path disjoint to a path to the loop header)
141 : SmallPtrSet<const BasicBlock *, 4> ReachedLoopExits;
142 :
143 : // if DefMap[B] == C then C is the dominating definition at block B
144 : // if DefMap[B] ~ undef (no entry) then we haven't seen B yet
145 : // if DefMap[B] == B then B is a join point of disjoint paths from X or B is
146 : // an immediate successor of X (initial value).
147 : using DefiningBlockMap = std::map<const BasicBlock *, const BasicBlock *>;
148 : DefiningBlockMap DefMap;
149 :
150 : // all blocks with pending visits
151 : std::unordered_set<const BasicBlock *> PendingUpdates;
152 :
// Binds the traversal, dominance and loop information by reference and
// allocates an empty JoinBlocks set.
153 12 : DivergencePropagator(const FunctionRPOT &FuncRPOT, const DominatorTree &DT,
154 : const PostDominatorTree &PDT, const LoopInfo &LI)
155 12 : : FuncRPOT(FuncRPOT), DT(DT), PDT(PDT), LI(LI),
156 12 : JoinBlocks(new ConstBlockSet) {}
157 :
158 : // record @DefBlock as the definition at @Block and mark @Block as pending for a visit
// NOTE(review): std::map::emplace never overwrites an existing entry. If
// @Block already has a definition, DefMap is left unchanged, WasAdded is
// false and @Block is not (re-)added to PendingUpdates. visitSuccessor calls
// this to "update the definition" of a freshly found join block, which
// always has an entry already, so that call appears to have no effect and
// the earlier definition keeps being propagated. The hit counts recorded
// below (17 calls, 6 insertions) are consistent with that. Confirm whether
// an overwrite (DefMap[&Block] = &DefBlock) was intended.
159 17 : void addPending(const BasicBlock &Block, const BasicBlock &DefBlock) {
160 17 : bool WasAdded = DefMap.emplace(&Block, &DefBlock).second;
161 17 : if (WasAdded)
162 6 : PendingUpdates.insert(&Block);
163 17 : }
164 :
// Debug helper: for every block yielded by FuncRPOT, prints the block name
// followed by the name of its recorded defining block. Blocks without a
// DefMap entry print an empty name; a null entry prints "<null>".
165 : void printDefs(raw_ostream &Out) {
166 : Out << "Propagator::DefMap {\n";
167 : for (const auto *Block : FuncRPOT) {
168 : auto It = DefMap.find(Block);
169 : Out << Block->getName() << " : ";
170 : if (It == DefMap.end()) {
171 : Out << "\n";
172 : } else {
173 : const auto *DefBlock = It->second;
174 : Out << (DefBlock ? DefBlock->getName() : "<null>") << "\n";
175 : }
176 : }
177 : Out << "}\n";
178 : }
179 :
180 : // process @SuccBlock with reaching definition @DefBlock
181 : // the original divergent branch was in @ParentLoop (null if not in a loop)
182 18 : void visitSuccessor(const BasicBlock &SuccBlock, const Loop *ParentLoop,
183 : const BasicBlock &DefBlock) {
184 :
185 : // @SuccBlock is a loop exit: remember it and its first reaching definition (emplace keeps an existing one), do not propagate further
186 18 : if (ParentLoop && !ParentLoop->contains(&SuccBlock)) {
187 0 : DefMap.emplace(&SuccBlock, &DefBlock);
188 0 : ReachedLoopExits.insert(&SuccBlock);
189 0 : return;
190 : }
191 :
192 : // first reaching def? then just record it and schedule the block
193 18 : auto ItLastDef = DefMap.find(&SuccBlock);
194 18 : if (ItLastDef == DefMap.end()) {
195 6 : addPending(SuccBlock, DefBlock);
196 6 : return;
197 : }
198 :
199 : // a join of at least two (different) definitions
200 12 : if (ItLastDef->second != &DefBlock) {
201 : // do we know this join already? (insert fails if it is already recorded)
202 11 : if (!JoinBlocks->insert(&SuccBlock).second)
203 : return;
204 :
205 : // update the definition (NOTE(review): no-op for an existing entry, see the note on addPending)
206 11 : addPending(SuccBlock, SuccBlock);
207 : }
208 : }
209 :
210 : // find all blocks reachable by two disjoint paths from the node rooted at @RootBlock.
211 : // This method works for both divergent TerminatorInsts and loops with
212 : // divergent exits.
213 : // @RootBlock is either the block containing the branch or the header of the
214 : // divergent loop.
215 : // @NodeSuccessors is the set of successors of the node (Loop or Terminator)
216 : // headed by @RootBlock.
217 : // @ParentLoop is the parent loop of the Loop or the loop that contains the
218 : // Terminator (may be null).
// Returns the set of join blocks. Ownership of JoinBlocks is moved to the
// caller, so this must not be called a second time on the same propagator
// (checked by the assert on entry).
219 : template <typename SuccessorIterable>
220 : std::unique_ptr<ConstBlockSet>
221 12 : computeJoinPoints(const BasicBlock &RootBlock,
222 : SuccessorIterable NodeSuccessors, const Loop *ParentLoop) {
223 : assert(JoinBlocks);
224 :
225 : // immediate post dominator (no join block beyond that block); null if @RootBlock has none
// NOTE(review): PdNode is dereferenced without a null check; this assumes
// PDT has a node for @RootBlock.
226 12 : const auto *PdNode = PDT.getNode(const_cast<BasicBlock *>(&RootBlock));
227 12 : const auto *IpdNode = PdNode->getIDom();
228 12 : const auto *PdBoundBlock = IpdNode ? IpdNode->getBlock() : nullptr;
229 :
230 : // bootstrap with branch targets: every successor starts out as its own definition
231 56 : for (const auto *SuccBlock : NodeSuccessors) {
232 : DefMap.emplace(SuccBlock, SuccBlock);
233 :
234 23 : if (ParentLoop && !ParentLoop->contains(SuccBlock)) {
235 : // immediate loop exit from node.
236 2 : ReachedLoopExits.insert(SuccBlock);
237 2 : continue;
238 : } else {
239 : // regular successor
240 : PendingUpdates.insert(SuccBlock);
241 : }
242 : }
243 :
244 12 : auto ItBeginRPO = FuncRPOT.begin();
245 :
246 : // skip ahead to @RootBlock (TODO RPOT won't let us start at @RootBlock directly)
// NOTE(review): this loop has no end-of-range check; it relies on @RootBlock
// being part of FuncRPOT.
247 26 : for (; *ItBeginRPO != &RootBlock; ++ItBeginRPO) {}
248 :
249 : auto ItEndRPO = FuncRPOT.end();
250 : assert(ItBeginRPO != ItEndRPO);
251 :
252 : // propagate definitions at the immediate successors of the node in RPO, stopping at the post-dominator bound
253 : auto ItBlockRPO = ItBeginRPO;
254 28 : while (++ItBlockRPO != ItEndRPO && *ItBlockRPO != PdBoundBlock) {
255 16 : const auto *Block = *ItBlockRPO;
256 :
257 : // skip @Block if not pending update
258 : auto ItPending = PendingUpdates.find(Block);
259 16 : if (ItPending == PendingUpdates.end())
260 0 : continue;
261 : PendingUpdates.erase(ItPending);
262 :
263 : // propagate definition at @Block to its successors (pending blocks were given a DefMap entry when they were scheduled)
264 : auto ItDef = DefMap.find(Block);
265 16 : const auto *DefBlock = ItDef->second;
266 : assert(DefBlock);
267 :
268 16 : auto *BlockLoop = LI.getLoopFor(Block);
269 0 : if (ParentLoop &&
270 16 : (ParentLoop != BlockLoop && ParentLoop->contains(BlockLoop))) {
271 : // @Block belongs to a loop nested in @ParentLoop: pretend that loop is a
272 : // single node with the loop's exits as successors
273 : SmallVector<BasicBlock *, 4> BlockLoopExits;
274 0 : BlockLoop->getExitBlocks(BlockLoopExits);
275 0 : for (const auto *BlockLoopExit : BlockLoopExits) {
276 0 : visitSuccessor(*BlockLoopExit, ParentLoop, *DefBlock);
277 : }
278 :
279 : } else {
280 : // the successors are either on the same loop level or loop exits
281 50 : for (const auto *SuccBlock : successors(Block)) {
282 18 : visitSuccessor(*SuccBlock, ParentLoop, *DefBlock);
283 : }
284 : }
285 : }
286 :
287 : // We need to know the definition at the parent loop header to decide
288 : // whether the definition at the header is different from the definition at
289 : // the loop exits, which would indicate a divergent loop exit.
290 : //
291 : // A // loop header
292 : // |
293 : // B // nested loop header
294 : // |
295 : // C -> X (exit from B loop) -..-> (A latch)
296 : // |
297 : // D -> back to B (B latch)
298 : // |
299 : // proper exit from both loops
300 : //
301 : // D post-dominates B as it is the only proper exit from the "A loop".
302 : // If C has a divergent branch, propagation will therefore stop at D.
303 : // That implies that B will never receive a definition.
304 : // But that definition can only be the same as at D (D itself in this case)
305 : // because all paths to anywhere have to pass through D.
306 : //
// NOTE(review): DefMap[PdBoundBlock] default-inserts a null definition if
// the bound block has no entry yet.
307 12 : const BasicBlock *ParentLoopHeader =
308 : ParentLoop ? ParentLoop->getHeader() : nullptr;
309 12 : if (ParentLoop && ParentLoop->contains(PdBoundBlock)) {
310 0 : DefMap[ParentLoopHeader] = DefMap[PdBoundBlock];
311 : }
312 :
313 : // analyze reached loop exits: an exit whose definition differs from the header's is a join point
314 12 : if (!ReachedLoopExits.empty()) {
315 : assert(ParentLoop);
// NOTE(review): operator[] inserts a null entry when the header has no
// definition; that case is only caught by the assert below.
316 2 : const auto *HeaderDefBlock = DefMap[ParentLoopHeader];
317 : LLVM_DEBUG(printDefs(dbgs()));
318 : assert(HeaderDefBlock && "no definition in header of carrying loop");
319 :
320 4 : for (const auto *ExitBlock : ReachedLoopExits) {
321 : auto ItExitDef = DefMap.find(ExitBlock);
322 : assert((ItExitDef != DefMap.end()) &&
323 : "no reaching def at reachable loop exit");
324 2 : if (ItExitDef->second != HeaderDefBlock) {
325 2 : JoinBlocks->insert(ExitBlock);
326 : }
327 : }
328 : }
329 :
// JoinBlocks is a member, so the explicit move is required to hand it out;
// the propagator's JoinBlocks is null afterwards.
330 12 : return std::move(JoinBlocks);
331 : }
332 10 : };
333 :
// Returns the join blocks computed for the exits of @Loop.
// - Returns EmptyBlockSet when the loop has no exit blocks.
// - Otherwise the result is computed once by a fresh DivergencePropagator,
//   rooted at the loop header with the loop's exit blocks as successors and
//   the loop's parent loop as the enclosing loop.
// - The result is stored in CachedLoopExitJoins, keyed by the loop.
// The returned reference points at the set held by that cache entry.
334 : const ConstBlockSet &SyncDependenceAnalysis::join_blocks(const Loop &Loop) {
335 : using LoopExitVec = SmallVector<BasicBlock *, 4>;
336 : LoopExitVec LoopExits;
337 10 : Loop.getExitBlocks(LoopExits);
// trivial case: a loop without exit blocks has no exit joins
338 10 : if (LoopExits.size() < 1) {
339 10 : return EmptyBlockSet;
340 : }
341 :
342 52 : // already available in cache? then reuse the stored set
343 : auto ItCached = CachedLoopExitJoins.find(&Loop);
344 : if (ItCached != CachedLoopExitJoins.end())
345 21 : return *ItCached->second;
346 :
347 2 : // compute all join points with a single-use propagator
348 2 : DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI};
349 : auto JoinBlocks = Propagator.computeJoinPoints<const LoopExitVec &>(
350 : *Loop.getHeader(), LoopExits, Loop.getParentLoop());
351 :
// the cache takes ownership of the result; the insertion is expected to
// succeed because the lookup above found no entry
352 : auto ItInserted = CachedLoopExitJoins.emplace(&Loop, std::move(JoinBlocks));
353 : assert(ItInserted.second);
354 : return *ItInserted.first->second;
355 10 : }
356 :
// Returns the join blocks computed for the terminator @Term.
// - Returns EmptyBlockSet when the terminator has no successors.
// - Otherwise the result is computed once by a fresh DivergencePropagator,
//   rooted at the terminator's parent block with that block's successors and
//   the loop (if any) that LI reports for the block.
// - The result is stored in CachedBranchJoins, keyed by the terminator.
// The returned reference points at the set held by that cache entry.
357 : const ConstBlockSet &
358 20 : SyncDependenceAnalysis::join_blocks(const TerminatorInst &Term) {
359 : // trivial case: no successors, hence no join points
360 : if (Term.getNumSuccessors() < 1) {
361 : return EmptyBlockSet;
362 : }
363 :
364 : // already available in cache? then reuse the stored set
365 26 : auto ItCached = CachedBranchJoins.find(&Term);
366 16 : if (ItCached != CachedBranchJoins.end())
367 : return *ItCached->second;
368 :
369 : // compute all join points with a single-use propagator
370 16 : DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI};
371 0 : const auto &TermBlock = *Term.getParent();
372 : auto JoinBlocks = Propagator.computeJoinPoints<succ_const_range>(
373 : TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock));
374 :
// the cache takes ownership of the result; the insertion is expected to
// succeed because the lookup above found no entry
375 : auto ItInserted = CachedBranchJoins.emplace(&Term, std::move(JoinBlocks));
376 16 : assert(ItInserted.second);
377 : return *ItInserted.first->second;
378 : }
379 16 :
380 0 : } // namespace llvm
|