LLVM  4.0.0
AArch64PBQPRegAlloc.cpp
Go to the documentation of this file.
1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 // This file contains the AArch64 / Cortex-A57 specific register allocation
10 // constraints for use by the PBQP register allocator.
11 //
12 // It is essentially a transcription of what is contained in
13 // AArch64A57FPLoadBalancing, which tries to use a balanced
14 // mix of odd and even D-registers when performing a critical sequence of
15 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
16 //===----------------------------------------------------------------------===//
17 
18 #define DEBUG_TYPE "aarch64-pbqp"
19 
#include "AArch64.h"
#include "AArch64PBQPRegAlloc.h"
#include "AArch64RegisterInfo.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
31 
32 using namespace llvm;
33 
34 namespace {
35 
36 #ifndef NDEBUG
37 bool isFPReg(unsigned reg) {
38  return AArch64::FPR32RegClass.contains(reg) ||
39  AArch64::FPR64RegClass.contains(reg) ||
40  AArch64::FPR128RegClass.contains(reg);
41 }
42 #endif
43 
44 bool isOdd(unsigned reg) {
45  switch (reg) {
46  default:
47  llvm_unreachable("Register is not from the expected class !");
48  case AArch64::S1:
49  case AArch64::S3:
50  case AArch64::S5:
51  case AArch64::S7:
52  case AArch64::S9:
53  case AArch64::S11:
54  case AArch64::S13:
55  case AArch64::S15:
56  case AArch64::S17:
57  case AArch64::S19:
58  case AArch64::S21:
59  case AArch64::S23:
60  case AArch64::S25:
61  case AArch64::S27:
62  case AArch64::S29:
63  case AArch64::S31:
64  case AArch64::D1:
65  case AArch64::D3:
66  case AArch64::D5:
67  case AArch64::D7:
68  case AArch64::D9:
69  case AArch64::D11:
70  case AArch64::D13:
71  case AArch64::D15:
72  case AArch64::D17:
73  case AArch64::D19:
74  case AArch64::D21:
75  case AArch64::D23:
76  case AArch64::D25:
77  case AArch64::D27:
78  case AArch64::D29:
79  case AArch64::D31:
80  case AArch64::Q1:
81  case AArch64::Q3:
82  case AArch64::Q5:
83  case AArch64::Q7:
84  case AArch64::Q9:
85  case AArch64::Q11:
86  case AArch64::Q13:
87  case AArch64::Q15:
88  case AArch64::Q17:
89  case AArch64::Q19:
90  case AArch64::Q21:
91  case AArch64::Q23:
92  case AArch64::Q25:
93  case AArch64::Q27:
94  case AArch64::Q29:
95  case AArch64::Q31:
96  return true;
97  case AArch64::S0:
98  case AArch64::S2:
99  case AArch64::S4:
100  case AArch64::S6:
101  case AArch64::S8:
102  case AArch64::S10:
103  case AArch64::S12:
104  case AArch64::S14:
105  case AArch64::S16:
106  case AArch64::S18:
107  case AArch64::S20:
108  case AArch64::S22:
109  case AArch64::S24:
110  case AArch64::S26:
111  case AArch64::S28:
112  case AArch64::S30:
113  case AArch64::D0:
114  case AArch64::D2:
115  case AArch64::D4:
116  case AArch64::D6:
117  case AArch64::D8:
118  case AArch64::D10:
119  case AArch64::D12:
120  case AArch64::D14:
121  case AArch64::D16:
122  case AArch64::D18:
123  case AArch64::D20:
124  case AArch64::D22:
125  case AArch64::D24:
126  case AArch64::D26:
127  case AArch64::D28:
128  case AArch64::D30:
129  case AArch64::Q0:
130  case AArch64::Q2:
131  case AArch64::Q4:
132  case AArch64::Q6:
133  case AArch64::Q8:
134  case AArch64::Q10:
135  case AArch64::Q12:
136  case AArch64::Q14:
137  case AArch64::Q16:
138  case AArch64::Q18:
139  case AArch64::Q20:
140  case AArch64::Q22:
141  case AArch64::Q24:
142  case AArch64::Q26:
143  case AArch64::Q28:
144  case AArch64::Q30:
145  return false;
146 
147  }
148 }
149 
150 bool haveSameParity(unsigned reg1, unsigned reg2) {
151  assert(isFPReg(reg1) && "Expecting an FP register for reg1");
152  assert(isFPReg(reg2) && "Expecting an FP register for reg2");
153 
154  return isOdd(reg1) == isOdd(reg2);
155 }
156 
157 }
158 
159 bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
160  unsigned Ra) {
161  if (Rd == Ra)
162  return false;
163 
164  LiveIntervals &LIs = G.getMetadata().LIS;
165 
166  if (TRI->isPhysicalRegister(Rd) || TRI->isPhysicalRegister(Ra)) {
167  DEBUG(dbgs() << "Rd is a physical reg:" << TRI->isPhysicalRegister(Rd)
168  << '\n');
169  DEBUG(dbgs() << "Ra is a physical reg:" << TRI->isPhysicalRegister(Ra)
170  << '\n');
171  return false;
172  }
173 
174  PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
175  PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
176 
177  const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
178  &G.getNodeMetadata(node1).getAllowedRegs();
179  const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =
180  &G.getNodeMetadata(node2).getAllowedRegs();
181 
182  PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
183 
184  // The edge does not exist. Create one with the appropriate interference
185  // costs.
186  if (edge == G.invalidEdgeId()) {
187  const LiveInterval &ld = LIs.getInterval(Rd);
188  const LiveInterval &la = LIs.getInterval(Ra);
189  bool livesOverlap = ld.overlaps(la);
190 
191  PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
192  vRaAllowed->size() + 1, 0);
193  for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
194  unsigned pRd = (*vRdAllowed)[i];
195  for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
196  unsigned pRa = (*vRaAllowed)[j];
197  if (livesOverlap && TRI->regsOverlap(pRd, pRa))
198  costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
199  else
200  costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
201  }
202  }
203  G.addEdge(node1, node2, std::move(costs));
204  return true;
205  }
206 
207  if (G.getEdgeNode1Id(edge) == node2) {
208  std::swap(node1, node2);
209  std::swap(vRdAllowed, vRaAllowed);
210  }
211 
212  // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
213  PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
214  for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
215  unsigned pRd = (*vRdAllowed)[i];
216 
217  // Get the maximum cost (excluding unallocatable reg) for same parity
218  // registers
220  for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
221  unsigned pRa = (*vRaAllowed)[j];
222  if (haveSameParity(pRd, pRa))
223  if (costs[i + 1][j + 1] !=
224  std::numeric_limits<PBQP::PBQPNum>::infinity() &&
225  costs[i + 1][j + 1] > sameParityMax)
226  sameParityMax = costs[i + 1][j + 1];
227  }
228 
229  // Ensure all registers with a different parity have a higher cost
230  // than sameParityMax
231  for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
232  unsigned pRa = (*vRaAllowed)[j];
233  if (!haveSameParity(pRd, pRa))
234  if (sameParityMax > costs[i + 1][j + 1])
235  costs[i + 1][j + 1] = sameParityMax + 1.0;
236  }
237  }
238  G.updateEdgeCosts(edge, std::move(costs));
239 
240  return true;
241 }
242 
243 void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
244  unsigned Ra) {
245  LiveIntervals &LIs = G.getMetadata().LIS;
246 
247  // Do some Chain management
248  if (Chains.count(Ra)) {
249  if (Rd != Ra) {
250  DEBUG(dbgs() << "Moving acc chain from " << PrintReg(Ra, TRI) << " to "
251  << PrintReg(Rd, TRI) << '\n';);
252  Chains.remove(Ra);
253  Chains.insert(Rd);
254  }
255  } else {
256  DEBUG(dbgs() << "Creating new acc chain for " << PrintReg(Rd, TRI)
257  << '\n';);
258  Chains.insert(Rd);
259  }
260 
261  PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
262 
263  const LiveInterval &ld = LIs.getInterval(Rd);
264  for (auto r : Chains) {
265  // Skip self
266  if (r == Rd)
267  continue;
268 
269  const LiveInterval &lr = LIs.getInterval(r);
270  if (ld.overlaps(lr)) {
271  const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
272  &G.getNodeMetadata(node1).getAllowedRegs();
273 
274  PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
275  const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =
276  &G.getNodeMetadata(node2).getAllowedRegs();
277 
278  PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
279  assert(edge != G.invalidEdgeId() &&
280  "PBQP error ! The edge should exist !");
281 
282  DEBUG(dbgs() << "Refining constraint !\n";);
283 
284  if (G.getEdgeNode1Id(edge) == node2) {
285  std::swap(node1, node2);
286  std::swap(vRdAllowed, vRrAllowed);
287  }
288 
289  // Enforce that cost is higher with all other Chains of the same parity
290  PBQP::Matrix costs(G.getEdgeCosts(edge));
291  for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
292  unsigned pRd = (*vRdAllowed)[i];
293 
294  // Get the maximum cost (excluding unallocatable reg) for all other
295  // parity registers
297  for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
298  unsigned pRa = (*vRrAllowed)[j];
299  if (!haveSameParity(pRd, pRa))
300  if (costs[i + 1][j + 1] !=
301  std::numeric_limits<PBQP::PBQPNum>::infinity() &&
302  costs[i + 1][j + 1] > sameParityMax)
303  sameParityMax = costs[i + 1][j + 1];
304  }
305 
306  // Ensure all registers with same parity have a higher cost
307  // than sameParityMax
308  for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
309  unsigned pRa = (*vRrAllowed)[j];
310  if (haveSameParity(pRd, pRa))
311  if (sameParityMax > costs[i + 1][j + 1])
312  costs[i + 1][j + 1] = sameParityMax + 1.0;
313  }
314  }
315  G.updateEdgeCosts(edge, std::move(costs));
316  }
317  }
318 }
319 
320 static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
321  const MachineInstr &MI) {
322  const LiveInterval &LI = LIs.getInterval(reg);
323  SlotIndex SI = LIs.getInstructionIndex(MI);
324  return LI.expiredAt(SI);
325 }
326 
328  const MachineFunction &MF = G.getMetadata().MF;
329  LiveIntervals &LIs = G.getMetadata().LIS;
330 
331  TRI = MF.getSubtarget().getRegisterInfo();
332  DEBUG(MF.dump());
333 
334  for (const auto &MBB: MF) {
335  Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
336 
337  for (const auto &MI: MBB) {
338 
339  // Forget Chains which have expired
340  for (auto r : Chains) {
342  if(regJustKilledBefore(LIs, r, MI)) {
343  DEBUG(dbgs() << "Killing chain " << PrintReg(r, TRI) << " at ";
344  MI.print(dbgs()););
345  toDel.push_back(r);
346  }
347 
348  while (!toDel.empty()) {
349  Chains.remove(toDel.back());
350  toDel.pop_back();
351  }
352  }
353 
354  switch (MI.getOpcode()) {
355  case AArch64::FMSUBSrrr:
356  case AArch64::FMADDSrrr:
357  case AArch64::FNMSUBSrrr:
358  case AArch64::FNMADDSrrr:
359  case AArch64::FMSUBDrrr:
360  case AArch64::FMADDDrrr:
361  case AArch64::FNMSUBDrrr:
362  case AArch64::FNMADDDrrr: {
363  unsigned Rd = MI.getOperand(0).getReg();
364  unsigned Ra = MI.getOperand(3).getReg();
365 
366  if (addIntraChainConstraint(G, Rd, Ra))
367  addInterChainConstraint(G, Rd, Ra);
368  break;
369  }
370 
371  case AArch64::FMLAv2f32:
372  case AArch64::FMLSv2f32: {
373  unsigned Rd = MI.getOperand(0).getReg();
374  addInterChainConstraint(G, Rd, Rd);
375  break;
376  }
377 
378  default:
379  break;
380  }
381  }
382  }
383 }
void push_back(const T &Elt)
Definition: SmallVector.h:211
NodeId getEdgeNode1Id(EdgeId EId) const
Get the first node connected to this edge.
Definition: Graph.h:533
size_t i
EdgeId addEdge(NodeId N1Id, NodeId N2Id, OtherVectorT Costs)
Add an edge between the given nodes with the given costs.
Definition: Graph.h:397
EdgeId findEdge(NodeId N1Id, NodeId N2Id)
Get the edge connecting two nodes.
Definition: Graph.h:561
LiveInterval - This class represents the liveness of a register, or stack slot.
Definition: LiveInterval.h:625
SlotIndex getInstructionIndex(const MachineInstr &Instr) const
Returns the base index of the given instruction.
void updateEdgeCosts(EdgeId EId, OtherMatrixT Costs)
Update an edge's cost matrix.
Definition: Graph.h:496
RegAllocSolverImpl::RawMatrix RawMatrix
Definition: Graph.h:54
NodeMetadata & getNodeMetadata(NodeId NId)
Definition: Graph.h:480
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
bool remove(const value_type &X)
Remove an item from the set vector.
Definition: SetVector.h:152
LLVM_NODISCARD bool empty() const
Definition: SmallVector.h:60
void apply(PBQPRAGraph &G) override
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:136
MachineBasicBlock * MBB
bool expiredAt(SlotIndex index) const
Definition: LiveInterval.h:372
Printable PrintReg(unsigned Reg, const TargetRegisterInfo *TRI=nullptr, unsigned SubRegIdx=0)
Prints virtual and physical registers with or without a TRI instance.
Maximum length of the test input libFuzzer tries to guess a good value based on the corpus and reports it always prefer smaller inputs during the corpus shuffle When libFuzzer itself reports a bug this exit code will be used If indicates the maximal total time in seconds to run the fuzzer minimizes the provided crash input Use with etc Experimental Use value profile to guide fuzzing Number of simultaneous worker processes to run the jobs If min(jobs, NumberOfCpuCores()/2)\" is used.") FUZZER_FLAG_INT(reload
bool regsOverlap(unsigned regA, unsigned regB) const
Returns true if the two registers are equal or alias each other.
PBQP Matrix class.
Definition: Math.h:120
bool overlaps(const LiveRange &other) const
overlaps - Return true if the intersection of the two live ranges is not empty.
Definition: LiveInterval.h:423
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
void dump() const
dump - Print the current MachineFunction to cerr, useful for debugger use.
float PBQPNum
Definition: Math.h:21
GraphMetadata & getMetadata()
Get a reference to the graph metadata.
Definition: Graph.h:335
const DataFlowGraph & G
Definition: RDFGraph.cpp:206
const Matrix & getEdgeCosts(EdgeId EId) const
Get an edge's cost matrix.
Definition: Graph.h:518
LiveInterval & getInterval(unsigned Reg)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:132
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:586
static EdgeId invalidEdgeId()
Returns a value representing an invalid (non-existent) edge.
Definition: Graph.h:40
Representation of each machine instruction.
Definition: MachineInstr.h:52
static bool isPhysicalRegister(unsigned Reg)
Return true if the specified register number is in the physical register namespace.
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:205
static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg, const MachineInstr &MI)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
#define DEBUG(X)
Definition: Debug.h:100
IRTranslator LLVM IR MI
virtual const TargetRegisterInfo * getRegisterInfo() const
getRegisterInfo - If register information is available, return it.
virtual void print(raw_ostream &O, const Module *M) const
print - Print out the internal state of the pass.
Definition: Pass.cpp:117
SlotIndex - An opaque wrapper around machine indexes.
Definition: SlotIndexes.h:76