LLVM 17.0.0git
DivergenceAnalysis.cpp
Go to the documentation of this file.
1//===---- DivergenceAnalysis.cpp --- Divergence Analysis Implementation ----==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a general divergence analysis for loop vectorization
10// and GPU programs. It determines which branches and values in a loop or GPU
11// program are divergent. It can help branch optimizations such as jump
12// threading and loop unswitching to make better decisions.
13//
14// GPU programs typically use the SIMD execution model, where multiple threads
15// in the same execution group have to execute in lock-step. Therefore, if the
16// code contains divergent branches (i.e., threads in a group do not agree on
17// which path of the branch to take), the group of threads has to execute all
18// the paths from that branch with different subsets of threads enabled until
19// they re-converge.
20//
21// Due to this execution model, some optimizations such as jump
22// threading and loop unswitching can interfere with thread re-convergence.
23// Therefore, an analysis that computes which branches in a GPU program are
24// divergent can help the compiler to selectively run these optimizations.
25//
26// This implementation is derived from the Vectorization Analysis of the
27// Region Vectorizer (RV). The analysis is based on the approach described in
28//
29// An abstract interpretation for SPMD divergence
30// on reducible control flow graphs.
31// Julian Rosemann, Simon Moll and Sebastian Hack
32// POPL '21
33//
34// This implementation is generic in the sense that it does
35// not itself identify original sources of divergence.
36// Instead, specialized adapter classes (LoopDivergenceAnalysis for loops and
37// DivergenceAnalysis for functions) identify the sources of divergence
38// (e.g., special variables that hold the thread ID or the iteration variable).
39//
40// The generic implementation propagates divergence to variables that are data
41// or sync dependent on a source of divergence.
42//
43// While data dependency is a well-known concept, the notion of sync dependency
44// is worth more explanation. Sync dependence characterizes the control flow
45// aspect of the propagation of branch divergence. For example,
46//
47// %cond = icmp slt i32 %tid, 10
48// br i1 %cond, label %then, label %else
49// then:
50// br label %merge
51// else:
52// br label %merge
53// merge:
54// %a = phi i32 [ 0, %then ], [ 1, %else ]
55//
56// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
57// because %tid is not on its use-def chains, %a is sync dependent on %tid
58// because the branch "br i1 %cond" depends on %tid and affects which value %a
59// is assigned to.
60//
61// The sync dependence detection (which branch induces divergence in which join
62// points) is implemented in the SyncDependenceAnalysis.
63//
64// The current implementation has the following limitations:
65// 1. intra-procedural. It conservatively considers the arguments of a
66// non-kernel-entry function and the return value of a function call as
67// divergent.
68// 2. memory as black box. It conservatively considers values loaded from
69// generic or local address spaces as divergent. This can be improved by leveraging
70// pointer analysis and/or by modelling non-escaping memory objects in SSA
71// as done in RV.
72//
73//===----------------------------------------------------------------------===//
74
77#include "llvm/Analysis/CFG.h"
81#include "llvm/IR/Dominators.h"
84#include "llvm/IR/Value.h"
85#include "llvm/Support/Debug.h"
87
88using namespace llvm;
89
90#define DEBUG_TYPE "divergence"
91
93 const Function &F, const Loop *RegionLoop, const DominatorTree &DT,
94 const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm)
95 : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA),
96 IsLCSSAForm(IsLCSSAForm) {}
97
99 if (isAlwaysUniform(DivVal))
100 return false;
101 assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal));
102 assert(!isAlwaysUniform(DivVal) && "cannot be a divergent");
103 return DivergentValues.insert(&DivVal).second;
104}
105
107 UniformOverrides.insert(&UniVal);
108}
109
110bool DivergenceAnalysisImpl::isTemporalDivergent(
111 const BasicBlock &ObservingBlock, const Value &Val) const {
112 const auto *Inst = dyn_cast<const Instruction>(&Val);
113 if (!Inst)
114 return false;
115 // check whether any divergent loop carrying Val terminates before control
116 // proceeds to ObservingBlock
117 for (const auto *Loop = LI.getLoopFor(Inst->getParent());
118 Loop != RegionLoop && !Loop->contains(&ObservingBlock);
119 Loop = Loop->getParentLoop()) {
120 if (DivergentLoops.contains(Loop))
121 return true;
122 }
123
124 return false;
125}
126
128 return I.getParent() && inRegion(*I.getParent());
129}
130
132 return RegionLoop ? RegionLoop->contains(&BB) : (BB.getParent() == &F);
133}
134
135void DivergenceAnalysisImpl::pushUsers(const Value &V) {
136 const auto *I = dyn_cast<const Instruction>(&V);
137
138 if (I && I->isTerminator()) {
139 analyzeControlDivergence(*I);
140 return;
141 }
142
143 for (const auto *User : V.users()) {
144 const auto *UserInst = dyn_cast<const Instruction>(User);
145 if (!UserInst)
146 continue;
147
148 // only compute divergent inside loop
149 if (!inRegion(*UserInst))
150 continue;
151
152 // All users of divergent values are immediate divergent
153 if (markDivergent(*UserInst))
154 Worklist.push_back(UserInst);
155 }
156}
157
159 const Loop &DivLoop) {
160 const auto *I = dyn_cast<const Instruction>(&U);
161 if (!I)
162 return nullptr;
163 if (!DivLoop.contains(I))
164 return nullptr;
165 return I;
166}
167
168void DivergenceAnalysisImpl::analyzeTemporalDivergence(
169 const Instruction &I, const Loop &OuterDivLoop) {
170 if (isAlwaysUniform(I))
171 return;
172 if (isDivergent(I))
173 return;
174
175 LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I.getName() << "\n");
176 assert((isa<PHINode>(I) || !IsLCSSAForm) &&
177 "In LCSSA form all users of loop-exiting defs are Phi nodes.");
178 for (const Use &Op : I.operands()) {
179 const auto *OpInst = getIfCarriedInstruction(Op, OuterDivLoop);
180 if (!OpInst)
181 continue;
182 if (markDivergent(I))
183 pushUsers(I);
184 return;
185 }
186}
187
188// marks all users of loop-carried values of the loop headed by LoopHeader as
189// divergent
190void DivergenceAnalysisImpl::analyzeLoopExitDivergence(
191 const BasicBlock &DivExit, const Loop &OuterDivLoop) {
192 // All users are in immediate exit blocks
193 if (IsLCSSAForm) {
194 for (const auto &Phi : DivExit.phis()) {
195 analyzeTemporalDivergence(Phi, OuterDivLoop);
196 }
197 return;
198 }
199
200 // For non-LCSSA we have to follow all live out edges wherever they may lead.
201 const BasicBlock &LoopHeader = *OuterDivLoop.getHeader();
203 TaintStack.push_back(&DivExit);
204
205 // Otherwise potential users of loop-carried values could be anywhere in the
206 // dominance region of DivLoop (including its fringes for phi nodes)
208 Visited.insert(&DivExit);
209
210 do {
211 auto *UserBlock = TaintStack.pop_back_val();
212
213 // don't spread divergence beyond the region
214 if (!inRegion(*UserBlock))
215 continue;
216
217 assert(!OuterDivLoop.contains(UserBlock) &&
218 "irreducible control flow detected");
219
220 // phi nodes at the fringes of the dominance region
221 if (!DT.dominates(&LoopHeader, UserBlock)) {
222 // all PHI nodes of UserBlock become divergent
223 for (const auto &Phi : UserBlock->phis()) {
224 analyzeTemporalDivergence(Phi, OuterDivLoop);
225 }
226 continue;
227 }
228
229 // Taint outside users of values carried by OuterDivLoop.
230 for (const auto &I : *UserBlock) {
231 analyzeTemporalDivergence(I, OuterDivLoop);
232 }
233
234 // visit all blocks in the dominance region
235 for (const auto *SuccBlock : successors(UserBlock)) {
236 if (!Visited.insert(SuccBlock).second) {
237 continue;
238 }
239 TaintStack.push_back(SuccBlock);
240 }
241 } while (!TaintStack.empty());
242}
243
244void DivergenceAnalysisImpl::propagateLoopExitDivergence(
245 const BasicBlock &DivExit, const Loop &InnerDivLoop) {
246 LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit.getName() << "\n");
247
248 // Find outer-most loop that does not contain \p DivExit
249 const Loop *DivLoop = &InnerDivLoop;
250 const Loop *OuterDivLoop = DivLoop;
251 const Loop *ExitLevelLoop = LI.getLoopFor(&DivExit);
252 const unsigned LoopExitDepth =
253 ExitLevelLoop ? ExitLevelLoop->getLoopDepth() : 0;
254 while (DivLoop && DivLoop->getLoopDepth() > LoopExitDepth) {
255 DivergentLoops.insert(DivLoop); // all crossed loops are divergent
256 OuterDivLoop = DivLoop;
257 DivLoop = DivLoop->getParentLoop();
258 }
259 LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop->getName()
260 << "\n");
261
262 analyzeLoopExitDivergence(DivExit, *OuterDivLoop);
263}
264
265// this is a divergent join point - mark all phi nodes as divergent and push
266// them onto the stack.
267void DivergenceAnalysisImpl::taintAndPushPhiNodes(const BasicBlock &JoinBlock) {
268 LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock.getName()
269 << "\n");
270
271 // ignore divergence outside the region
272 if (!inRegion(JoinBlock)) {
273 return;
274 }
275
276 // push non-divergent phi nodes in JoinBlock to the worklist
277 for (const auto &Phi : JoinBlock.phis()) {
278 if (isDivergent(Phi))
279 continue;
280 // FIXME Theoretically ,the 'undef' value could be replaced by any other
281 // value causing spurious divergence.
282 if (Phi.hasConstantOrUndefValue())
283 continue;
284 if (markDivergent(Phi))
285 Worklist.push_back(&Phi);
286 }
287}
288
289void DivergenceAnalysisImpl::analyzeControlDivergence(const Instruction &Term) {
290 LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term.getParent()->getName()
291 << "\n");
292
293 // Don't propagate divergence from unreachable blocks.
294 if (!DT.isReachableFromEntry(Term.getParent()))
295 return;
296
297 const auto *BranchLoop = LI.getLoopFor(Term.getParent());
298
299 const auto &DivDesc = SDA.getJoinBlocks(Term);
300
301 // Iterate over all blocks now reachable by a disjoint path join
302 for (const auto *JoinBlock : DivDesc.JoinDivBlocks) {
303 taintAndPushPhiNodes(*JoinBlock);
304 }
305
306 assert(DivDesc.LoopDivBlocks.empty() || BranchLoop);
307 for (const auto *DivExitBlock : DivDesc.LoopDivBlocks) {
308 propagateLoopExitDivergence(*DivExitBlock, *BranchLoop);
309 }
310}
311
313 // Initialize worklist.
314 auto DivValuesCopy = DivergentValues;
315 for (const auto *DivVal : DivValuesCopy) {
316 assert(isDivergent(*DivVal) && "Worklist invariant violated!");
317 pushUsers(*DivVal);
318 }
319
320 // All values on the Worklist are divergent.
321 // Their users may not have been updated yed.
322 while (!Worklist.empty()) {
323 const Instruction &I = *Worklist.back();
324 Worklist.pop_back();
325
326 // propagate value divergence to users
327 assert(isDivergent(I) && "Worklist invariant violated!");
328 pushUsers(I);
329 }
330}
331
333 return UniformOverrides.contains(&V);
334}
335
337 return DivergentValues.contains(&V);
338}
339
341 Value &V = *U.get();
342 Instruction &I = *cast<Instruction>(U.getUser());
343 return isDivergent(V) || isTemporalDivergent(*I.getParent(), V);
344}
345
347 const PostDominatorTree &PDT, const LoopInfo &LI,
349 bool KnownReducible)
350 : F(F) {
351 if (!KnownReducible) {
353 RPOTraversal FuncRPOT(&F);
354 if (containsIrreducibleCFG<const BasicBlock *, const RPOTraversal,
355 const LoopInfo>(FuncRPOT, LI)) {
356 ContainsIrreducible = true;
357 return;
358 }
359 }
360 SDA = std::make_unique<SyncDependenceAnalysis>(DT, PDT, LI);
361 DA = std::make_unique<DivergenceAnalysisImpl>(F, nullptr, DT, LI, *SDA,
362 /* LCSSA */ false);
363 for (auto &I : instructions(F)) {
364 if (TTI.isSourceOfDivergence(&I)) {
365 DA->markDivergent(I);
366 } else if (TTI.isAlwaysUniform(&I)) {
367 DA->addUniformOverride(I);
368 }
369 }
370 for (auto &Arg : F.args()) {
372 DA->markDivergent(Arg);
373 }
374 }
375
376 DA->compute();
377}
378
379AnalysisKey DivergenceAnalysis::Key;
380
383 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
384 auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F);
385 auto &LI = AM.getResult<LoopAnalysis>(F);
386 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
387
388 return DivergenceInfo(F, DT, PDT, LI, TTI, /* KnownReducible = */ false);
389}
390
393 auto &DI = FAM.getResult<DivergenceAnalysis>(F);
394 OS << "'Divergence Analysis' for function '" << F.getName() << "':\n";
395 if (DI.hasDivergence()) {
396 for (auto &Arg : F.args()) {
397 OS << (DI.isDivergent(Arg) ? "DIVERGENT: " : " ");
398 OS << Arg << "\n";
399 }
400 for (const BasicBlock &BB : F) {
401 OS << "\n " << BB.getName() << ":\n";
402 for (const auto &I : BB.instructionsWithoutDebug()) {
403 OS << (DI.isDivergent(I) ? "DIVERGENT: " : " ");
404 OS << I << "\n";
405 }
406 }
407 }
408 return PreservedAnalyses::all();
409}
amdgpu Simplify well known AMD library false FunctionCallee Value * Arg
#define LLVM_DEBUG(X)
Definition: Debug.h:101
static const Instruction * getIfCarriedInstruction(const Use &U, const Loop &DivLoop)
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
print must be executed print the must be executed context for all instructions
FunctionAnalysisManager FAM
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This pass exposes codegen information to IR-level passes.
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:620
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:774
LLVM Basic Block Representation.
Definition: BasicBlock.h:56
iterator_range< const_phi_iterator > phis() const
Returns a range that iterates over the phis in the basic block.
Definition: BasicBlock.h:372
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:112
Implements a dense probed hash-table based set.
Definition: DenseSet.h:271
bool inRegion(const BasicBlock &BB) const
Whether BB is part of the region.
bool isAlwaysUniform(const Value &Val) const
Whether Val will always return a uniform value regardless of its operands.
void addUniformOverride(const Value &UniVal)
Mark UniVal as a value that is always uniform.
void compute()
Propagate divergence to all instructions in the region.
bool isDivergentUse(const Use &U) const
Whether U is divergent.
bool markDivergent(const Value &DivVal)
Mark DivVal as a value that is always divergent.
bool isDivergent(const Value &Val) const
Whether Val is divergent at its definition.
DivergenceAnalysisImpl(const Function &F, const Loop *RegionLoop, const DominatorTree &DT, const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm)
This instance will analyze the whole function F or the loop RegionLoop.
Divergence analysis frontend for GPU kernels.
Result run(Function &F, FunctionAnalysisManager &AM)
Runs the divergence analysis on @F, a GPU kernel.
DivergenceInfo(Function &F, const DominatorTree &DT, const PostDominatorTree &PDT, const LoopInfo &LI, const TargetTransformInfo &TTI, bool KnownReducible)
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:166
bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
Definition: Dominators.cpp:321
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:1268
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
Definition: LoopInfo.h:139
BlockT * getHeader() const
Definition: LoopInfo.h:105
unsigned getLoopDepth() const
Return the nesting level of this loop.
Definition: LoopInfo.h:97
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
Definition: LoopInfo.h:114
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
Definition: LoopInfo.h:992
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:547
StringRef getName() const
Definition: LoopInfo.h:891
Analysis pass which computes a PostDominatorTree.
PostDominatorTree Class - Concrete subclass of DominatorTree that is used to compute the post-dominat...
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:152
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:158
bool empty() const
Definition: SmallVector.h:94
void push_back(const T &Elt)
Definition: SmallVector.h:416
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1200
Relates points of divergent control to join points in reducible CFGs.
const ControlDivergenceDesc & getJoinBlocks(const Instruction &Term)
Computes divergent join points and loop exits caused by branch divergence in Term.
Analysis pass providing the TargetTransformInfo.
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
bool isAlwaysUniform(const Value *V) const
bool isSourceOfDivergence(const Value *V) const
Returns whether V is a source of divergence.
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
LLVM Value Representation.
Definition: Value.h:74
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:308
std::pair< iterator, bool > insert(const ValueT &V)
Definition: DenseSet.h:206
bool contains(const_arg_type_t< ValueT > V) const
Check if the set contains the given element.
Definition: DenseSet.h:185
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
auto successors(const MachineBasicBlock *BB)
bool containsIrreducibleCFG(RPOTraversalT &RPOTraversal, const LoopInfoT &LI)
Return true if the control flow in RPOTraversal is irreducible.
Definition: CFG.h:136
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition: PassManager.h:69
PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM)