24#include "llvm/IR/IntrinsicsSPIRV.h"
32#include <unordered_set>
37using BlockSet = std::unordered_set<BasicBlock *>;
38using Edge = std::pair<BasicBlock *, BasicBlock *>;
45 V.partialOrderVisit(Start, std::move(
Op));
52 if (
Node->Entry == BB)
55 for (
auto *Child :
Node->Children) {
66 std::unordered_set<BasicBlock *> ExitTargets;
74 assert(ExitTargets.size() <= 1);
75 if (ExitTargets.size() == 0)
78 return *ExitTargets.begin();
88 if (
II->getIntrinsicID() != Intrinsic::spv_loop_merge &&
89 II->getIntrinsicID() != Intrinsic::spv_selection_merge)
103 if (
II->getIntrinsicID() != Intrinsic::spv_loop_merge)
113 for (
auto &
I : Header) {
166 std::vector<Instruction *> Output;
169 Output.push_back(&
I);
190 std::stack<BasicBlock *> ToVisit;
193 ToVisit.push(&Start);
194 Seen.
insert(ToVisit.top());
195 while (ToVisit.size() != 0) {
219 for (
size_t i = 0; i < BI->getNumSuccessors(); i++) {
220 if (BI->getSuccessor(i) == OldTarget)
221 BI->setSuccessor(i, NewTarget);
225 if (BI->getSuccessor(0) != BI->getSuccessor(1))
231 Builder.SetInsertPoint(BI);
232 Builder.CreateBr(BI->getSuccessor(0));
233 BI->eraseFromParent();
244 if (!
II ||
II->getIntrinsicID() != Intrinsic::spv_selection_merge)
248 II->eraseFromParent();
249 if (!
C->isConstantUsed())
250 C->destroyConstant();
263 if (BI->getSuccessor() == OldTarget)
264 BI->setSuccessor(NewTarget);
272 for (
size_t i = 0; i <
SI->getNumSuccessors(); i++) {
273 if (
SI->getSuccessor(i) == OldTarget)
274 SI->setSuccessor(i, NewTarget);
279 assert(
false &&
"Unhandled terminator type.");
286 struct DivergentConstruct;
290 using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>;
296 struct DivergentConstruct {
301 DivergentConstruct *Parent =
nullptr;
302 ConstructList Children;
315 Splitter(Function &F, LoopInfo &LI) : F(F), LI(LI) { invalidate(); }
324 std::vector<BasicBlock *> getLoopConstructBlocks(BasicBlock *Header,
327 std::vector<BasicBlock *> Output;
331 if (DT.dominates(
Merge, BB) || !DT.dominates(Header, BB))
333 Output.push_back(BB);
340 std::vector<BasicBlock *>
341 getSelectionConstructBlocks(DivergentConstruct *Node) {
344 OutsideBlocks.insert(
Node->Merge);
346 for (DivergentConstruct *It =
Node->Parent; It !=
nullptr;
348 OutsideBlocks.insert(It->Merge);
350 OutsideBlocks.insert(It->Continue);
353 std::vector<BasicBlock *> Output;
355 if (OutsideBlocks.count(BB) != 0)
357 if (DT.dominates(Node->Merge, BB) || !DT.dominates(Node->Header, BB))
359 Output.push_back(BB);
366 std::vector<BasicBlock *> getSwitchConstructBlocks(BasicBlock *Header,
370 std::vector<BasicBlock *> Output;
373 if (!DT.dominates(Header, BB))
379 Output.push_back(BB);
386 std::vector<BasicBlock *> getCaseConstructBlocks(BasicBlock *Target,
390 std::vector<BasicBlock *> Output;
394 if (!DT.dominates(Target, BB))
400 Output.push_back(BB);
429 createAliasBlocksForComplexEdges(std::vector<Edge> Edges) {
430 std::unordered_set<BasicBlock *> Seen;
431 std::vector<Edge> Output;
432 Output.reserve(Edges.size());
434 for (
auto &[Src, Dst] : Edges) {
435 auto [Iterator,
Inserted] = Seen.insert(Src);
440 F.getContext(), Src->getName() +
".new.src", &F);
443 Builder.CreateBr(Dst);
447 Output.emplace_back(Src, Dst);
455 BasicBlock *createSingleExitNode(BasicBlock *Header,
456 std::vector<Edge> &Edges) {
458 std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);
460 std::vector<BasicBlock *> Dsts;
461 std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
463 Header->getName() +
".new.exit", &F);
465 for (
auto &[Src, Dst] : FixedEdges) {
466 if (DstToIndex.count(Dst) != 0)
468 DstToIndex.emplace(Dst, ExitBuilder.getInt32(DstToIndex.size()));
472 if (Dsts.size() == 1) {
473 for (
auto &[Src, Dst] : FixedEdges) {
476 ExitBuilder.CreateBr(Dsts[0]);
481 for (
auto &[Src, Dst] : FixedEdges) {
483 B2.SetInsertPoint(Src->getFirstInsertionPt());
484 B2.CreateStore(DstToIndex[Dst], Variable);
488 Value *
Load = ExitBuilder.CreateLoad(ExitBuilder.getInt32Ty(), Variable);
493 if (Dsts.size() == 2) {
496 ExitBuilder.CreateCondBr(Condition, Dsts[0], Dsts[1]);
500 SwitchInst *Sw = ExitBuilder.CreateSwitch(Load, Dsts[0], Dsts.size() - 1);
502 Sw->
addCase(DstToIndex[BB], BB);
511 Builder.CreateUnreachable();
516 bool addMergeForLoops(Function &
F) {
517 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
518 auto *TopLevelRegion =
519 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
521 .getTopLevelRegion();
545 if (
Merge ==
nullptr) {
547 Merge = CreateUnreachable(
F);
548 Builder.SetInsertPoint(Br);
550 Br->eraseFromParent();
558 SmallVector<Value *, 2>
Args = {MergeAddress, ContinueAddress};
559 SmallVector<unsigned, 1> LoopControlImms =
561 for (
unsigned Imm : LoopControlImms)
562 Args.emplace_back(ConstantInt::get(Builder.getInt32Ty(), Imm));
563 Builder.CreateIntrinsic(Intrinsic::spv_loop_merge, {
Args});
573 bool addMergeForNodesWithMultiplePredecessors(Function &
F) {
592 Builder.SetInsertPoint(Header->getTerminator());
595 createOpSelectMerge(&Builder, MergeAddress);
608 bool sortSelectionMerge(Function &
F, BasicBlock &
Block) {
609 std::vector<Instruction *> MergeInstructions;
610 for (Instruction &
I :
Block)
612 MergeInstructions.push_back(&
I);
614 if (MergeInstructions.size() <= 1)
617 Instruction *InsertionPoint = *MergeInstructions.begin();
619 PartialOrderingVisitor Visitor(
F);
620 std::sort(MergeInstructions.begin(), MergeInstructions.end(),
621 [&Visitor](Instruction *
Left, Instruction *
Right) {
624 BasicBlock *RightMerge = getDesignatedMergeBlock(Right);
625 BasicBlock *LeftMerge = getDesignatedMergeBlock(Left);
626 return !Visitor.compare(RightMerge, LeftMerge);
629 for (Instruction *
I : MergeInstructions) {
640 bool sortSelectionMergeHeaders(Function &
F) {
642 for (BasicBlock &BB :
F) {
650 bool splitBlocksWithMultipleHeaders(Function &
F) {
651 std::stack<BasicBlock *> Work;
654 if (MergeInstructions.size() <= 1)
659 const bool Modified = Work.size() > 0;
660 while (Work.size() > 0) {
664 std::vector<Instruction *> MergeInstructions =
666 for (
unsigned i = 1; i < MergeInstructions.size(); i++) {
668 Header->splitBasicBlock(MergeInstructions[i],
"new.header");
675 Builder.SetInsertPoint(Term);
676 Builder.CreateCondBr(Builder.getTrue(), NewBlock, Unreachable);
677 Term->eraseFromParent();
689 bool addMergeForDivergentBlocks(Function &
F) {
701 std::vector<BasicBlock *> Candidates;
710 if (Candidates.size() <= 1)
719 createOpSelectMerge(&Builder, MergeAddress);
727 std::vector<Edge> getExitsFrom(
const BlockSet &Construct,
728 BasicBlock &Header) {
729 std::vector<Edge> Output;
730 visit(Header, [&](BasicBlock *Item) {
731 if (Construct.count(Item) == 0)
746 void constructDivergentConstruct(
BlockSet &Visited, Splitter &S,
747 BasicBlock *BB, DivergentConstruct *Parent) {
748 if (Visited.count(BB) != 0)
753 if (MIS.size() == 0) {
755 constructDivergentConstruct(Visited, S,
Successor, Parent);
765 auto Output = std::make_unique<DivergentConstruct>();
767 Output->Merge =
Merge;
769 Output->Parent = Parent;
771 constructDivergentConstruct(Visited, S,
Merge, Parent);
773 constructDivergentConstruct(Visited, S,
Continue, Output.get());
776 constructDivergentConstruct(Visited, S,
Successor, Output.get());
779 Parent->Children.emplace_back(std::move(Output));
783 BlockSet getConstructBlocks(Splitter &S, DivergentConstruct *Node) {
786 if (
Node->Continue) {
787 auto LoopBlocks = S.getLoopConstructBlocks(
Node->Header,
Node->Merge);
788 return BlockSet(LoopBlocks.begin(), LoopBlocks.end());
791 auto SelectionBlocks = S.getSelectionConstructBlocks(Node);
792 return BlockSet(SelectionBlocks.begin(), SelectionBlocks.end());
797 bool fixupConstruct(Splitter &S, DivergentConstruct *Node) {
799 for (
auto &Child :
Node->Children)
800 Modified |= fixupConstruct(S, Child.get());
804 if (
Node->Parent ==
nullptr)
810 if (
Node->Parent->Header ==
nullptr)
817 BlockSet ConstructBlocks = getConstructBlocks(S, Node);
818 auto Edges = getExitsFrom(ConstructBlocks, *
Node->Header);
821 if (Edges.size() < 1)
824 bool HasBadEdge =
Node->Merge ==
Node->Parent->Merge ||
825 Node->Merge ==
Node->Parent->Continue;
827 for (
auto &[Src, Dst] : Edges) {
834 if (
Node->Merge == Dst)
839 if (
Node->Continue == Dst)
851 BasicBlock *NewExit = S.createSingleExitNode(
Node->Header, Edges);
858 assert(MergeInstructions.size() == 1);
863 I->setOperand(0, MergeAddress);
871 Node->Merge = NewExit;
878 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
881 DivergentConstruct Root;
883 constructDivergentConstruct(Visited, S, &*
F.begin(), &Root);
884 return fixupConstruct(S, &Root);
892 bool simplifyBranches(Function &
F) {
895 for (BasicBlock &BB :
F) {
899 if (
SI->getNumCases() > 1)
904 Builder.SetInsertPoint(SI);
906 if (
SI->getNumCases() == 0) {
907 Builder.CreateBr(
SI->getDefaultDest());
911 SI->case_begin()->getCaseValue());
912 Builder.CreateCondBr(Condition,
SI->case_begin()->getCaseSuccessor(),
913 SI->getDefaultDest());
915 SI->eraseFromParent();
924 bool splitSwitchCases(Function &
F) {
927 for (BasicBlock &BB :
F) {
933 Seen.insert(
SI->getDefaultDest());
935 auto It =
SI->case_begin();
936 while (It !=
SI->case_end()) {
938 if (Seen.count(Target) == 0) {
948 Builder.CreateBr(Target);
949 SI->addCase(It->getCaseValue(), NewTarget);
950 It =
SI->removeCase(It);
959 bool removeUselessBlocks(Function &
F) {
965 for (BasicBlock &BB :
F) {
972 if (MergeBlocks.count(&BB) != 0 || ContinueBlocks.count(&BB) != 0)
981 for (BasicBlock *Predecessor : Predecessors)
992 bool addHeaderToRemainingDivergentDAG(Function &
F) {
1004 for (BasicBlock &BB :
F) {
1005 if (HeaderBlocks.count(&BB) != 0)
1010 size_t CandidateEdges = 0;
1012 if (MergeBlocks.count(
Successor) != 0 ||
1017 CandidateEdges += 1;
1020 if (CandidateEdges <= 1)
1026 bool HasBadBlock =
false;
1027 visit(*Header, [&](
const BasicBlock *Node) {
1032 if (Node == Header || Node ==
Merge)
1035 HasBadBlock |= MergeBlocks.count(Node) != 0 ||
1036 ContinueBlocks.count(Node) != 0 ||
1037 HeaderBlocks.count(Node) != 0;
1038 return !HasBadBlock;
1046 if (
Merge ==
nullptr) {
1049 Builder.SetInsertPoint(Header->getTerminator());
1052 createOpSelectMerge(&Builder, MergeAddress);
1058 SplitInstruction = SplitInstruction->
getPrevNode();
1060 Merge->splitBasicBlockBefore(SplitInstruction,
"new.merge");
1063 Builder.SetInsertPoint(Header->getTerminator());
1066 createOpSelectMerge(&Builder, MergeAddress);
1075 SPIRVStructurizer() : FunctionPass(ID) {}
1095 Modified |= addMergeForNodesWithMultiplePredecessors(
F);
1100 Modified |= sortSelectionMergeHeaders(
F);
1105 Modified |= splitBlocksWithMultipleHeaders(
F);
1111 Modified |= addMergeForDivergentBlocks(
F);
1135 Modified |= addHeaderToRemainingDivergentDAG(
F);
1143 void getAnalysisUsage(AnalysisUsage &AU)
const override {
1146 AU.
addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
1148 AU.
addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
1149 FunctionPass::getAnalysisUsage(AU);
1152 void createOpSelectMerge(
IRBuilder<> *Builder, BlockAddress *MergeAddress) {
1155 MDNode *MDNode = BBTerminatorInst->
getMetadata(
"hlsl.controlflow.hint");
1157 ConstantInt *BranchHint = ConstantInt::get(Builder->
getInt32Ty(), 0);
1161 "invalid metadata hlsl.controlflow.hint");
1165 SmallVector<Value *, 2>
Args = {MergeAddress, BranchHint};
1173char SPIRVStructurizer::ID = 0;
1176 "structurize SPIRV",
false,
false)
1186 return new SPIRVStructurizer();
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
ReachingDefInfo InstSet & ToRemove
This file defines the DenseMap class.
static bool runOnFunction(Function &F, bool PostInlining)
This file provides various utilities for inspecting and working with the control flow graph in LLVM IR.
static bool splitCriticalEdges(CallBrInst *CBR, DominatorTree *DT)
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
static BasicBlock * getDesignatedMergeBlock(Instruction *I)
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
static std::vector< Instruction * > getMergeInstructions(BasicBlock &BB)
static BasicBlock * getDesignatedContinueBlock(Instruction *I)
std::unordered_set< BasicBlock * > BlockSet
static const ConvergenceRegion * getRegionForHeader(const ConvergenceRegion *Node, BasicBlock *BB)
static bool hasLoopMergeInstruction(BasicBlock &BB)
static SmallPtrSet< BasicBlock *, 2 > getContinueBlocks(Function &F)
static SmallPtrSet< BasicBlock *, 2 > getMergeBlocks(Function &F)
static SmallPtrSet< BasicBlock *, 2 > getHeaderBlocks(Function &F)
static bool isDefinedAsSelectionMergeBy(BasicBlock &Header, BasicBlock &Merge)
static void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, BasicBlock *NewTarget)
static void partialOrderVisit(BasicBlock &Start, std::function< bool(BasicBlock *)> Op)
static bool isMergeInstruction(Instruction *I)
static BasicBlock * getExitFor(const ConvergenceRegion *CR)
static void replaceIfBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, BasicBlock *NewTarget)
This file defines the SmallPtrSet class.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM Basic Block Representation.
const Function * getParent() const
Return the enclosing method, or null if none.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
LLVM_ABI const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
LLVM_ABI SymbolTableList< BasicBlock >::iterator eraseFromParent()
Unlink 'this' from the containing function and delete it.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
The address of a basic block.
BasicBlock * getBasicBlock() const
static LLVM_ABI BlockAddress * get(Function *F, BasicBlock *BB)
Return a BlockAddress for the specified function and basic block.
Represents analyses that only rely on functions' control flow.
This is an important base class in LLVM.
LLVM_ABI bool isConstantUsed() const
Return true if the constant has users other than constant expressions and other dangling things.
LLVM_ABI void destroyConstant()
Called if some element of this constant is no longer valid.
bool dominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
dominates - Returns true iff A dominates B.
void recalculate(ParentType &Func)
recalculate - compute a dominator tree for the given function
DomTreeNodeBase< NodeT > * getNode(const NodeT *BB) const
getNode - return the (Post)DominatorTree node for the specified basic block.
Legacy analysis pass which computes a DominatorTree.
FunctionPass class - This class is used to implement most global optimizations.
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > OverloadTypes, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="", ArrayRef< OperandBundleDef > OpBundles={})
Create a call to intrinsic ID with Args, mangled using OverloadTypes.
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
BasicBlock * GetInsertBlock() const
This provides a uniform API for creating instructions and inserting them into a basic block: either at the end of a BasicBlock, or at a specific iterator location in a block.
MDNode * getMetadata(unsigned KindID) const
Get the metadata of given kind attached to this Instruction.
A wrapper class for inspecting calls to intrinsic functions.
bool isLoopHeader(const BlockT *BB) const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
const MDOperand & getOperand(unsigned I) const
unsigned getNumOperands() const
Return number of MDNode operands.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserveSet()
Mark an analysis set as preserved.
PreservedAnalyses run(Function &M, FunctionAnalysisManager &AM)
SmallPtrSet< BasicBlock *, 2 > Exits
SmallPtrSet< BasicBlock *, 8 > Blocks
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
LLVM_ABI void addCase(ConstantInt *OnVal, BasicBlock *Dest)
Add an entry to the switch instruction.
BasicBlock * getSuccessor(unsigned i=0) const
Type * getType() const
All values are typed, get the type of this value.
self_iterator getIterator()
FunctionPassManager manages FunctionPasses.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
@ C
The default llvm calling convention, compatible with C.
PostDomTreeBase< BasicBlock > BBPostDomTree
DomTreeBase< BasicBlock > BBDomTree
@ BasicBlock
Various leaf nodes.
std::enable_if_t< detail::IsValidPointer< X, Y >::value, X * > extract(Y &&MD)
Extract a Value from Metadata.
NodeAddr< NodeBase * > Node
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
LLVM_ABI iterator begin() const
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
FunctionAddr VTableAddr Value
FunctionPass * createSPIRVStructurizerPass()
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
auto successors(const MachineBasicBlock *BB)
bool sortBlocks(Function &F)
auto pred_size(const MachineBasicBlock *BB)
AllocaInst * createVariable(Function &F, Type *Type)
auto dyn_cast_or_null(const Y &Val)
SmallVector< unsigned, 1 > getSpirvLoopControlOperandsFromLoopMetadata(MDNode *LoopMD)
auto succ_size(const MachineBasicBlock *BB)
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type arguments.
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
auto predecessors(const MachineBasicBlock *BB)
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.