25#include "llvm/IR/IntrinsicsSPIRV.h"
33#include <unordered_set>
38using BlockSet = std::unordered_set<BasicBlock *>;
39using Edge = std::pair<BasicBlock *, BasicBlock *>;
46 V.partialOrderVisit(Start,
Op);
53 if (
Node->Entry == BB)
56 for (
auto *Child :
Node->Children) {
67 std::unordered_set<BasicBlock *> ExitTargets;
75 assert(ExitTargets.size() <= 1);
76 if (ExitTargets.size() == 0)
79 return *ExitTargets.begin();
89 if (
II->getIntrinsicID() != Intrinsic::spv_loop_merge &&
90 II->getIntrinsicID() != Intrinsic::spv_selection_merge)
104 if (
II->getIntrinsicID() != Intrinsic::spv_loop_merge)
114 for (
auto &
I : Header) {
167 std::vector<Instruction *> Output;
170 Output.push_back(&
I);
191 std::stack<BasicBlock *> ToVisit;
194 ToVisit.push(&Start);
195 Seen.
insert(ToVisit.top());
196 while (ToVisit.size() != 0) {
220 for (
size_t i = 0; i < BI->getNumSuccessors(); i++) {
221 if (BI->getSuccessor(i) == OldTarget)
222 BI->setSuccessor(i, NewTarget);
226 if (BI->isUnconditional())
230 if (BI->getSuccessor(0) != BI->getSuccessor(1))
237 Builder.
CreateBr(BI->getSuccessor(0));
249 if (!
II ||
II->getIntrinsicID() != Intrinsic::spv_selection_merge)
253 II->eraseFromParent();
254 if (!
C->isConstantUsed())
255 C->destroyConstant();
265 if (isa<ReturnInst>(
T))
268 if (isa<BranchInst>(
T))
271 if (
auto *SI = dyn_cast<SwitchInst>(
T)) {
272 for (
size_t i = 0; i < SI->getNumSuccessors(); i++) {
273 if (SI->getSuccessor(i) == OldTarget)
274 SI->setSuccessor(i, NewTarget);
279 assert(
false &&
"Unhandled terminator type.");
286 struct DivergentConstruct;
290 using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>;
296 struct DivergentConstruct {
301 DivergentConstruct *Parent =
nullptr;
324 std::vector<BasicBlock *> getLoopConstructBlocks(
BasicBlock *Header,
327 std::vector<BasicBlock *> Output;
333 Output.push_back(BB);
340 std::vector<BasicBlock *>
341 getSelectionConstructBlocks(DivergentConstruct *
Node) {
344 OutsideBlocks.insert(
Node->Merge);
346 for (DivergentConstruct *It =
Node->Parent; It !=
nullptr;
348 OutsideBlocks.insert(It->Merge);
350 OutsideBlocks.insert(It->Continue);
353 std::vector<BasicBlock *> Output;
355 if (OutsideBlocks.count(BB) != 0)
357 if (DT.dominates(Node->Merge, BB) || !DT.dominates(Node->Header, BB))
359 Output.push_back(BB);
366 std::vector<BasicBlock *> getSwitchConstructBlocks(
BasicBlock *Header,
370 std::vector<BasicBlock *> Output;
379 Output.push_back(BB);
390 std::vector<BasicBlock *> Output;
400 Output.push_back(BB);
429 createAliasBlocksForComplexEdges(std::vector<Edge> Edges) {
430 std::unordered_set<BasicBlock *> Seen;
431 std::vector<Edge> Output;
432 Output.reserve(Edges.size());
434 for (
auto &[Src, Dst] : Edges) {
435 auto [Iterator,
Inserted] = Seen.insert(Src);
440 F.getContext(), Src->getName() +
".new.src", &F);
443 Builder.CreateBr(Dst);
447 Output.emplace_back(Src, Dst);
463 std::vector<Edge> &Edges) {
465 std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);
467 std::vector<BasicBlock *> Dsts;
468 std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
470 Header->getName() +
".new.exit", &F);
472 for (
auto &[Src, Dst] : FixedEdges) {
473 if (DstToIndex.count(Dst) != 0)
475 DstToIndex.emplace(Dst, ExitBuilder.getInt32(DstToIndex.size()));
479 if (Dsts.size() == 1) {
480 for (
auto &[Src, Dst] : FixedEdges) {
483 ExitBuilder.CreateBr(Dsts[0]);
487 AllocaInst *Variable = CreateVariable(F, ExitBuilder.getInt32Ty(),
488 F.begin()->getFirstInsertionPt());
489 for (
auto &[Src, Dst] : FixedEdges) {
491 B2.SetInsertPoint(Src->getFirstInsertionPt());
492 B2.CreateStore(DstToIndex[Dst], Variable);
496 Value *
Load = ExitBuilder.CreateLoad(ExitBuilder.getInt32Ty(), Variable);
501 if (Dsts.size() == 2) {
504 ExitBuilder.CreateCondBr(Condition, Dsts[0], Dsts[1]);
508 SwitchInst *Sw = ExitBuilder.CreateSwitch(Load, Dsts[0], Dsts.size() - 1);
510 Sw->
addCase(DstToIndex[BB], BB);
517 Value *createExitVariable(
521 if (isa<ReturnInst>(
T))
525 Builder.SetInsertPoint(
T);
527 if (
auto *BI = dyn_cast<BranchInst>(
T)) {
531 BI->isConditional() ? BI->getSuccessor(1) :
nullptr;
536 if (LHS ==
nullptr || RHS ==
nullptr)
538 return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
549 Builder.CreateUnreachable();
555 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
556 auto *TopLevelRegion =
557 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
559 .getTopLevelRegion();
583 if (
Merge ==
nullptr) {
587 Merge = CreateUnreachable(
F);
588 Builder.SetInsertPoint(Br);
601 for (
unsigned Imm : LoopControlImms)
602 Args.emplace_back(ConstantInt::get(Builder.getInt32Ty(), Imm));
603 Builder.CreateIntrinsic(Intrinsic::spv_loop_merge, {
Args});
613 bool addMergeForNodesWithMultiplePredecessors(
Function &
F) {
632 Builder.SetInsertPoint(Header->getTerminator());
635 createOpSelectMerge(&Builder, MergeAddress);
649 std::vector<Instruction *> MergeInstructions;
652 MergeInstructions.push_back(&
I);
654 if (MergeInstructions.size() <= 1)
660 std::sort(MergeInstructions.begin(), MergeInstructions.end(),
664 BasicBlock *RightMerge = getDesignatedMergeBlock(Right);
665 BasicBlock *LeftMerge = getDesignatedMergeBlock(Left);
666 return !Visitor.compare(RightMerge, LeftMerge);
680 bool sortSelectionMergeHeaders(
Function &
F) {
690 bool splitBlocksWithMultipleHeaders(
Function &
F) {
691 std::stack<BasicBlock *> Work;
694 if (MergeInstructions.size() <= 1)
699 const bool Modified = Work.size() > 0;
700 while (Work.size() > 0) {
704 std::vector<Instruction *> MergeInstructions =
706 for (
unsigned i = 1; i < MergeInstructions.size(); i++) {
708 Header->splitBasicBlock(MergeInstructions[i],
"new.header");
713 BranchInst *BI = cast<BranchInst>(Header->getTerminator());
715 Builder.SetInsertPoint(BI);
716 Builder.CreateCondBr(Builder.getTrue(), NewBlock, Unreachable);
729 bool addMergeForDivergentBlocks(
Function &
F) {
741 std::vector<BasicBlock *> Candidates;
750 if (Candidates.size() <= 1)
759 createOpSelectMerge(&Builder, MergeAddress);
767 std::vector<Edge> getExitsFrom(
const BlockSet &Construct,
769 std::vector<Edge> Output;
771 if (Construct.count(Item) == 0)
786 void constructDivergentConstruct(
BlockSet &Visited, Splitter &S,
788 if (Visited.count(BB) != 0)
793 if (MIS.size() == 0) {
795 constructDivergentConstruct(Visited, S,
Successor, Parent);
805 auto Output = std::make_unique<DivergentConstruct>();
807 Output->Merge =
Merge;
809 Output->Parent = Parent;
811 constructDivergentConstruct(Visited, S,
Merge, Parent);
813 constructDivergentConstruct(Visited, S,
Continue, Output.get());
816 constructDivergentConstruct(Visited, S,
Successor, Output.get());
819 Parent->Children.emplace_back(std::move(Output));
823 BlockSet getConstructBlocks(Splitter &S, DivergentConstruct *
Node) {
826 if (
Node->Continue) {
827 auto LoopBlocks = S.getLoopConstructBlocks(
Node->Header,
Node->Merge);
828 return BlockSet(LoopBlocks.begin(), LoopBlocks.end());
831 auto SelectionBlocks = S.getSelectionConstructBlocks(
Node);
832 return BlockSet(SelectionBlocks.begin(), SelectionBlocks.end());
837 bool fixupConstruct(Splitter &S, DivergentConstruct *
Node) {
839 for (
auto &Child :
Node->Children)
840 Modified |= fixupConstruct(S, Child.get());
844 if (
Node->Parent ==
nullptr)
850 if (
Node->Parent->Header ==
nullptr)
858 auto Edges = getExitsFrom(ConstructBlocks, *
Node->Header);
861 if (Edges.size() < 1)
864 bool HasBadEdge =
Node->Merge ==
Node->Parent->Merge ||
865 Node->Merge ==
Node->Parent->Continue;
867 for (
auto &[Src, Dst] : Edges) {
874 if (
Node->Merge == Dst)
879 if (
Node->Continue == Dst)
891 BasicBlock *NewExit = S.createSingleExitNode(
Node->Header, Edges);
898 assert(MergeInstructions.size() == 1);
903 I->setOperand(0, MergeAddress);
911 Node->Merge = NewExit;
918 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
921 DivergentConstruct Root;
923 constructDivergentConstruct(Visited, S, &*
F.begin(), &Root);
924 return fixupConstruct(S, &Root);
939 if (
SI->getNumCases() > 1)
944 Builder.SetInsertPoint(SI);
946 if (
SI->getNumCases() == 0) {
947 Builder.CreateBr(
SI->getDefaultDest());
951 SI->case_begin()->getCaseValue());
952 Builder.CreateCondBr(Condition,
SI->case_begin()->getCaseSuccessor(),
953 SI->getDefaultDest());
955 SI->eraseFromParent();
973 Seen.insert(
SI->getDefaultDest());
975 auto It =
SI->case_begin();
976 while (It !=
SI->case_end()) {
978 if (Seen.count(
Target) == 0) {
989 SI->addCase(It->getCaseValue(), NewTarget);
990 It =
SI->removeCase(It);
1000 std::vector<BasicBlock *>
ToRemove;
1012 if (MergeBlocks.count(&BB) != 0 || ContinueBlocks.count(&BB) != 0)
1032 bool addHeaderToRemainingDivergentDAG(
Function &
F) {
1045 if (HeaderBlocks.count(&BB) != 0)
1050 size_t CandidateEdges = 0;
1052 if (MergeBlocks.count(
Successor) != 0 ||
1057 CandidateEdges += 1;
1060 if (CandidateEdges <= 1)
1066 bool HasBadBlock =
false;
1075 HasBadBlock |= MergeBlocks.count(
Node) != 0 ||
1076 ContinueBlocks.count(
Node) != 0 ||
1077 HeaderBlocks.count(
Node) != 0;
1078 return !HasBadBlock;
1086 if (
Merge ==
nullptr) {
1089 Builder.SetInsertPoint(Header->getTerminator());
1092 createOpSelectMerge(&Builder, MergeAddress);
1098 SplitInstruction = SplitInstruction->
getPrevNode();
1100 Merge->splitBasicBlockBefore(SplitInstruction,
"new.merge");
1103 Builder.SetInsertPoint(Header->getTerminator());
1106 createOpSelectMerge(&Builder, MergeAddress);
1135 Modified |= addMergeForNodesWithMultiplePredecessors(
F);
1140 Modified |= sortSelectionMergeHeaders(
F);
1145 Modified |= splitBlocksWithMultipleHeaders(
F);
1151 Modified |= addMergeForDivergentBlocks(
F);
1175 Modified |= addHeaderToRemainingDivergentDAG(
F);
1201 "invalid metadata hlsl.controlflow.hint");
1213char SPIRVStructurizer::ID = 0;
1216 "structurize SPIRV",
false,
false)
1226 return new SPIRVStructurizer();
for(const MachineOperand &MO :llvm::drop_begin(OldMI.operands(), Desc.getNumOperands()))
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
ReachingDefAnalysis InstSet & ToRemove
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file defines the DenseMap class.
This file provides various utilities for inspecting and working with the control flow graph in LLVM IR.
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
static BasicBlock * getDesignatedMergeBlock(Instruction *I)
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
static std::vector< Instruction * > getMergeInstructions(BasicBlock &BB)
static BasicBlock * getDesignatedContinueBlock(Instruction *I)
std::unordered_set< BasicBlock * > BlockSet
static const ConvergenceRegion * getRegionForHeader(const ConvergenceRegion *Node, BasicBlock *BB)
static bool hasLoopMergeInstruction(BasicBlock &BB)
std::pair< BasicBlock *, BasicBlock * > Edge
static SmallPtrSet< BasicBlock *, 2 > getContinueBlocks(Function &F)
static SmallPtrSet< BasicBlock *, 2 > getMergeBlocks(Function &F)
static SmallPtrSet< BasicBlock *, 2 > getHeaderBlocks(Function &F)
static bool isDefinedAsSelectionMergeBy(BasicBlock &Header, BasicBlock &Merge)
static void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, BasicBlock *NewTarget)
static void partialOrderVisit(BasicBlock &Start, std::function< bool(BasicBlock *)> Op)
static bool isMergeInstruction(Instruction *I)
static BasicBlock * getExitFor(const ConvergenceRegion *CR)
static void replaceIfBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, BasicBlock *NewTarget)
This file defines the SmallPtrSet class.
an instruction to allocate memory on the stack
A container for analyses that lazily runs them and caches their results.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM Basic Block Representation.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
LLVM_ABI const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
const Function * getParent() const
Return the enclosing method, or null if none.
LLVM_ABI SymbolTableList< BasicBlock >::iterator eraseFromParent()
Unlink 'this' from the containing function and delete it.
InstListType::iterator iterator
Instruction iterators...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well formed.
The address of a basic block.
BasicBlock * getBasicBlock() const
static LLVM_ABI BlockAddress * get(Function *F, BasicBlock *BB)
Return a BlockAddress for the specified function and basic block.
Conditional or Unconditional Branch instruction.
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Represents analyses that only rely on functions' control flow.
This is the shared class of boolean and integer constants.
This is an important base class in LLVM.
LLVM_ABI bool isConstantUsed() const
Return true if the constant has users other than constant expressions and other dangling things.
LLVM_ABI void destroyConstant()
Called if some element of this constant is no longer valid.
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exists.
DomTreeNodeBase * getIDom() const
Core dominator tree base class.
bool dominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
dominates - Returns true iff A dominates B.
void recalculate(ParentType &Func)
recalculate - compute a dominator tree for the given function
DomTreeNodeBase< NodeT > * getNode(const NodeT *BB) const
getNode - return the (Post)DominatorTree node for the specified basic block.
Legacy analysis pass which computes a DominatorTree.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
BasicBlock * GetInsertBlock() const
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
BranchInst * CreateBr(BasicBlock *Dest)
Create an unconditional 'br label X' instruction.
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
This provides a uniform API for creating instructions and inserting them into a basic block: either at the end of a BasicBlock, or at a specific iterator location in a block.
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
MDNode * getMetadata(unsigned KindID) const
Get the metadata of given kind attached to this Instruction.
A wrapper class for inspecting calls to intrinsic functions.
bool isLoopHeader(const BlockT *BB) const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
const MDOperand & getOperand(unsigned I) const
unsigned getNumOperands() const
Return number of MDNode operands.
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do their job.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserveSet()
Mark an analysis set as preserved.
PreservedAnalyses run(Function &M, FunctionAnalysisManager &AM)
SmallPtrSet< BasicBlock *, 2 > Exits
SmallPtrSet< BasicBlock *, 8 > Blocks
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
LLVM_ABI void addCase(ConstantInt *OnVal, BasicBlock *Dest)
Add an entry to the switch instruction.
Target - Wrapper for Target specific information.
The instances of the Type class are immutable: once they are created, they are never changed.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
FunctionPassManager manages FunctionPasses.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
LLVM_ABI const_iterator begin(StringRef path LLVM_LIFETIME_BOUND, Style style=Style::native)
Get begin iterator over path.
LLVM_ABI const_iterator end(StringRef path LLVM_LIFETIME_BOUND)
Get end iterator over path.
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
FunctionPass * createSPIRVStructurizerPass()
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
auto successors(const MachineBasicBlock *BB)
bool sortBlocks(Function &F)
auto pred_size(const MachineBasicBlock *BB)
SmallVector< unsigned, 1 > getSpirvLoopControlOperandsFromLoopMetadata(Loop *L)
auto succ_size(const MachineBasicBlock *BB)
auto predecessors(const MachineBasicBlock *BB)