Line data Source code
1 : //===- AMDGPUUnifyDivergentExitNodes.cpp ----------------------------------===//
2 : //
3 : // The LLVM Compiler Infrastructure
4 : //
5 : // This file is distributed under the University of Illinois Open Source
6 : // License. See LICENSE.TXT for details.
7 : //
8 : //===----------------------------------------------------------------------===//
9 : //
10 : // This is a variant of the UnifyDivergentExitNodes pass. Rather than ensuring
11 : // there is at most one ret and one unreachable instruction, it ensures there is
12 : // at most one divergent exiting block.
13 : //
14 : // StructurizeCFG can't deal with multi-exit regions formed by branches to
15 : // multiple return nodes. It is not desirable to structurize regions with
16 : // uniform branches, so unifying those to the same return block as divergent
17 : // branches inhibits use of scalar branching. It still can't deal with the case
18 : // where one branch goes to return, and one unreachable. Replace unreachable in
19 : // this case with a return.
20 : //
21 : //===----------------------------------------------------------------------===//
22 :
23 : #include "AMDGPU.h"
24 : #include "llvm/ADT/ArrayRef.h"
25 : #include "llvm/ADT/SmallPtrSet.h"
26 : #include "llvm/ADT/SmallVector.h"
27 : #include "llvm/ADT/StringRef.h"
28 : #include "llvm/Analysis/LegacyDivergenceAnalysis.h"
29 : #include "llvm/Analysis/PostDominators.h"
30 : #include "llvm/Analysis/TargetTransformInfo.h"
31 : #include "llvm/Transforms/Utils/Local.h"
32 : #include "llvm/IR/BasicBlock.h"
33 : #include "llvm/IR/CFG.h"
34 : #include "llvm/IR/Constants.h"
35 : #include "llvm/IR/Function.h"
36 : #include "llvm/IR/InstrTypes.h"
37 : #include "llvm/IR/Instructions.h"
38 : #include "llvm/IR/Intrinsics.h"
39 : #include "llvm/IR/Type.h"
40 : #include "llvm/Pass.h"
41 : #include "llvm/Support/Casting.h"
42 : #include "llvm/Transforms/Scalar.h"
43 : #include "llvm/Transforms/Utils.h"
44 :
45 : using namespace llvm;
46 :
47 : #define DEBUG_TYPE "amdgpu-unify-divergent-exit-nodes"
48 :
49 : namespace {
50 :
51 : class AMDGPUUnifyDivergentExitNodes : public FunctionPass {
52 : public:
53 : static char ID; // Pass identification, replacement for typeid
54 :
55 1966 : AMDGPUUnifyDivergentExitNodes() : FunctionPass(ID) {
56 1966 : initializeAMDGPUUnifyDivergentExitNodesPass(*PassRegistry::getPassRegistry());
57 : }
58 :
59 : // We can preserve non-critical-edgeness when we unify function exit nodes
60 : void getAnalysisUsage(AnalysisUsage &AU) const override;
61 : bool runOnFunction(Function &F) override;
62 : };
63 :
64 : } // end anonymous namespace
65 :
66 : char AMDGPUUnifyDivergentExitNodes::ID = 0;
67 :
68 : char &llvm::AMDGPUUnifyDivergentExitNodesID = AMDGPUUnifyDivergentExitNodes::ID;
69 :
70 85105 : INITIALIZE_PASS_BEGIN(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
71 : "Unify divergent function exit nodes", false, false)
72 85105 : INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
73 85105 : INITIALIZE_PASS_DEPENDENCY(LegacyDivergenceAnalysis)
74 200990 : INITIALIZE_PASS_END(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
75 : "Unify divergent function exit nodes", false, false)
76 :
77 1950 : void AMDGPUUnifyDivergentExitNodes::getAnalysisUsage(AnalysisUsage &AU) const{
78 : // TODO: Preserve dominator tree.
79 : AU.addRequired<PostDominatorTreeWrapperPass>();
80 :
81 : AU.addRequired<LegacyDivergenceAnalysis>();
82 :
83 : // No divergent values are changed, only blocks and branch edges.
84 : AU.addPreserved<LegacyDivergenceAnalysis>();
85 :
86 : // We preserve the non-critical-edgeness property
87 1950 : AU.addPreservedID(BreakCriticalEdgesID);
88 :
89 : // This is a cluster of orthogonal Transforms
90 1950 : AU.addPreservedID(LowerSwitchID);
91 1950 : FunctionPass::getAnalysisUsage(AU);
92 :
93 : AU.addRequired<TargetTransformInfoWrapperPass>();
94 1950 : }
95 :
96 : /// \returns true if \p BB is reachable through only uniform branches.
97 : /// XXX - Is there a more efficient way to find this?
98 178 : static bool isUniformlyReached(const LegacyDivergenceAnalysis &DA,
99 : BasicBlock &BB) {
100 : SmallVector<BasicBlock *, 8> Stack;
101 : SmallPtrSet<BasicBlock *, 8> Visited;
102 :
103 409 : for (BasicBlock *Pred : predecessors(&BB))
104 231 : Stack.push_back(Pred);
105 :
106 334 : while (!Stack.empty()) {
107 : BasicBlock *Top = Stack.pop_back_val();
108 : if (!DA.isUniform(Top->getTerminator()))
109 92 : return false;
110 :
111 236 : for (BasicBlock *Pred : predecessors(Top)) {
112 80 : if (Visited.insert(Pred).second)
113 63 : Stack.push_back(Pred);
114 : }
115 : }
116 :
117 : return true;
118 : }
119 :
120 43 : static BasicBlock *unifyReturnBlockSet(Function &F,
121 : ArrayRef<BasicBlock *> ReturningBlocks,
122 : const TargetTransformInfo &TTI,
123 : StringRef Name) {
124 : // Otherwise, we need to insert a new basic block into the function, add a PHI
125 : // nodes (if the function returns values), and convert all of the return
126 : // instructions into unconditional branches.
127 86 : BasicBlock *NewRetBlock = BasicBlock::Create(F.getContext(), Name, &F);
128 :
129 : PHINode *PN = nullptr;
130 43 : if (F.getReturnType()->isVoidTy()) {
131 34 : ReturnInst::Create(F.getContext(), nullptr, NewRetBlock);
132 : } else {
133 : // If the function doesn't return void... add a PHI node to the block...
134 9 : PN = PHINode::Create(F.getReturnType(), ReturningBlocks.size(),
135 : "UnifiedRetVal");
136 9 : NewRetBlock->getInstList().push_back(PN);
137 9 : ReturnInst::Create(F.getContext(), PN, NewRetBlock);
138 : }
139 :
140 : // Loop over all of the blocks, replacing the return instruction with an
141 : // unconditional branch.
142 131 : for (BasicBlock *BB : ReturningBlocks) {
143 : // Add an incoming element to the PHI node for every return instruction that
144 : // is merging into this new block...
145 88 : if (PN)
146 36 : PN->addIncoming(BB->getTerminator()->getOperand(0), BB);
147 :
148 : // Remove and delete the return inst.
149 88 : BB->getTerminator()->eraseFromParent();
150 88 : BranchInst::Create(NewRetBlock, BB);
151 : }
152 :
153 131 : for (BasicBlock *BB : ReturningBlocks) {
154 : // Cleanup possible branch to unconditional branch to the return.
155 88 : simplifyCFG(BB, TTI, {2});
156 : }
157 :
158 43 : return NewRetBlock;
159 : }
160 :
161 19390 : bool AMDGPUUnifyDivergentExitNodes::runOnFunction(Function &F) {
162 19390 : auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
163 19390 : if (PDT.getRoots().size() <= 1)
164 : return false;
165 :
166 93 : LegacyDivergenceAnalysis &DA = getAnalysis<LegacyDivergenceAnalysis>();
167 :
168 : // Loop over all of the blocks in a function, tracking all of the blocks that
169 : // return.
170 : SmallVector<BasicBlock *, 4> ReturningBlocks;
171 : SmallVector<BasicBlock *, 4> UnreachableBlocks;
172 :
173 : // Dummy return block for infinite loop.
174 93 : BasicBlock *DummyReturnBB = nullptr;
175 :
176 288 : for (BasicBlock *BB : PDT.getRoots()) {
177 195 : if (isa<ReturnInst>(BB->getTerminator())) {
178 125 : if (!isUniformlyReached(DA, *BB))
179 65 : ReturningBlocks.push_back(BB);
180 70 : } else if (isa<UnreachableInst>(BB->getTerminator())) {
181 53 : if (!isUniformlyReached(DA, *BB))
182 27 : UnreachableBlocks.push_back(BB);
183 : } else if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) {
184 :
185 17 : ConstantInt *BoolTrue = ConstantInt::getTrue(F.getContext());
186 17 : if (DummyReturnBB == nullptr) {
187 30 : DummyReturnBB = BasicBlock::Create(F.getContext(),
188 : "DummyReturnBlock", &F);
189 : Type *RetTy = F.getReturnType();
190 15 : Value *RetVal = RetTy->isVoidTy() ? nullptr : UndefValue::get(RetTy);
191 15 : ReturnInst::Create(F.getContext(), RetVal, DummyReturnBB);
192 15 : ReturningBlocks.push_back(DummyReturnBB);
193 : }
194 :
195 17 : if (BI->isUnconditional()) {
196 : BasicBlock *LoopHeaderBB = BI->getSuccessor(0);
197 15 : BI->eraseFromParent(); // Delete the unconditional branch.
198 : // Add a new conditional branch with a dummy edge to the return block.
199 15 : BranchInst::Create(LoopHeaderBB, DummyReturnBB, BoolTrue, BB);
200 : } else { // Conditional branch.
201 : // Create a new transition block to hold the conditional branch.
202 2 : BasicBlock *TransitionBB = BasicBlock::Create(F.getContext(),
203 : "TransitionBlock", &F);
204 :
205 : // Move BI from BB to the new transition block.
206 2 : BI->removeFromParent();
207 2 : TransitionBB->getInstList().push_back(BI);
208 :
209 : // Create a branch that will always branch to the transition block.
210 2 : BranchInst::Create(TransitionBB, DummyReturnBB, BoolTrue, BB);
211 : }
212 : }
213 : }
214 :
215 93 : if (!UnreachableBlocks.empty()) {
216 20 : BasicBlock *UnreachableBlock = nullptr;
217 :
218 20 : if (UnreachableBlocks.size() == 1) {
219 14 : UnreachableBlock = UnreachableBlocks.front();
220 : } else {
221 6 : UnreachableBlock = BasicBlock::Create(F.getContext(),
222 : "UnifiedUnreachableBlock", &F);
223 12 : new UnreachableInst(F.getContext(), UnreachableBlock);
224 :
225 19 : for (BasicBlock *BB : UnreachableBlocks) {
226 : // Remove and delete the unreachable inst.
227 13 : BB->getTerminator()->eraseFromParent();
228 13 : BranchInst::Create(UnreachableBlock, BB);
229 : }
230 : }
231 :
232 20 : if (!ReturningBlocks.empty()) {
233 : // Don't create a new unreachable inst if we have a return. The
234 : // structurizer/annotator can't handle the multiple exits
235 :
236 : Type *RetTy = F.getReturnType();
237 16 : Value *RetVal = RetTy->isVoidTy() ? nullptr : UndefValue::get(RetTy);
238 : // Remove and delete the unreachable inst.
239 32 : UnreachableBlock->getTerminator()->eraseFromParent();
240 :
241 : Function *UnreachableIntrin =
242 16 : Intrinsic::getDeclaration(F.getParent(), Intrinsic::amdgcn_unreachable);
243 :
244 : // Insert a call to an intrinsic tracking that this is an unreachable
245 : // point, in case we want to kill the active lanes or something later.
246 32 : CallInst::Create(UnreachableIntrin, {}, "", UnreachableBlock);
247 :
248 : // Don't create a scalar trap. We would only want to trap if this code was
249 : // really reached, but a scalar trap would happen even if no lanes
250 : // actually reached here.
251 16 : ReturnInst::Create(F.getContext(), RetVal, UnreachableBlock);
252 16 : ReturningBlocks.push_back(UnreachableBlock);
253 : }
254 : }
255 :
256 : // Now handle return blocks.
257 93 : if (ReturningBlocks.empty())
258 : return false; // No blocks return
259 :
260 51 : if (ReturningBlocks.size() == 1)
261 : return false; // Already has a single return block
262 :
263 : const TargetTransformInfo &TTI
264 43 : = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
265 :
266 43 : unifyReturnBlockSet(F, ReturningBlocks, TTI, "UnifiedReturnBlock");
267 43 : return true;
268 : }
|