84 #ifdef EXPENSIVE_CHECKS
90 #define DEBUG_TYPE "dfa-jump-threading"
92 STATISTIC(NumTransforms,
"Number of transformations done");
93 STATISTIC(NumCloned,
"Number of blocks cloned");
94 STATISTIC(NumPaths,
"Number of individual paths threaded");
98 cl::desc(
"View the CFG before DFA Jump Threading"),
102 "dfa-max-path-length",
103 cl::desc(
"Max number of blocks searched to find a threading path"),
108 cl::desc(
"Max number of paths enumerated around a switch"),
113 cl::desc(
"Maximum cost accepted for the transformation"),
118 class SelectInstToUnfold {
126 PHINode *getUse() {
return SIUse; }
128 explicit operator bool()
const {
return SI && SIUse; }
132 std::vector<SelectInstToUnfold> *NewSIsToUnfold,
133 std::vector<BasicBlock *> *NewBBs);
135 class DFAJumpThreading {
139 : AC(AC), DT(DT),
TTI(
TTI), ORE(ORE) {}
149 for (SelectInstToUnfold SIToUnfold : SelectInsts)
150 Stack.push_back(SIToUnfold);
152 while (!
Stack.empty()) {
153 SelectInstToUnfold SIToUnfold =
Stack.pop_back_val();
155 std::vector<SelectInstToUnfold> NewSIsToUnfold;
156 std::vector<BasicBlock *> NewBBs;
157 unfold(&DTU, SIToUnfold, &NewSIsToUnfold, &NewBBs);
160 for (
const SelectInstToUnfold &NewSIToUnfold : NewSIsToUnfold)
161 Stack.push_back(NewSIToUnfold);
171 class DFAJumpThreadingLegacyPass :
public FunctionPass {
189 &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
F);
190 DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
192 &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
194 &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
196 return DFAJumpThreading(AC, DT,
TTI, ORE).run(
F);
203 "DFA Jump Threading",
false,
false)
213 return new DFAJumpThreadingLegacyPass();
219 void createBasicBlockAndSinkSelectInst(
222 BranchInst **NewBranch, std::vector<SelectInstToUnfold> *NewSIsToUnfold,
223 std::vector<BasicBlock *> *NewBBs) {
229 NewBBs->push_back(*NewBlock);
232 NewSIsToUnfold->push_back(SelectInstToUnfold(SIToSink, SIUse));
244 std::vector<SelectInstToUnfold> *NewSIsToUnfold,
245 std::vector<BasicBlock *> *NewBBs) {
247 PHINode *SIUse = SIToUnfold.getUse();
264 if (
SelectInst *SIOp = dyn_cast<SelectInst>(
SI->getTrueValue())) {
265 createBasicBlockAndSinkSelectInst(DTU,
SI, SIUse, SIOp, EndBlock,
266 "si.unfold.true", &TrueBlock, &TrueBranch,
267 NewSIsToUnfold, NewBBs);
269 if (
SelectInst *SIOp = dyn_cast<SelectInst>(
SI->getFalseValue())) {
270 createBasicBlockAndSinkSelectInst(DTU,
SI, SIUse, SIOp, EndBlock,
271 "si.unfold.false", &FalseBlock,
272 &FalseBranch, NewSIsToUnfold, NewBBs);
277 if (!TrueBlock && !FalseBlock) {
280 NewBBs->push_back(FalseBlock);
292 if (TrueBlock && FalseBlock) {
305 Phi.addIncoming(Phi.getIncomingValueForBlock(StartBlock), TrueBlock);
306 Phi.addIncoming(Phi.getIncomingValueForBlock(StartBlock), FalseBlock);
311 Value *SIOp1 =
SI->getTrueValue();
312 Value *SIOp2 =
SI->getFalseValue();
316 NewBlock = FalseBlock;
321 NewBlock = TrueBlock;
334 for (
auto II = EndBlock->
begin();
PHINode *Phi = dyn_cast<PHINode>(II);
337 Phi->addIncoming(Phi->getIncomingValueForBlock(StartBlock), NewBlock);
346 SI->eraseFromParent();
354 typedef std::deque<BasicBlock *> PathType;
355 typedef std::vector<PathType> PathsType;
357 typedef std::vector<ClonedBlock> CloneList;
386 struct ThreadingPath {
388 uint64_t getExitValue()
const {
return ExitVal; }
393 bool isExitValueSet()
const {
return IsExitValSet; }
396 const BasicBlock *getDeterminatorBB()
const {
return DBB; }
400 const PathType &getPath()
const {
return Path; }
401 void setPath(
const PathType &NewPath) {
Path = NewPath; }
404 OS <<
Path <<
" [ " << ExitVal <<
", " << DBB->getName() <<
" ]";
411 bool IsExitValSet =
false;
423 if (isCandidate(
SI)) {
428 <<
"Switch instruction is not predictable.";
433 virtual ~MainSwitch() =
default;
435 SwitchInst *getInstr()
const {
return Instr; }
446 std::deque<Value *> Q;
450 Value *SICond =
SI->getCondition();
452 if (!isa<PHINode>(SICond))
455 addToQueue(SICond, Q, SeenValues);
458 Value *Current = Q.front();
461 if (
auto *Phi = dyn_cast<PHINode>(Current)) {
462 for (
Value *Incoming : Phi->incoming_values()) {
463 addToQueue(Incoming, Q, SeenValues);
466 }
else if (
SelectInst *SelI = dyn_cast<SelectInst>(Current)) {
467 if (!isValidSelectInst(SelI))
469 addToQueue(SelI->getTrueValue(), Q, SeenValues);
470 addToQueue(SelI->getFalseValue(), Q, SeenValues);
472 if (
auto *SelIUse = dyn_cast<PHINode>(SelI->user_back()))
473 SelectInsts.push_back(SelectInstToUnfold(SelI, SelIUse));
474 }
else if (isa<Constant>(Current)) {
490 void addToQueue(
Value *Val, std::deque<Value *> &Q,
499 if (!
SI->hasOneUse())
504 if (!SIUse && !(isa<PHINode>(SIUse) || isa<SelectInst>(SIUse)))
515 if (isa<PHINode>(SIUse) &&
521 for (SelectInstToUnfold SIToUnfold : SelectInsts) {
535 struct AllSwitchPaths {
540 std::vector<ThreadingPath> &getThreadingPaths() {
return TPaths; }
541 unsigned getNumThreadingPaths() {
return TPaths.size(); }
543 BasicBlock *getSwitchBlock() {
return SwitchBlock; }
546 VisitedBlocks Visited;
547 PathsType LoopPaths =
paths(SwitchBlock, Visited, 1);
548 StateDefMap StateDef = getStateDefMap(LoopPaths);
550 if (StateDef.empty()) {
554 <<
"Switch instruction is not predictable.";
559 for (PathType Path : LoopPaths) {
564 if (StateDef.count(
BB) != 0) {
565 const PHINode *Phi = dyn_cast<PHINode>(StateDef[
BB]);
566 assert(Phi &&
"Expected a state-defining instr to be a phi node.");
569 if (
const ConstantInt *
C = dyn_cast<const ConstantInt>(V)) {
570 TPath.setExitValue(
C);
571 TPath.setDeterminator(
BB);
577 if (TPath.isExitValueSet() &&
BB ==
Path.front())
583 if (TPath.isExitValueSet() && isSupported(TPath))
584 TPaths.push_back(TPath);
594 unsigned PathDepth)
const {
602 <<
"Exploration stopped after visiting MaxPathLength="
614 if (!Successors.
insert(Succ).second)
618 if (Succ == SwitchBlock) {
624 if (Visited.contains(Succ))
627 PathsType SuccPaths =
paths(Succ, Visited, PathDepth + 1);
628 for (PathType Path : SuccPaths) {
629 PathType NewPath(Path);
630 NewPath.push_front(
BB);
631 Res.push_back(NewPath);
648 StateDefMap getStateDefMap(
const PathsType &LoopPaths)
const {
653 for (
const PathType &Path : LoopPaths) {
660 assert(isa<PHINode>(FirstDef) &&
"The first definition must be a phi.");
663 Stack.push_back(dyn_cast<PHINode>(FirstDef));
666 while (!
Stack.empty()) {
670 SeenValues.
insert(CurPhi);
674 bool IsOutsideLoops = LoopBBs.
count(IncomingBB) == 0;
675 if (Incoming == FirstDef || isa<ConstantInt>(Incoming) ||
676 SeenValues.
contains(Incoming) || IsOutsideLoops) {
681 if (!isa<PHINode>(Incoming))
682 return StateDefMap();
684 Stack.push_back(cast<PHINode>(Incoming));
709 bool isSupported(
const ThreadingPath &TPath) {
717 const BasicBlock *DeterminatorBB = TPath.getDeterminatorBB();
720 SwitchCondUseBB == TPath.getPath().front() &&
721 "The first BB in a threading path should have the switch instruction");
722 if (SwitchCondUseBB != TPath.getPath().front())
726 PathType
Path = TPath.getPath();
730 bool IsDetBBSeen =
false;
731 bool IsDefBBSeen =
false;
732 bool IsUseBBSeen =
false;
734 if (
BB == DeterminatorBB)
736 if (
BB == SwitchCondDefBB)
738 if (
BB == SwitchCondUseBB)
740 if (IsDetBBSeen && IsUseBBSeen && !IsDefBBSeen)
750 std::vector<ThreadingPath> TPaths;
753 struct TransformDFA {
758 : SwitchPaths(SwitchPaths), DT(DT), AC(AC),
TTI(
TTI), ORE(ORE),
759 EphValues(EphValues) {}
762 if (isLegalAndProfitableToTransform()) {
763 createAllExitPaths();
773 bool isLegalAndProfitableToTransform() {
779 DuplicateBlockMap DuplicateMap;
781 for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
782 PathType PathBBs = TPath.getPath();
783 uint64_t NextState = TPath.getExitValue();
784 const BasicBlock *Determinator = TPath.getDeterminatorBB();
788 BasicBlock *VisitedBB = getClonedBB(
BB, NextState, DuplicateMap);
791 DuplicateMap[
BB].push_back({
BB, NextState});
796 if (PathBBs.front() == Determinator)
801 auto DetIt =
std::find(PathBBs.begin(), PathBBs.end(), Determinator);
802 for (
auto BBIt = DetIt; BBIt != PathBBs.end(); BBIt++) {
804 VisitedBB = getClonedBB(
BB, NextState, DuplicateMap);
808 DuplicateMap[
BB].push_back({
BB, NextState});
812 LLVM_DEBUG(
dbgs() <<
"DFA Jump Threading: Not jump threading, contains "
813 <<
"non-duplicatable instructions.\n");
817 <<
"Contains non-duplicatable instructions.";
823 LLVM_DEBUG(
dbgs() <<
"DFA Jump Threading: Not jump threading, contains "
824 <<
"convergent instructions.\n");
827 <<
"Contains convergent instructions.";
832 if (!
Metrics.NumInsts.isValid()) {
833 LLVM_DEBUG(
dbgs() <<
"DFA Jump Threading: Not jump threading, contains "
834 <<
"instructions with invalid cost.\n");
837 <<
"Contains instructions with invalid cost.";
843 unsigned DuplicationCost = 0;
845 unsigned JumpTableSize = 0;
848 if (JumpTableSize == 0) {
852 unsigned CondBranches =
854 DuplicationCost = *
Metrics.NumInsts.getValue() / CondBranches;
862 DuplicationCost = *
Metrics.NumInsts.getValue() / JumpTableSize;
865 LLVM_DEBUG(
dbgs() <<
"\nDFA Jump Threading: Cost to jump thread block "
866 << SwitchPaths->getSwitchBlock()->getName()
867 <<
" is: " << DuplicationCost <<
"\n\n");
870 LLVM_DEBUG(
dbgs() <<
"Not jump threading, duplication cost exceeds the "
871 <<
"cost threshold.\n");
874 <<
"Duplication cost exceeds the cost threshold (cost="
875 <<
ore::NV(
"Cost", DuplicationCost)
883 <<
"Switch statement jump-threaded.";
890 void createAllExitPaths() {
894 BasicBlock *SwitchBlock = SwitchPaths->getSwitchBlock();
895 for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
897 PathType NewPath(TPath.getPath());
898 NewPath.push_back(SwitchBlock);
899 TPath.setPath(NewPath);
903 DuplicateBlockMap DuplicateMap;
910 for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
911 createExitPath(NewDefs, TPath, DuplicateMap, BlocksToClean, &DTU);
917 for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths())
918 updateLastSuccessor(TPath, DuplicateMap, &DTU);
934 void createExitPath(DefMap &NewDefs, ThreadingPath &Path,
935 DuplicateBlockMap &DuplicateMap,
940 PathType PathBBs =
Path.getPath();
943 if (PathBBs.front() == Determinator)
946 auto DetIt =
std::find(PathBBs.begin(), PathBBs.end(), Determinator);
947 auto Prev = std::prev(DetIt);
949 for (
auto BBIt = DetIt; BBIt != PathBBs.end(); BBIt++) {
955 BasicBlock *NextBB = getClonedBB(
BB, NextState, DuplicateMap);
957 updatePredecessor(PrevBB,
BB, NextBB, DTU);
963 BasicBlock *NewBB = cloneBlockAndUpdatePredecessor(
964 BB, PrevBB, NextState, DuplicateMap, NewDefs, DTU);
965 DuplicateMap[
BB].push_back({NewBB, NextState});
966 BlocksToClean.
insert(NewBB);
977 void updateSSA(DefMap &NewDefs) {
981 for (
auto KV : NewDefs) {
984 std::vector<Instruction *> Cloned = KV.second;
988 for (
Use &U :
I->uses()) {
991 if (UserPN->getIncomingBlock(U) ==
BB)
993 }
else if (
User->getParent() ==
BB) {
997 UsesToRename.push_back(&U);
1002 if (UsesToRename.empty())
1010 unsigned VarNum = SSAUpdate.
AddVariable(
I->getName(),
I->getType());
1015 while (!UsesToRename.empty())
1031 DuplicateBlockMap &DuplicateMap,
1044 if (isa<PHINode>(&
I))
1052 updateSuccessorPhis(
BB, NewBB, NextState, VMap, DuplicateMap);
1053 updatePredecessor(PrevBB,
BB, NewBB, DTU);
1054 updateDefMap(NewDefs, VMap);
1059 if (SuccSet.
insert(SuccBB).second)
1072 DuplicateBlockMap &DuplicateMap) {
1073 std::vector<BasicBlock *> BlocksToUpdate;
1077 if (
BB == SwitchPaths->getSwitchBlock()) {
1079 BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState);
1080 BlocksToUpdate.push_back(NextCase);
1081 BasicBlock *ClonedSucc = getClonedBB(NextCase, NextState, DuplicateMap);
1083 BlocksToUpdate.push_back(ClonedSucc);
1088 BlocksToUpdate.push_back(Succ);
1093 BasicBlock *ClonedSucc = getClonedBB(Succ, NextState, DuplicateMap);
1095 BlocksToUpdate.push_back(ClonedSucc);
1103 for (
auto II = Succ->begin();
PHINode *Phi = dyn_cast<PHINode>(II);
1107 if (isa<Constant>(Incoming)) {
1111 Value *ClonedVal = VMap[Incoming];
1127 if (!isPredecessor(OldBB, PrevBB))
1147 for (
auto Entry : VMap) {
1149 dyn_cast<Instruction>(
const_cast<Value *
>(Entry.first));
1150 if (!Inst || !Entry.second || isa<BranchInst>(Inst) ||
1151 isa<SwitchInst>(Inst)) {
1155 Instruction *Cloned = dyn_cast<Instruction>(Entry.second);
1159 NewDefsVector.push_back({Inst, Cloned});
1163 sort(NewDefsVector, [](
const auto &
LHS,
const auto &
RHS) {
1164 if (
LHS.first ==
RHS.first)
1165 return LHS.second->comesBefore(
RHS.second);
1166 return LHS.first->comesBefore(
RHS.first);
1169 for (
const auto &KV : NewDefsVector)
1170 NewDefs[KV.first].push_back(KV.second);
1178 void updateLastSuccessor(ThreadingPath &TPath,
1179 DuplicateBlockMap &DuplicateMap,
1181 uint64_t NextState = TPath.getExitValue();
1183 BasicBlock *LastBlock = getClonedBB(
BB, NextState, DuplicateMap);
1190 BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState);
1192 std::vector<DominatorTree::UpdateType> DTUpdates;
1195 if (Succ != NextCase && SuccSet.
insert(Succ).second)
1199 Switch->eraseFromParent();
1210 std::vector<PHINode *> PhiToRemove;
1211 for (
auto II =
BB->begin();
PHINode *Phi = dyn_cast<PHINode>(II); ++II) {
1212 PhiToRemove.push_back(Phi);
1214 for (
PHINode *PN : PhiToRemove) {
1216 PN->eraseFromParent();
1222 for (
auto II =
BB->begin();
PHINode *Phi = dyn_cast<PHINode>(II); ++II) {
1223 std::vector<BasicBlock *> BlocksToRemove;
1225 if (!isPredecessor(
BB, IncomingBB))
1226 BlocksToRemove.push_back(IncomingBB);
1236 DuplicateBlockMap &DuplicateMap) {
1237 CloneList ClonedBBs = DuplicateMap[
BB];
1241 auto It =
llvm::find_if(ClonedBBs, [NextState](
const ClonedBlock &
C) {
1242 return C.State == NextState;
1244 return It != ClonedBBs.end() ? (*It).BB :
nullptr;
1251 for (
auto Case :
Switch->cases()) {
1252 if (Case.getCaseValue()->getZExtValue() == NextState) {
1253 NextCase = Case.getCaseSuccessor();
1258 NextCase =
Switch->getDefaultDest();
1267 AllSwitchPaths *SwitchPaths;
1273 std::vector<ThreadingPath> TPaths;
1277 LLVM_DEBUG(
dbgs() <<
"\nDFA Jump threading: " <<
F.getName() <<
"\n");
1279 if (
F.hasOptSize()) {
1280 LLVM_DEBUG(
dbgs() <<
"Skipping due to the 'minsize' attribute\n");
1288 bool MadeChanges =
false;
1291 auto *
SI = dyn_cast<SwitchInst>(
BB.getTerminator());
1296 <<
" is a candidate\n");
1303 <<
"candidate for jump threading\n");
1306 unfoldSelectInstrs(DT,
Switch.getSelectInsts());
1307 if (!
Switch.getSelectInsts().empty())
1310 AllSwitchPaths SwitchPaths(&Switch, ORE);
1313 if (SwitchPaths.getNumThreadingPaths() > 0) {
1314 ThreadableLoops.push_back(SwitchPaths);
1326 if (ThreadableLoops.size() > 0)
1329 for (AllSwitchPaths SwitchPaths : ThreadableLoops) {
1330 TransformDFA Transform(&SwitchPaths, DT, AC,
TTI, ORE, EphValues);
1335 #ifdef EXPENSIVE_CHECKS
1353 if (!DFAJumpThreading(&AC, &DT, &
TTI, &ORE).
run(
F))