LLVM 20.0.0git
SPIRVMergeRegionExitTargets.cpp
Go to the documentation of this file.
1//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Merge the multiple exit targets of a convergence region into a single block.
10// Each exit target will be assigned a constant value. Each exiting block
11// stores the constant of the target it branches to into a stack variable
// (alloca); the new exit target loads it and uses a switch to re-route to the
// correct basic block.
12//
13//===----------------------------------------------------------------------===//
14
16#include "SPIRV.h"
17#include "SPIRVSubtarget.h"
18#include "SPIRVTargetMachine.h"
19#include "SPIRVUtils.h"
20#include "llvm/ADT/DenseMap.h"
24#include "llvm/IR/CFG.h"
25#include "llvm/IR/Dominators.h"
26#include "llvm/IR/IRBuilder.h"
28#include "llvm/IR/Intrinsics.h"
29#include "llvm/IR/IntrinsicsSPIRV.h"
34
35using namespace llvm;
36
37namespace llvm {
39
41public:
42 static char ID;
43
46 };
47
48 // Gather all the successors of |BB|.
49 // This function asserts if the terminator neither a branch, switch or return.
50 std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
51 std::unordered_set<BasicBlock *> output;
52 auto *T = BB->getTerminator();
53
54 if (auto *BI = dyn_cast<BranchInst>(T)) {
55 output.insert(BI->getSuccessor(0));
56 if (BI->isConditional())
57 output.insert(BI->getSuccessor(1));
58 return output;
59 }
60
61 if (auto *SI = dyn_cast<SwitchInst>(T)) {
62 output.insert(SI->getDefaultDest());
63 for (auto &Case : SI->cases())
64 output.insert(Case.getCaseSuccessor());
65 return output;
66 }
67
68 assert(isa<ReturnInst>(T) && "Unhandled terminator type.");
69 return output;
70 }
71
/// Create a value in BB set to the value associated with the branch the block
/// terminator will take.
///
/// Returns nullptr when BB ends in a return. For a branch terminator:
///  - if exactly one successor has an entry in |TargetToValue|, returns that
///    constant directly (no instruction is created);
///  - if both successors of a conditional branch have an entry, returns a
///    `select` on the branch condition, inserted just before the terminator;
///  - if no successor has an entry, returns nullptr.
/// Switch terminators are not supported yet (llvm_unreachable).
BasicBlock *BB,
const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
auto *T = BB->getTerminator();
if (isa<ReturnInst>(T))
return nullptr;

// Any instruction created below is placed immediately before the terminator.
IRBuilder<> Builder(BB);
Builder.SetInsertPoint(T);

if (auto *BI = dyn_cast<BranchInst>(T)) {

// For an unconditional branch there is no second successor; RHSTarget is
// nullptr and is expected to have no entry in TargetToValue.
BasicBlock *LHSTarget = BI->getSuccessor(0);
BasicBlock *RHSTarget =
BI->isConditional() ? BI->getSuccessor(1) : nullptr;

// Successors without an entry in TargetToValue map to nullptr.
Value *LHS = TargetToValue.count(LHSTarget) != 0
? TargetToValue.at(LHSTarget)
: nullptr;
Value *RHS = TargetToValue.count(RHSTarget) != 0
? TargetToValue.at(RHSTarget)
: nullptr;

// At most one side has a value: return it (possibly nullptr) as-is.
if (LHS == nullptr || RHS == nullptr)
return LHS == nullptr ? RHS : LHS;
return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
}

// TODO: add support for switch cases.
llvm_unreachable("Unhandled terminator type.");
}
105
/// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
///
/// Return terminators are left untouched. Branch and switch terminators have
/// every successor slot found in |ToReplace| rewritten to |NewTarget|; for a
/// switch this covers the default destination as well as the cases. Any other
/// terminator kind fails an assertion.
/// NOTE(review): only the terminator is rewritten here; PHI nodes in the
/// replaced targets are not updated by this function — confirm callers do not
/// rely on that.
const SmallPtrSet<BasicBlock *, 4> &ToReplace,
BasicBlock *NewTarget) {
auto *T = BB->getTerminator();
if (isa<ReturnInst>(T))
return;

if (auto *BI = dyn_cast<BranchInst>(T)) {
for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
if (ToReplace.count(BI->getSuccessor(i)) != 0)
BI->setSuccessor(i, NewTarget);
}
return;
}

if (auto *SI = dyn_cast<SwitchInst>(T)) {
for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
if (ToReplace.count(SI->getSuccessor(i)) != 0)
SI->setSuccessor(i, NewTarget);
}
return;
}

// NOTE(review): with assertions disabled this silently leaves unsupported
// terminators unchanged; llvm_unreachable (as in createExitVariable) may be
// the intended behavior.
assert(false && "Unhandled terminator type.");
}
132
BasicBlock::iterator Position) {
// Creates an alloca of |Type| named "reg" (no array size operand) in the
// function's alloca address space, inserted at |Position|.
const DataLayout &DL = F.getDataLayout();
return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
Position);
}
139
// Run the pass on the given convergence region, ignoring the sub-regions.
// Returns true if the CFG changed, false otherwise.
//
// When the region's exits branch to two or more blocks outside the region, a
// new block "new.exit" is created and all exiting edges are redirected to it.
// Each exiting block stores an i32 id of its original target into an alloca
// placed in the function's entry block; "new.exit" loads that id and switches
// to the original target. "new.exit" is then added to every ancestor region.
// Gather all the exit targets for this region.
for (BasicBlock *Exit : CR->Exits) {
for (BasicBlock *Target : gatherSuccessors(Exit)) {
if (CR->Blocks.count(Target) == 0)
ExitTargets.insert(Target);
}
}

// If we have zero or one exit target, there is nothing to do.
if (ExitTargets.size() <= 1)
return false;

// Create the new single exit target.
auto F = CR->Entry->getParent();
auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
IRBuilder<> Builder(NewExitTarget);

AllocaInst *Variable = CreateVariable(*F, Builder.getInt32Ty(),
F->begin()->getFirstInsertionPt());

// CodeGen output needs to be stable. Using the set as-is would order
// the targets differently depending on the allocation pattern.
// Sorting per basic-block ordering in the function.
std::vector<BasicBlock *> SortedExitTargets;
std::vector<BasicBlock *> SortedExits;
for (BasicBlock &BB : *F) {
if (ExitTargets.count(&BB) != 0)
SortedExitTargets.push_back(&BB);
if (CR->Exits.count(&BB) != 0)
SortedExits.push_back(&BB);
}

// Create one constant per distinct exit target; these ids are used to route
// to the correct target. The i-th sorted target gets id i (0, 1, ...),
// because the map grows by one entry per iteration.
for (BasicBlock *Target : SortedExitTargets)
TargetToValue.insert(
std::make_pair(Target, Builder.getInt32(TargetToValue.size())));

// In each exiting block, store the constant matching the targeted external
// block into the shared variable.
// NOTE(review): ExitToVariable is filled below but never read in this
// function.
std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
for (auto Exit : SortedExits) {
llvm::Value *Value = createExitVariable(Exit, TargetToValue);
// NOTE(review): createExitVariable may create a `select` just before Exit's
// terminator, but the store below is inserted at Exit's first insertion
// point, i.e. before that select — which would be a use before definition.
// Inserting before the terminator looks intended; verify. Value can also be
// nullptr (e.g. for a return terminator) — confirm that cannot happen here.
IRBuilder<> B2(Exit);
B2.SetInsertPoint(Exit->getFirstInsertionPt());
B2.CreateStore(Value, Variable);
ExitToVariable.emplace_back(std::make_pair(Exit, Value));
}

llvm::Value *Load = Builder.CreateLoad(Builder.getInt32Ty(), Variable);

// Creating the switch to jump to the correct exit target. Target 0 is the
// default destination; every other target gets an explicit case.
llvm::SwitchInst *Sw = Builder.CreateSwitch(Load, SortedExitTargets[0],
SortedExitTargets.size() - 1);
for (size_t i = 1; i < SortedExitTargets.size(); i++) {
BasicBlock *BB = SortedExitTargets[i];
Sw->addCase(TargetToValue[BB], BB);
}

// Fix exit branches to redirect to the new exit.
for (auto Exit : CR->Exits)
replaceBranchTargets(Exit, ExitTargets, NewExitTarget);

// The new block is registered in every ancestor region (not in CR itself).
CR = CR->Parent;
while (CR) {
CR->Blocks.insert(NewExitTarget);
CR = CR->Parent;
}

return true;
}
217
/// Run the pass on the given convergence region and sub-regions (DFS).
/// Returns true if a region/sub-region was modified, false otherwise.
/// This returns as soon as one region/sub-region has been modified.
/// Children are processed before |CR| itself (post-order).
// Stop at the first change so the caller can restart the whole traversal.
for (auto *Child : CR->Children)
if (runOnConvergenceRegion(LI, Child))
return true;

return runOnConvergenceRegionNoRecurse(LI, CR);
}
228
#if !NDEBUG
/// Validates each edge exiting the region has the same destination basic
/// block.
/// Checks children first, then asserts that |CR| has at most one exit target
/// (successor of an exit block that is outside the region). It has no effect
/// other than the assertion.
// NOTE(review): `#if !NDEBUG` fails to preprocess if NDEBUG is defined with an
// empty value; `#ifndef NDEBUG` is more robust. Any call site must be guarded
// by a condition that implies this one.
for (auto *Child : CR->Children)
validateRegionExits(Child);

std::unordered_set<BasicBlock *> ExitTargets;
for (auto *Exit : CR->Exits) {
auto Set = gatherSuccessors(Exit);
for (auto *BB : Set) {
if (CR->Blocks.count(BB) == 0)
ExitTargets.insert(BB);
}
}

assert(ExitTargets.size() <= 1);
#endif
248
249 virtual bool runOnFunction(Function &F) override {
250 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
251 auto *TopLevelRegion =
252 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
253 .getRegionInfo()
254 .getWritableTopLevelRegion();
255
256 // FIXME: very inefficient method: each time a region is modified, we bubble
257 // back up, and recompute the whole convergence region tree. Once the
258 // algorithm is completed and test coverage good enough, rewrite this pass
259 // to be efficient instead of simple.
260 bool modified = false;
261 while (runOnConvergenceRegion(LI, TopLevelRegion)) {
262 modified = true;
263 }
264
265#if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
266 validateRegionExits(TopLevelRegion);
267#endif
268 return modified;
269 }
270
/// Declares to the legacy pass manager which analyses this pass requires and
/// which it preserves.
void getAnalysisUsage(AnalysisUsage &AU) const override {

}
279};
280} // namespace llvm
281
283
285 "SPIRV split region exit blocks", false, false)
286INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
290
292 "SPIRV split region exit blocks", false, false)
293
// Returns a newly heap-allocated instance of the pass.
return new SPIRVMergeRegionExitTargets();
}
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file defines the DenseMap class.
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
#define F(x, y, z)
Definition: MD5.cpp:55
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
convergence region
split region exit blocks
This file defines the SmallPtrSet class.
Value * RHS
Value * LHS
an instruction to allocate memory on the stack
Definition: Instructions.h:63
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:212
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:219
InstListType::iterator iterator
Instruction iterators...
Definition: BasicBlock.h:177
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:239
A parsed version of the target data layout string, and methods for querying it.
Definition: DataLayout.h:63
unsigned size() const
Definition: DenseMap.h:99
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: DenseMap.h:152
const ValueT & at(const_arg_type_t< KeyT > Val) const
at - Return the entry for the specified key, or abort if no such entry exists.
Definition: DenseMap.h:202
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:211
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:317
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:310
Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
Definition: IRBuilder.cpp:1048
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
Definition: IRBuilder.h:523
ConstantInt * getInt32(uint32_t C)
Get a constant 32-bit value.
Definition: IRBuilder.h:483
SwitchInst * CreateSwitch(Value *V, BasicBlock *Dest, unsigned NumCases=10, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a switch instruction with the specified value, default dest, and with a hint for the number of...
Definition: IRBuilder.h:1167
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
Definition: IRBuilder.h:1813
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Definition: IRBuilder.h:1826
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:177
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2697
void setSuccessor(unsigned Idx, BasicBlock *BB)
Update the specified successor to point at the provided block.
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:593
PassRegistry - This class manages the registration and initialization of the pass subsystem as appli...
Definition: PassRegistry.h:37
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition: Pass.cpp:98
std::unordered_set< BasicBlock * > gatherSuccessors(BasicBlock *BB)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
void replaceBranchTargets(BasicBlock *BB, const SmallPtrSet< BasicBlock *, 4 > &ToReplace, BasicBlock *NewTarget)
Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
bool runOnConvergenceRegion(LoopInfo &LI, SPIRV::ConvergenceRegion *CR)
Run the pass on the given convergence region and sub-regions (DFS).
llvm::Value * createExitVariable(BasicBlock *BB, const DenseMap< BasicBlock *, ConstantInt * > &TargetToValue)
Create a value in BB set to the value associated with the branch the block terminator will take.
bool runOnConvergenceRegionNoRecurse(LoopInfo &LI, SPIRV::ConvergenceRegion *CR)
void validateRegionExits(const SPIRV::ConvergenceRegion *CR)
Validates each edge exiting the region has the same destination basic block.
virtual bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
AllocaInst * CreateVariable(Function &F, Type *Type, BasicBlock::iterator Position)
SmallVector< ConvergenceRegion * > Children
SmallPtrSet< BasicBlock *, 2 > Exits
SmallPtrSet< BasicBlock *, 8 > Blocks
size_type size() const
Definition: SmallPtrSet.h:94
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:452
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:384
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:519
Multiway switch.
void addCase(ConstantInt *OnVal, BasicBlock *Dest)
Add an entry to the switch instruction.
Target - Wrapper for Target specific information.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
LLVM Value Representation.
Definition: Value.h:74
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &)
FunctionPass * createSPIRVMergeRegionExitTargetsPass()