LCOV - code coverage report
Current view: top level - lib/Target/ARM - ARMParallelDSP.cpp (source / functions) Hit Total Coverage
Test: llvm-toolchain.info Lines: 172 207 83.1 %
Date: 2018-10-20 13:21:21 Functions: 20 25 80.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
       2             : //
       3             : //                     The LLVM Compiler Infrastructure
       4             : //
       5             : // This file is distributed under the University of Illinois Open Source
       6             : // License. See LICENSE.TXT for details.
       7             : //
       8             : //===----------------------------------------------------------------------===//
       9             : //
      10             : /// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map onto these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled.
      15             : //
      16             : //===----------------------------------------------------------------------===//
      17             : 
      18             : #include "llvm/ADT/Statistic.h"
      19             : #include "llvm/ADT/SmallPtrSet.h"
      20             : #include "llvm/Analysis/AliasAnalysis.h"
      21             : #include "llvm/Analysis/LoopAccessAnalysis.h"
      22             : #include "llvm/Analysis/LoopPass.h"
      23             : #include "llvm/Analysis/LoopInfo.h"
      24             : #include "llvm/IR/Instructions.h"
      25             : #include "llvm/IR/NoFolder.h"
      26             : #include "llvm/Transforms/Scalar.h"
      27             : #include "llvm/Transforms/Utils/BasicBlockUtils.h"
      28             : #include "llvm/Transforms/Utils/LoopUtils.h"
      29             : #include "llvm/Pass.h"
      30             : #include "llvm/PassRegistry.h"
      31             : #include "llvm/PassSupport.h"
      32             : #include "llvm/Support/Debug.h"
      33             : #include "llvm/IR/PatternMatch.h"
      34             : #include "llvm/CodeGen/TargetPassConfig.h"
      35             : #include "ARM.h"
      36             : #include "ARMSubtarget.h"
      37             : 
      38             : using namespace llvm;
      39             : using namespace PatternMatch;
      40             : 
      41             : #define DEBUG_TYPE "arm-parallel-dsp"
      42             : 
      43             : STATISTIC(NumSMLAD , "Number of smlad instructions generated");
      44             : 
      45             : static cl::opt<bool>
      46             : DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
      47             :                    cl::desc("Disable the ARM Parallel DSP pass"));
      48             : 
      49             : namespace {
      50             :   struct OpChain;
      51             :   struct BinOpChain;
      52             :   struct Reduction;
      53             : 
      54             :   using OpChainList     = SmallVector<std::unique_ptr<OpChain>, 8>;
      55             :   using ReductionList   = SmallVector<Reduction, 8>;
      56             :   using ValueList       = SmallVector<Value*, 8>;
      57             :   using MemInstList     = SmallVector<Instruction*, 8>;
      58             :   using PMACPair        = std::pair<BinOpChain*,BinOpChain*>;
      59             :   using PMACPairList    = SmallVector<PMACPair, 8>;
      60             :   using Instructions    = SmallVector<Instruction*,16>;
      61             :   using MemLocList      = SmallVector<MemoryLocation, 4>;
      62             : 
      63             :   struct OpChain {
      64             :     Instruction   *Root;
      65             :     ValueList     AllValues;
      66             :     MemInstList   VecLd;    // List of all load instructions.
      67             :     MemLocList    MemLocs;  // All memory locations read by this tree.
      68             :     bool          ReadOnly = true;
      69             : 
      70          61 :     OpChain(Instruction *I, ValueList &vl) : Root(I), AllValues(vl) { }
      71          61 :     virtual ~OpChain() = default;
      72             : 
      73          61 :     void SetMemoryLocations() {
      74             :       const auto Size = LocationSize::unknown();
      75         305 :       for (auto *V : AllValues) {
      76             :         if (auto *I = dyn_cast<Instruction>(V)) {
      77         244 :           if (I->mayWriteToMemory())
      78           4 :             ReadOnly = false;
      79             :           if (auto *Ld = dyn_cast<LoadInst>(V))
      80         122 :             MemLocs.push_back(MemoryLocation(Ld->getPointerOperand(), Size));
      81             :         }
      82             :       }
      83          61 :     }
      84             : 
      85          61 :     unsigned size() const { return AllValues.size(); }
      86             :   };
      87             : 
      88             :   // 'BinOpChain' and 'Reduction' are just some bookkeeping data structures.
      89             :   // 'Reduction' contains the phi-node and accumulator statement from where we
      90             :   // start pattern matching, and 'BinOpChain' the multiplication
      91             :   // instructions that are candidates for parallel execution.
      92           0 :   struct BinOpChain : public OpChain {
      93             :     ValueList     LHS;      // List of all (narrow) left hand operands.
      94             :     ValueList     RHS;      // List of all (narrow) right hand operands.
      95             :     bool Exchange = false;
      96             : 
      97          61 :     BinOpChain(Instruction *I, ValueList &lhs, ValueList &rhs) :
      98         122 :       OpChain(I, lhs), LHS(lhs), RHS(rhs) {
      99         183 :         for (auto *V : RHS)
     100         122 :           AllValues.push_back(V);
     101          61 :       }
     102             :   };
     103             : 
     104         156 :   struct Reduction {
     105             :     PHINode         *Phi;             // The Phi-node from where we start
     106             :                                       // pattern matching.
     107             :     Instruction     *AccIntAdd;       // The accumulating integer add statement,
     108             :                                       // i.e, the reduction statement.
     109             : 
     110             :     OpChainList     MACCandidates;    // The MAC candidates associated with
     111             :                                       // this reduction statement.
     112          52 :     Reduction (PHINode *P, Instruction *Acc) : Phi(P), AccIntAdd(Acc) { };
     113             :   };
     114             : 
     115             :   class ARMParallelDSP : public LoopPass {
     116             :     ScalarEvolution   *SE;
     117             :     AliasAnalysis     *AA;
     118             :     TargetLibraryInfo *TLI;
     119             :     DominatorTree     *DT;
     120             :     LoopInfo          *LI;
     121             :     Loop              *L;
     122             :     const DataLayout  *DL;
     123             :     Module            *M;
     124             : 
     125             :     bool InsertParallelMACs(Reduction &Reduction, PMACPairList &PMACPairs);
     126             :     bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
     127             :     PMACPairList CreateParallelMACPairs(OpChainList &Candidates);
     128             :     Instruction *CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
     129             :                                  Instruction *Acc, bool Exchange,
     130             :                                  Instruction *InsertAfter);
     131             : 
     132             :     /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
     133             :     /// Dual performs two signed 16x16-bit multiplications. It adds the
     134             :     /// products to a 32-bit accumulate operand. Optionally, the instruction can
     135             :     /// exchange the halfwords of the second operand before performing the
     136             :     /// arithmetic.
     137             :     bool MatchSMLAD(Function &F);
     138             : 
     139             :   public:
     140             :     static char ID;
     141             : 
     142        5198 :     ARMParallelDSP() : LoopPass(ID) { }
     143             : 
     144        2594 :     void getAnalysisUsage(AnalysisUsage &AU) const override {
     145        2594 :       LoopPass::getAnalysisUsage(AU);
     146             :       AU.addRequired<AssumptionCacheTracker>();
     147             :       AU.addRequired<ScalarEvolutionWrapperPass>();
     148             :       AU.addRequired<AAResultsWrapperPass>();
     149             :       AU.addRequired<TargetLibraryInfoWrapperPass>();
     150             :       AU.addRequired<LoopInfoWrapperPass>();
     151             :       AU.addRequired<DominatorTreeWrapperPass>();
     152             :       AU.addRequired<TargetPassConfig>();
     153             :       AU.addPreserved<LoopInfoWrapperPass>();
     154        2594 :       AU.setPreservesCFG();
     155        2594 :     }
     156             : 
     157         783 :     bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
     158         783 :       if (DisableParallelDSP)
     159             :         return false;
     160         783 :       L = TheLoop;
     161         783 :       SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
     162         783 :       AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
     163         783 :       TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
     164         783 :       DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
     165         783 :       LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
     166         783 :       auto &TPC = getAnalysis<TargetPassConfig>();
     167             : 
     168             :       BasicBlock *Header = TheLoop->getHeader();
     169         783 :       if (!Header)
     170             :         return false;
     171             : 
     172             :       // TODO: We assume the loop header and latch to be the same block.
     173             :       // This is not a fundamental restriction, but lifting this would just
     174             :       // require more work to do the transformation and then patch up the CFG.
     175         783 :       if (Header != TheLoop->getLoopLatch()) {
     176             :         LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
     177             :                              "running pass ARMParallelDSP\n");
     178             :         return false;
     179             :       }
     180             : 
     181         587 :       Function &F = *Header->getParent();
     182         587 :       M = F.getParent();
     183         587 :       DL = &M->getDataLayout();
     184             : 
     185         587 :       auto &TM = TPC.getTM<TargetMachine>();
     186             :       auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
     187             : 
     188         587 :       if (!ST->allowsUnalignedMem()) {
     189             :         LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
     190             :                              "running pass ARMParallelDSP\n");
     191             :         return false;
     192             :       }
     193             : 
     194         552 :       if (!ST->hasDSP()) {
     195             :         LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
     196             :                              "ARMParallelDSP\n");
     197             :         return false;
     198             :       }
     199             : 
     200         946 :       LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI);
     201             :       bool Changes = false;
     202             : 
     203             :       LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
     204             :       LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");
     205         473 :       Changes = MatchSMLAD(F);
     206             :       return Changes;
     207             :     }
     208             :   };
     209             : }
     210             : 
     211             : // MaxBitwidth: the maximum supported bitwidth of the elements in the DSP
     212             : // instructions, which is set to 16. So here we should collect all i8 and i16
     213             : // narrow operations.
     214             : // TODO: we currently only collect i16, and will support i8 later, so that's
     215             : // why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
     216             : template<unsigned MaxBitWidth>
     217         134 : static bool IsNarrowSequence(Value *V, ValueList &VL) {
     218             :   LLVM_DEBUG(dbgs() << "Is narrow sequence? "; V->dump());
     219             :   ConstantInt *CInt;
     220             : 
     221             :   if (match(V, m_ConstantInt(CInt))) {
     222             :     // TODO: if a constant is used, it needs to fit within the bit width.
     223             :     return false;
     224             :   }
     225             : 
     226             :   auto *I = dyn_cast<Instruction>(V);
     227             :   if (!I)
     228             :    return false;
     229             : 
     230             :   Value *Val, *LHS, *RHS;
     231         132 :   if (match(V, m_Trunc(m_Value(Val)))) {
     232           0 :     if (cast<TruncInst>(I)->getDestTy()->getIntegerBitWidth() == MaxBitWidth)
     233           0 :       return IsNarrowSequence<MaxBitWidth>(Val, VL);
     234         132 :   } else if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) {
     235             :     // TODO: we need to implement sadd16/sadd8 for this, which enables to
     236             :     // also do the rewrite for smlad8.ll, but it is unsupported for now.
     237             :     LLVM_DEBUG(dbgs() << "No, unsupported Op:\t"; I->dump());
     238             :     return false;
     239         130 :   } else if (match(V, m_ZExtOrSExt(m_Value(Val)))) {
     240         128 :     if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth) {
     241             :       LLVM_DEBUG(dbgs() << "No, wrong SrcTy size: " <<
     242             :         cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() << "\n");
     243             :       return false;
     244             :     }
     245             : 
     246         124 :     if (match(Val, m_Load(m_Value()))) {
     247             :       LLVM_DEBUG(dbgs() << "Yes, found narrow Load:\t"; Val->dump());
     248         124 :       VL.push_back(Val);
     249         124 :       VL.push_back(I);
     250         124 :       return true;
     251             :     }
     252             :   }
     253             :   LLVM_DEBUG(dbgs() << "No, unsupported Op:\t"; I->dump());
     254             :   return false;
     255             : }
     256             : 
     257             : // Element-by-element comparison of Value lists returning true if they are
     258             : // instructions with the same opcode or constants with the same value.
     259          68 : static bool AreSymmetrical(const ValueList &VL0,
     260             :                            const ValueList &VL1) {
     261          68 :   if (VL0.size() != VL1.size()) {
     262             :     LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
     263             :                       << VL0.size() << " != " << VL1.size() << "\n");
     264             :     return false;
     265             :   }
     266             : 
     267             :   const unsigned Pairs = VL0.size();
     268             :   LLVM_DEBUG(dbgs() << "Number of operand pairs: " << Pairs << "\n");
     269             : 
     270         202 :   for (unsigned i = 0; i < Pairs; ++i) {
     271         272 :     const Value *V0 = VL0[i];
     272         136 :     const Value *V1 = VL1[i];
     273             :     const auto *Inst0 = dyn_cast<Instruction>(V0);
     274             :     const auto *Inst1 = dyn_cast<Instruction>(V1);
     275             : 
     276             :     LLVM_DEBUG(dbgs() << "Pair " << i << ":\n";
     277             :                dbgs() << "mul1: "; V0->dump();
     278             :                dbgs() << "mul2: "; V1->dump());
     279             : 
     280         136 :     if (!Inst0 || !Inst1)
     281           2 :       return false;
     282             : 
     283         136 :     if (Inst0->isSameOperationAs(Inst1)) {
     284             :       LLVM_DEBUG(dbgs() << "OK: same operation found!\n");
     285         134 :       continue;
     286             :     }
     287             : 
     288             :     const APInt *C0, *C1;
     289           2 :     if (!(match(V0, m_APInt(C0)) && match(V1, m_APInt(C1)) && C0 == C1))
     290           2 :       return false;
     291             :   }
     292             : 
     293             :   LLVM_DEBUG(dbgs() << "OK: found symmetrical operand lists.\n");
     294             :   return true;
     295             : }
     296             : 
     297             : template<typename MemInst>
     298          80 : static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
     299             :                                   MemInstList &VecMem, const DataLayout &DL,
     300             :                                   ScalarEvolution &SE) {
     301             :   if (!MemOp0->isSimple() || !MemOp1->isSimple()) {
     302             :     LLVM_DEBUG(dbgs() << "No, not touching volatile access\n");
     303           0 :     return false;
     304             :   }
     305          80 :   if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE)) {
     306             :     VecMem.clear();
     307          52 :     VecMem.push_back(MemOp0);
     308          52 :     VecMem.push_back(MemOp1);
     309             :     LLVM_DEBUG(dbgs() << "OK: accesses are consecutive.\n");
     310          52 :     return true;
     311             :   }
     312             :   LLVM_DEBUG(dbgs() << "No, accesses aren't consecutive.\n");
     313             :   return false;
     314             : }
     315             : 
     316           0 : bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
     317             :                                         MemInstList &VecMem) {
     318           0 :   if (!Ld0 || !Ld1)
     319           0 :     return false;
     320             : 
     321             :   LLVM_DEBUG(dbgs() << "Are consecutive loads:\n";
     322             :     dbgs() << "Ld0:"; Ld0->dump();
     323             :     dbgs() << "Ld1:"; Ld1->dump();
     324             :   );
     325             : 
     326           0 :   if (!Ld0->hasOneUse() || !Ld1->hasOneUse()) {
     327             :     LLVM_DEBUG(dbgs() << "No, load has more than one use.\n");
     328           0 :     return false;
     329             :   }
     330             : 
     331           0 :   return AreSequentialAccesses<LoadInst>(Ld0, Ld1, VecMem, *DL, *SE);
     332             : }
     333             : 
     334             : PMACPairList
     335          49 : ARMParallelDSP::CreateParallelMACPairs(OpChainList &Candidates) {
     336          49 :   const unsigned Elems = Candidates.size();
     337             :   PMACPairList PMACPairs;
     338             : 
     339          49 :   if (Elems < 2)
     340             :     return PMACPairs;
     341             : 
     342             :   SmallPtrSet<const Instruction*, 4> Paired;
     343          72 :   for (unsigned i = 0; i < Elems; ++i) {
     344          52 :     BinOpChain *PMul0 = static_cast<BinOpChain*>(Candidates[i].get());
     345          52 :     if (Paired.count(PMul0->Root))
     346             :       continue;
     347             : 
     348          81 :     for (unsigned j = 0; j < Elems; ++j) {
     349          73 :       if (i == j)
     350             :         continue;
     351             : 
     352          45 :       BinOpChain *PMul1 = static_cast<BinOpChain*>(Candidates[j].get());
     353          45 :       if (Paired.count(PMul1->Root))
     354             :         continue;
     355             : 
     356          34 :       const Instruction *Mul0 = PMul0->Root;
     357          34 :       const Instruction *Mul1 = PMul1->Root;
     358          34 :       if (Mul0 == Mul1)
     359             :         continue;
     360             : 
     361             :       assert(PMul0 != PMul1 && "expected different chains");
     362             : 
     363             :       LLVM_DEBUG(dbgs() << "\nCheck parallel muls:\n";
     364             :                  dbgs() << "- "; Mul0->dump();
     365             :                  dbgs() << "- "; Mul1->dump());
     366             : 
     367          34 :       const ValueList &Mul0_LHS = PMul0->LHS;
     368          34 :       const ValueList &Mul0_RHS = PMul0->RHS;
     369          34 :       const ValueList &Mul1_LHS = PMul1->LHS;
     370          34 :       const ValueList &Mul1_RHS = PMul1->RHS;
     371             : 
     372          68 :       if (!AreSymmetrical(Mul0_LHS, Mul1_LHS) ||
     373          34 :           !AreSymmetrical(Mul0_RHS, Mul1_RHS))
     374           2 :         continue;
     375             : 
     376             :       LLVM_DEBUG(dbgs() << "OK: mul operands list match:\n");
     377             :       // The first elements of each vector should be loads with sexts. If we
     378             :       // find that its two pairs of consecutive loads, then these can be
     379             :       // transformed into two wider loads and the users can be replaced with
     380             :       // DSP intrinsics.
     381             :       bool Found = false;
     382          64 :       for (unsigned x = 0; x < Mul0_LHS.size(); x += 2) {
     383          32 :         auto *Ld0 = dyn_cast<LoadInst>(Mul0_LHS[x]);
     384          32 :         auto *Ld1 = dyn_cast<LoadInst>(Mul1_LHS[x]);
     385          32 :         auto *Ld2 = dyn_cast<LoadInst>(Mul0_RHS[x]);
     386          32 :         auto *Ld3 = dyn_cast<LoadInst>(Mul1_RHS[x]);
     387             : 
     388          32 :         if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
     389             :           continue;
     390             : 
     391             :         LLVM_DEBUG(dbgs() << "Looking at operands " << x << ":\n"
     392             :                    << "\t Ld0: " << *Ld0 << "\n"
     393             :                    << "\t Ld1: " << *Ld1 << "\n"
     394             :                    << "and operands " << x + 2 << ":\n"
     395             :                    << "\t Ld2: " << *Ld2 << "\n"
     396             :                    << "\t Ld3: " << *Ld3 << "\n");
     397             : 
     398          32 :         if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
     399          17 :           if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
     400             :             LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
     401          12 :             PMACPairs.push_back(std::make_pair(PMul0, PMul1));
     402             :             Found = true;
     403           5 :           } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
     404             :             LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
     405             :             LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
     406           5 :             PMul1->Exchange = true;
     407           5 :             PMACPairs.push_back(std::make_pair(PMul0, PMul1));
     408             :             Found = true;
     409             :           }
     410          15 :         } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd)) {
     411          11 :           if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
     412             :             LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
     413             :             LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
     414             :             LLVM_DEBUG(dbgs() << "    and swapping muls\n");
     415           7 :             PMul0->Exchange = true;
     416             :             // Only the second operand can be exchanged, so swap the muls.
     417           7 :             PMACPairs.push_back(std::make_pair(PMul1, PMul0));
     418             :             Found = true;
     419             :           }
     420             :         }
     421             :       }
     422          32 :       if (Found) {
     423          24 :         Paired.insert(Mul0);
     424          24 :         Paired.insert(Mul1);
     425             :         break;
     426             :       }
     427             :     }
     428             :   }
     429             :   return PMACPairs;
     430             : }
     431             : 
     432           0 : bool ARMParallelDSP::InsertParallelMACs(Reduction &Reduction,
     433             :                                         PMACPairList &PMACPairs) {
     434           0 :   Instruction *Acc = Reduction.Phi;
     435           0 :   Instruction *InsertAfter = Reduction.AccIntAdd;
     436             : 
     437           0 :   for (auto &Pair : PMACPairs) {
     438           0 :     BinOpChain *PMul0 = Pair.first;
     439           0 :     BinOpChain *PMul1 = Pair.second;
     440             :     LLVM_DEBUG(dbgs() << "Found parallel MACs!!\n";
     441             :                dbgs() << "- "; PMul0->Root->dump();
     442             :                dbgs() << "- "; PMul1->Root->dump());
     443             : 
     444           0 :     auto *VecLd0 = cast<LoadInst>(PMul0->VecLd[0]);
     445           0 :     auto *VecLd1 = cast<LoadInst>(PMul1->VecLd[0]);
     446           0 :     Acc = CreateSMLADCall(VecLd0, VecLd1, Acc, PMul1->Exchange, InsertAfter);
     447             :     InsertAfter = Acc;
     448             :   }
     449             : 
     450           0 :   if (Acc != Reduction.Phi) {
     451             :     LLVM_DEBUG(dbgs() << "Replace Accumulate: "; Acc->dump());
     452           0 :     Reduction.AccIntAdd->replaceAllUsesWith(Acc);
     453           0 :     return true;
     454             :   }
     455             :   return false;
     456             : }
     457             : 
     458         473 : static void MatchReductions(Function &F, Loop *TheLoop, BasicBlock *Header,
     459             :                             ReductionList &Reductions) {
     460         470 :   RecurrenceDescriptor RecDesc;
     461             :   const bool HasFnNoNaNAttr =
     462         473 :     F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
     463         473 :   const BasicBlock *Latch = TheLoop->getLoopLatch();
     464             : 
     465             :   // We need a preheader as getIncomingValueForBlock assumes there is one.
     466         473 :   if (!TheLoop->getLoopPreheader()) {
     467             :     LLVM_DEBUG(dbgs() << "No preheader found, bailing out\n");
     468           3 :     return;
     469             :   }
     470             : 
     471        1266 :   for (PHINode &Phi : Header->phis()) {
     472         326 :     const auto *Ty = Phi.getType();
     473         326 :     if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
     474             :       continue;
     475             : 
     476             :     const bool IsReduction =
     477         208 :       RecurrenceDescriptor::AddReductionVar(&Phi,
     478             :                                             RecurrenceDescriptor::RK_IntegerAdd,
     479             :                                             TheLoop, HasFnNoNaNAttr, RecDesc);
     480         208 :     if (!IsReduction)
     481             :       continue;
     482             : 
     483          52 :     Instruction *Acc = dyn_cast<Instruction>(Phi.getIncomingValueForBlock(Latch));
     484             :     if (!Acc)
     485             :       continue;
     486             : 
     487         156 :     Reductions.push_back(Reduction(&Phi, Acc));
     488             :   }
     489             : 
     490             :   LLVM_DEBUG(
     491             :     dbgs() << "\nAccumulating integer additions (reductions) found:\n";
     492             :     for (auto &R : Reductions) {
     493             :       dbgs() << "-  "; R.Phi->dump();
     494             :       dbgs() << "-> "; R.AccIntAdd->dump();
     495             :     }
     496             :   );
     497             : }
     498             : 
     499          71 : static void AddMACCandidate(OpChainList &Candidates,
     500             :                             Instruction *Mul,
     501             :                             Value *MulOp0, Value *MulOp1) {
     502             :   LLVM_DEBUG(dbgs() << "OK, found acc mul:\t"; Mul->dump());
     503             :   assert(Mul->getOpcode() == Instruction::Mul &&
     504             :          "expected mul instruction");
     505             :   ValueList LHS;
     506             :   ValueList RHS;
     507         134 :   if (IsNarrowSequence<16>(MulOp0, LHS) &&
     508          63 :       IsNarrowSequence<16>(MulOp1, RHS)) {
     509             :     LLVM_DEBUG(dbgs() << "OK, found narrow mul: "; Mul->dump());
     510         183 :     Candidates.push_back(make_unique<BinOpChain>(Mul, LHS, RHS));
     511             :   }
     512          71 : }
     513             : 
     514           0 : static void MatchParallelMACSequences(Reduction &R,
     515             :                                       OpChainList &Candidates) {
     516           0 :   Instruction *Acc = R.AccIntAdd;
     517             :   LLVM_DEBUG(dbgs() << "\n- Analysing:\t" << *Acc);
     518             : 
     519             :   // Returns false to signal the search should be stopped.
     520             :   std::function<bool(Value*)> Match =
     521             :     [&Candidates, &Match](Value *V) -> bool {
     522             : 
     523             :     auto *I = dyn_cast<Instruction>(V);
     524             :     if (!I)
     525             :       return false;
     526             : 
     527             :     Value *MulOp0, *MulOp1;
     528             : 
     529             :     switch (I->getOpcode()) {
     530             :     case Instruction::Add:
     531             :       if (Match(I->getOperand(0)) || (Match(I->getOperand(1))))
     532             :         return true;
     533             :       break;
     534             :     case Instruction::Mul:
     535             :       if (match (I, (m_Mul(m_Value(MulOp0), m_Value(MulOp1))))) {
     536             :         AddMACCandidate(Candidates, I, MulOp0, MulOp1);
     537             :         return false;
     538             :       }
     539             :       break;
     540             :     case Instruction::SExt:
     541             :       if (match (I, (m_SExt(m_Mul(m_Value(MulOp0), m_Value(MulOp1)))))) {
     542             :         Instruction *Mul = cast<Instruction>(I->getOperand(0));
     543             :         AddMACCandidate(Candidates, Mul, MulOp0, MulOp1);
     544             :         return false;
     545             :       }
     546             :       break;
     547             :     }
     548             :     return false;
     549             :   };
     550             : 
     551           0 :   while (Match (Acc));
     552             :   LLVM_DEBUG(dbgs() << "Finished matching MAC sequences, found "
     553             :              << Candidates.size() << " candidates.\n");
     554           0 : }
     555             : 
     556             : // Collects all instructions that are not part of the MAC chains, which is the
     557             : // set of instructions that can potentially alias with the MAC operands.
     558         473 : static void AliasCandidates(BasicBlock *Header, Instructions &Reads,
     559             :                             Instructions &Writes) {
     560        5649 :   for (auto &I : *Header) {
     561        5176 :     if (I.mayReadFromMemory())
     562         944 :       Reads.push_back(&I);
     563        5176 :     if (I.mayWriteToMemory())
     564         717 :       Writes.push_back(&I);
     565             :   }
     566         473 : }
     567             : 
     568             : // Check whether statements in the basic block that write to memory alias with
     569             : // the memory locations accessed by the MAC-chains.
     570             : // TODO: we need the read statements when we accept more complicated chains.
     571           0 : static bool AreAliased(AliasAnalysis *AA, Instructions &Reads,
     572             :                        Instructions &Writes, OpChainList &MACCandidates) {
     573             :   LLVM_DEBUG(dbgs() << "Alias checks:\n");
     574           0 :   for (auto &MAC : MACCandidates) {
     575             :     LLVM_DEBUG(dbgs() << "mul: "; MAC->Root->dump());
     576             : 
     577             :     // At the moment, we allow only simple chains that only consist of reads,
     578             :     // accumulate their result with an integer add, and thus that don't write
     579             :     // memory, and simply bail if they do.
     580           0 :     if (!MAC->ReadOnly)
     581           0 :       return true;
     582             : 
     583             :     // Now for all writes in the basic block, check that they don't alias with
     584             :     // the memory locations accessed by our MAC-chain:
     585           0 :     for (auto *I : Writes) {
     586             :       LLVM_DEBUG(dbgs() << "- "; I->dump());
     587             :       assert(MAC->MemLocs.size() >= 2 && "expecting at least 2 memlocs");
     588           0 :       for (auto &MemLoc : MAC->MemLocs) {
     589           0 :         if (isModOrRefSet(intersectModRef(AA->getModRefInfo(I, MemLoc),
     590             :                                           ModRefInfo::ModRef))) {
     591             :           LLVM_DEBUG(dbgs() << "Yes, aliases found\n");
     592           0 :           return true;
     593             :         }
     594             :       }
     595             :     }
     596             :   }
     597             : 
     598             :   LLVM_DEBUG(dbgs() << "OK: no aliases found!\n");
     599             :   return false;
     600             : }
     601             : 
     602          52 : static bool CheckMACMemory(OpChainList &Candidates) {
     603         113 :   for (auto &C : Candidates) {
     604             :     // A mul has 2 operands, and a narrow op consist of sext and a load; thus
     605             :     // we expect at least 4 items in this operand value list.
     606          61 :     if (C->size() < 4) {
     607             :       LLVM_DEBUG(dbgs() << "Operand list too short.\n");
     608             :       return false;
     609             :     }
     610          61 :     C->SetMemoryLocations();
     611             :     ValueList &LHS = static_cast<BinOpChain*>(C.get())->LHS;
     612             :     ValueList &RHS = static_cast<BinOpChain*>(C.get())->RHS;
     613             : 
     614             :     // Use +=2 to skip over the expected extend instructions.
     615         122 :     for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
     616         183 :       if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
     617             :         return false;
     618             :     }
     619             :   }
     620             :   return true;
     621             : }
     622             : 
     623             : // Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
     624             : // multiplications.
     625             : // To use SMLAD:
     626             : // 1) we first need to find integer add reduction PHIs,
     627             : // 2) then from the PHI, look for this pattern:
     628             : //
     629             : // acc0 = phi i32 [0, %entry], [%acc1, %loop.body]
     630             : // ld0 = load i16
     631             : // sext0 = sext i16 %ld0 to i32
     632             : // ld1 = load i16
     633             : // sext1 = sext i16 %ld1 to i32
     634             : // mul0 = mul %sext0, %sext1
     635             : // ld2 = load i16
     636             : // sext2 = sext i16 %ld2 to i32
     637             : // ld3 = load i16
     638             : // sext3 = sext i16 %ld3 to i32
     639             : // mul1 = mul i32 %sext2, %sext3
     640             : // add0 = add i32 %mul0, %acc0
     641             : // acc1 = add i32 %add0, %mul1
     642             : //
     643             : // Which can be selected to:
     644             : //
     645             : // ldr.h r0
     646             : // ldr.h r1
     647             : // smlad r2, r0, r1, r2
     648             : //
     649             : // If constants are used instead of loads, these will need to be hoisted
     650             : // out and into a register.
     651             : //
     652             : // If loop invariants are used instead of loads, these need to be packed
     653             : // before the loop begins.
     654             : //
     655         473 : bool ARMParallelDSP::MatchSMLAD(Function &F) {
     656         473 :   BasicBlock *Header = L->getHeader();
     657             :   LLVM_DEBUG(dbgs() << "= Matching SMLAD =\n";
     658             :              dbgs() << "Header block:\n"; Header->dump();
     659             :              dbgs() << "Loop info:\n\n"; L->dump());
     660             : 
     661             :   bool Changed = false;
     662         473 :   ReductionList Reductions;
     663         473 :   MatchReductions(F, L, Header, Reductions);
     664             : 
     665         525 :   for (auto &R : Reductions) {
     666          52 :     OpChainList MACCandidates;
     667          52 :     MatchParallelMACSequences(R, MACCandidates);
     668          52 :     if (!CheckMACMemory(MACCandidates))
     669           0 :       continue;
     670             : 
     671             :     R.MACCandidates = std::move(MACCandidates);
     672             : 
     673             :     LLVM_DEBUG(dbgs() << "MAC candidates:\n";
     674             :       for (auto &M : R.MACCandidates)
     675             :         M->Root->dump();
     676             :       dbgs() << "\n";);
     677             :   }
     678             : 
     679             :   // Collect all instructions that may read or write memory. Our alias
     680             :   // analysis checks bail out if any of these instructions aliases with an
     681             :   // instruction from the MAC-chain.
     682             :   Instructions Reads, Writes;
     683         473 :   AliasCandidates(Header, Reads, Writes);
     684             : 
     685         522 :   for (auto &R : Reductions) {
     686          52 :     if (AreAliased(AA, Reads, Writes, R.MACCandidates))
     687           3 :       return false;
     688          49 :     PMACPairList PMACPairs = CreateParallelMACPairs(R.MACCandidates);
     689          49 :     Changed |= InsertParallelMACs(R, PMACPairs);
     690             :   }
     691             : 
     692             :   LLVM_DEBUG(if (Changed) dbgs() << "Header block:\n"; Header->dump(););
     693             :   return Changed;
     694             : }
     695             : 
     696          48 : static LoadInst *CreateLoadIns(IRBuilder<NoFolder> &IRB, LoadInst &BaseLoad,
     697             :                                const Type *LoadTy) {
     698             :   const unsigned AddrSpace = BaseLoad.getPointerAddressSpace();
     699             : 
     700          48 :   Value *VecPtr = IRB.CreateBitCast(BaseLoad.getPointerOperand(),
     701          48 :                                     LoadTy->getPointerTo(AddrSpace));
     702          48 :   return IRB.CreateAlignedLoad(VecPtr, BaseLoad.getAlignment());
     703             : }
     704             : 
     705          24 : Instruction *ARMParallelDSP::CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
     706             :                                              Instruction *Acc, bool Exchange,
     707             :                                              Instruction *InsertAfter) {
     708             :   LLVM_DEBUG(dbgs() << "Create SMLAD intrinsic using:\n"
     709             :              << "- " << *VecLd0 << "\n"
     710             :              << "- " << *VecLd1 << "\n"
     711             :              << "- " << *Acc << "\n"
     712             :              << "Exchange: " << Exchange << "\n");
     713             : 
     714             :   IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
     715          24 :                               ++BasicBlock::iterator(InsertAfter));
     716             : 
     717             :   // Replace the reduction chain with an intrinsic call
     718          24 :   const Type *Ty = IntegerType::get(M->getContext(), 32);
     719          24 :   LoadInst *NewLd0 = CreateLoadIns(Builder, VecLd0[0], Ty);
     720          24 :   LoadInst *NewLd1 = CreateLoadIns(Builder, VecLd1[0], Ty);
     721          24 :   Value* Args[] = { NewLd0, NewLd1, Acc };
     722             :   Function *SMLAD = nullptr;
     723          24 :   if (Exchange)
     724          12 :     SMLAD = Acc->getType()->isIntegerTy(32) ?
     725           4 :       Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
     726          12 :       Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
     727             :   else
     728          12 :     SMLAD = Acc->getType()->isIntegerTy(32) ?
     729           4 :       Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
     730          12 :       Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);
     731          24 :   CallInst *Call = Builder.CreateCall(SMLAD, Args);
     732             :   NumSMLAD++;
     733          24 :   return Call;
     734             : }
     735             : 
     736        2569 : Pass *llvm::createARMParallelDSPPass() {
     737        2569 :   return new ARMParallelDSP();
     738             : }
     739             : 
// Pass identifier. Following the legacy LLVM pass convention, the pass
// manager presumably identifies the pass by this variable's address rather
// than its value.
char ARMParallelDSP::ID = 0;

// Register the pass under the command-line name "arm-parallel-dsp". The two
// trailing `false` arguments are INITIALIZE_PASS's cfg-only and is-analysis
// flags.
// NOTE(review): no INITIALIZE_PASS_DEPENDENCY appears between BEGIN and END,
// so no analysis dependencies are registered here. Confirm that the
// analyses required in getAnalysisUsage are initialized elsewhere.
INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                "Transform loops to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                "Transform loops to use DSP intrinsics", false, false)

Generated by: LCOV version 1.13