LCOV - code coverage report
Current view: top level - lib/Analysis - DivergenceAnalysis.cpp (source / functions) Hit Total Coverage
Test: llvm-toolchain.info Lines: 103 106 97.2 %
Date: 2017-09-14 15:23:50 Functions: 12 13 92.3 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //===- DivergenceAnalysis.cpp --------- Divergence Analysis Implementation -==//
       2             : //
       3             : //                     The LLVM Compiler Infrastructure
       4             : //
       5             : // This file is distributed under the University of Illinois Open Source
       6             : // License. See LICENSE.TXT for details.
       7             : //
       8             : //===----------------------------------------------------------------------===//
       9             : //
      10             : // This file implements divergence analysis which determines whether a branch
       11             : // in a GPU program is divergent. It can help branch optimizations such as jump
      12             : // threading and loop unswitching to make better decisions.
      13             : //
      14             : // GPU programs typically use the SIMD execution model, where multiple threads
      15             : // in the same execution group have to execute in lock-step. Therefore, if the
      16             : // code contains divergent branches (i.e., threads in a group do not agree on
      17             : // which path of the branch to take), the group of threads has to execute all
      18             : // the paths from that branch with different subsets of threads enabled until
      19             : // they converge at the immediately post-dominating BB of the paths.
      20             : //
      21             : // Due to this execution model, some optimizations such as jump
      22             : // threading and loop unswitching can be unfortunately harmful when performed on
      23             : // divergent branches. Therefore, an analysis that computes which branches in a
      24             : // GPU program are divergent can help the compiler to selectively run these
      25             : // optimizations.
      26             : //
      27             : // This file defines divergence analysis which computes a conservative but
      28             : // non-trivial approximation of all divergent branches in a GPU program. It
      29             : // partially implements the approach described in
      30             : //
      31             : //   Divergence Analysis
      32             : //   Sampaio, Souza, Collange, Pereira
      33             : //   TOPLAS '13
      34             : //
      35             : // The divergence analysis identifies the sources of divergence (e.g., special
      36             : // variables that hold the thread ID), and recursively marks variables that are
      37             : // data or sync dependent on a source of divergence as divergent.
      38             : //
      39             : // While data dependency is a well-known concept, the notion of sync dependency
      40             : // is worth more explanation. Sync dependence characterizes the control flow
      41             : // aspect of the propagation of branch divergence. For example,
      42             : //
      43             : //   %cond = icmp slt i32 %tid, 10
      44             : //   br i1 %cond, label %then, label %else
      45             : // then:
      46             : //   br label %merge
      47             : // else:
      48             : //   br label %merge
      49             : // merge:
      50             : //   %a = phi i32 [ 0, %then ], [ 1, %else ]
      51             : //
      52             : // Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
      53             : // because %tid is not on its use-def chains, %a is sync dependent on %tid
      54             : // because the branch "br i1 %cond" depends on %tid and affects which value %a
      55             : // is assigned to.
      56             : //
      57             : // The current implementation has the following limitations:
      58             : // 1. intra-procedural. It conservatively considers the arguments of a
      59             : //    non-kernel-entry function and the return value of a function call as
      60             : //    divergent.
      61             : // 2. memory as black box. It conservatively considers values loaded from
       62             : //    generic or local addresses as divergent. This can be improved by leveraging
      63             : //    pointer analysis.
      64             : //
      65             : //===----------------------------------------------------------------------===//
      66             : 
      67             : #include "llvm/Analysis/DivergenceAnalysis.h"
      68             : #include "llvm/Analysis/Passes.h"
      69             : #include "llvm/Analysis/PostDominators.h"
      70             : #include "llvm/Analysis/TargetTransformInfo.h"
      71             : #include "llvm/IR/Dominators.h"
      72             : #include "llvm/IR/InstIterator.h"
      73             : #include "llvm/IR/Instructions.h"
      74             : #include "llvm/IR/IntrinsicInst.h"
      75             : #include "llvm/IR/Value.h"
      76             : #include "llvm/Support/Debug.h"
      77             : #include "llvm/Support/raw_ostream.h"
      78             : #include <vector>
      79             : using namespace llvm;
      80             : 
      81             : namespace {
      82             : 
      83      119352 : class DivergencePropagator {
      84             : public:
      85             :   DivergencePropagator(Function &F, TargetTransformInfo &TTI, DominatorTree &DT,
      86             :                        PostDominatorTree &PDT, DenseSet<const Value *> &DV)
      87      119352 :       : F(F), TTI(TTI), DT(DT), PDT(PDT), DV(DV) {}
      88             :   void populateWithSourcesOfDivergence();
      89             :   void propagate();
      90             : 
      91             : private:
      92             :   // A helper function that explores data dependents of V.
      93             :   void exploreDataDependency(Value *V);
      94             :   // A helper function that explores sync dependents of TI.
      95             :   void exploreSyncDependency(TerminatorInst *TI);
      96             :   // Computes the influence region from Start to End. This region includes all
      97             :   // basic blocks on any simple path from Start to End.
      98             :   void computeInfluenceRegion(BasicBlock *Start, BasicBlock *End,
      99             :                               DenseSet<BasicBlock *> &InfluenceRegion);
     100             :   // Finds all users of I that are outside the influence region, and add these
     101             :   // users to Worklist.
     102             :   void findUsersOutsideInfluenceRegion(
     103             :       Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion);
     104             : 
     105             :   Function &F;
     106             :   TargetTransformInfo &TTI;
     107             :   DominatorTree &DT;
     108             :   PostDominatorTree &PDT;
     109             :   std::vector<Value *> Worklist; // Stack for DFS.
     110             :   DenseSet<const Value *> &DV;   // Stores all divergent values.
     111             : };
     112             : 
     113       59676 : void DivergencePropagator::populateWithSourcesOfDivergence() {
     114      119352 :   Worklist.clear();
     115      119352 :   DV.clear();
     116      940058 :   for (auto &I : instructions(F)) {
     117      380515 :     if (TTI.isSourceOfDivergence(&I)) {
     118       56876 :       Worklist.push_back(&I);
     119       56876 :       DV.insert(&I);
     120             :     }
     121             :   }
     122      189445 :   for (auto &Arg : F.args()) {
     123      129769 :     if (TTI.isSourceOfDivergence(&Arg)) {
     124       21464 :       Worklist.push_back(&Arg);
     125       21464 :       DV.insert(&Arg);
     126             :     }
     127             :   }
     128       59676 : }
     129             : 
     130        1589 : void DivergencePropagator::exploreSyncDependency(TerminatorInst *TI) {
     131             :   // Propagation rule 1: if branch TI is divergent, all PHINodes in TI's
     132             :   // immediate post dominator are divergent. This rule handles if-then-else
     133             :   // patterns. For example,
     134             :   //
     135             :   // if (tid < 5)
     136             :   //   a1 = 1;
     137             :   // else
     138             :   //   a2 = 2;
     139             :   // a = phi(a1, a2); // sync dependent on (tid < 5)
     140        1589 :   BasicBlock *ThisBB = TI->getParent();
     141             : 
     142             :   // Unreachable blocks may not be in the dominator tree.
     143        1589 :   if (!DT.isReachableFromEntry(ThisBB))
     144         131 :     return;
     145             : 
     146             :   // If the function has no exit blocks or doesn't reach any exit blocks, the
     147             :   // post dominator may be null.
     148        3176 :   DomTreeNode *ThisNode = PDT.getNode(ThisBB);
     149        1588 :   if (!ThisNode)
     150             :     return;
     151             : 
     152        1588 :   BasicBlock *IPostDom = ThisNode->getIDom()->getBlock();
     153        1588 :   if (IPostDom == nullptr)
     154             :     return;
     155             : 
     156        5112 :   for (auto I = IPostDom->begin(); isa<PHINode>(I); ++I) {
     157             :     // A PHINode is uniform if it returns the same value no matter which path is
     158             :     // taken.
     159        4384 :     if (!cast<PHINode>(I)->hasConstantOrUndefValue() && DV.insert(&*I).second)
     160        2340 :       Worklist.push_back(&*I);
     161             :   }
     162             : 
     163             :   // Propagation rule 2: if a value defined in a loop is used outside, the user
     164             :   // is sync dependent on the condition of the loop exits that dominate the
     165             :   // user. For example,
     166             :   //
     167             :   // int i = 0;
     168             :   // do {
     169             :   //   i++;
     170             :   //   if (foo(i)) ... // uniform
     171             :   // } while (i < tid);
     172             :   // if (bar(i)) ...   // divergent
     173             :   //
     174             :   // A program may contain unstructured loops. Therefore, we cannot leverage
     175             :   // LoopInfo, which only recognizes natural loops.
     176             :   //
     177             :   // The algorithm used here handles both natural and unstructured loops.  Given
     178             :   // a branch TI, we first compute its influence region, the union of all simple
     179             :   // paths from TI to its immediate post dominator (IPostDom). Then, we search
     180             :   // for all the values defined in the influence region but used outside. All
     181             :   // these users are sync dependent on TI.
     182        2916 :   DenseSet<BasicBlock *> InfluenceRegion;
     183        1458 :   computeInfluenceRegion(ThisBB, IPostDom, InfluenceRegion);
     184             :   // An insight that can speed up the search process is that all the in-region
     185             :   // values that are used outside must dominate TI. Therefore, instead of
     186             :   // searching every basic blocks in the influence region, we search all the
     187             :   // dominators of TI until it is outside the influence region.
     188        1458 :   BasicBlock *InfluencedBB = ThisBB;
     189        2054 :   while (InfluenceRegion.count(InfluencedBB)) {
     190        8049 :     for (auto &I : *InfluencedBB)
     191        7155 :       findUsersOutsideInfluenceRegion(I, InfluenceRegion);
     192         596 :     DomTreeNode *IDomNode = DT.getNode(InfluencedBB)->getIDom();
     193         298 :     if (IDomNode == nullptr)
     194             :       break;
     195         298 :     InfluencedBB = IDomNode->getBlock();
     196             :   }
     197             : }
     198             : 
     199        7155 : void DivergencePropagator::findUsersOutsideInfluenceRegion(
     200             :     Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion) {
     201       44481 :   for (User *U : I.users()) {
     202       11508 :     Instruction *UserInst = cast<Instruction>(U);
     203       23016 :     if (!InfluenceRegion.count(UserInst->getParent())) {
     204        8400 :       if (DV.insert(UserInst).second)
     205        5866 :         Worklist.push_back(UserInst);
     206             :     }
     207             :   }
     208        7155 : }
     209             : 
     210             : // A helper function for computeInfluenceRegion that adds successors of "ThisBB"
     211             : // to the influence region.
     212             : static void
     213        4023 : addSuccessorsToInfluenceRegion(BasicBlock *ThisBB, BasicBlock *End,
     214             :                                DenseSet<BasicBlock *> &InfluenceRegion,
     215             :                                std::vector<BasicBlock *> &InfluenceStack) {
     216       20736 :   for (BasicBlock *Succ : successors(ThisBB)) {
     217        9613 :     if (Succ != End && InfluenceRegion.insert(Succ).second)
     218        2565 :       InfluenceStack.push_back(Succ);
     219             :   }
     220        4023 : }
     221             : 
     222        1458 : void DivergencePropagator::computeInfluenceRegion(
     223             :     BasicBlock *Start, BasicBlock *End,
     224             :     DenseSet<BasicBlock *> &InfluenceRegion) {
     225             :   assert(PDT.properlyDominates(End, Start) &&
     226             :          "End does not properly dominate Start");
     227             : 
     228             :   // The influence region starts from the end of "Start" to the beginning of
     229             :   // "End". Therefore, "Start" should not be in the region unless "Start" is in
     230             :   // a loop that doesn't contain "End".
     231        2916 :   std::vector<BasicBlock *> InfluenceStack;
     232        1458 :   addSuccessorsToInfluenceRegion(Start, End, InfluenceRegion, InfluenceStack);
     233        4023 :   while (!InfluenceStack.empty()) {
     234        2565 :     BasicBlock *BB = InfluenceStack.back();
     235        2565 :     InfluenceStack.pop_back();
     236        2565 :     addSuccessorsToInfluenceRegion(BB, End, InfluenceRegion, InfluenceStack);
     237             :   }
     238        1458 : }
     239             : 
     240      153525 : void DivergencePropagator::exploreDataDependency(Value *V) {
     241             :   // Follow def-use chains of V.
     242      589996 :   for (User *U : V->users()) {
     243      141473 :     Instruction *UserInst = cast<Instruction>(U);
     244      313751 :     if (!TTI.isAlwaysUniform(U) && DV.insert(UserInst).second)
     245      221284 :       Worklist.push_back(UserInst);
     246             :   }
     247      153525 : }
     248             : 
     249       59676 : void DivergencePropagator::propagate() {
     250             :   // Traverse the dependency graph using DFS.
     251      579927 :   while (!Worklist.empty()) {
     252      307050 :     Value *V = Worklist.back();
     253      307050 :     Worklist.pop_back();
     254        2922 :     if (TerminatorInst *TI = dyn_cast<TerminatorInst>(V)) {
     255             :       // Terminators with less than two successors won't introduce sync
     256             :       // dependency. Ignore them.
     257        2922 :       if (TI->getNumSuccessors() > 1)
     258        1589 :         exploreSyncDependency(TI);
     259             :     }
     260      153525 :     exploreDataDependency(V);
     261             :   }
     262       59676 : }
     263             : 
     264             : } /// end namespace anonymous
     265             : 
     266             : // Register this pass.
     267             : char DivergenceAnalysis::ID = 0;
     268       53053 : INITIALIZE_PASS_BEGIN(DivergenceAnalysis, "divergence", "Divergence Analysis",
     269             :                       false, true)
     270       53053 : INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
     271       53053 : INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
     272      944090 : INITIALIZE_PASS_END(DivergenceAnalysis, "divergence", "Divergence Analysis",
     273             :                     false, true)
     274             : 
     275           0 : FunctionPass *llvm::createDivergenceAnalysisPass() {
     276           0 :   return new DivergenceAnalysis();
     277             : }
     278             : 
     279        5917 : void DivergenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
     280        5917 :   AU.addRequired<DominatorTreeWrapperPass>();
     281        5917 :   AU.addRequired<PostDominatorTreeWrapperPass>();
     282        5917 :   AU.setPreservesAll();
     283        5917 : }
     284             : 
     285       59866 : bool DivergenceAnalysis::runOnFunction(Function &F) {
     286       59866 :   auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
     287       59866 :   if (TTIWP == nullptr)
     288             :     return false;
     289             : 
     290       59684 :   TargetTransformInfo &TTI = TTIWP->getTTI(F);
     291             :   // Fast path: if the target does not have branch divergence, we do not mark
     292             :   // any branch as divergent.
     293       59684 :   if (!TTI.hasBranchDivergence())
     294             :     return false;
     295             : 
     296      119352 :   DivergentValues.clear();
     297      119352 :   auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
     298             :   DivergencePropagator DP(F, TTI,
     299       59676 :                           getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
     300      179028 :                           PDT, DivergentValues);
     301       59676 :   DP.populateWithSourcesOfDivergence();
     302       59676 :   DP.propagate();
     303       59676 :   return false;
     304             : }
     305             : 
     306          50 : void DivergenceAnalysis::print(raw_ostream &OS, const Module *) const {
     307         100 :   if (DivergentValues.empty())
     308             :     return;
     309         147 :   const Value *FirstDivergentValue = *DivergentValues.begin();
     310             :   const Function *F;
     311          13 :   if (const Argument *Arg = dyn_cast<Argument>(FirstDivergentValue)) {
     312          13 :     F = Arg->getParent();
     313             :   } else if (const Instruction *I =
     314          36 :                  dyn_cast<Instruction>(FirstDivergentValue)) {
     315          36 :     F = I->getParent()->getParent();
     316             :   } else {
     317           0 :     llvm_unreachable("Only arguments and instructions can be divergent");
     318             :   }
     319             : 
     320             :   // Dumps all divergent values in F, arguments and then instructions.
     321         158 :   for (auto &Arg : F->args()) {
     322         218 :     if (DivergentValues.count(&Arg))
     323         182 :       OS << "DIVERGENT:  " << Arg << "\n";
     324             :   }
     325             :   // Iterate instructions using instructions() to ensure a deterministic order.
     326         508 :   for (auto &I : instructions(F)) {
     327         410 :     if (DivergentValues.count(&I))
     328         296 :       OS << "DIVERGENT:" << I << "\n";
     329             :   }
     330             : }

Generated by: LCOV version 1.13