LLVM  6.0.0svn
AMDGPUUnifyDivergentExitNodes.cpp
Go to the documentation of this file.
1 //===- AMDGPUUnifyDivergentExitNodes.cpp ----------------------------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This is a variant of the UnifyFunctionExitNodes pass. Rather than ensuring
11 // there is at most one ret and one unreachable instruction, it ensures there is
12 // at most one divergent exiting block.
13 //
14 // StructurizeCFG can't deal with multi-exit regions formed by branches to
15 // multiple return nodes. It is not desirable to structurize regions with
16 // uniform branches, so only divergently reached exits are unified; merging
17 // uniform ones into the same return block would inhibit scalar branching.
18 // StructurizeCFG also can't deal with one branch going to a return and another
19 // to an unreachable, so in that case the unreachable is replaced with a return.
20 //
21 //===----------------------------------------------------------------------===//
22 
23 #include "AMDGPU.h"
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringRef.h"
31 #include "llvm/IR/BasicBlock.h"
32 #include "llvm/IR/CFG.h"
33 #include "llvm/IR/Constants.h"
34 #include "llvm/IR/Function.h"
35 #include "llvm/IR/InstrTypes.h"
36 #include "llvm/IR/Instructions.h"
37 #include "llvm/IR/Intrinsics.h"
38 #include "llvm/IR/Type.h"
39 #include "llvm/Pass.h"
40 #include "llvm/Support/Casting.h"
41 #include "llvm/Transforms/Scalar.h"
43 
44 using namespace llvm;
45 
46 #define DEBUG_TYPE "amdgpu-unify-divergent-exit-nodes"
47 
48 namespace {
49 
50 class AMDGPUUnifyDivergentExitNodes : public FunctionPass {
51 public:
52  static char ID; // Pass identification, replacement for typeid
53 
54  AMDGPUUnifyDivergentExitNodes() : FunctionPass(ID) {
56  }
57 
58  // We can preserve non-critical-edgeness when we unify function exit nodes
59  void getAnalysisUsage(AnalysisUsage &AU) const override;
60  bool runOnFunction(Function &F) override;
61 };
62 
63 } // end anonymous namespace
64 
66 
68 
69 INITIALIZE_PASS_BEGIN(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
70  "Unify divergent function exit nodes", false, false)
73 INITIALIZE_PASS_END(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE,
74  "Unify divergent function exit nodes", false, false)
75 
76 void AMDGPUUnifyDivergentExitNodes::getAnalysisUsage(AnalysisUsage &AU) const{
77  // TODO: Preserve dominator tree.
78  AU.addRequired<PostDominatorTreeWrapperPass>();
79 
80  AU.addRequired<DivergenceAnalysis>();
81 
82  // No divergent values are changed, only blocks and branch edges.
83  AU.addPreserved<DivergenceAnalysis>();
84 
85  // We preserve the non-critical-edgeness property
86  AU.addPreservedID(BreakCriticalEdgesID);
87 
88  // This is a cluster of orthogonal Transforms
89  AU.addPreservedID(LowerSwitchID);
91 
92  AU.addRequired<TargetTransformInfoWrapperPass>();
93 }
94 
95 /// \returns true if \p BB is reachable through only uniform branches.
96 /// XXX - Is there a more efficient way to find this?
97 static bool isUniformlyReached(const DivergenceAnalysis &DA,
98  BasicBlock &BB) {
101 
102  for (BasicBlock *Pred : predecessors(&BB))
103  Stack.push_back(Pred);
104 
105  while (!Stack.empty()) {
106  BasicBlock *Top = Stack.pop_back_val();
107  if (!DA.isUniform(Top->getTerminator()))
108  return false;
109 
110  for (BasicBlock *Pred : predecessors(Top)) {
111  if (Visited.insert(Pred).second)
112  Stack.push_back(Pred);
113  }
114  }
115 
116  return true;
117 }
118 
120  ArrayRef<BasicBlock *> ReturningBlocks,
121  const TargetTransformInfo &TTI,
122  StringRef Name) {
123  // Otherwise, we need to insert a new basic block into the function, add a PHI
124  // nodes (if the function returns values), and convert all of the return
125  // instructions into unconditional branches.
126  BasicBlock *NewRetBlock = BasicBlock::Create(F.getContext(), Name, &F);
127 
128  PHINode *PN = nullptr;
129  if (F.getReturnType()->isVoidTy()) {
130  ReturnInst::Create(F.getContext(), nullptr, NewRetBlock);
131  } else {
132  // If the function doesn't return void... add a PHI node to the block...
133  PN = PHINode::Create(F.getReturnType(), ReturningBlocks.size(),
134  "UnifiedRetVal");
135  NewRetBlock->getInstList().push_back(PN);
136  ReturnInst::Create(F.getContext(), PN, NewRetBlock);
137  }
138 
139  // Loop over all of the blocks, replacing the return instruction with an
140  // unconditional branch.
141  for (BasicBlock *BB : ReturningBlocks) {
142  // Add an incoming element to the PHI node for every return instruction that
143  // is merging into this new block...
144  if (PN)
145  PN->addIncoming(BB->getTerminator()->getOperand(0), BB);
146 
147  BB->getInstList().pop_back(); // Remove the return insn
148  BranchInst::Create(NewRetBlock, BB);
149  }
150 
151  for (BasicBlock *BB : ReturningBlocks) {
152  // Cleanup possible branch to unconditional branch to the return.
153  simplifyCFG(BB, TTI, {2});
154  }
155 
156  return NewRetBlock;
157 }
158 
159 bool AMDGPUUnifyDivergentExitNodes::runOnFunction(Function &F) {
160  auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
161  if (PDT.getRoots().size() <= 1)
162  return false;
163 
164  DivergenceAnalysis &DA = getAnalysis<DivergenceAnalysis>();
165 
166  // Loop over all of the blocks in a function, tracking all of the blocks that
167  // return.
168  SmallVector<BasicBlock *, 4> ReturningBlocks;
169  SmallVector<BasicBlock *, 4> UnreachableBlocks;
170 
171  for (BasicBlock *BB : PDT.getRoots()) {
172  if (isa<ReturnInst>(BB->getTerminator())) {
173  if (!isUniformlyReached(DA, *BB))
174  ReturningBlocks.push_back(BB);
175  } else if (isa<UnreachableInst>(BB->getTerminator())) {
176  if (!isUniformlyReached(DA, *BB))
177  UnreachableBlocks.push_back(BB);
178  }
179  }
180 
181  if (!UnreachableBlocks.empty()) {
182  BasicBlock *UnreachableBlock = nullptr;
183 
184  if (UnreachableBlocks.size() == 1) {
185  UnreachableBlock = UnreachableBlocks.front();
186  } else {
187  UnreachableBlock = BasicBlock::Create(F.getContext(),
188  "UnifiedUnreachableBlock", &F);
189  new UnreachableInst(F.getContext(), UnreachableBlock);
190 
191  for (BasicBlock *BB : UnreachableBlocks) {
192  BB->getInstList().pop_back(); // Remove the unreachable inst.
193  BranchInst::Create(UnreachableBlock, BB);
194  }
195  }
196 
197  if (!ReturningBlocks.empty()) {
198  // Don't create a new unreachable inst if we have a return. The
199  // structurizer/annotator can't handle the multiple exits
200 
201  Type *RetTy = F.getReturnType();
202  Value *RetVal = RetTy->isVoidTy() ? nullptr : UndefValue::get(RetTy);
203  UnreachableBlock->getInstList().pop_back(); // Remove the unreachable inst.
204 
205  Function *UnreachableIntrin =
206  Intrinsic::getDeclaration(F.getParent(), Intrinsic::amdgcn_unreachable);
207 
208  // Insert a call to an intrinsic tracking that this is an unreachable
209  // point, in case we want to kill the active lanes or something later.
210  CallInst::Create(UnreachableIntrin, {}, "", UnreachableBlock);
211 
212  // Don't create a scalar trap. We would only want to trap if this code was
213  // really reached, but a scalar trap would happen even if no lanes
214  // actually reached here.
215  ReturnInst::Create(F.getContext(), RetVal, UnreachableBlock);
216  ReturningBlocks.push_back(UnreachableBlock);
217  }
218  }
219 
220  // Now handle return blocks.
221  if (ReturningBlocks.empty())
222  return false; // No blocks return
223 
224  if (ReturningBlocks.size() == 1)
225  return false; // Already has a single return block
226 
227  const TargetTransformInfo &TTI
228  = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
229 
230  unifyReturnBlockSet(F, ReturningBlocks, TTI, "UnifiedReturnBlock");
231  return true;
232 }
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Compute iterated dominance frontiers using a linear time algorithm.
Definition: AllocatorList.h:24
LLVM_ATTRIBUTE_ALWAYS_INLINE size_type size() const
Definition: SmallVector.h:136
F(f)
static CallInst * Create(Value *Func, ArrayRef< Value *> Args, ArrayRef< OperandBundleDef > Bundles=None, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, Instruction *InsertBefore=nullptr)
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:51
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:91
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: APInt.h:33
Function * getDeclaration(Module *M, ID id, ArrayRef< Type *> Tys=None)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
Definition: Function.cpp:980
bool isUniform(const Value *V) const
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:141
Type * getReturnType() const
Returns the type of the ret val.
Definition: Function.h:150
Wrapper pass for TargetTransformInfo.
LLVM Basic Block Representation.
Definition: BasicBlock.h:59
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:46
char & BreakCriticalEdgesID
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:149
This function has undefined behavior.
This file contains the declarations for the subclasses of Constant, which represent the different fla...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:371
Represent the analysis usage information of a pass.
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:285
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:101
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function. ...
Definition: Function.cpp:194
char & LowerSwitchID
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1320
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE, "Assign register bank of generic virtual registers", false, false) RegBankSelect
static bool isUniformlyReached(const DivergenceAnalysis &DA, BasicBlock &BB)
const InstListType & getInstList() const
Return the underlying instruction list container.
Definition: BasicBlock.h:317
INITIALIZE_PASS_BEGIN(AMDGPUUnifyDivergentExitNodes, DEBUG_TYPE, "Unify divergent function exit nodes", false, false) INITIALIZE_PASS_END(AMDGPUUnifyDivergentExitNodes
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements...
Definition: SmallPtrSet.h:418
char & AMDGPUUnifyDivergentExitNodesID
This pass provides access to the codegen interfaces that are needed for IR-level transformations.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:864
LLVM_NODISCARD T pop_back_val()
Definition: SmallVector.h:385
static BranchInst * Create(BasicBlock *IfTrue, Instruction *InsertBefore=nullptr)
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
pred_range predecessors(BasicBlock *BB)
Definition: CFG.h:110
void push_back(pointer val)
Definition: ilist.h:326
iterator_range< typename GraphTraits< GraphType >::nodes_iterator > nodes(const GraphType &G)
Definition: GraphTraits.h:89
LLVM_NODISCARD bool empty() const
Definition: SmallVector.h:61
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:545
LLVM Value Representation.
Definition: Value.h:73
print Print MemDeps of function
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:49
bool simplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI, const SimplifyCFGOptions &Options={}, SmallPtrSetImpl< BasicBlock *> *LoopHeaders=nullptr)
This function is used to do simplification of a CFG.
This pass exposes codegen information to IR-level passes.
const TerminatorInst * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.cpp:120
void initializeAMDGPUUnifyDivergentExitNodesPass(PassRegistry &)
void pop_back()
Definition: ilist.h:331
static BasicBlock * unifyReturnBlockSet(Function &F, ArrayRef< BasicBlock *> ReturningBlocks, const TargetTransformInfo &TTI, StringRef Name)