27#include "llvm/IR/IntrinsicsSPIRV.h"
34#include <unordered_set>
45using BlockSet = std::unordered_set<BasicBlock *>;
46using Edge = std::pair<BasicBlock *, BasicBlock *>;
53 V.partialOrderVisit(Start,
Op);
60 if (
Node->Entry == BB)
63 for (
auto *Child :
Node->Children) {
64 const auto *CR = getRegionForHeader(Child, BB);
74 std::unordered_set<BasicBlock *> ExitTargets;
82 assert(ExitTargets.size() <= 1);
83 if (ExitTargets.size() == 0)
86 return *ExitTargets.begin();
96 if (
II->getIntrinsicID() != Intrinsic::spv_loop_merge &&
97 II->getIntrinsicID() != Intrinsic::spv_selection_merge)
111 if (
II->getIntrinsicID() != Intrinsic::spv_loop_merge)
121 for (
auto &
I : Header) {
132 if (getDesignatedContinueBlock(&
I))
140 return getDesignatedMergeBlock(
I) !=
nullptr;
149 if (getDesignatedMergeBlock(&
I) !=
nullptr)
173std::vector<Instruction *> getMergeInstructions(
BasicBlock &BB) {
174 std::vector<Instruction *> Output;
176 if (isMergeInstruction(&
I))
177 Output.push_back(&
I);
198 std::stack<BasicBlock *> ToVisit;
201 ToVisit.push(&Start);
202 Seen.
insert(ToVisit.top());
203 while (ToVisit.size() != 0) {
227 for (
size_t i = 0; i < BI->getNumSuccessors(); i++) {
228 if (BI->getSuccessor(i) == OldTarget)
229 BI->setSuccessor(i, NewTarget);
233 if (BI->isUnconditional())
237 if (BI->getSuccessor(0) != BI->getSuccessor(1))
243 Builder.SetInsertPoint(BI);
244 Builder.CreateBr(BI->getSuccessor(0));
245 BI->eraseFromParent();
256 if (!
II ||
II->getIntrinsicID() != Intrinsic::spv_selection_merge)
260 II->eraseFromParent();
261 if (!
C->isConstantUsed())
262 C->destroyConstant();
272 if (isa<ReturnInst>(
T))
275 if (isa<BranchInst>(
T))
276 return replaceIfBranchTargets(BB, OldTarget, NewTarget);
278 if (
auto *SI = dyn_cast<SwitchInst>(
T)) {
279 for (
size_t i = 0; i <
SI->getNumSuccessors(); i++) {
280 if (
SI->getSuccessor(i) == OldTarget)
281 SI->setSuccessor(i, NewTarget);
286 assert(
false &&
"Unhandled terminator type.");
295 struct DivergentConstruct;
299 using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>;
305 struct DivergentConstruct {
310 DivergentConstruct *Parent =
nullptr;
311 ConstructList Children;
333 std::vector<BasicBlock *> getLoopConstructBlocks(
BasicBlock *Header,
336 std::vector<BasicBlock *> Output;
337 partialOrderVisit(*Header, [&](
BasicBlock *BB) {
342 Output.push_back(BB);
349 std::vector<BasicBlock *>
350 getSelectionConstructBlocks(DivergentConstruct *Node) {
352 BlockSet OutsideBlocks;
353 OutsideBlocks.insert(Node->Merge);
355 for (DivergentConstruct *It = Node->Parent; It !=
nullptr;
357 OutsideBlocks.insert(It->Merge);
359 OutsideBlocks.insert(It->Continue);
362 std::vector<BasicBlock *> Output;
363 partialOrderVisit(*Node->Header, [&](
BasicBlock *BB) {
364 if (OutsideBlocks.count(BB) != 0)
366 if (DT.dominates(Node->Merge, BB) || !DT.dominates(Node->Header, BB))
368 Output.push_back(BB);
375 std::vector<BasicBlock *> getSwitchConstructBlocks(
BasicBlock *Header,
379 std::vector<BasicBlock *> Output;
380 partialOrderVisit(*Header, [&](
BasicBlock *BB) {
388 Output.push_back(BB);
399 std::vector<BasicBlock *> Output;
409 Output.push_back(BB);
438 createAliasBlocksForComplexEdges(std::vector<Edge> Edges) {
439 std::unordered_set<BasicBlock *> Seen;
440 std::vector<Edge> Output;
443 for (
auto &[Src, Dst] : Edges) {
444 auto [Iterator, Inserted] = Seen.insert(Src);
449 F.getContext(), Src->getName() +
".new.src", &
F);
450 replaceBranchTargets(Src, Dst, NewSrc);
456 Output.emplace_back(Src, Dst);
472 std::vector<Edge> &Edges) {
474 std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);
476 std::vector<BasicBlock *> Dsts;
477 std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
479 Header->getName() +
".new.exit", &
F);
481 for (
auto &[Src, Dst] : FixedEdges) {
482 if (DstToIndex.count(Dst) != 0)
484 DstToIndex.emplace(Dst, ExitBuilder.
getInt32(DstToIndex.size()));
488 if (Dsts.size() == 1) {
489 for (
auto &[Src, Dst] : FixedEdges) {
490 replaceBranchTargets(Src, Dst, NewExit);
497 F.begin()->getFirstInsertionPt());
498 for (
auto &[Src, Dst] : FixedEdges) {
502 replaceBranchTargets(Src, Dst, NewExit);
511 if (Dsts.size() == 2) {
519 for (
auto It = Dsts.begin() + 1; It != Dsts.end(); ++It) {
520 Sw->
addCase(DstToIndex[*It], *It);
528 Value *createExitVariable(
532 if (isa<ReturnInst>(
T))
538 if (
auto *BI = dyn_cast<BranchInst>(
T)) {
542 BI->isConditional() ? BI->getSuccessor(1) :
nullptr;
545 ? TargetToValue.
at(LHSTarget)
548 ? TargetToValue.
at(RHSTarget)
551 if (
LHS ==
nullptr ||
RHS ==
nullptr)
570 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
571 auto *TopLevelRegion =
572 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
574 .getTopLevelRegion();
585 auto *CR = getRegionForHeader(TopLevelRegion, &BB);
591 auto *
Merge = getExitFor(CR);
598 if (
Merge ==
nullptr) {
601 "This assumes the branch is not a switch. Maybe that's wrong?");
604 Merge = CreateUnreachable(
F);
627 bool addMergeForNodesWithMultiplePredecessors(
Function &
F) {
636 if (hasLoopMergeInstruction(BB) &&
pred_size(&BB) <= 2)
642 if (isDefinedAsSelectionMergeBy(*Header, BB))
664 std::vector<Instruction *> MergeInstructions;
666 if (isMergeInstruction(&
I))
667 MergeInstructions.push_back(&
I);
669 if (MergeInstructions.size() <= 1)
675 std::sort(MergeInstructions.begin(), MergeInstructions.end(),
679 BasicBlock *RightMerge = getDesignatedMergeBlock(Right);
680 BasicBlock *LeftMerge = getDesignatedMergeBlock(Left);
681 return !Visitor.compare(RightMerge, LeftMerge);
695 bool sortSelectionMergeHeaders(
Function &
F) {
705 bool splitBlocksWithMultipleHeaders(
Function &
F) {
706 std::stack<BasicBlock *> Work;
708 std::vector<Instruction *> MergeInstructions = getMergeInstructions(BB);
709 if (MergeInstructions.size() <= 1)
714 const bool Modified = Work.size() > 0;
715 while (Work.size() > 0) {
719 std::vector<Instruction *> MergeInstructions =
720 getMergeInstructions(*Header);
721 for (
unsigned i = 1; i < MergeInstructions.size(); i++) {
723 Header->splitBasicBlock(MergeInstructions[i],
"new.header");
725 if (getDesignatedContinueBlock(MergeInstructions[0]) ==
nullptr) {
728 BranchInst *BI = cast<BranchInst>(Header->getTerminator());
744 bool addMergeForDivergentBlocks(
Function &
F) {
749 auto MergeBlocks = getMergeBlocks(
F);
750 auto ContinueBlocks = getContinueBlocks(
F);
753 if (getMergeInstructions(BB).
size() != 0)
756 std::vector<BasicBlock *> Candidates;
765 if (Candidates.size() <= 1)
783 std::vector<Edge> getExitsFrom(
const BlockSet &Construct,
785 std::vector<Edge> Output;
787 if (Construct.count(Item) == 0)
802 void constructDivergentConstruct(BlockSet &Visited, Splitter &S,
804 if (Visited.count(BB) != 0)
808 auto MIS = getMergeInstructions(*BB);
809 if (MIS.size() == 0) {
811 constructDivergentConstruct(Visited, S,
Successor, Parent);
821 auto Output = std::make_unique<DivergentConstruct>();
823 Output->Merge =
Merge;
825 Output->Parent = Parent;
827 constructDivergentConstruct(Visited, S,
Merge, Parent);
829 constructDivergentConstruct(Visited, S,
Continue, Output.get());
832 constructDivergentConstruct(Visited, S,
Successor, Output.get());
835 Parent->Children.emplace_back(std::move(Output));
839 BlockSet getConstructBlocks(Splitter &S, DivergentConstruct *Node) {
840 assert(Node->Header && Node->Merge);
842 if (Node->Continue) {
843 auto LoopBlocks = S.getLoopConstructBlocks(Node->Header, Node->Merge);
844 return BlockSet(LoopBlocks.begin(), LoopBlocks.end());
847 auto SelectionBlocks = S.getSelectionConstructBlocks(Node);
848 return BlockSet(SelectionBlocks.begin(), SelectionBlocks.end());
853 bool fixupConstruct(Splitter &S, DivergentConstruct *Node) {
855 for (
auto &Child : Node->Children)
856 Modified |= fixupConstruct(S, Child.get());
860 if (Node->Parent ==
nullptr)
866 if (Node->Parent->Header ==
nullptr)
870 assert(Node->Header && Node->Merge);
871 assert(Node->Parent->Header && Node->Parent->Merge);
873 BlockSet ConstructBlocks = getConstructBlocks(S, Node);
874 auto Edges = getExitsFrom(ConstructBlocks, *Node->Header);
877 if (Edges.size() < 1)
880 bool HasBadEdge = Node->Merge == Node->Parent->Merge ||
881 Node->Merge == Node->Parent->Continue;
883 for (
auto &[Src, Dst] : Edges) {
890 if (Node->Merge == Dst)
895 if (Node->Continue == Dst)
907 BasicBlock *NewExit = S.createSingleExitNode(Node->Header, Edges);
913 auto MergeInstructions = getMergeInstructions(*Node->Header);
914 assert(MergeInstructions.size() == 1);
919 I->setOperand(0, MergeAddress);
927 Node->Merge = NewExit;
934 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
937 DivergentConstruct Root;
939 constructDivergentConstruct(Visited, S, &*
F.begin(), &Root);
940 return fixupConstruct(S, &Root);
955 if (SI->getNumCases() > 1)
962 if (SI->getNumCases() == 0) {
963 Builder.
CreateBr(SI->getDefaultDest());
967 SI->case_begin()->getCaseValue());
968 Builder.
CreateCondBr(Condition, SI->case_begin()->getCaseSuccessor(),
969 SI->getDefaultDest());
971 SI->eraseFromParent();
989 Seen.
insert(SI->getDefaultDest());
991 auto It = SI->case_begin();
992 while (It != SI->case_end()) {
994 if (Seen.count(
Target) == 0) {
1005 SI->addCase(It->getCaseValue(), NewTarget);
1006 It = SI->removeCase(It);
1016 std::vector<BasicBlock *>
ToRemove;
1018 auto MergeBlocks = getMergeBlocks(
F);
1019 auto ContinueBlocks = getContinueBlocks(
F);
1028 if (MergeBlocks.count(&BB) != 0 || ContinueBlocks.count(&BB) != 0)
1035 std::vector<BasicBlock *> Predecessors(
predecessors(&BB).begin(),
1038 replaceBranchTargets(Predecessor, &BB,
Successor);
1048 bool addHeaderToRemainingDivergentDAG(
Function &
F) {
1051 auto MergeBlocks = getMergeBlocks(
F);
1052 auto ContinueBlocks = getContinueBlocks(
F);
1053 auto HeaderBlocks = getHeaderBlocks(
F);
1061 if (HeaderBlocks.count(&BB) != 0)
1066 size_t CandidateEdges = 0;
1068 if (MergeBlocks.count(
Successor) != 0 ||
1073 CandidateEdges += 1;
1076 if (CandidateEdges <= 1)
1082 bool HasBadBlock =
false;
1088 if (Node == Header || Node ==
Merge)
1091 HasBadBlock |= MergeBlocks.count(Node) != 0 ||
1092 ContinueBlocks.count(Node) != 0 ||
1093 HeaderBlocks.count(Node) != 0;
1094 return !HasBadBlock;
1102 if (
Merge ==
nullptr) {
1114 if (isMergeInstruction(SplitInstruction->
getPrevNode()))
1115 SplitInstruction = SplitInstruction->
getPrevNode();
1117 Merge->splitBasicBlockBefore(SplitInstruction,
"new.merge");
1155 Modified |= addMergeForNodesWithMultiplePredecessors(
F);
1160 Modified |= sortSelectionMergeHeaders(
F);
1165 Modified |= splitBlocksWithMultipleHeaders(
F);
1171 Modified |= addMergeForDivergentBlocks(
F);
1195 Modified |= addHeaderToRemainingDivergentDAG(
F);
1217 "structurize SPIRV",
false,
false)
for(const MachineOperand &MO :llvm::drop_begin(OldMI.operands(), Desc.getNumOperands()))
ReachingDefAnalysis InstSet & ToRemove
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file defines the DenseMap class.
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
This file defines the SmallPtrSet class.
an instruction to allocate memory on the stack
A container for analyses that lazily runs them and caches their results.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM Basic Block Representation.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
const Function * getParent() const
Return the enclosing method, or null if none.
SymbolTableList< BasicBlock >::iterator eraseFromParent()
Unlink 'this' from the containing function and delete it.
InstListType::iterator iterator
Instruction iterators...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
The address of a basic block.
BasicBlock * getBasicBlock() const
static BlockAddress * get(Function *F, BasicBlock *BB)
Return a BlockAddress for the specified function and basic block.
Conditional or Unconditional Branch instruction.
BasicBlock * getSuccessor(unsigned i) const
Represents analyses that only rely on functions' control flow.
This is an important base class in LLVM.
bool isConstantUsed() const
Return true if the constant has users other than constant expressions and other dangling things.
void destroyConstant()
Called if some element of this constant is no longer valid.
This class represents an Operation in the Expression.
A parsed version of the target data layout string, and methods for querying it.
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
const ValueT & at(const_arg_type_t< KeyT > Val) const
at - Return the entry for the specified key, or abort if no such entry exists.
DomTreeNodeBase * getIDom() const
Core dominator tree base class.
bool dominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
dominates - Returns true iff A dominates B.
void recalculate(ParentType &Func)
recalculate - compute a dominator tree for the given function
DomTreeNodeBase< NodeT > * getNode(const NodeT *BB) const
getNode - return the (Post)DominatorTree node for the specified basic block.
Legacy analysis pass which computes a DominatorTree.
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
UnreachableInst * CreateUnreachable()
ConstantInt * getTrue()
Get the constant value for i1 true.
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, Instruction *FMFSource=nullptr, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
ConstantInt * getInt32(uint32_t C)
Get a constant 32-bit value.
Value * CreateCmp(CmpInst::Predicate Pred, Value *LHS, Value *RHS, const Twine &Name="", MDNode *FPMathTag=nullptr)
SwitchInst * CreateSwitch(Value *V, BasicBlock *Dest, unsigned NumCases=10, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a switch instruction with the specified value, default dest, and with a hint for the number of...
BranchInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a conditional 'br Cond, TrueDest, FalseDest' instruction.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
ConstantInt * getFalse()
Get the constant value for i1 false.
BranchInst * CreateBr(BasicBlock *Dest)
Create an unconditional 'br label X' instruction.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
A wrapper class for inspecting calls to intrinsic functions.
bool isLoopHeader(const BlockT *BB) const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
PassRegistry - This class manages the registration and initialization of the pass subsystem as appli...
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserveSet()
Mark an analysis set as preserved.
PreservedAnalyses run(Function &M, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
virtual bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
SmallPtrSet< BasicBlock *, 2 > Exits
SmallPtrSet< BasicBlock *, 8 > Blocks
void reserve(size_type NumEntries)
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
void addCase(ConstantInt *OnVal, BasicBlock *Dest)
Add an entry to the switch instruction.
Target - Wrapper for Target specific information.
The instances of the Type class are immutable: once they are created, they are never changed.
LLVM Value Representation.
Pass manager infrastructure for declaring and invalidating analyses.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
@ C
The default llvm calling convention, compatible with C.
This is an optimization pass for GlobalISel generic memory operations.
FunctionPass * createSPIRVStructurizerPass()
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
auto successors(const MachineBasicBlock *BB)
bool sortBlocks(Function &F)
auto pred_size(const MachineBasicBlock *BB)
auto succ_size(const MachineBasicBlock *BB)
void initializeSPIRVStructurizerPass(PassRegistry &)
auto predecessors(const MachineBasicBlock *BB)