24#include "llvm/IR/IntrinsicsSPIRV.h"
32#include <unordered_set>
37using BlockSet = std::unordered_set<BasicBlock *>;
38using Edge = std::pair<BasicBlock *, BasicBlock *>;
45 V.partialOrderVisit(Start, std::move(
Op));
52 if (
Node->Entry == BB)
55 for (
auto *Child :
Node->Children) {
66 std::unordered_set<BasicBlock *> ExitTargets;
74 assert(ExitTargets.size() <= 1);
75 if (ExitTargets.size() == 0)
78 return *ExitTargets.begin();
88 if (
II->getIntrinsicID() != Intrinsic::spv_loop_merge &&
89 II->getIntrinsicID() != Intrinsic::spv_selection_merge)
103 if (
II->getIntrinsicID() != Intrinsic::spv_loop_merge)
113 for (
auto &
I : Header) {
166 std::vector<Instruction *> Output;
169 Output.push_back(&
I);
190 std::stack<BasicBlock *> ToVisit;
193 ToVisit.push(&Start);
194 Seen.
insert(ToVisit.top());
195 while (ToVisit.size() != 0) {
219 for (
size_t i = 0; i < BI->getNumSuccessors(); i++) {
220 if (BI->getSuccessor(i) == OldTarget)
221 BI->setSuccessor(i, NewTarget);
225 if (BI->getSuccessor(0) != BI->getSuccessor(1))
231 Builder.SetInsertPoint(BI);
232 Builder.CreateBr(BI->getSuccessor(0));
233 BI->eraseFromParent();
244 if (!
II ||
II->getIntrinsicID() != Intrinsic::spv_selection_merge)
248 II->eraseFromParent();
249 if (!
C->isConstantUsed())
250 C->destroyConstant();
263 if (BI->getSuccessor() == OldTarget)
264 BI->setSuccessor(NewTarget);
272 for (
size_t i = 0; i <
SI->getNumSuccessors(); i++) {
273 if (
SI->getSuccessor(i) == OldTarget)
274 SI->setSuccessor(i, NewTarget);
279 assert(
false &&
"Unhandled terminator type.");
286 struct DivergentConstruct;
290 using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>;
296 struct DivergentConstruct {
301 DivergentConstruct *Parent =
nullptr;
302 ConstructList Children;
315 Splitter(Function &F, LoopInfo &LI) : F(F), LI(LI) { invalidate(); }
324 std::vector<BasicBlock *> getLoopConstructBlocks(BasicBlock *Header,
327 std::vector<BasicBlock *> Output;
331 if (DT.dominates(
Merge, BB) || !DT.dominates(Header, BB))
333 Output.push_back(BB);
340 std::vector<BasicBlock *>
341 getSelectionConstructBlocks(DivergentConstruct *Node) {
344 OutsideBlocks.insert(
Node->Merge);
346 for (DivergentConstruct *It =
Node->Parent; It !=
nullptr;
348 OutsideBlocks.insert(It->Merge);
350 OutsideBlocks.insert(It->Continue);
353 std::vector<BasicBlock *> Output;
355 if (OutsideBlocks.count(BB) != 0)
357 if (DT.dominates(Node->Merge, BB) || !DT.dominates(Node->Header, BB))
359 Output.push_back(BB);
366 std::vector<BasicBlock *> getSwitchConstructBlocks(BasicBlock *Header,
370 std::vector<BasicBlock *> Output;
373 if (!DT.dominates(Header, BB))
379 Output.push_back(BB);
386 std::vector<BasicBlock *> getCaseConstructBlocks(BasicBlock *Target,
390 std::vector<BasicBlock *> Output;
394 if (!DT.dominates(Target, BB))
400 Output.push_back(BB);
429 createAliasBlocksForComplexEdges(std::vector<Edge> Edges) {
430 std::unordered_set<BasicBlock *> Seen;
431 std::vector<Edge> Output;
432 Output.reserve(Edges.size());
434 for (
auto &[Src, Dst] : Edges) {
435 auto [Iterator,
Inserted] = Seen.insert(Src);
440 F.getContext(), Src->getName() +
".new.src", &F);
443 Builder.CreateBr(Dst);
447 Output.emplace_back(Src, Dst);
453 AllocaInst *CreateVariable(Function &F,
Type *
Type,
455 const DataLayout &
DL = F.getDataLayout();
456 return new AllocaInst(
Type,
DL.getAllocaAddrSpace(),
nullptr,
"reg",
462 BasicBlock *createSingleExitNode(BasicBlock *Header,
463 std::vector<Edge> &Edges) {
465 std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);
467 std::vector<BasicBlock *> Dsts;
468 std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
470 Header->getName() +
".new.exit", &F);
472 for (
auto &[Src, Dst] : FixedEdges) {
473 if (DstToIndex.count(Dst) != 0)
475 DstToIndex.emplace(Dst, ExitBuilder.getInt32(DstToIndex.size()));
479 if (Dsts.size() == 1) {
480 for (
auto &[Src, Dst] : FixedEdges) {
483 ExitBuilder.CreateBr(Dsts[0]);
487 AllocaInst *
Variable = CreateVariable(F, ExitBuilder.getInt32Ty(),
488 F.begin()->getFirstInsertionPt());
489 for (
auto &[Src, Dst] : FixedEdges) {
491 B2.SetInsertPoint(Src->getFirstInsertionPt());
492 B2.CreateStore(DstToIndex[Dst], Variable);
496 Value *
Load = ExitBuilder.CreateLoad(ExitBuilder.getInt32Ty(), Variable);
501 if (Dsts.size() == 2) {
504 ExitBuilder.CreateCondBr(Condition, Dsts[0], Dsts[1]);
508 SwitchInst *Sw = ExitBuilder.CreateSwitch(Load, Dsts[0], Dsts.size() - 1);
510 Sw->
addCase(DstToIndex[BB], BB);
517 Value *createExitVariable(
519 const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
524 return TargetToValue.
lookup(BI->getSuccessor());
527 Builder.SetInsertPoint(
T);
533 if (
LHS ==
nullptr ||
RHS ==
nullptr)
535 return Builder.CreateSelect(BI->getCondition(),
LHS,
RHS);
546 Builder.CreateUnreachable();
551 bool addMergeForLoops(Function &
F) {
552 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
553 auto *TopLevelRegion =
554 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
556 .getTopLevelRegion();
580 if (
Merge ==
nullptr) {
582 Merge = CreateUnreachable(
F);
583 Builder.SetInsertPoint(Br);
593 SmallVector<Value *, 2>
Args = {MergeAddress, ContinueAddress};
596 for (
unsigned Imm : LoopControlImms)
597 Args.emplace_back(ConstantInt::get(Builder.getInt32Ty(), Imm));
598 Builder.CreateIntrinsic(Intrinsic::spv_loop_merge, {
Args});
608 bool addMergeForNodesWithMultiplePredecessors(Function &
F) {
627 Builder.SetInsertPoint(Header->getTerminator());
630 createOpSelectMerge(&Builder, MergeAddress);
643 bool sortSelectionMerge(Function &
F, BasicBlock &
Block) {
644 std::vector<Instruction *> MergeInstructions;
645 for (Instruction &
I :
Block)
647 MergeInstructions.push_back(&
I);
649 if (MergeInstructions.size() <= 1)
652 Instruction *InsertionPoint = *MergeInstructions.begin();
654 PartialOrderingVisitor Visitor(
F);
655 std::sort(MergeInstructions.begin(), MergeInstructions.end(),
656 [&Visitor](Instruction *
Left, Instruction *
Right) {
659 BasicBlock *RightMerge = getDesignatedMergeBlock(Right);
660 BasicBlock *LeftMerge = getDesignatedMergeBlock(Left);
661 return !Visitor.compare(RightMerge, LeftMerge);
664 for (Instruction *
I : MergeInstructions) {
675 bool sortSelectionMergeHeaders(Function &
F) {
677 for (BasicBlock &BB :
F) {
685 bool splitBlocksWithMultipleHeaders(Function &
F) {
686 std::stack<BasicBlock *> Work;
689 if (MergeInstructions.size() <= 1)
694 const bool Modified = Work.size() > 0;
695 while (Work.size() > 0) {
699 std::vector<Instruction *> MergeInstructions =
701 for (
unsigned i = 1; i < MergeInstructions.size(); i++) {
703 Header->splitBasicBlock(MergeInstructions[i],
"new.header");
710 Builder.SetInsertPoint(Term);
711 Builder.CreateCondBr(Builder.getTrue(), NewBlock, Unreachable);
712 Term->eraseFromParent();
724 bool addMergeForDivergentBlocks(Function &
F) {
736 std::vector<BasicBlock *> Candidates;
745 if (Candidates.size() <= 1)
754 createOpSelectMerge(&Builder, MergeAddress);
762 std::vector<Edge> getExitsFrom(
const BlockSet &Construct,
763 BasicBlock &Header) {
764 std::vector<Edge> Output;
765 visit(Header, [&](BasicBlock *Item) {
766 if (Construct.count(Item) == 0)
781 void constructDivergentConstruct(
BlockSet &Visited, Splitter &S,
782 BasicBlock *BB, DivergentConstruct *Parent) {
783 if (Visited.count(BB) != 0)
788 if (MIS.size() == 0) {
790 constructDivergentConstruct(Visited, S,
Successor, Parent);
800 auto Output = std::make_unique<DivergentConstruct>();
802 Output->Merge =
Merge;
804 Output->Parent = Parent;
806 constructDivergentConstruct(Visited, S,
Merge, Parent);
808 constructDivergentConstruct(Visited, S,
Continue, Output.get());
811 constructDivergentConstruct(Visited, S,
Successor, Output.get());
814 Parent->Children.emplace_back(std::move(Output));
818 BlockSet getConstructBlocks(Splitter &S, DivergentConstruct *Node) {
821 if (
Node->Continue) {
822 auto LoopBlocks = S.getLoopConstructBlocks(
Node->Header,
Node->Merge);
823 return BlockSet(LoopBlocks.begin(), LoopBlocks.end());
826 auto SelectionBlocks = S.getSelectionConstructBlocks(Node);
827 return BlockSet(SelectionBlocks.begin(), SelectionBlocks.end());
832 bool fixupConstruct(Splitter &S, DivergentConstruct *Node) {
834 for (
auto &Child :
Node->Children)
835 Modified |= fixupConstruct(S, Child.get());
839 if (
Node->Parent ==
nullptr)
845 if (
Node->Parent->Header ==
nullptr)
852 BlockSet ConstructBlocks = getConstructBlocks(S, Node);
853 auto Edges = getExitsFrom(ConstructBlocks, *
Node->Header);
856 if (Edges.size() < 1)
859 bool HasBadEdge =
Node->Merge ==
Node->Parent->Merge ||
860 Node->Merge ==
Node->Parent->Continue;
862 for (
auto &[Src, Dst] : Edges) {
869 if (
Node->Merge == Dst)
874 if (
Node->Continue == Dst)
886 BasicBlock *NewExit = S.createSingleExitNode(
Node->Header, Edges);
893 assert(MergeInstructions.size() == 1);
898 I->setOperand(0, MergeAddress);
906 Node->Merge = NewExit;
912 bool splitCriticalEdges(Function &
F) {
913 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
916 DivergentConstruct Root;
918 constructDivergentConstruct(Visited, S, &*
F.begin(), &Root);
919 return fixupConstruct(S, &Root);
927 bool simplifyBranches(Function &
F) {
930 for (BasicBlock &BB :
F) {
934 if (
SI->getNumCases() > 1)
939 Builder.SetInsertPoint(SI);
941 if (
SI->getNumCases() == 0) {
942 Builder.CreateBr(
SI->getDefaultDest());
946 SI->case_begin()->getCaseValue());
947 Builder.CreateCondBr(Condition,
SI->case_begin()->getCaseSuccessor(),
948 SI->getDefaultDest());
950 SI->eraseFromParent();
959 bool splitSwitchCases(Function &
F) {
962 for (BasicBlock &BB :
F) {
968 Seen.insert(
SI->getDefaultDest());
970 auto It =
SI->case_begin();
971 while (It !=
SI->case_end()) {
973 if (Seen.count(Target) == 0) {
983 Builder.CreateBr(Target);
984 SI->addCase(It->getCaseValue(), NewTarget);
985 It =
SI->removeCase(It);
994 bool removeUselessBlocks(Function &
F) {
1000 for (BasicBlock &BB :
F) {
1007 if (MergeBlocks.count(&BB) != 0 || ContinueBlocks.count(&BB) != 0)
1016 for (BasicBlock *Predecessor : Predecessors)
1027 bool addHeaderToRemainingDivergentDAG(Function &
F) {
1039 for (BasicBlock &BB :
F) {
1040 if (HeaderBlocks.count(&BB) != 0)
1045 size_t CandidateEdges = 0;
1047 if (MergeBlocks.count(
Successor) != 0 ||
1052 CandidateEdges += 1;
1055 if (CandidateEdges <= 1)
1061 bool HasBadBlock =
false;
1062 visit(*Header, [&](
const BasicBlock *Node) {
1067 if (Node == Header || Node ==
Merge)
1070 HasBadBlock |= MergeBlocks.count(Node) != 0 ||
1071 ContinueBlocks.count(Node) != 0 ||
1072 HeaderBlocks.count(Node) != 0;
1073 return !HasBadBlock;
1081 if (
Merge ==
nullptr) {
1084 Builder.SetInsertPoint(Header->getTerminator());
1087 createOpSelectMerge(&Builder, MergeAddress);
1093 SplitInstruction = SplitInstruction->
getPrevNode();
1095 Merge->splitBasicBlockBefore(SplitInstruction,
"new.merge");
1098 Builder.SetInsertPoint(Header->getTerminator());
1101 createOpSelectMerge(&Builder, MergeAddress);
1110 SPIRVStructurizer() : FunctionPass(ID) {}
1130 Modified |= addMergeForNodesWithMultiplePredecessors(
F);
1135 Modified |= sortSelectionMergeHeaders(
F);
1140 Modified |= splitBlocksWithMultipleHeaders(
F);
1146 Modified |= addMergeForDivergentBlocks(
F);
1170 Modified |= addHeaderToRemainingDivergentDAG(
F);
1178 void getAnalysisUsage(AnalysisUsage &AU)
const override {
1181 AU.
addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
1183 AU.
addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
1184 FunctionPass::getAnalysisUsage(AU);
1187 void createOpSelectMerge(
IRBuilder<> *Builder, BlockAddress *MergeAddress) {
1190 MDNode *MDNode = BBTerminatorInst->
getMetadata(
"hlsl.controlflow.hint");
1192 ConstantInt *BranchHint = ConstantInt::get(Builder->
getInt32Ty(), 0);
1196 "invalid metadata hlsl.controlflow.hint");
1200 SmallVector<Value *, 2>
Args = {MergeAddress, BranchHint};
1208char SPIRVStructurizer::ID = 0;
1211 "structurize SPIRV",
false,
false)
1221 return new SPIRVStructurizer();
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
ReachingDefInfo InstSet & ToRemove
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file defines the DenseMap class.
static bool runOnFunction(Function &F, bool PostInlining)
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
static BasicBlock * getDesignatedMergeBlock(Instruction *I)
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
static std::vector< Instruction * > getMergeInstructions(BasicBlock &BB)
static BasicBlock * getDesignatedContinueBlock(Instruction *I)
std::unordered_set< BasicBlock * > BlockSet
static const ConvergenceRegion * getRegionForHeader(const ConvergenceRegion *Node, BasicBlock *BB)
static bool hasLoopMergeInstruction(BasicBlock &BB)
static SmallPtrSet< BasicBlock *, 2 > getContinueBlocks(Function &F)
static SmallPtrSet< BasicBlock *, 2 > getMergeBlocks(Function &F)
static SmallPtrSet< BasicBlock *, 2 > getHeaderBlocks(Function &F)
static bool isDefinedAsSelectionMergeBy(BasicBlock &Header, BasicBlock &Merge)
static void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, BasicBlock *NewTarget)
static void partialOrderVisit(BasicBlock &Start, std::function< bool(BasicBlock *)> Op)
static bool isMergeInstruction(Instruction *I)
static BasicBlock * getExitFor(const ConvergenceRegion *CR)
static void replaceIfBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, BasicBlock *NewTarget)
This file defines the SmallPtrSet class.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM Basic Block Representation.
const Function * getParent() const
Return the enclosing method, or null if none.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
LLVM_ABI const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
LLVM_ABI SymbolTableList< BasicBlock >::iterator eraseFromParent()
Unlink 'this' from the containing function and delete it.
InstListType::iterator iterator
Instruction iterators...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
The address of a basic block.
BasicBlock * getBasicBlock() const
static LLVM_ABI BlockAddress * get(Function *F, BasicBlock *BB)
Return a BlockAddress for the specified function and basic block.
Represents analyses that only rely on functions' control flow.
This is an important base class in LLVM.
LLVM_ABI bool isConstantUsed() const
Return true if the constant has users other than constant expressions and other dangling things.
LLVM_ABI void destroyConstant()
Called if some element of this constant is no longer valid.
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
bool dominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
dominates - Returns true iff A dominates B.
void recalculate(ParentType &Func)
recalculate - compute a dominator tree for the given function
DomTreeNodeBase< NodeT > * getNode(const NodeT *BB) const
getNode - return the (Post)DominatorTree node for the specified basic block.
Legacy analysis pass which computes a DominatorTree.
FunctionPass class - This class is used to implement most global optimizations.
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
BasicBlock * GetInsertBlock() const
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
MDNode * getMetadata(unsigned KindID) const
Get the metadata of given kind attached to this Instruction.
A wrapper class for inspecting calls to intrinsic functions.
bool isLoopHeader(const BlockT *BB) const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
const MDOperand & getOperand(unsigned I) const
unsigned getNumOperands() const
Return number of MDNode operands.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserveSet()
Mark an analysis set as preserved.
PreservedAnalyses run(Function &M, FunctionAnalysisManager &AM)
SmallPtrSet< BasicBlock *, 2 > Exits
SmallPtrSet< BasicBlock *, 8 > Blocks
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
LLVM_ABI void addCase(ConstantInt *OnVal, BasicBlock *Dest)
Add an entry to the switch instruction.
BasicBlock * getSuccessor(unsigned i=0) const
Type * getType() const
All values are typed, get the type of this value.
self_iterator getIterator()
FunctionPassManager manages FunctionPasses.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
@ C
The default llvm calling convention, compatible with C.
PostDomTreeBase< BasicBlock > BBPostDomTree
DomTreeBase< BasicBlock > BBDomTree
@ BasicBlock
Various leaf nodes.
std::enable_if_t< detail::IsValidPointer< X, Y >::value, X * > extract(Y &&MD)
Extract a Value from Metadata.
NodeAddr< NodeBase * > Node
friend class Instruction
Iterator for Instructions in a `BasicBlock.
LLVM_ABI iterator begin() const
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
FunctionAddr VTableAddr Value
FunctionPass * createSPIRVStructurizerPass()
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
auto successors(const MachineBasicBlock *BB)
bool sortBlocks(Function &F)
auto pred_size(const MachineBasicBlock *BB)
auto dyn_cast_or_null(const Y &Val)
SmallVector< unsigned, 1 > getSpirvLoopControlOperandsFromLoopMetadata(MDNode *LoopMD)
auto succ_size(const MachineBasicBlock *BB)
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
auto predecessors(const MachineBasicBlock *BB)
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.