LLVM  7.0.0svn
AMDGPUUnifyDivergentExitNodes.cpp
Go to the documentation of this file.
1 //===- AMDGPUUnifyDivergentExitNodes.cpp ----------------------------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This is a variant of the UnifyFunctionExitNodes pass. Rather than ensuring
11 // there is at most one ret and one unreachable instruction, it ensures there is
12 // at most one divergent exiting block.
13 //
14 // StructurizeCFG can't deal with multi-exit regions formed by branches to
15 // multiple return nodes. It is not desirable to structurize regions with
16 // uniform branches, so unifying those to the same return block as divergent
17 // branches inhibits use of scalar branching. It still can't deal with the case
18 // where one branch goes to return, and one unreachable. Replace unreachable in
19 // this case with a return.
20 //
21 //===----------------------------------------------------------------------===//
22 
23 #include "AMDGPU.h"
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringRef.h"
32 #include "llvm/IR/BasicBlock.h"
33 #include "llvm/IR/CFG.h"
34 #include "llvm/IR/Constants.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/InstrTypes.h"
37 #include "llvm/IR/Instructions.h"
38 #include "llvm/IR/Intrinsics.h"
39 #include "llvm/IR/Type.h"
40 #include "llvm/Pass.h"
41 #include "llvm/Support/Casting.h"
42 #include "llvm/Transforms/Scalar.h"
43 #include "llvm/Transforms/Utils.h"
44 
45 using namespace llvm;
46 
47 #define DEBUG_TYPE "amdgpu-unify-divergent-exit-nodes"
48 
49 namespace {
50 
/// FunctionPass that rewrites a function so that at most one divergently
/// reached exiting block remains (see runOnFunction for the transformation).
// NOTE(review): the embedded listing numbers jump from 55 to 57 below, so the
// line that formed the constructor body is missing from this rendering --
// restore it from the upstream file before relying on this text.
51 class AMDGPUUnifyDivergentExitNodes : public FunctionPass {
52 public:
53  static char ID; // Pass identification, replacement for typeid
54 
    // The constructor hands the address-identified ID to the FunctionPass base.
55  AMDGPUUnifyDivergentExitNodes() : FunctionPass(ID) {
57  }
58 
59  // We can preserve non-critical-edgeness when we unify function exit nodes
60  void getAnalysisUsage(AnalysisUsage &AU) const override;
    // Performs the unification on F; the bool result is the "IR was modified"
    // flag expected from a FunctionPass.
61  bool runOnFunction(Function &F) override;
62 };
63 
64 } // end anonymous namespace
65 
67 
69 
// Pass registration: associates the class with the DEBUG_TYPE command-line
// name and the human-readable description string given below.
// NOTE(review): listing lines 72-73 are absent between the BEGIN and END
// macros in this rendering; presumably they held INITIALIZE_PASS_DEPENDENCY
// entries for the analyses required in getAnalysisUsage -- confirm against
// the upstream file.
70 INITIALIZE_PASS_BEGIN(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
71  "Unify divergent function exit nodes", false, false)
74 INITIALIZE_PASS_END(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
75  "Unify divergent function exit nodes", false, false)
76 
/// Declares the analyses this pass depends on and the ones it keeps valid.
///
/// Required: post-dominator tree, divergence analysis, target transform info.
/// Preserved: divergence analysis, BreakCriticalEdges, LowerSwitch.
77 void AMDGPUUnifyDivergentExitNodes::getAnalysisUsage(AnalysisUsage &AU) const{
78  // TODO: Preserve dominator tree.
    // runOnFunction enumerates the post-dominator tree roots to find the
    // function's exiting blocks.
79  AU.addRequired<PostDominatorTreeWrapperPass>();
80 
    // Used to decide whether an exiting block is reached only through
    // uniform terminators.
81  AU.addRequired<DivergenceAnalysis>();
82 
83  // No divergent values are changed, only blocks and branch edges.
84  AU.addPreserved<DivergenceAnalysis>();
85 
86  // We preserve the non-critical-edgeness property
87  AU.addPreservedID(BreakCriticalEdgesID);
88 
89  // This is a cluster of orthogonal Transforms
90  AU.addPreservedID(LowerSwitchID);
    // NOTE(review): listing line 91 is absent from this rendering -- check the
    // upstream file for the statement that belongs here.
92 
    // The TargetTransformInfo is handed to simplifyCFG when the return blocks
    // are merged.
93  AU.addRequired<TargetTransformInfoWrapperPass>();
94 }
95 
96 /// \returns true if \p BB is reachable through only uniform branches.
97 /// XXX - Is there a more efficient way to find this?
///
/// Walks the predecessor graph backwards from \p BB with an explicit worklist
/// and stops with false as soon as a visited block's terminator is not
/// reported uniform by \p DA. A block without predecessors yields true.
98 static bool isUniformlyReached(const DivergenceAnalysis &DA,
99  BasicBlock &BB) {
    // NOTE(review): listing lines 100-101 are absent from this rendering; the
    // declarations of `Stack` (worklist) and `Visited` (seen-set) used below
    // presumably lived there -- confirm their types against upstream.
102 
    // Seed the worklist with the direct predecessors. They are not recorded
    // in Visited here, so one of them can be queued a second time if it is
    // also found as a predecessor of another visited block (redundant work
    // only, assuming Visited starts out empty).
103  for (BasicBlock *Pred : predecessors(&BB))
104  Stack.push_back(Pred);
105 
106  while (!Stack.empty()) {
107  BasicBlock *Top = Stack.pop_back_val();
     // One non-uniform terminator anywhere upstream is enough to reject.
108  if (!DA.isUniform(Top->getTerminator()))
109  return false;
110 
111  for (BasicBlock *Pred : predecessors(Top)) {
     // insert().second is true only for a block not seen before, which
     // keeps the walk from cycling forever around loops.
112  if (Visited.insert(Pred).second)
113  Stack.push_back(Pred);
114  }
115  }
116 
117  return true;
118 }
119 
/// Redirects every block in \p ReturningBlocks to one newly created return
/// block named \p Name and returns that new block. For a non-void function
/// the returned values are merged through a PHI node in the new block.
/// \p TTI is only forwarded to simplifyCFG for the final cleanup.
// NOTE(review): the line that opens this definition (listing line 120, with
// the return type, name and first parameter `F`) is absent from this
// rendering; the parameter list continues below.
121  ArrayRef<BasicBlock *> ReturningBlocks,
122  const TargetTransformInfo &TTI,
123  StringRef Name) {
124  // Insert a new basic block into the function, add a PHI node to it (if the
125  // function returns a value), and convert all of the return
126  // instructions into unconditional branches.
127  BasicBlock *NewRetBlock = BasicBlock::Create(F.getContext(), Name, &F);
128 
     // PN stays null for void functions; the loop below keys off that.
129  PHINode *PN = nullptr;
130  if (F.getReturnType()->isVoidTy()) {
131  ReturnInst::Create(F.getContext(), nullptr, NewRetBlock);
132  } else {
133  // If the function doesn't return void... add a PHI node to the block...
134  PN = PHINode::Create(F.getReturnType(), ReturningBlocks.size(),
135  "UnifiedRetVal");
136  NewRetBlock->getInstList().push_back(PN);
137  ReturnInst::Create(F.getContext(), PN, NewRetBlock);
138  }
139 
140  // Loop over all of the blocks, replacing the return instruction with an
141  // unconditional branch.
142  for (BasicBlock *BB : ReturningBlocks) {
143  // Add an incoming element to the PHI node for every return instruction that
144  // is merging into this new block...
     // Operand 0 of the terminator is read before the terminator is erased
     // just below; the order of these two steps matters.
145  if (PN)
146  PN->addIncoming(BB->getTerminator()->getOperand(0), BB);
147 
148  // Remove and delete the return inst.
149  BB->getTerminator()->eraseFromParent();
150  BranchInst::Create(NewRetBlock, BB);
151  }
152 
153  for (BasicBlock *BB : ReturningBlocks) {
154  // Cleanup possible branch to unconditional branch to the return.
     // NOTE(review): {2} brace-initializes the SimplifyCFGOptions argument;
     // presumably it sets the bonus-instruction threshold -- confirm.
155  simplifyCFG(BB, TTI, {2});
156  }
157 
158  return NewRetBlock;
159 }
160 
/// Pass entry point. Collects the post-dominator-tree roots that end in a
/// return or unreachable and are not uniformly reached, gives roots that end
/// in a branch a dummy edge to a dummy return block, folds the collected
/// unreachable blocks into one (turning it into a return when returns exist
/// too), and finally merges all collected returning blocks into a single
/// "UnifiedReturnBlock".
// NOTE(review): the line that opens this definition (listing line 161) is
// absent from this rendering; the body starts below.
162  auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
     // Zero or one root: nothing to unify.
163  if (PDT.getRoots().size() <= 1)
164  return false;
165 
166  DivergenceAnalysis &DA = getAnalysis<DivergenceAnalysis>();
167 
168  // Loop over all of the blocks in a function, tracking all of the blocks that
169  // return.
170  SmallVector<BasicBlock *, 4> ReturningBlocks;
171  SmallVector<BasicBlock *, 4> UnreachableBlocks;
172 
173  // Dummy return block for infinite loop.
174  BasicBlock *DummyReturnBB = nullptr;
175 
176  for (BasicBlock *BB : PDT.getRoots()) {
     // Uniformly reached returns/unreachables are deliberately left alone;
     // only divergently reached ones are queued for unification.
177  if (isa<ReturnInst>(BB->getTerminator())) {
178  if (!isUniformlyReached(DA, *BB))
179  ReturningBlocks.push_back(BB);
180  } else if (isa<UnreachableInst>(BB->getTerminator())) {
181  if (!isUniformlyReached(DA, *BB))
182  UnreachableBlocks.push_back(BB);
183  } else if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) {
184 
     // NOTE(review): listing line 185 is absent from this rendering;
     // `BoolTrue`, used in the two BranchInst::Create calls below, is
     // presumably declared there as the constant true value -- confirm.
     // The dummy return block is created once and shared by all such roots.
186  if (DummyReturnBB == nullptr) {
187  DummyReturnBB = BasicBlock::Create(F.getContext(),
188  "DummyReturnBlock", &F);
189  Type *RetTy = F.getReturnType();
190  Value *RetVal = RetTy->isVoidTy() ? nullptr : UndefValue::get(RetTy);
191  ReturnInst::Create(F.getContext(), RetVal, DummyReturnBB);
192  ReturningBlocks.push_back(DummyReturnBB);
193  }
194 
195  if (BI->isUnconditional()) {
196  BasicBlock *LoopHeaderBB = BI->getSuccessor(0);
197  BI->eraseFromParent(); // Delete the unconditional branch.
198  // Add a new conditional branch with a dummy edge to the return block.
199  BranchInst::Create(LoopHeaderBB, DummyReturnBB, BoolTrue, BB);
200  } else { // Conditional branch.
201  // Create a new transition block to hold the conditional branch.
202  BasicBlock *TransitionBB = BasicBlock::Create(F.getContext(),
203  "TransitionBlock", &F);
204 
205  // Move BI from BB to the new transition block.
     // NOTE(review): BI's successors now have TransitionBB rather than BB
     // as their predecessor, but no PHI nodes in those successors are
     // updated here -- verify this cannot leave stale incoming blocks.
206  BI->removeFromParent();
207  TransitionBB->getInstList().push_back(BI);
208 
209  // Create a branch that will always branch to the transition block.
210  BranchInst::Create(TransitionBB, DummyReturnBB, BoolTrue, BB);
211  }
212  }
213  }
214 
215  if (!UnreachableBlocks.empty()) {
216  BasicBlock *UnreachableBlock = nullptr;
217 
     // A single unreachable block is reused as is; several are funnelled
     // into one new block that holds the only unreachable instruction.
218  if (UnreachableBlocks.size() == 1) {
219  UnreachableBlock = UnreachableBlocks.front();
220  } else {
221  UnreachableBlock = BasicBlock::Create(F.getContext(),
222  "UnifiedUnreachableBlock", &F);
223  new UnreachableInst(F.getContext(), UnreachableBlock);
224 
225  for (BasicBlock *BB : UnreachableBlocks) {
226  // Remove and delete the unreachable inst.
227  BB->getTerminator()->eraseFromParent();
228  BranchInst::Create(UnreachableBlock, BB);
229  }
230  }
231 
232  if (!ReturningBlocks.empty()) {
233  // Don't create a new unreachable inst if we have a return. The
234  // structurizer/annotator can't handle the multiple exits
235 
236  Type *RetTy = F.getReturnType();
237  Value *RetVal = RetTy->isVoidTy() ? nullptr : UndefValue::get(RetTy);
238  // Remove and delete the unreachable inst.
239  UnreachableBlock->getTerminator()->eraseFromParent();
240 
241  Function *UnreachableIntrin =
242  Intrinsic::getDeclaration(F.getParent(), Intrinsic::amdgcn_unreachable);
243 
244  // Insert a call to an intrinsic tracking that this is an unreachable
245  // point, in case we want to kill the active lanes or something later.
246  CallInst::Create(UnreachableIntrin, {}, "", UnreachableBlock);
247 
248  // Don't create a scalar trap. We would only want to trap if this code was
249  // really reached, but a scalar trap would happen even if no lanes
250  // actually reached here.
251  ReturnInst::Create(F.getContext(), RetVal, UnreachableBlock);
     // The former unreachable block now returns, so it joins the set that
     // is merged below.
252  ReturningBlocks.push_back(UnreachableBlock);
253  }
254  }
255 
256  // Now handle return blocks.
     // NOTE(review): both early exits below report "unchanged" even though the
     // code above may already have created blocks and rewritten terminators
     // (dummy return edge, unified unreachable block). Consider tracking a
     // `Changed` flag and returning it instead.
257  if (ReturningBlocks.empty())
258  return false; // No blocks return
259 
260  if (ReturningBlocks.size() == 1)
261  return false; // Already has a single return block
262 
263  const TargetTransformInfo &TTI
264  = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
265 
266  unifyReturnBlockSet(F, ReturningBlocks, TTI, "UnifiedReturnBlock");
267  return true;
268 }
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:68
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Compute iterated dominance frontiers using a linear time algorithm.
Definition: AllocatorList.h:24
LLVM_ATTRIBUTE_ALWAYS_INLINE size_type size() const
Definition: SmallVector.h:137
F(f)
static CallInst * Create(Value *Func, ArrayRef< Value *> Args, ArrayRef< OperandBundleDef > Bundles=None, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, Instruction *InsertBefore=nullptr)
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:51
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition: Pass.cpp:92
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: APInt.h:33
Function * getDeclaration(Module *M, ID id, ArrayRef< Type *> Tys=None)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
Definition: Function.cpp:1007
bool isUniform(const Value *V) const
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:141
static bool runOnFunction(Function &F, bool PostInlining)
Type * getReturnType() const
Returns the type of the ret val.
Definition: Function.h:155
Wrapper pass for TargetTransformInfo.
LLVM Basic Block Representation.
Definition: BasicBlock.h:59
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:46
Conditional or Unconditional Branch instruction.
char & BreakCriticalEdgesID
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:149
This function has undefined behavior.
This file contains the declarations for the subclasses of Constant, which represent the different fla...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:371
Represent the analysis usage information of a pass.
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:285
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:101
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function. ...
Definition: Function.cpp:194
char & LowerSwitchID
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1392
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE, "Assign register bank of generic virtual registers", false, false) RegBankSelect
static bool isUniformlyReached(const DivergenceAnalysis &DA, BasicBlock &BB)
const InstListType & getInstList() const
Return the underlying instruction list container.
Definition: BasicBlock.h:329
void removeFromParent()
Unlink 'this' from the containing function, but do not delete it.
Definition: BasicBlock.cpp:111
INITIALIZE_PASS_BEGIN(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE, "Unify divergent function exit nodes", false, false) INITIALIZE_PASS_END(AMDGPUUnifyDivergentExitNodes
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements...
Definition: SmallPtrSet.h:418
This is the shared class of boolean and integer constants.
Definition: Constants.h:84
char & AMDGPUUnifyDivergentExitNodesID
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:861
LLVM_NODISCARD T pop_back_val()
Definition: SmallVector.h:382
static BranchInst * Create(BasicBlock *IfTrue, Instruction *InsertBefore=nullptr)
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
pred_range predecessors(BasicBlock *BB)
Definition: CFG.h:113
static ConstantInt * getTrue(LLVMContext &Context)
Definition: Constants.cpp:577
void push_back(pointer val)
Definition: ilist.h:313
iterator_range< typename GraphTraits< GraphType >::nodes_iterator > nodes(const GraphType &G)
Definition: GraphTraits.h:102
LLVM_NODISCARD bool empty() const
Definition: SmallVector.h:62
SymbolTableList< BasicBlock >::iterator eraseFromParent()
Unlink 'this' from the containing function and delete it.
Definition: BasicBlock.cpp:115
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:565
LLVM Value Representation.
Definition: Value.h:73
print Print MemDeps of function
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:49
bool simplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI, const SimplifyCFGOptions &Options={}, SmallPtrSetImpl< BasicBlock *> *LoopHeaders=nullptr)
This function is used to do simplification of a CFG.
This pass exposes codegen information to IR-level passes.
const TerminatorInst * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.cpp:138
void initializeAMDGPUUnifyDivergentExitNodesPass(PassRegistry &)
static BasicBlock * unifyReturnBlockSet(Function &F, ArrayRef< BasicBlock *> ReturningBlocks, const TargetTransformInfo &TTI, StringRef Name)