Line data Source code
1 : //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
2 : //
3 : // The LLVM Compiler Infrastructure
4 : //
5 : // This file is distributed under the University of Illinois Open Source
6 : // License. See LICENSE.TXT for details.
7 : //
8 : //===----------------------------------------------------------------------===//
9 : // This file contains the AArch64 / Cortex-A57 specific register allocation
10 : // constraints for use by the PBQP register allocator.
11 : //
12 : // It is essentially a transcription of what is contained in
13 : // AArch64A57FPLoadBalancing, which tries to use a balanced
14 : // mix of odd and even D-registers when performing a critical sequence of
15 : // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
16 : //===----------------------------------------------------------------------===//
17 :
18 : #define DEBUG_TYPE "aarch64-pbqp"
19 :
20 : #include "AArch64PBQPRegAlloc.h"
21 : #include "AArch64.h"
22 : #include "AArch64RegisterInfo.h"
23 : #include "llvm/CodeGen/LiveIntervals.h"
24 : #include "llvm/CodeGen/MachineBasicBlock.h"
25 : #include "llvm/CodeGen/MachineFunction.h"
26 : #include "llvm/CodeGen/MachineRegisterInfo.h"
27 : #include "llvm/CodeGen/RegAllocPBQP.h"
28 : #include "llvm/Support/Debug.h"
29 : #include "llvm/Support/ErrorHandling.h"
30 : #include "llvm/Support/raw_ostream.h"
31 :
32 : using namespace llvm;
33 :
34 : namespace {
35 :
36 : #ifndef NDEBUG
37 : bool isFPReg(unsigned reg) {
38 : return AArch64::FPR32RegClass.contains(reg) ||
39 : AArch64::FPR64RegClass.contains(reg) ||
40 : AArch64::FPR128RegClass.contains(reg);
41 : }
42 : #endif
43 :
44 163840 : bool isOdd(unsigned reg) {
45 163840 : switch (reg) {
46 0 : default:
47 0 : llvm_unreachable("Register is not from the expected class !");
48 : case AArch64::S1:
49 : case AArch64::S3:
50 : case AArch64::S5:
51 : case AArch64::S7:
52 : case AArch64::S9:
53 : case AArch64::S11:
54 : case AArch64::S13:
55 : case AArch64::S15:
56 : case AArch64::S17:
57 : case AArch64::S19:
58 : case AArch64::S21:
59 : case AArch64::S23:
60 : case AArch64::S25:
61 : case AArch64::S27:
62 : case AArch64::S29:
63 : case AArch64::S31:
64 : case AArch64::D1:
65 : case AArch64::D3:
66 : case AArch64::D5:
67 : case AArch64::D7:
68 : case AArch64::D9:
69 : case AArch64::D11:
70 : case AArch64::D13:
71 : case AArch64::D15:
72 : case AArch64::D17:
73 : case AArch64::D19:
74 : case AArch64::D21:
75 : case AArch64::D23:
76 : case AArch64::D25:
77 : case AArch64::D27:
78 : case AArch64::D29:
79 : case AArch64::D31:
80 : case AArch64::Q1:
81 : case AArch64::Q3:
82 : case AArch64::Q5:
83 : case AArch64::Q7:
84 : case AArch64::Q9:
85 : case AArch64::Q11:
86 : case AArch64::Q13:
87 : case AArch64::Q15:
88 : case AArch64::Q17:
89 : case AArch64::Q19:
90 : case AArch64::Q21:
91 : case AArch64::Q23:
92 : case AArch64::Q25:
93 : case AArch64::Q27:
94 : case AArch64::Q29:
95 : case AArch64::Q31:
96 : return true;
97 81920 : case AArch64::S0:
98 : case AArch64::S2:
99 : case AArch64::S4:
100 : case AArch64::S6:
101 : case AArch64::S8:
102 : case AArch64::S10:
103 : case AArch64::S12:
104 : case AArch64::S14:
105 : case AArch64::S16:
106 : case AArch64::S18:
107 : case AArch64::S20:
108 : case AArch64::S22:
109 : case AArch64::S24:
110 : case AArch64::S26:
111 : case AArch64::S28:
112 : case AArch64::S30:
113 : case AArch64::D0:
114 : case AArch64::D2:
115 : case AArch64::D4:
116 : case AArch64::D6:
117 : case AArch64::D8:
118 : case AArch64::D10:
119 : case AArch64::D12:
120 : case AArch64::D14:
121 : case AArch64::D16:
122 : case AArch64::D18:
123 : case AArch64::D20:
124 : case AArch64::D22:
125 : case AArch64::D24:
126 : case AArch64::D26:
127 : case AArch64::D28:
128 : case AArch64::D30:
129 : case AArch64::Q0:
130 : case AArch64::Q2:
131 : case AArch64::Q4:
132 : case AArch64::Q6:
133 : case AArch64::Q8:
134 : case AArch64::Q10:
135 : case AArch64::Q12:
136 : case AArch64::Q14:
137 : case AArch64::Q16:
138 : case AArch64::Q18:
139 : case AArch64::Q20:
140 : case AArch64::Q22:
141 : case AArch64::Q24:
142 : case AArch64::Q26:
143 : case AArch64::Q28:
144 : case AArch64::Q30:
145 81920 : return false;
146 :
147 : }
148 : }
149 :
150 : bool haveSameParity(unsigned reg1, unsigned reg2) {
151 : assert(isFPReg(reg1) && "Expecting an FP register for reg1");
152 : assert(isFPReg(reg2) && "Expecting an FP register for reg2");
153 :
154 81920 : return isOdd(reg1) == isOdd(reg2);
155 : }
156 :
157 : }
158 :
159 28 : bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
160 : unsigned Ra) {
161 28 : if (Rd == Ra)
162 : return false;
163 :
164 28 : LiveIntervals &LIs = G.getMetadata().LIS;
165 :
166 28 : if (TRI->isPhysicalRegister(Rd) || TRI->isPhysicalRegister(Ra)) {
167 : LLVM_DEBUG(dbgs() << "Rd is a physical reg:" << TRI->isPhysicalRegister(Rd)
168 : << '\n');
169 : LLVM_DEBUG(dbgs() << "Ra is a physical reg:" << TRI->isPhysicalRegister(Ra)
170 : << '\n');
171 : return false;
172 : }
173 :
174 28 : PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
175 28 : PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
176 :
177 : const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
178 : &G.getNodeMetadata(node1).getAllowedRegs();
179 : const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =
180 : &G.getNodeMetadata(node2).getAllowedRegs();
181 :
182 : PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
183 :
184 : // The edge does not exist. Create one with the appropriate interference
185 : // costs.
186 28 : if (edge == G.invalidEdgeId()) {
187 28 : const LiveInterval &ld = LIs.getInterval(Rd);
188 28 : const LiveInterval &la = LIs.getInterval(Ra);
189 28 : bool livesOverlap = ld.overlaps(la);
190 :
191 28 : PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
192 28 : vRaAllowed->size() + 1, 0);
193 924 : for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
194 896 : unsigned pRd = (*vRdAllowed)[i];
195 29568 : for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
196 28672 : unsigned pRa = (*vRaAllowed)[j];
197 28672 : if (livesOverlap && TRI->regsOverlap(pRd, pRa))
198 0 : costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
199 : else
200 57344 : costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
201 : }
202 : }
203 84 : G.addEdge(node1, node2, std::move(costs));
204 : return true;
205 : }
206 :
207 0 : if (G.getEdgeNode1Id(edge) == node2) {
208 : std::swap(node1, node2);
209 : std::swap(vRdAllowed, vRaAllowed);
210 : }
211 :
212 : // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
213 0 : PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
214 0 : for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
215 0 : unsigned pRd = (*vRdAllowed)[i];
216 :
217 : // Get the maximum cost (excluding unallocatable reg) for same parity
218 : // registers
219 : PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
220 0 : for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
221 0 : unsigned pRa = (*vRaAllowed)[j];
222 0 : if (haveSameParity(pRd, pRa))
223 0 : if (costs[i + 1][j + 1] !=
224 0 : std::numeric_limits<PBQP::PBQPNum>::infinity() &&
225 : costs[i + 1][j + 1] > sameParityMax)
226 : sameParityMax = costs[i + 1][j + 1];
227 : }
228 :
229 : // Ensure all registers with a different parity have a higher cost
230 : // than sameParityMax
231 0 : for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
232 0 : unsigned pRa = (*vRaAllowed)[j];
233 0 : if (!haveSameParity(pRd, pRa))
234 0 : if (sameParityMax > costs[i + 1][j + 1])
235 0 : costs[i + 1][j + 1] = sameParityMax + 1.0;
236 : }
237 : }
238 0 : G.updateEdgeCosts(edge, std::move(costs));
239 :
240 : return true;
241 : }
242 :
243 28 : void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
244 : unsigned Ra) {
245 28 : LiveIntervals &LIs = G.getMetadata().LIS;
246 :
247 : // Do some Chain management
248 28 : if (Chains.count(Ra)) {
249 24 : if (Rd != Ra) {
250 : LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI)
251 : << " to " << printReg(Rd, TRI) << '\n';);
252 24 : Chains.remove(Ra);
253 24 : Chains.insert(Rd);
254 : }
255 : } else {
256 : LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI)
257 : << '\n';);
258 4 : Chains.insert(Rd);
259 : }
260 :
261 28 : PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
262 :
263 28 : const LiveInterval &ld = LIs.getInterval(Rd);
264 82 : for (auto r : Chains) {
265 : // Skip self
266 54 : if (r == Rd)
267 : continue;
268 :
269 26 : const LiveInterval &lr = LIs.getInterval(r);
270 52 : if (ld.overlaps(lr)) {
271 : const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
272 : &G.getNodeMetadata(node1).getAllowedRegs();
273 :
274 26 : PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
275 : const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =
276 : &G.getNodeMetadata(node2).getAllowedRegs();
277 :
278 : PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
279 : assert(edge != G.invalidEdgeId() &&
280 : "PBQP error ! The edge should exist !");
281 :
282 : LLVM_DEBUG(dbgs() << "Refining constraint !\n";);
283 :
284 26 : if (G.getEdgeNode1Id(edge) == node2) {
285 : std::swap(node1, node2);
286 : std::swap(vRdAllowed, vRrAllowed);
287 : }
288 :
289 : // Enforce that cost is higher with all other Chains of the same parity
290 26 : PBQP::Matrix costs(G.getEdgeCosts(edge));
291 858 : for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
292 832 : unsigned pRd = (*vRdAllowed)[i];
293 :
294 : // Get the maximum cost (excluding unallocatable reg) for all other
295 : // parity registers
296 : PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
297 27456 : for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
298 26624 : unsigned pRa = (*vRrAllowed)[j];
299 26624 : if (!haveSameParity(pRd, pRa))
300 13312 : if (costs[i + 1][j + 1] !=
301 13312 : std::numeric_limits<PBQP::PBQPNum>::infinity() &&
302 : costs[i + 1][j + 1] > sameParityMax)
303 : sameParityMax = costs[i + 1][j + 1];
304 : }
305 :
306 : // Ensure all registers with same parity have a higher cost
307 : // than sameParityMax
308 27456 : for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
309 26624 : unsigned pRa = (*vRrAllowed)[j];
310 26624 : if (haveSameParity(pRd, pRa))
311 26624 : if (sameParityMax > costs[i + 1][j + 1])
312 12480 : costs[i + 1][j + 1] = sameParityMax + 1.0;
313 : }
314 : }
315 78 : G.updateEdgeCosts(edge, std::move(costs));
316 : }
317 : }
318 28 : }
319 :
320 132 : static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
321 : const MachineInstr &MI) {
322 : const LiveInterval &LI = LIs.getInterval(reg);
323 132 : SlotIndex SI = LIs.getInstructionIndex(MI);
324 132 : return LI.expiredAt(SI);
325 : }
326 :
327 5 : void A57ChainingConstraint::apply(PBQPRAGraph &G) {
328 5 : const MachineFunction &MF = G.getMetadata().MF;
329 5 : LiveIntervals &LIs = G.getMetadata().LIS;
330 :
331 5 : TRI = MF.getSubtarget().getRegisterInfo();
332 : LLVM_DEBUG(MF.dump());
333 :
334 13 : for (const auto &MBB: MF) {
335 8 : Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
336 :
337 170 : for (const auto &MI: MBB) {
338 :
339 : // Forget Chains which have expired
340 294 : for (auto r : Chains) {
341 : SmallVector<unsigned, 8> toDel;
342 132 : if(regJustKilledBefore(LIs, r, MI)) {
343 : LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at ";
344 : MI.print(dbgs()););
345 4 : toDel.push_back(r);
346 : }
347 :
348 136 : while (!toDel.empty()) {
349 4 : Chains.remove(toDel.back());
350 : toDel.pop_back();
351 : }
352 : }
353 :
354 324 : switch (MI.getOpcode()) {
355 28 : case AArch64::FMSUBSrrr:
356 : case AArch64::FMADDSrrr:
357 : case AArch64::FNMSUBSrrr:
358 : case AArch64::FNMADDSrrr:
359 : case AArch64::FMSUBDrrr:
360 : case AArch64::FMADDDrrr:
361 : case AArch64::FNMSUBDrrr:
362 : case AArch64::FNMADDDrrr: {
363 28 : unsigned Rd = MI.getOperand(0).getReg();
364 28 : unsigned Ra = MI.getOperand(3).getReg();
365 :
366 28 : if (addIntraChainConstraint(G, Rd, Ra))
367 28 : addInterChainConstraint(G, Rd, Ra);
368 : break;
369 : }
370 :
371 0 : case AArch64::FMLAv2f32:
372 : case AArch64::FMLSv2f32: {
373 0 : unsigned Rd = MI.getOperand(0).getReg();
374 0 : addInterChainConstraint(G, Rd, Rd);
375 0 : break;
376 : }
377 :
378 : default:
379 : break;
380 : }
381 : }
382 : }
383 5 : }
|