Line data Source code
1 : //===- llvm/Analysis/DivergenceAnalysis.h - Divergence Analysis -*- C++ -*-===//
2 : //
3 : // The LLVM Compiler Infrastructure
4 : //
5 : // This file is distributed under the University of Illinois Open Source
6 : // License. See LICENSE.TXT for details.
7 : //
8 : //===----------------------------------------------------------------------===//
9 : ///
10 : /// \file
11 : /// The divergence analysis determines which instructions and branches are
12 : /// divergent given a set of divergent source instructions.
13 : ///
14 : //===----------------------------------------------------------------------===//
15 :
16 : #ifndef LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
17 : #define LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
18 :
19 : #include "llvm/ADT/DenseSet.h"
20 : #include "llvm/Analysis/SyncDependenceAnalysis.h"
21 : #include "llvm/IR/Function.h"
22 : #include "llvm/Pass.h"
23 : #include <vector>
24 :
25 : namespace llvm {
26 : class Module;
27 : class Value;
28 : class Instruction;
29 : class Loop;
30 : class raw_ostream;
31 : class TargetTransformInfo;
32 :
33 : /// \brief Generic divergence analysis for reducible CFGs.
34 : ///
35 : /// This analysis propagates divergence in a data-parallel context from sources
36 : /// of divergence to all users. It requires reducible CFGs. All assignments
37 : /// should be in SSA form.
38 : class DivergenceAnalysis {
39 : public:
40 : /// \brief This instance will analyze the whole function \p F or the loop \p
41 : /// RegionLoop.
42 : ///
43 : /// \param RegionLoop if non-null the analysis is restricted to \p RegionLoop.
44 : /// Otherwise the whole function is analyzed.
45 : /// \param IsLCSSAForm whether the analysis may assume that the IR in the
46 : /// region is in LCSSA form.
47 : DivergenceAnalysis(const Function &F, const Loop *RegionLoop,
48 : const DominatorTree &DT, const LoopInfo &LI,
49 : SyncDependenceAnalysis &SDA, bool IsLCSSAForm);
50 :
51 : /// \brief The loop that defines the analyzed region (null if none was given).
52 0 : const Loop *getRegionLoop() const { return RegionLoop; }
53 : const Function &getFunction() const { return F; }
54 :
55 : /// \brief Whether \p BB is part of the region.
56 : bool inRegion(const BasicBlock &BB) const;
57 : /// \brief Whether \p I is part of the region.
58 : bool inRegion(const Instruction &I) const;
59 :
60 : /// \brief Mark \p UniVal as a value that is always uniform.
61 : void addUniformOverride(const Value &UniVal);
62 :
63 : /// \brief Mark \p DivVal as a value that is always divergent.
64 : void markDivergent(const Value &DivVal);
65 :
66 : /// \brief Propagate divergence to all instructions in the region.
67 : /// Divergence is seeded by calls to markDivergent().
68 : void compute();
69 :
70 : /// \brief Whether any value was marked or analyzed to be divergent.
71 5 : bool hasDetectedDivergence() const { return !DivergentValues.empty(); }
72 :
73 : /// \brief Whether \p Val will always return a uniform value regardless of its
74 : /// operands.
75 : bool isAlwaysUniform(const Value &Val) const;
76 :
77 : /// \brief Whether \p Val is a divergent value.
78 : bool isDivergent(const Value &Val) const;
79 :
80 : void print(raw_ostream &OS, const Module *) const;
81 :
82 : private:
83 : bool updateTerminator(const TerminatorInst &Term) const;
84 : bool updatePHINode(const PHINode &Phi) const;
85 :
86 : /// \brief Computes whether \p Inst is divergent based on the
87 : /// divergence of its operands.
88 : ///
89 : /// \returns Whether \p Inst is divergent.
90 : ///
91 : /// This should only be called for non-phi, non-terminator instructions.
92 : bool updateNormalInstruction(const Instruction &Inst) const;
93 :
94 : /// \brief Mark users of live-out values as divergent.
95 : ///
96 : /// \param LoopHeader the header of the divergent loop.
97 : ///
98 : /// Marks all users of live-out values of the loop headed by \p LoopHeader
99 : /// as divergent and puts them on the worklist.
100 : void taintLoopLiveOuts(const BasicBlock &LoopHeader);
101 :
102 : /// \brief Push all users of \p I (in the region) to the worklist.
103 : void pushUsers(const Value &I);
104 :
105 : /// \brief Push all phi nodes in \p Block to the worklist.
106 : void pushPHINodes(const BasicBlock &Block);
107 :
108 : /// \brief Mark \p Block as join divergent.
109 : ///
110 : /// A block is join divergent if two threads may reach it from different
111 : /// incoming blocks at the same time.
112 : void markBlockJoinDivergent(const BasicBlock &Block) {
113 11 : DivergentJoinBlocks.insert(&Block);
114 : }
115 :
116 : /// \brief Whether \p Val is divergent when read in \p ObservingBlock.
117 : bool isTemporalDivergent(const BasicBlock &ObservingBlock,
118 : const Value &Val) const;
119 :
120 : /// \brief Whether \p Block is join divergent.
121 : ///
122 : /// (see markBlockJoinDivergent).
123 : bool isJoinDivergent(const BasicBlock &Block) const {
124 : return DivergentJoinBlocks.find(&Block) != DivergentJoinBlocks.end();
125 : }
126 :
127 : /// \brief Propagate control-induced divergence to users (phi nodes and
128 : /// instructions).
129 : ///
130 : /// \param JoinBlock is a divergent loop exit or join point of two disjoint
131 : /// paths.
132 : /// \returns Whether \p JoinBlock is a divergent loop exit of \p TermLoop.
133 : bool propagateJoinDivergence(const BasicBlock &JoinBlock,
134 : const Loop *TermLoop);
135 :
136 : /// \brief Propagate induced value divergence due to control divergence in \p
137 : /// Term.
138 : void propagateBranchDivergence(const TerminatorInst &Term);
139 :
140 : /// \brief Propagate divergence caused by a divergent loop exit.
141 : ///
142 : /// \param ExitingLoop is a divergent loop.
143 : void propagateLoopDivergence(const Loop &ExitingLoop);
144 :
145 : private:
146 : const Function &F;
147 : // If RegionLoop != nullptr, analysis is only performed within RegionLoop.
148 : // Otherwise, analyze the whole function.
149 : const Loop *RegionLoop;
150 :
151 : const DominatorTree &DT;
152 : const LoopInfo &LI;
153 :
154 : // Recognized divergent loops
155 : DenseSet<const Loop *> DivergentLoops;
156 :
157 : // The SDA links divergent branches to divergent control-flow joins.
158 : SyncDependenceAnalysis &SDA;
159 :
160 : // Use simplified code path for LCSSA form.
161 : bool IsLCSSAForm;
162 :
163 : // Set of known-uniform values.
164 : DenseSet<const Value *> UniformOverrides;
165 :
166 : // Blocks with joining divergent control from different predecessors.
167 : DenseSet<const BasicBlock *> DivergentJoinBlocks;
168 :
169 : // Detected/marked divergent values.
170 : DenseSet<const Value *> DivergentValues;
171 :
172 : // Internal worklist for divergence propagation.
173 : std::vector<const Instruction *> Worklist;
174 : };
175 :
176 : } // namespace llvm
177 :
178 : #endif // LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
|