LCOV - code coverage report
Current view: top level - lib/Analysis - LegacyDivergenceAnalysis.cpp (source / functions) Hit Total Coverage
Test: llvm-toolchain.info Lines: 83 93 89.2 %
Date: 2018-10-20 13:21:21 Functions: 11 13 84.6 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //===- LegacyDivergenceAnalysis.cpp --------- Legacy Divergence Analysis Implementation -==//
       2             : //
       3             : //                     The LLVM Compiler Infrastructure
       4             : //
       5             : // This file is distributed under the University of Illinois Open Source
       6             : // License. See LICENSE.TXT for details.
       7             : //
       8             : //===----------------------------------------------------------------------===//
       9             : //
      10             : // This file implements divergence analysis which determines whether a branch
      11             : // in a GPU program is divergent. It can help branch optimizations such as jump
      12             : // threading and loop unswitching to make better decisions.
      13             : //
      14             : // GPU programs typically use the SIMD execution model, where multiple threads
      15             : // in the same execution group have to execute in lock-step. Therefore, if the
      16             : // code contains divergent branches (i.e., threads in a group do not agree on
      17             : // which path of the branch to take), the group of threads has to execute all
      18             : // the paths from that branch with different subsets of threads enabled until
      19             : // they converge at the immediately post-dominating BB of the paths.
      20             : //
      21             : // Due to this execution model, some optimizations such as jump
      22             : // threading and loop unswitching can be unfortunately harmful when performed on
      23             : // divergent branches. Therefore, an analysis that computes which branches in a
      24             : // GPU program are divergent can help the compiler to selectively run these
      25             : // optimizations.
      26             : //
      27             : // This file defines divergence analysis which computes a conservative but
      28             : // non-trivial approximation of all divergent branches in a GPU program. It
      29             : // partially implements the approach described in
      30             : //
      31             : //   Divergence Analysis
      32             : //   Sampaio, Souza, Collange, Pereira
      33             : //   TOPLAS '13
      34             : //
      35             : // The divergence analysis identifies the sources of divergence (e.g., special
      36             : // variables that hold the thread ID), and recursively marks variables that are
      37             : // data or sync dependent on a source of divergence as divergent.
      38             : //
      39             : // While data dependency is a well-known concept, the notion of sync dependency
      40             : // is worth more explanation. Sync dependence characterizes the control flow
      41             : // aspect of the propagation of branch divergence. For example,
      42             : //
      43             : //   %cond = icmp slt i32 %tid, 10
      44             : //   br i1 %cond, label %then, label %else
      45             : // then:
      46             : //   br label %merge
      47             : // else:
      48             : //   br label %merge
      49             : // merge:
      50             : //   %a = phi i32 [ 0, %then ], [ 1, %else ]
      51             : //
      52             : // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
      53             : // because %tid is not on its use-def chains, %a is sync dependent on %tid
      54             : // because the branch "br i1 %cond" depends on %tid and affects which value %a
      55             : // is assigned to.
      56             : //
      57             : // The current implementation has the following limitations:
      58             : // 1. intra-procedural. It conservatively considers the arguments of a
      59             : //    non-kernel-entry function and the return value of a function call as
      60             : //    divergent.
      61             : // 2. memory as black box. It conservatively considers values loaded from
      62             : //    generic or local address as divergent. This can be improved by leveraging
      63             : //    pointer analysis.
      64             : //
      65             : //===----------------------------------------------------------------------===//
      66             : 
      67             : #include "llvm/Analysis/LegacyDivergenceAnalysis.h"
      68             : #include "llvm/Analysis/Passes.h"
      69             : #include "llvm/Analysis/PostDominators.h"
      70             : #include "llvm/Analysis/TargetTransformInfo.h"
      71             : #include "llvm/IR/Dominators.h"
      72             : #include "llvm/IR/InstIterator.h"
      73             : #include "llvm/IR/Instructions.h"
      74             : #include "llvm/IR/Value.h"
      75             : #include "llvm/Support/Debug.h"
      76             : #include "llvm/Support/raw_ostream.h"
      77             : #include <vector>
      78             : using namespace llvm;
      79             : 
      80             : #define DEBUG_TYPE "divergence"
      81             : 
      82             : namespace {
      83             : 
      84             : class DivergencePropagator {
      85             : public:
      86             :   DivergencePropagator(Function &F, TargetTransformInfo &TTI, DominatorTree &DT,
      87             :                        PostDominatorTree &PDT, DenseSet<const Value *> &DV)
      88       97961 :       : F(F), TTI(TTI), DT(DT), PDT(PDT), DV(DV) {}
      89             :   void populateWithSourcesOfDivergence();
      90             :   void propagate();
      91             : 
      92             : private:
      93             :   // A helper function that explores data dependents of V.
      94             :   void exploreDataDependency(Value *V);
      95             :   // A helper function that explores sync dependents of TI.
      96             :   void exploreSyncDependency(Instruction *TI);
      97             :   // Computes the influence region from Start to End. This region includes all
      98             :   // basic blocks on any simple path from Start to End.
      99             :   void computeInfluenceRegion(BasicBlock *Start, BasicBlock *End,
     100             :                               DenseSet<BasicBlock *> &InfluenceRegion);
     101             :   // Finds all users of I that are outside the influence region, and add these
     102             :   // users to Worklist.
     103             :   void findUsersOutsideInfluenceRegion(
     104             :       Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion);
     105             : 
     106             :   Function &F;
     107             :   TargetTransformInfo &TTI;
     108             :   DominatorTree &DT;
     109             :   PostDominatorTree &PDT;
     110             :   std::vector<Value *> Worklist; // Stack for DFS.
     111             :   DenseSet<const Value *> &DV;   // Stores all divergent values.
     112             : };
     113             : 
     114       97961 : void DivergencePropagator::populateWithSourcesOfDivergence() {
     115       97961 :   Worklist.clear();
     116       97961 :   DV.clear();
     117     1316993 :   for (auto &I : instructions(F)) {
     118     1219032 :     if (TTI.isSourceOfDivergence(&I)) {
     119       43370 :       Worklist.push_back(&I);
     120       43370 :       DV.insert(&I);
     121             :     }
     122             :   }
     123      323666 :   for (auto &Arg : F.args()) {
     124      225705 :     if (TTI.isSourceOfDivergence(&Arg)) {
     125       40046 :       Worklist.push_back(&Arg);
     126       40046 :       DV.insert(&Arg);
     127             :     }
     128             :   }
     129       97961 : }
     130             : 
     131        2793 : void DivergencePropagator::exploreSyncDependency(Instruction *TI) {
     132             :   // Propagation rule 1: if branch TI is divergent, all PHINodes in TI's
     133             :   // immediate post dominator are divergent. This rule handles if-then-else
     134             :   // patterns. For example,
     135             :   //
     136             :   // if (tid < 5)
     137             :   //   a1 = 1;
     138             :   // else
     139             :   //   a2 = 2;
     140             :   // a = phi(a1, a2); // sync dependent on (tid < 5)
     141        2793 :   BasicBlock *ThisBB = TI->getParent();
     142             : 
     143             :   // Unreachable blocks may not be in the dominator tree.
     144        2793 :   if (!DT.isReachableFromEntry(ThisBB))
     145         140 :     return;
     146             : 
     147             :   // If the function has no exit blocks or doesn't reach any exit blocks, the
     148             :   // post dominator may be null.
     149        2792 :   DomTreeNode *ThisNode = PDT.getNode(ThisBB);
     150             :   if (!ThisNode)
     151           0 :     return;
     152             : 
     153        2792 :   BasicBlock *IPostDom = ThisNode->getIDom()->getBlock();
     154        2792 :   if (IPostDom == nullptr)
     155             :     return;
     156             : 
     157        6762 :   for (auto I = IPostDom->begin(); isa<PHINode>(I); ++I) {
     158             :     // A PHINode is uniform if it returns the same value no matter which path is
     159             :     // taken.
     160        4109 :     if (!cast<PHINode>(I)->hasConstantOrUndefValue() && DV.insert(&*I).second)
     161         986 :       Worklist.push_back(&*I);
     162             :   }
     163             : 
     164             :   // Propagation rule 2: if a value defined in a loop is used outside, the user
     165             :   // is sync dependent on the condition of the loop exits that dominate the
     166             :   // user. For example,
     167             :   //
     168             :   // int i = 0;
     169             :   // do {
     170             :   //   i++;
     171             :   //   if (foo(i)) ... // uniform
     172             :   // } while (i < tid);
     173             :   // if (bar(i)) ...   // divergent
     174             :   //
     175             :   // A program may contain unstructured loops. Therefore, we cannot leverage
     176             :   // LoopInfo, which only recognizes natural loops.
     177             :   //
     178             :   // The algorithm used here handles both natural and unstructured loops.  Given
     179             :   // a branch TI, we first compute its influence region, the union of all simple
     180             :   // paths from TI to its immediate post dominator (IPostDom). Then, we search
     181             :   // for all the values defined in the influence region but used outside. All
     182             :   // these users are sync dependent on TI.
     183             :   DenseSet<BasicBlock *> InfluenceRegion;
     184        2653 :   computeInfluenceRegion(ThisBB, IPostDom, InfluenceRegion);
     185             :   // An insight that can speed up the search process is that all the in-region
     186             :   // values that are used outside must dominate TI. Therefore, instead of
     187             :   // searching every basic blocks in the influence region, we search all the
     188             :   // dominators of TI until it is outside the influence region.
     189             :   BasicBlock *InfluencedBB = ThisBB;
     190         622 :   while (InfluenceRegion.count(InfluencedBB)) {
     191       11601 :     for (auto &I : *InfluencedBB)
     192       10979 :       findUsersOutsideInfluenceRegion(I, InfluenceRegion);
     193         622 :     DomTreeNode *IDomNode = DT.getNode(InfluencedBB)->getIDom();
     194         622 :     if (IDomNode == nullptr)
     195             :       break;
     196         622 :     InfluencedBB = IDomNode->getBlock();
     197             :   }
     198             : }
     199             : 
     200       10979 : void DivergencePropagator::findUsersOutsideInfluenceRegion(
     201             :     Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion) {
     202       27813 :   for (User *U : I.users()) {
     203             :     Instruction *UserInst = cast<Instruction>(U);
     204       16834 :     if (!InfluenceRegion.count(UserInst->getParent())) {
     205        5535 :       if (DV.insert(UserInst).second)
     206        3768 :         Worklist.push_back(UserInst);
     207             :     }
     208             :   }
     209       10979 : }
     210             : 
     211             : // A helper function for computeInfluenceRegion that adds successors of "ThisBB"
     212             : // to the influence region.
     213             : static void
     214        6893 : addSuccessorsToInfluenceRegion(BasicBlock *ThisBB, BasicBlock *End,
     215             :                                DenseSet<BasicBlock *> &InfluenceRegion,
     216             :                                std::vector<BasicBlock *> &InfluenceStack) {
     217       24627 :   for (BasicBlock *Succ : successors(ThisBB)) {
     218       16184 :     if (Succ != End && InfluenceRegion.insert(Succ).second)
     219        4240 :       InfluenceStack.push_back(Succ);
     220             :   }
     221        6893 : }
     222             : 
     223           0 : void DivergencePropagator::computeInfluenceRegion(
     224             :     BasicBlock *Start, BasicBlock *End,
     225             :     DenseSet<BasicBlock *> &InfluenceRegion) {
     226             :   assert(PDT.properlyDominates(End, Start) &&
     227             :          "End does not properly dominate Start");
     228             : 
     229             :   // The influence region starts from the end of "Start" to the beginning of
     230             :   // "End". Therefore, "Start" should not be in the region unless "Start" is in
     231             :   // a loop that doesn't contain "End".
     232             :   std::vector<BasicBlock *> InfluenceStack;
     233           0 :   addSuccessorsToInfluenceRegion(Start, End, InfluenceRegion, InfluenceStack);
     234           0 :   while (!InfluenceStack.empty()) {
     235           0 :     BasicBlock *BB = InfluenceStack.back();
     236             :     InfluenceStack.pop_back();
     237           0 :     addSuccessorsToInfluenceRegion(BB, End, InfluenceRegion, InfluenceStack);
     238             :   }
     239           0 : }
     240             : 
     241      282397 : void DivergencePropagator::exploreDataDependency(Value *V) {
     242             :   // Follow def-use chains of V.
     243      536479 :   for (User *U : V->users()) {
     244             :     Instruction *UserInst = cast<Instruction>(U);
     245      254082 :     if (!TTI.isAlwaysUniform(U) && DV.insert(UserInst).second)
     246      194227 :       Worklist.push_back(UserInst);
     247             :   }
     248      282397 : }
     249             : 
     250       97961 : void DivergencePropagator::propagate() {
     251             :   // Traverse the dependency graph using DFS.
     252      380358 :   while (!Worklist.empty()) {
     253      282397 :     Value *V = Worklist.back();
     254             :     Worklist.pop_back();
     255             :     if (Instruction *I = dyn_cast<Instruction>(V)) {
     256             :       // Terminators with less than two successors won't introduce sync
     257             :       // dependency. Ignore them.
     258      242351 :       if (I->isTerminator() && I->getNumSuccessors() > 1)
     259        2793 :         exploreSyncDependency(I);
     260             :     }
     261      282397 :     exploreDataDependency(V);
     262             :   }
     263       97961 : }
     264             : 
     265             : } /// end namespace anonymous
     266             : 
     267             : // Register this pass.
     268             : char LegacyDivergenceAnalysis::ID = 0;
     269       85117 : INITIALIZE_PASS_BEGIN(LegacyDivergenceAnalysis, "divergence", "Legacy Divergence Analysis",
     270             :                       false, true)
     271       85117 : INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
     272       85117 : INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
     273      681711 : INITIALIZE_PASS_END(LegacyDivergenceAnalysis, "divergence", "Legacy Divergence Analysis",
     274             :                     false, true)
     275             : 
     276           0 : FunctionPass *llvm::createLegacyDivergenceAnalysisPass() {
     277           0 :   return new LegacyDivergenceAnalysis();
     278             : }
     279             : 
     280       10143 : void LegacyDivergenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
     281             :   AU.addRequired<DominatorTreeWrapperPass>();
     282             :   AU.addRequired<PostDominatorTreeWrapperPass>();
     283             :   AU.setPreservesAll();
     284       10143 : }
     285             : 
     286      101149 : bool LegacyDivergenceAnalysis::runOnFunction(Function &F) {
     287      101149 :   auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
     288      101149 :   if (TTIWP == nullptr)
     289             :     return false;
     290             : 
     291      101149 :   TargetTransformInfo &TTI = TTIWP->getTTI(F);
     292             :   // Fast path: if the target does not have branch divergence, we do not mark
     293             :   // any branch as divergent.
     294      101149 :   if (!TTI.hasBranchDivergence())
     295             :     return false;
     296             : 
     297             :   DivergentValues.clear();
     298       97961 :   auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
     299             :   DivergencePropagator DP(F, TTI,
     300       97961 :                           getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
     301       97961 :                           PDT, DivergentValues);
     302       97961 :   DP.populateWithSourcesOfDivergence();
     303       97961 :   DP.propagate();
     304             :   LLVM_DEBUG(
     305             :     dbgs() << "\nAfter divergence analysis on " << F.getName() << ":\n";
     306             :     print(dbgs(), F.getParent())
     307             :   );
     308             :   return false;
     309             : }
     310             : 
     311          53 : void LegacyDivergenceAnalysis::print(raw_ostream &OS, const Module *) const {
     312          53 :   if (DivergentValues.empty())
     313             :     return;
     314          52 :   const Value *FirstDivergentValue = *DivergentValues.begin();
     315             :   const Function *F;
     316             :   if (const Argument *Arg = dyn_cast<Argument>(FirstDivergentValue)) {
     317          12 :     F = Arg->getParent();
     318             :   } else if (const Instruction *I =
     319             :                  dyn_cast<Instruction>(FirstDivergentValue)) {
     320          40 :     F = I->getParent()->getParent();
     321             :   } else {
     322           0 :     llvm_unreachable("Only arguments and instructions can be divergent");
     323             :   }
     324             : 
     325             :   // Dumps all divergent values in F, arguments and then instructions.
     326         167 :   for (auto &Arg : F->args()) {
     327         230 :     OS << (DivergentValues.count(&Arg) ? "DIVERGENT: " : "           ");
     328         115 :     OS << Arg << "\n";
     329             :   }
     330             :   // Iterate instructions using instructions() to ensure a deterministic order.
     331         143 :   for (auto BI = F->begin(), BE = F->end(); BI != BE; ++BI) {
     332             :     auto &BB = *BI;
     333          91 :     OS << "\n           " << BB.getName() << ":\n";
     334         485 :     for (auto &I : BB.instructionsWithoutDebug()) {
     335         424 :       OS << (DivergentValues.count(&I) ? "DIVERGENT:     " : "               ");
     336         212 :       OS << I << "\n";
     337             :     }
     338             :   }
     339          52 :   OS << "\n";
     340             : }

Generated by: LCOV version 1.13