LLVM 20.0.0git
AArch64PBQPRegAlloc.cpp
Go to the documentation of this file.
1//===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This file contains the AArch64 / Cortex-A57 specific register allocation
9// constraints for use by the PBQP register allocator.
10//
11// It is essentially a transcription of what is contained in
12// AArch64A57FPLoadBalancing, which tries to use a balanced
13// mix of odd and even D-registers when performing a critical sequence of
14// independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
15//===----------------------------------------------------------------------===//
16
17#include "AArch64PBQPRegAlloc.h"
18#include "AArch64InstrInfo.h"
19#include "AArch64RegisterInfo.h"
24#include "llvm/Support/Debug.h"
27
28#define DEBUG_TYPE "aarch64-pbqp"
29
30using namespace llvm;
31
32namespace {
33
34bool isOdd(unsigned reg) {
35 switch (reg) {
36 default:
37 llvm_unreachable("Register is not from the expected class !");
38 case AArch64::S1:
39 case AArch64::S3:
40 case AArch64::S5:
41 case AArch64::S7:
42 case AArch64::S9:
43 case AArch64::S11:
44 case AArch64::S13:
45 case AArch64::S15:
46 case AArch64::S17:
47 case AArch64::S19:
48 case AArch64::S21:
49 case AArch64::S23:
50 case AArch64::S25:
51 case AArch64::S27:
52 case AArch64::S29:
53 case AArch64::S31:
54 case AArch64::D1:
55 case AArch64::D3:
56 case AArch64::D5:
57 case AArch64::D7:
58 case AArch64::D9:
59 case AArch64::D11:
60 case AArch64::D13:
61 case AArch64::D15:
62 case AArch64::D17:
63 case AArch64::D19:
64 case AArch64::D21:
65 case AArch64::D23:
66 case AArch64::D25:
67 case AArch64::D27:
68 case AArch64::D29:
69 case AArch64::D31:
70 case AArch64::Q1:
71 case AArch64::Q3:
72 case AArch64::Q5:
73 case AArch64::Q7:
74 case AArch64::Q9:
75 case AArch64::Q11:
76 case AArch64::Q13:
77 case AArch64::Q15:
78 case AArch64::Q17:
79 case AArch64::Q19:
80 case AArch64::Q21:
81 case AArch64::Q23:
82 case AArch64::Q25:
83 case AArch64::Q27:
84 case AArch64::Q29:
85 case AArch64::Q31:
86 return true;
87 case AArch64::S0:
88 case AArch64::S2:
89 case AArch64::S4:
90 case AArch64::S6:
91 case AArch64::S8:
92 case AArch64::S10:
93 case AArch64::S12:
94 case AArch64::S14:
95 case AArch64::S16:
96 case AArch64::S18:
97 case AArch64::S20:
98 case AArch64::S22:
99 case AArch64::S24:
100 case AArch64::S26:
101 case AArch64::S28:
102 case AArch64::S30:
103 case AArch64::D0:
104 case AArch64::D2:
105 case AArch64::D4:
106 case AArch64::D6:
107 case AArch64::D8:
108 case AArch64::D10:
109 case AArch64::D12:
110 case AArch64::D14:
111 case AArch64::D16:
112 case AArch64::D18:
113 case AArch64::D20:
114 case AArch64::D22:
115 case AArch64::D24:
116 case AArch64::D26:
117 case AArch64::D28:
118 case AArch64::D30:
119 case AArch64::Q0:
120 case AArch64::Q2:
121 case AArch64::Q4:
122 case AArch64::Q6:
123 case AArch64::Q8:
124 case AArch64::Q10:
125 case AArch64::Q12:
126 case AArch64::Q14:
127 case AArch64::Q16:
128 case AArch64::Q18:
129 case AArch64::Q20:
130 case AArch64::Q22:
131 case AArch64::Q24:
132 case AArch64::Q26:
133 case AArch64::Q28:
134 case AArch64::Q30:
135 return false;
136
137 }
138}
139
140bool haveSameParity(unsigned reg1, unsigned reg2) {
142 "Expecting an FP register for reg1");
144 "Expecting an FP register for reg2");
145
146 return isOdd(reg1) == isOdd(reg2);
147}
148
149}
150
151bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
152 unsigned Ra) {
153 if (Rd == Ra)
154 return false;
155
156 LiveIntervals &LIs = G.getMetadata().LIS;
157
159 LLVM_DEBUG(dbgs() << "Rd is a physical reg:"
160 << Register::isPhysicalRegister(Rd) << '\n');
161 LLVM_DEBUG(dbgs() << "Ra is a physical reg:"
162 << Register::isPhysicalRegister(Ra) << '\n');
163 return false;
164 }
165
166 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
167 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
168
169 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
170 &G.getNodeMetadata(node1).getAllowedRegs();
171 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =
172 &G.getNodeMetadata(node2).getAllowedRegs();
173
174 PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
175
176 // The edge does not exist. Create one with the appropriate interference
177 // costs.
178 if (edge == G.invalidEdgeId()) {
179 const LiveInterval &ld = LIs.getInterval(Rd);
180 const LiveInterval &la = LIs.getInterval(Ra);
181 bool livesOverlap = ld.overlaps(la);
182
183 PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
184 vRaAllowed->size() + 1, 0);
185 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
186 unsigned pRd = (*vRdAllowed)[i];
187 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
188 unsigned pRa = (*vRaAllowed)[j];
189 if (livesOverlap && TRI->regsOverlap(pRd, pRa))
190 costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
191 else
192 costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
193 }
194 }
195 G.addEdge(node1, node2, std::move(costs));
196 return true;
197 }
198
199 if (G.getEdgeNode1Id(edge) == node2) {
200 std::swap(node1, node2);
201 std::swap(vRdAllowed, vRaAllowed);
202 }
203
204 // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
205 PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
206 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
207 unsigned pRd = (*vRdAllowed)[i];
208
209 // Get the maximum cost (excluding unallocatable reg) for same parity
210 // registers
211 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
212 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
213 unsigned pRa = (*vRaAllowed)[j];
214 if (haveSameParity(pRd, pRa))
215 if (costs[i + 1][j + 1] !=
216 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
217 costs[i + 1][j + 1] > sameParityMax)
218 sameParityMax = costs[i + 1][j + 1];
219 }
220
221 // Ensure all registers with a different parity have a higher cost
222 // than sameParityMax
223 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
224 unsigned pRa = (*vRaAllowed)[j];
225 if (!haveSameParity(pRd, pRa))
226 if (sameParityMax > costs[i + 1][j + 1])
227 costs[i + 1][j + 1] = sameParityMax + 1.0;
228 }
229 }
230 G.updateEdgeCosts(edge, std::move(costs));
231
232 return true;
233}
234
235void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
236 unsigned Ra) {
237 LiveIntervals &LIs = G.getMetadata().LIS;
238
239 // Do some Chain management
240 if (Chains.count(Ra)) {
241 if (Rd != Ra) {
242 LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI)
243 << " to " << printReg(Rd, TRI) << '\n');
244 Chains.remove(Ra);
245 Chains.insert(Rd);
246 }
247 } else {
248 LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI)
249 << '\n');
250 Chains.insert(Rd);
251 }
252
253 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
254
255 const LiveInterval &ld = LIs.getInterval(Rd);
256 for (auto r : Chains) {
257 // Skip self
258 if (r == Rd)
259 continue;
260
261 const LiveInterval &lr = LIs.getInterval(r);
262 if (ld.overlaps(lr)) {
263 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
264 &G.getNodeMetadata(node1).getAllowedRegs();
265
266 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
267 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =
268 &G.getNodeMetadata(node2).getAllowedRegs();
269
270 PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
271 assert(edge != G.invalidEdgeId() &&
272 "PBQP error ! The edge should exist !");
273
274 LLVM_DEBUG(dbgs() << "Refining constraint !\n");
275
276 if (G.getEdgeNode1Id(edge) == node2) {
277 std::swap(node1, node2);
278 std::swap(vRdAllowed, vRrAllowed);
279 }
280
281 // Enforce that cost is higher with all other Chains of the same parity
282 PBQP::Matrix costs(G.getEdgeCosts(edge));
283 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
284 unsigned pRd = (*vRdAllowed)[i];
285
286 // Get the maximum cost (excluding unallocatable reg) for all other
287 // parity registers
288 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
289 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
290 unsigned pRa = (*vRrAllowed)[j];
291 if (!haveSameParity(pRd, pRa))
292 if (costs[i + 1][j + 1] !=
293 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
294 costs[i + 1][j + 1] > sameParityMax)
295 sameParityMax = costs[i + 1][j + 1];
296 }
297
298 // Ensure all registers with same parity have a higher cost
299 // than sameParityMax
300 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
301 unsigned pRa = (*vRrAllowed)[j];
302 if (haveSameParity(pRd, pRa))
303 if (sameParityMax > costs[i + 1][j + 1])
304 costs[i + 1][j + 1] = sameParityMax + 1.0;
305 }
306 }
307 G.updateEdgeCosts(edge, std::move(costs));
308 }
309 }
310}
311
312static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
313 const MachineInstr &MI) {
314 const LiveInterval &LI = LIs.getInterval(reg);
316 return LI.expiredAt(SI);
317}
318
320 const MachineFunction &MF = G.getMetadata().MF;
321 LiveIntervals &LIs = G.getMetadata().LIS;
322
323 TRI = MF.getSubtarget().getRegisterInfo();
324 LLVM_DEBUG(MF.dump());
325
326 for (const auto &MBB: MF) {
327 Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
328
329 for (const auto &MI: MBB) {
330
331 // Forget Chains which have expired
332 for (auto r : Chains) {
334 if(regJustKilledBefore(LIs, r, MI)) {
335 LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at ";
336 MI.print(dbgs()));
337 toDel.push_back(r);
338 }
339
340 while (!toDel.empty()) {
341 Chains.remove(toDel.back());
342 toDel.pop_back();
343 }
344 }
345
346 switch (MI.getOpcode()) {
347 case AArch64::FMSUBSrrr:
348 case AArch64::FMADDSrrr:
349 case AArch64::FNMSUBSrrr:
350 case AArch64::FNMADDSrrr:
351 case AArch64::FMSUBDrrr:
352 case AArch64::FMADDDrrr:
353 case AArch64::FNMSUBDrrr:
354 case AArch64::FNMADDDrrr: {
355 Register Rd = MI.getOperand(0).getReg();
356 Register Ra = MI.getOperand(3).getReg();
357
358 if (addIntraChainConstraint(G, Rd, Ra))
359 addInterChainConstraint(G, Rd, Ra);
360 break;
361 }
362
363 case AArch64::FMLAv2f32:
364 case AArch64::FMLSv2f32: {
365 Register Rd = MI.getOperand(0).getReg();
366 addInterChainConstraint(G, Rd, Rd);
367 break;
368 }
369
370 default:
371 break;
372 }
373 }
374 }
375}
static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg, const MachineInstr &MI)
MachineBasicBlock & MBB
#define LLVM_DEBUG(...)
Definition: Debug.h:106
IRTranslator LLVM IR MI
#define G(x, y, z)
Definition: MD5.cpp:56
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
void apply(PBQPRAGraph &G) override
static bool isFpOrNEON(Register Reg)
Returns whether the physical register is FP or NEON.
LiveInterval - This class represents the liveness of a register, or stack slot.
Definition: LiveInterval.h:687
SlotIndex getInstructionIndex(const MachineInstr &Instr) const
Returns the base index of the given instruction.
LiveInterval & getInterval(Register Reg)
bool overlaps(const LiveRange &other) const
overlaps - Return true if the intersection of the two live ranges is not empty.
Definition: LiveInterval.h:448
bool expiredAt(SlotIndex index) const
Definition: LiveInterval.h:397
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
void dump() const
dump - Print the current MachineFunction to cerr, useful for debugger use.
Representation of each machine instruction.
Definition: MachineInstr.h:69
typename SolverT::RawMatrix RawMatrix
Definition: Graph.h:52
PBQP Matrix class.
Definition: Math.h:121
virtual void print(raw_ostream &OS, const Module *M) const
print - Print out the internal state of the pass.
Definition: Pass.cpp:130
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
static constexpr bool isPhysicalRegister(unsigned Reg)
Return true if the specified register number is in the physical register namespace.
Definition: Register.h:65
bool remove(const value_type &X)
Remove an item from the set vector.
Definition: SetVector.h:188
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:264
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:162
SlotIndex - An opaque wrapper around machine indexes.
Definition: SlotIndexes.h:65
bool empty() const
Definition: SmallVector.h:81
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
bool regsOverlap(Register RegA, Register RegB) const
Returns true if the two registers are equal or alias each other.
virtual const TargetRegisterInfo * getRegisterInfo() const
getRegisterInfo - If register information is available, return it.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
float PBQPNum
Definition: Math.h:22
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
Printable printReg(Register Reg, const TargetRegisterInfo *TRI=nullptr, unsigned SubIdx=0, const MachineRegisterInfo *MRI=nullptr)
Prints virtual and physical registers with or without a TRI instance.
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860