26#include "llvm/IR/IntrinsicsSPIRV.h"
36#include <unordered_set>
47using BlockSet = std::unordered_set<BasicBlock *>;
48using Edge = std::pair<BasicBlock *, BasicBlock *>;
55 V.partialOrderVisit(Start,
Op);
62 if (
Node->Entry == BB)
65 for (
auto *Child :
Node->Children) {
66 const auto *CR = getRegionForHeader(Child, BB);
76 std::unordered_set<BasicBlock *> ExitTargets;
84 assert(ExitTargets.size() <= 1);
85 if (ExitTargets.size() == 0)
88 return *ExitTargets.begin();
98 if (
II->getIntrinsicID() != Intrinsic::spv_loop_merge &&
99 II->getIntrinsicID() != Intrinsic::spv_selection_merge)
113 if (
II->getIntrinsicID() != Intrinsic::spv_loop_merge)
123 for (
auto &
I : Header) {
134 if (getDesignatedContinueBlock(&
I))
142 return getDesignatedMergeBlock(
I) !=
nullptr;
151 if (getDesignatedMergeBlock(&
I) !=
nullptr)
175std::vector<Instruction *> getMergeInstructions(
BasicBlock &BB) {
176 std::vector<Instruction *> Output;
178 if (isMergeInstruction(&
I))
179 Output.push_back(&
I);
200 std::stack<BasicBlock *> ToVisit;
203 ToVisit.push(&Start);
204 Seen.
insert(ToVisit.top());
205 while (ToVisit.size() != 0) {
229 for (
size_t i = 0; i < BI->getNumSuccessors(); i++) {
230 if (BI->getSuccessor(i) == OldTarget)
231 BI->setSuccessor(i, NewTarget);
235 if (BI->isUnconditional())
239 if (BI->getSuccessor(0) != BI->getSuccessor(1))
245 Builder.SetInsertPoint(BI);
246 Builder.CreateBr(BI->getSuccessor(0));
247 BI->eraseFromParent();
258 if (!
II ||
II->getIntrinsicID() != Intrinsic::spv_selection_merge)
262 II->eraseFromParent();
263 if (!
C->isConstantUsed())
264 C->destroyConstant();
274 if (isa<ReturnInst>(
T))
277 if (isa<BranchInst>(
T))
278 return replaceIfBranchTargets(BB, OldTarget, NewTarget);
280 if (
auto *SI = dyn_cast<SwitchInst>(
T)) {
281 for (
size_t i = 0; i <
SI->getNumSuccessors(); i++) {
282 if (
SI->getSuccessor(i) == OldTarget)
283 SI->setSuccessor(i, NewTarget);
288 assert(
false &&
"Unhandled terminator type.");
297 struct DivergentConstruct;
301 using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>;
307 struct DivergentConstruct {
312 DivergentConstruct *Parent =
nullptr;
313 ConstructList Children;
335 std::vector<BasicBlock *> getLoopConstructBlocks(
BasicBlock *Header,
338 std::vector<BasicBlock *> Output;
339 partialOrderVisit(*Header, [&](
BasicBlock *BB) {
344 Output.push_back(BB);
351 std::vector<BasicBlock *>
352 getSelectionConstructBlocks(DivergentConstruct *Node) {
354 BlockSet OutsideBlocks;
355 OutsideBlocks.insert(Node->Merge);
357 for (DivergentConstruct *It = Node->Parent; It !=
nullptr;
359 OutsideBlocks.insert(It->Merge);
361 OutsideBlocks.insert(It->Continue);
364 std::vector<BasicBlock *> Output;
365 partialOrderVisit(*Node->Header, [&](
BasicBlock *BB) {
366 if (OutsideBlocks.count(BB) != 0)
368 if (DT.dominates(Node->Merge, BB) || !DT.dominates(Node->Header, BB))
370 Output.push_back(BB);
377 std::vector<BasicBlock *> getSwitchConstructBlocks(
BasicBlock *Header,
381 std::vector<BasicBlock *> Output;
382 partialOrderVisit(*Header, [&](
BasicBlock *BB) {
390 Output.push_back(BB);
401 std::vector<BasicBlock *> Output;
411 Output.push_back(BB);
440 createAliasBlocksForComplexEdges(std::vector<Edge> Edges) {
441 std::unordered_set<BasicBlock *> Seen;
442 std::vector<Edge> Output;
445 for (
auto &[Src, Dst] : Edges) {
446 auto [Iterator, Inserted] = Seen.insert(Src);
451 F.getContext(), Src->getName() +
".new.src", &
F);
452 replaceBranchTargets(Src, Dst, NewSrc);
458 Output.emplace_back(Src, Dst);
474 std::vector<Edge> &Edges) {
476 std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);
478 std::vector<BasicBlock *> Dsts;
479 std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
481 Header->getName() +
".new.exit", &
F);
483 for (
auto &[Src, Dst] : FixedEdges) {
484 if (DstToIndex.count(Dst) != 0)
486 DstToIndex.emplace(Dst, ExitBuilder.
getInt32(DstToIndex.size()));
490 if (Dsts.size() == 1) {
491 for (
auto &[Src, Dst] : FixedEdges) {
492 replaceBranchTargets(Src, Dst, NewExit);
499 F.begin()->getFirstInsertionPt());
500 for (
auto &[Src, Dst] : FixedEdges) {
504 replaceBranchTargets(Src, Dst, NewExit);
513 if (Dsts.size() == 2) {
521 for (
auto It = Dsts.begin() + 1; It != Dsts.end(); ++It) {
522 Sw->
addCase(DstToIndex[*It], *It);
530 Value *createExitVariable(
534 if (isa<ReturnInst>(
T))
540 if (
auto *BI = dyn_cast<BranchInst>(
T)) {
544 BI->isConditional() ? BI->getSuccessor(1) :
nullptr;
547 ? TargetToValue.
at(LHSTarget)
550 ? TargetToValue.
at(RHSTarget)
553 if (
LHS ==
nullptr ||
RHS ==
nullptr)
572 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
573 auto *TopLevelRegion =
574 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
576 .getTopLevelRegion();
587 auto *CR = getRegionForHeader(TopLevelRegion, &BB);
593 auto *
Merge = getExitFor(CR);
600 if (
Merge ==
nullptr) {
603 "This assumes the branch is not a switch. Maybe that's wrong?");
606 Merge = CreateUnreachable(
F);
629 bool addMergeForNodesWithMultiplePredecessors(
Function &
F) {
638 if (hasLoopMergeInstruction(BB) &&
pred_size(&BB) <= 2)
644 if (isDefinedAsSelectionMergeBy(*Header, BB))
665 std::vector<Instruction *> MergeInstructions;
667 if (isMergeInstruction(&
I))
668 MergeInstructions.push_back(&
I);
670 if (MergeInstructions.size() <= 1)
676 std::sort(MergeInstructions.begin(), MergeInstructions.end(),
680 BasicBlock *RightMerge = getDesignatedMergeBlock(Right);
681 BasicBlock *LeftMerge = getDesignatedMergeBlock(Left);
682 return !Visitor.compare(RightMerge, LeftMerge);
696 bool sortSelectionMergeHeaders(
Function &
F) {
706 bool splitBlocksWithMultipleHeaders(
Function &
F) {
707 std::stack<BasicBlock *> Work;
709 std::vector<Instruction *> MergeInstructions = getMergeInstructions(BB);
710 if (MergeInstructions.size() <= 1)
715 const bool Modified = Work.size() > 0;
716 while (Work.size() > 0) {
720 std::vector<Instruction *> MergeInstructions =
721 getMergeInstructions(*Header);
722 for (
unsigned i = 1; i < MergeInstructions.size(); i++) {
724 Header->splitBasicBlock(MergeInstructions[i],
"new.header");
726 if (getDesignatedContinueBlock(MergeInstructions[0]) ==
nullptr) {
729 BranchInst *BI = cast<BranchInst>(Header->getTerminator());
745 bool addMergeForDivergentBlocks(
Function &
F) {
750 auto MergeBlocks = getMergeBlocks(
F);
751 auto ContinueBlocks = getContinueBlocks(
F);
754 if (getMergeInstructions(BB).
size() != 0)
757 std::vector<BasicBlock *> Candidates;
766 if (Candidates.size() <= 1)
783 std::vector<Edge> getExitsFrom(
const BlockSet &Construct,
785 std::vector<Edge> Output;
787 if (Construct.count(Item) == 0)
802 void constructDivergentConstruct(BlockSet &Visited, Splitter &S,
804 if (Visited.count(BB) != 0)
808 auto MIS = getMergeInstructions(*BB);
809 if (MIS.size() == 0) {
811 constructDivergentConstruct(Visited, S,
Successor, Parent);
821 auto Output = std::make_unique<DivergentConstruct>();
823 Output->Merge =
Merge;
825 Output->Parent = Parent;
827 constructDivergentConstruct(Visited, S,
Merge, Parent);
829 constructDivergentConstruct(Visited, S,
Continue, Output.get());
832 constructDivergentConstruct(Visited, S,
Successor, Output.get());
835 Parent->Children.emplace_back(std::move(Output));
839 BlockSet getConstructBlocks(Splitter &S, DivergentConstruct *Node) {
840 assert(Node->Header && Node->Merge);
842 if (Node->Continue) {
843 auto LoopBlocks = S.getLoopConstructBlocks(Node->Header, Node->Merge);
844 return BlockSet(LoopBlocks.begin(), LoopBlocks.end());
847 auto SelectionBlocks = S.getSelectionConstructBlocks(Node);
848 return BlockSet(SelectionBlocks.begin(), SelectionBlocks.end());
853 bool fixupConstruct(Splitter &S, DivergentConstruct *Node) {
855 for (
auto &Child : Node->Children)
856 Modified |= fixupConstruct(S, Child.get());
860 if (Node->Parent ==
nullptr)
866 if (Node->Parent->Header ==
nullptr)
870 assert(Node->Header && Node->Merge);
871 assert(Node->Parent->Header && Node->Parent->Merge);
873 BlockSet ConstructBlocks = getConstructBlocks(S, Node);
874 auto Edges = getExitsFrom(ConstructBlocks, *Node->Header);
877 if (Edges.size() < 1)
880 bool HasBadEdge = Node->Merge == Node->Parent->Merge ||
881 Node->Merge == Node->Parent->Continue;
883 for (
auto &[Src, Dst] : Edges) {
890 if (Node->Merge == Dst)
895 if (Node->Continue == Dst)
907 BasicBlock *NewExit = S.createSingleExitNode(Node->Header, Edges);
913 auto MergeInstructions = getMergeInstructions(*Node->Header);
914 assert(MergeInstructions.size() == 1);
919 I->setOperand(0, MergeAddress);
927 Node->Merge = NewExit;
934 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
937 DivergentConstruct Root;
939 constructDivergentConstruct(Visited, S, &*
F.begin(), &Root);
940 return fixupConstruct(S, &Root);
955 if (SI->getNumCases() > 1)
962 if (SI->getNumCases() == 0) {
963 Builder.
CreateBr(SI->getDefaultDest());
967 SI->case_begin()->getCaseValue());
968 Builder.
CreateCondBr(Condition, SI->case_begin()->getCaseSuccessor(),
969 SI->getDefaultDest());
971 SI->eraseFromParent();
989 Seen.
insert(SI->getDefaultDest());
991 auto It = SI->case_begin();
992 while (It != SI->case_end()) {
994 if (Seen.count(
Target) == 0) {
1005 SI->addCase(It->getCaseValue(), NewTarget);
1006 It = SI->removeCase(It);
1016 std::vector<BasicBlock *>
ToRemove;
1018 auto MergeBlocks = getMergeBlocks(
F);
1019 auto ContinueBlocks = getContinueBlocks(
F);
1028 if (MergeBlocks.count(&BB) != 0 || ContinueBlocks.count(&BB) != 0)
1035 std::vector<BasicBlock *> Predecessors(
predecessors(&BB).begin(),
1038 replaceBranchTargets(Predecessor, &BB,
Successor);
1048 bool addHeaderToRemainingDivergentDAG(
Function &
F) {
1051 auto MergeBlocks = getMergeBlocks(
F);
1052 auto ContinueBlocks = getContinueBlocks(
F);
1053 auto HeaderBlocks = getHeaderBlocks(
F);
1061 if (HeaderBlocks.count(&BB) != 0)
1066 size_t CandidateEdges = 0;
1068 if (MergeBlocks.count(
Successor) != 0 ||
1073 CandidateEdges += 1;
1076 if (CandidateEdges <= 1)
1082 bool HasBadBlock =
false;
1088 if (Node == Header || Node ==
Merge)
1091 HasBadBlock |= MergeBlocks.count(Node) != 0 ||
1092 ContinueBlocks.count(Node) != 0 ||
1093 HeaderBlocks.count(Node) != 0;
1094 return !HasBadBlock;
1102 if (
Merge ==
nullptr) {
1113 if (isMergeInstruction(SplitInstruction->
getPrevNode()))
1114 SplitInstruction = SplitInstruction->
getPrevNode();
1116 Merge->splitBasicBlockBefore(SplitInstruction,
"new.merge");
1153 Modified |= addMergeForNodesWithMultiplePredecessors(
F);
1158 Modified |= sortSelectionMergeHeaders(
F);
1163 Modified |= splitBlocksWithMultipleHeaders(
F);
1169 Modified |= addMergeForDivergentBlocks(
F);
1193 Modified |= addHeaderToRemainingDivergentDAG(
F);
1219 "invalid metadata hlsl.controlflow.hint");
1222 assert(BranchHint &&
"invalid metadata value for hlsl.controlflow.hint");
1228 {MergeAddress->
getType()}, {Args});
1236 "structurize SPIRV",
false,
false)
for(const MachineOperand &MO :llvm::drop_begin(OldMI.operands(), Desc.getNumOperands()))
ReachingDefAnalysis InstSet & ToRemove
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file defines the DenseMap class.
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
This file defines the SmallPtrSet class.
an instruction to allocate memory on the stack
A container for analyses that lazily runs them and caches their results.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM Basic Block Representation.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
const Function * getParent() const
Return the enclosing method, or null if none.
SymbolTableList< BasicBlock >::iterator eraseFromParent()
Unlink 'this' from the containing function and delete it.
InstListType::iterator iterator
Instruction iterators...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
The address of a basic block.
BasicBlock * getBasicBlock() const
static BlockAddress * get(Function *F, BasicBlock *BB)
Return a BlockAddress for the specified function and basic block.
Conditional or Unconditional Branch instruction.
BasicBlock * getSuccessor(unsigned i) const
Represents analyses that only rely on functions' control flow.
This is the shared class of boolean and integer constants.
This is an important base class in LLVM.
bool isConstantUsed() const
Return true if the constant has users other than constant expressions and other dangling things.
void destroyConstant()
Called if some element of this constant is no longer valid.
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
const ValueT & at(const_arg_type_t< KeyT > Val) const
at - Return the entry for the specified key, or abort if no such entry exists.
DomTreeNodeBase * getIDom() const
Core dominator tree base class.
bool dominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
dominates - Returns true iff A dominates B.
void recalculate(ParentType &Func)
recalculate - compute a dominator tree for the given function
DomTreeNodeBase< NodeT > * getNode(const NodeT *BB) const
getNode - return the (Post)DominatorTree node for the specified basic block.
Legacy analysis pass which computes a DominatorTree.
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FunctionPass class - This class is used to implement most global optimizations.
UnreachableInst * CreateUnreachable()
ConstantInt * getTrue()
Get the constant value for i1 true.
Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
BasicBlock * GetInsertBlock() const
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
ConstantInt * getInt32(uint32_t C)
Get a constant 32-bit value.
Value * CreateCmp(CmpInst::Predicate Pred, Value *LHS, Value *RHS, const Twine &Name="", MDNode *FPMathTag=nullptr)
SwitchInst * CreateSwitch(Value *V, BasicBlock *Dest, unsigned NumCases=10, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a switch instruction with the specified value, default dest, and with a hint for the number of...
BranchInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a conditional 'br Cond, TrueDest, FalseDest' instruction.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
ConstantInt * getFalse()
Get the constant value for i1 false.
BranchInst * CreateBr(BasicBlock *Dest)
Create an unconditional 'br label X' instruction.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
MDNode * getMetadata(unsigned KindID) const
Get the metadata of given kind attached to this Instruction.
A wrapper class for inspecting calls to intrinsic functions.
bool isLoopHeader(const BlockT *BB) const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
const MDOperand & getOperand(unsigned I) const
unsigned getNumOperands() const
Return number of MDNode operands.
PassRegistry - This class manages the registration and intitialization of the pass subsystem as appli...
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserveSet()
Mark an analysis set as preserved.
PreservedAnalyses run(Function &M, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void createOpSelectMerge(IRBuilder<> *Builder, BlockAddress *MergeAddress)
virtual bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
SmallPtrSet< BasicBlock *, 2 > Exits
SmallPtrSet< BasicBlock *, 8 > Blocks
void reserve(size_type NumEntries)
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
void addCase(ConstantInt *OnVal, BasicBlock *Dest)
Add an entry to the switch instruction.
Target - Wrapper for Target specific information.
The instances of the Type class are immutable: once they are created, they are never changed.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
FunctionPassManager manages FunctionPasses.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
@ C
The default llvm calling convention, compatible with C.
This is an optimization pass for GlobalISel generic memory operations.
FunctionPass * createSPIRVStructurizerPass()
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
auto successors(const MachineBasicBlock *BB)
bool sortBlocks(Function &F)
auto pred_size(const MachineBasicBlock *BB)
auto succ_size(const MachineBasicBlock *BB)
void initializeSPIRVStructurizerPass(PassRegistry &)
auto predecessors(const MachineBasicBlock *BB)