LCOV - code coverage report
Current view: top level - lib/Target/AArch64 - AArch64PBQPRegAlloc.cpp (source / functions) Hit Total Coverage
Test: llvm-toolchain.info Lines: 76 98 77.6 %
Date: 2018-10-20 13:21:21 Functions: 5 5 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
       2             : //
       3             : //                     The LLVM Compiler Infrastructure
       4             : //
       5             : // This file is distributed under the University of Illinois Open Source
       6             : // License. See LICENSE.TXT for details.
       7             : //
       8             : //===----------------------------------------------------------------------===//
       9             : // This file contains the AArch64 / Cortex-A57 specific register allocation
      10             : // constraints for use by the PBQP register allocator.
      11             : //
      12             : // It is essentially a transcription of what is contained in
      13             : // AArch64A57FPLoadBalancing, which tries to use a balanced
      14             : // mix of odd and even D-registers when performing a critical sequence of
      15             : // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
      16             : //===----------------------------------------------------------------------===//
      17             : 
      18             : #define DEBUG_TYPE "aarch64-pbqp"
      19             : 
      20             : #include "AArch64PBQPRegAlloc.h"
      21             : #include "AArch64.h"
      22             : #include "AArch64RegisterInfo.h"
      23             : #include "llvm/CodeGen/LiveIntervals.h"
      24             : #include "llvm/CodeGen/MachineBasicBlock.h"
      25             : #include "llvm/CodeGen/MachineFunction.h"
      26             : #include "llvm/CodeGen/MachineRegisterInfo.h"
      27             : #include "llvm/CodeGen/RegAllocPBQP.h"
      28             : #include "llvm/Support/Debug.h"
      29             : #include "llvm/Support/ErrorHandling.h"
      30             : #include "llvm/Support/raw_ostream.h"
      31             : 
      32             : using namespace llvm;
      33             : 
      34             : namespace {
      35             : 
      36             : #ifndef NDEBUG
      37             : bool isFPReg(unsigned reg) {
      38             :   return AArch64::FPR32RegClass.contains(reg) ||
      39             :          AArch64::FPR64RegClass.contains(reg) ||
      40             :          AArch64::FPR128RegClass.contains(reg);
      41             : }
      42             : #endif
      43             : 
      44      163840 : bool isOdd(unsigned reg) {
      45      163840 :   switch (reg) {
      46           0 :   default:
      47           0 :     llvm_unreachable("Register is not from the expected class !");
      48             :   case AArch64::S1:
      49             :   case AArch64::S3:
      50             :   case AArch64::S5:
      51             :   case AArch64::S7:
      52             :   case AArch64::S9:
      53             :   case AArch64::S11:
      54             :   case AArch64::S13:
      55             :   case AArch64::S15:
      56             :   case AArch64::S17:
      57             :   case AArch64::S19:
      58             :   case AArch64::S21:
      59             :   case AArch64::S23:
      60             :   case AArch64::S25:
      61             :   case AArch64::S27:
      62             :   case AArch64::S29:
      63             :   case AArch64::S31:
      64             :   case AArch64::D1:
      65             :   case AArch64::D3:
      66             :   case AArch64::D5:
      67             :   case AArch64::D7:
      68             :   case AArch64::D9:
      69             :   case AArch64::D11:
      70             :   case AArch64::D13:
      71             :   case AArch64::D15:
      72             :   case AArch64::D17:
      73             :   case AArch64::D19:
      74             :   case AArch64::D21:
      75             :   case AArch64::D23:
      76             :   case AArch64::D25:
      77             :   case AArch64::D27:
      78             :   case AArch64::D29:
      79             :   case AArch64::D31:
      80             :   case AArch64::Q1:
      81             :   case AArch64::Q3:
      82             :   case AArch64::Q5:
      83             :   case AArch64::Q7:
      84             :   case AArch64::Q9:
      85             :   case AArch64::Q11:
      86             :   case AArch64::Q13:
      87             :   case AArch64::Q15:
      88             :   case AArch64::Q17:
      89             :   case AArch64::Q19:
      90             :   case AArch64::Q21:
      91             :   case AArch64::Q23:
      92             :   case AArch64::Q25:
      93             :   case AArch64::Q27:
      94             :   case AArch64::Q29:
      95             :   case AArch64::Q31:
      96             :     return true;
      97       81920 :   case AArch64::S0:
      98             :   case AArch64::S2:
      99             :   case AArch64::S4:
     100             :   case AArch64::S6:
     101             :   case AArch64::S8:
     102             :   case AArch64::S10:
     103             :   case AArch64::S12:
     104             :   case AArch64::S14:
     105             :   case AArch64::S16:
     106             :   case AArch64::S18:
     107             :   case AArch64::S20:
     108             :   case AArch64::S22:
     109             :   case AArch64::S24:
     110             :   case AArch64::S26:
     111             :   case AArch64::S28:
     112             :   case AArch64::S30:
     113             :   case AArch64::D0:
     114             :   case AArch64::D2:
     115             :   case AArch64::D4:
     116             :   case AArch64::D6:
     117             :   case AArch64::D8:
     118             :   case AArch64::D10:
     119             :   case AArch64::D12:
     120             :   case AArch64::D14:
     121             :   case AArch64::D16:
     122             :   case AArch64::D18:
     123             :   case AArch64::D20:
     124             :   case AArch64::D22:
     125             :   case AArch64::D24:
     126             :   case AArch64::D26:
     127             :   case AArch64::D28:
     128             :   case AArch64::D30:
     129             :   case AArch64::Q0:
     130             :   case AArch64::Q2:
     131             :   case AArch64::Q4:
     132             :   case AArch64::Q6:
     133             :   case AArch64::Q8:
     134             :   case AArch64::Q10:
     135             :   case AArch64::Q12:
     136             :   case AArch64::Q14:
     137             :   case AArch64::Q16:
     138             :   case AArch64::Q18:
     139             :   case AArch64::Q20:
     140             :   case AArch64::Q22:
     141             :   case AArch64::Q24:
     142             :   case AArch64::Q26:
     143             :   case AArch64::Q28:
     144             :   case AArch64::Q30:
     145       81920 :     return false;
     146             : 
     147             :   }
     148             : }
     149             : 
     150             : bool haveSameParity(unsigned reg1, unsigned reg2) {
     151             :   assert(isFPReg(reg1) && "Expecting an FP register for reg1");
     152             :   assert(isFPReg(reg2) && "Expecting an FP register for reg2");
     153             : 
     154       81920 :   return isOdd(reg1) == isOdd(reg2);
     155             : }
     156             : 
     157             : }
     158             : 
     159          28 : bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
     160             :                                                  unsigned Ra) {
     161          28 :   if (Rd == Ra)
     162             :     return false;
     163             : 
     164          28 :   LiveIntervals &LIs = G.getMetadata().LIS;
     165             : 
     166          28 :   if (TRI->isPhysicalRegister(Rd) || TRI->isPhysicalRegister(Ra)) {
     167             :     LLVM_DEBUG(dbgs() << "Rd is a physical reg:" << TRI->isPhysicalRegister(Rd)
     168             :                       << '\n');
     169             :     LLVM_DEBUG(dbgs() << "Ra is a physical reg:" << TRI->isPhysicalRegister(Ra)
     170             :                       << '\n');
     171             :     return false;
     172             :   }
     173             : 
     174          28 :   PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
     175          28 :   PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
     176             : 
     177             :   const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
     178             :     &G.getNodeMetadata(node1).getAllowedRegs();
     179             :   const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =
     180             :     &G.getNodeMetadata(node2).getAllowedRegs();
     181             : 
     182             :   PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
     183             : 
     184             :   // The edge does not exist. Create one with the appropriate interference
     185             :   // costs.
     186          28 :   if (edge == G.invalidEdgeId()) {
     187          28 :     const LiveInterval &ld = LIs.getInterval(Rd);
     188          28 :     const LiveInterval &la = LIs.getInterval(Ra);
     189          28 :     bool livesOverlap = ld.overlaps(la);
     190             : 
     191          28 :     PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
     192          28 :                                  vRaAllowed->size() + 1, 0);
     193         924 :     for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
     194         896 :       unsigned pRd = (*vRdAllowed)[i];
     195       29568 :       for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
     196       28672 :         unsigned pRa = (*vRaAllowed)[j];
     197       28672 :         if (livesOverlap && TRI->regsOverlap(pRd, pRa))
     198           0 :           costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
     199             :         else
     200       57344 :           costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
     201             :       }
     202             :     }
     203          84 :     G.addEdge(node1, node2, std::move(costs));
     204             :     return true;
     205             :   }
     206             : 
     207           0 :   if (G.getEdgeNode1Id(edge) == node2) {
     208             :     std::swap(node1, node2);
     209             :     std::swap(vRdAllowed, vRaAllowed);
     210             :   }
     211             : 
     212             :   // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
     213           0 :   PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
     214           0 :   for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
     215           0 :     unsigned pRd = (*vRdAllowed)[i];
     216             : 
     217             :     // Get the maximum cost (excluding unallocatable reg) for same parity
     218             :     // registers
     219             :     PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
     220           0 :     for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
     221           0 :       unsigned pRa = (*vRaAllowed)[j];
     222           0 :       if (haveSameParity(pRd, pRa))
     223           0 :         if (costs[i + 1][j + 1] !=
     224           0 :                 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
     225             :             costs[i + 1][j + 1] > sameParityMax)
     226             :           sameParityMax = costs[i + 1][j + 1];
     227             :     }
     228             : 
     229             :     // Ensure all registers with a different parity have a higher cost
     230             :     // than sameParityMax
     231           0 :     for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
     232           0 :       unsigned pRa = (*vRaAllowed)[j];
     233           0 :       if (!haveSameParity(pRd, pRa))
     234           0 :         if (sameParityMax > costs[i + 1][j + 1])
     235           0 :           costs[i + 1][j + 1] = sameParityMax + 1.0;
     236             :     }
     237             :   }
     238           0 :   G.updateEdgeCosts(edge, std::move(costs));
     239             : 
     240             :   return true;
     241             : }
     242             : 
     243          28 : void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
     244             :                                                  unsigned Ra) {
     245          28 :   LiveIntervals &LIs = G.getMetadata().LIS;
     246             : 
     247             :   // Do some Chain management
     248          28 :   if (Chains.count(Ra)) {
     249          24 :     if (Rd != Ra) {
     250             :       LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI)
     251             :                         << " to " << printReg(Rd, TRI) << '\n';);
     252          24 :       Chains.remove(Ra);
     253          24 :       Chains.insert(Rd);
     254             :     }
     255             :   } else {
     256             :     LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI)
     257             :                       << '\n';);
     258           4 :     Chains.insert(Rd);
     259             :   }
     260             : 
     261          28 :   PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
     262             : 
     263          28 :   const LiveInterval &ld = LIs.getInterval(Rd);
     264          82 :   for (auto r : Chains) {
     265             :     // Skip self
     266          54 :     if (r == Rd)
     267             :       continue;
     268             : 
     269          26 :     const LiveInterval &lr = LIs.getInterval(r);
     270          52 :     if (ld.overlaps(lr)) {
     271             :       const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
     272             :         &G.getNodeMetadata(node1).getAllowedRegs();
     273             : 
     274          26 :       PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
     275             :       const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =
     276             :         &G.getNodeMetadata(node2).getAllowedRegs();
     277             : 
     278             :       PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
     279             :       assert(edge != G.invalidEdgeId() &&
     280             :              "PBQP error ! The edge should exist !");
     281             : 
     282             :       LLVM_DEBUG(dbgs() << "Refining constraint !\n";);
     283             : 
     284          26 :       if (G.getEdgeNode1Id(edge) == node2) {
     285             :         std::swap(node1, node2);
     286             :         std::swap(vRdAllowed, vRrAllowed);
     287             :       }
     288             : 
     289             :       // Enforce that cost is higher with all other Chains of the same parity
     290          26 :       PBQP::Matrix costs(G.getEdgeCosts(edge));
     291         858 :       for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
     292         832 :         unsigned pRd = (*vRdAllowed)[i];
     293             : 
     294             :         // Get the maximum cost (excluding unallocatable reg) for all other
     295             :         // parity registers
     296             :         PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
     297       27456 :         for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
     298       26624 :           unsigned pRa = (*vRrAllowed)[j];
     299       26624 :           if (!haveSameParity(pRd, pRa))
     300       13312 :             if (costs[i + 1][j + 1] !=
     301       13312 :                     std::numeric_limits<PBQP::PBQPNum>::infinity() &&
     302             :                 costs[i + 1][j + 1] > sameParityMax)
     303             :               sameParityMax = costs[i + 1][j + 1];
     304             :         }
     305             : 
     306             :         // Ensure all registers with same parity have a higher cost
     307             :         // than sameParityMax
     308       27456 :         for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
     309       26624 :           unsigned pRa = (*vRrAllowed)[j];
     310       26624 :           if (haveSameParity(pRd, pRa))
     311       26624 :             if (sameParityMax > costs[i + 1][j + 1])
     312       12480 :               costs[i + 1][j + 1] = sameParityMax + 1.0;
     313             :         }
     314             :       }
     315          78 :       G.updateEdgeCosts(edge, std::move(costs));
     316             :     }
     317             :   }
     318          28 : }
     319             : 
     320         132 : static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
     321             :                                 const MachineInstr &MI) {
     322             :   const LiveInterval &LI = LIs.getInterval(reg);
     323         132 :   SlotIndex SI = LIs.getInstructionIndex(MI);
     324         132 :   return LI.expiredAt(SI);
     325             : }
     326             : 
     327           5 : void A57ChainingConstraint::apply(PBQPRAGraph &G) {
     328           5 :   const MachineFunction &MF = G.getMetadata().MF;
     329           5 :   LiveIntervals &LIs = G.getMetadata().LIS;
     330             : 
     331           5 :   TRI = MF.getSubtarget().getRegisterInfo();
     332             :   LLVM_DEBUG(MF.dump());
     333             : 
     334          13 :   for (const auto &MBB: MF) {
     335           8 :     Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
     336             : 
     337         170 :     for (const auto &MI: MBB) {
     338             : 
     339             :       // Forget Chains which have expired
     340         294 :       for (auto r : Chains) {
     341             :         SmallVector<unsigned, 8> toDel;
     342         132 :         if(regJustKilledBefore(LIs, r, MI)) {
     343             :           LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at ";
     344             :                      MI.print(dbgs()););
     345           4 :           toDel.push_back(r);
     346             :         }
     347             : 
     348         136 :         while (!toDel.empty()) {
     349           4 :           Chains.remove(toDel.back());
     350             :           toDel.pop_back();
     351             :         }
     352             :       }
     353             : 
     354         324 :       switch (MI.getOpcode()) {
     355          28 :       case AArch64::FMSUBSrrr:
     356             :       case AArch64::FMADDSrrr:
     357             :       case AArch64::FNMSUBSrrr:
     358             :       case AArch64::FNMADDSrrr:
     359             :       case AArch64::FMSUBDrrr:
     360             :       case AArch64::FMADDDrrr:
     361             :       case AArch64::FNMSUBDrrr:
     362             :       case AArch64::FNMADDDrrr: {
     363          28 :         unsigned Rd = MI.getOperand(0).getReg();
     364          28 :         unsigned Ra = MI.getOperand(3).getReg();
     365             : 
     366          28 :         if (addIntraChainConstraint(G, Rd, Ra))
     367          28 :           addInterChainConstraint(G, Rd, Ra);
     368             :         break;
     369             :       }
     370             : 
     371           0 :       case AArch64::FMLAv2f32:
     372             :       case AArch64::FMLSv2f32: {
     373           0 :         unsigned Rd = MI.getOperand(0).getReg();
     374           0 :         addInterChainConstraint(G, Rd, Rd);
     375           0 :         break;
     376             :       }
     377             : 
     378             :       default:
     379             :         break;
     380             :       }
     381             :     }
     382             :   }
     383           5 : }

Generated by: LCOV version 1.13