Line data Source code
1 : //===- GCNRegPressure.cpp -------------------------------------------------===//
2 : //
3 : // The LLVM Compiler Infrastructure
4 : //
5 : // This file is distributed under the University of Illinois Open Source
6 : // License. See LICENSE.TXT for details.
7 : //
8 : //===----------------------------------------------------------------------===//
9 :
10 : #include "GCNRegPressure.h"
11 : #include "AMDGPUSubtarget.h"
12 : #include "SIRegisterInfo.h"
13 : #include "llvm/ADT/SmallVector.h"
14 : #include "llvm/CodeGen/LiveInterval.h"
15 : #include "llvm/CodeGen/LiveIntervals.h"
16 : #include "llvm/CodeGen/MachineInstr.h"
17 : #include "llvm/CodeGen/MachineOperand.h"
18 : #include "llvm/CodeGen/MachineRegisterInfo.h"
19 : #include "llvm/CodeGen/RegisterPressure.h"
20 : #include "llvm/CodeGen/SlotIndexes.h"
21 : #include "llvm/CodeGen/TargetRegisterInfo.h"
22 : #include "llvm/Config/llvm-config.h"
23 : #include "llvm/MC/LaneBitmask.h"
24 : #include "llvm/Support/Compiler.h"
25 : #include "llvm/Support/Debug.h"
26 : #include "llvm/Support/ErrorHandling.h"
27 : #include "llvm/Support/raw_ostream.h"
28 : #include <algorithm>
29 : #include <cassert>
30 :
31 : using namespace llvm;
32 :
33 : #define DEBUG_TYPE "machine-scheduler"
34 :
35 : #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
36 : LLVM_DUMP_METHOD
37 : void llvm::printLivesAt(SlotIndex SI,
38 : const LiveIntervals &LIS,
39 : const MachineRegisterInfo &MRI) {
40 : dbgs() << "Live regs at " << SI << ": "
41 : << *LIS.getInstructionFromIndex(SI);
42 : unsigned Num = 0;
43 : for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
44 : const unsigned Reg = TargetRegisterInfo::index2VirtReg(I);
45 : if (!LIS.hasInterval(Reg))
46 : continue;
47 : const auto &LI = LIS.getInterval(Reg);
48 : if (LI.hasSubRanges()) {
49 : bool firstTime = true;
50 : for (const auto &S : LI.subranges()) {
51 : if (!S.liveAt(SI)) continue;
52 : if (firstTime) {
53 : dbgs() << " " << printReg(Reg, MRI.getTargetRegisterInfo())
54 : << '\n';
55 : firstTime = false;
56 : }
57 : dbgs() << " " << S << '\n';
58 : ++Num;
59 : }
60 : } else if (LI.liveAt(SI)) {
61 : dbgs() << " " << LI << '\n';
62 : ++Num;
63 : }
64 : }
65 : if (!Num) dbgs() << " <none>\n";
66 : }
67 :
68 : static bool isEqual(const GCNRPTracker::LiveRegSet &S1,
69 : const GCNRPTracker::LiveRegSet &S2) {
70 : if (S1.size() != S2.size())
71 : return false;
72 :
73 : for (const auto &P : S1) {
74 : auto I = S2.find(P.first);
75 : if (I == S2.end() || I->second != P.second)
76 : return false;
77 : }
78 : return true;
79 : }
80 : #endif
81 :
82 : ///////////////////////////////////////////////////////////////////////////////
83 : // GCNRegPressure
84 :
85 1232427 : unsigned GCNRegPressure::getRegKind(unsigned Reg,
86 : const MachineRegisterInfo &MRI) {
87 : assert(TargetRegisterInfo::isVirtualRegister(Reg));
88 : const auto RC = MRI.getRegClass(Reg);
89 1232427 : auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
90 1232427 : return STI->isSGPRClass(RC) ?
91 : (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
92 1232427 : (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
93 : }
94 :
95 2454451 : void GCNRegPressure::inc(unsigned Reg,
96 : LaneBitmask PrevMask,
97 : LaneBitmask NewMask,
98 : const MachineRegisterInfo &MRI) {
99 2454451 : if (NewMask == PrevMask)
100 : return;
101 :
102 : int Sign = 1;
103 1232427 : if (NewMask < PrevMask) {
104 : std::swap(NewMask, PrevMask);
105 : Sign = -1;
106 : }
107 : #ifndef NDEBUG
108 : const auto MaxMask = MRI.getMaxLaneMaskForVReg(Reg);
109 : #endif
110 1232427 : switch (auto Kind = getRegKind(Reg, MRI)) {
111 443090 : case SGPR32:
112 : case VGPR32:
113 : assert(PrevMask.none() && NewMask == MaxMask);
114 443090 : Value[Kind] += Sign;
115 443090 : break;
116 :
117 789337 : case SGPR_TUPLE:
118 : case VGPR_TUPLE:
119 : assert(NewMask < MaxMask || NewMask == MaxMask);
120 : assert(PrevMask < NewMask);
121 :
122 1578674 : Value[Kind == SGPR_TUPLE ? SGPR32 : VGPR32] +=
123 789337 : Sign * (~PrevMask & NewMask).getNumLanes();
124 :
125 789337 : if (PrevMask.none()) {
126 : assert(NewMask.any());
127 409764 : Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
128 : }
129 : break;
130 :
131 0 : default: llvm_unreachable("Unknown register kind");
132 : }
133 : }
134 :
135 0 : bool GCNRegPressure::less(const GCNSubtarget &ST,
136 : const GCNRegPressure& O,
137 : unsigned MaxOccupancy) const {
138 0 : const auto SGPROcc = std::min(MaxOccupancy,
139 0 : ST.getOccupancyWithNumSGPRs(getSGPRNum()));
140 0 : const auto VGPROcc = std::min(MaxOccupancy,
141 0 : ST.getOccupancyWithNumVGPRs(getVGPRNum()));
142 0 : const auto OtherSGPROcc = std::min(MaxOccupancy,
143 0 : ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
144 0 : const auto OtherVGPROcc = std::min(MaxOccupancy,
145 0 : ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));
146 :
147 0 : const auto Occ = std::min(SGPROcc, VGPROcc);
148 0 : const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
149 0 : if (Occ != OtherOcc)
150 0 : return Occ > OtherOcc;
151 :
152 0 : bool SGPRImportant = SGPROcc < VGPROcc;
153 0 : const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
154 :
155 : // if both pressures disagree on what is more important compare vgprs
156 0 : if (SGPRImportant != OtherSGPRImportant) {
157 : SGPRImportant = false;
158 : }
159 :
160 : // compare large regs pressure
161 : bool SGPRFirst = SGPRImportant;
162 0 : for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
163 0 : if (SGPRFirst) {
164 0 : auto SW = getSGPRTuplesWeight();
165 0 : auto OtherSW = O.getSGPRTuplesWeight();
166 0 : if (SW != OtherSW)
167 0 : return SW < OtherSW;
168 : } else {
169 0 : auto VW = getVGPRTuplesWeight();
170 0 : auto OtherVW = O.getVGPRTuplesWeight();
171 0 : if (VW != OtherVW)
172 0 : return VW < OtherVW;
173 : }
174 : }
175 0 : return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
176 0 : (getVGPRNum() < O.getVGPRNum());
177 : }
178 :
179 : #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
180 : LLVM_DUMP_METHOD
181 : void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
182 : OS << "VGPRs: " << getVGPRNum();
183 : if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
184 : OS << ", SGPRs: " << getSGPRNum();
185 : if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
186 : OS << ", LVGPR WT: " << getVGPRTuplesWeight()
187 : << ", LSGPR WT: " << getSGPRTuplesWeight();
188 : if (ST) OS << " -> Occ: " << getOccupancy(*ST);
189 : OS << '\n';
190 : }
191 : #endif
192 :
193 582406 : static LaneBitmask getDefRegMask(const MachineOperand &MO,
194 : const MachineRegisterInfo &MRI) {
195 : assert(MO.isDef() && MO.isReg() &&
196 : TargetRegisterInfo::isVirtualRegister(MO.getReg()));
197 :
198 : // We don't rely on read-undef flag because in case of tentative schedule
199 : // tracking it isn't set correctly yet. This works correctly however since
200 : // use mask has been tracked before using LIS.
201 : return MO.getSubReg() == 0 ?
202 348860 : MRI.getMaxLaneMaskForVReg(MO.getReg()) :
203 582406 : MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
204 : }
205 :
206 8350 : static LaneBitmask getUsedRegMask(const MachineOperand &MO,
207 : const MachineRegisterInfo &MRI,
208 : const LiveIntervals &LIS) {
209 : assert(MO.isUse() && MO.isReg() &&
210 : TargetRegisterInfo::isVirtualRegister(MO.getReg()));
211 :
212 8350 : if (auto SubReg = MO.getSubReg())
213 2510 : return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
214 :
215 5840 : auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
216 5840 : if (MaxMask == LaneBitmask::getLane(0)) // cannot have subregs
217 4870 : return MaxMask;
218 :
219 : // For a tentative schedule LIS isn't updated yet but livemask should remain
220 : // the same on any schedule. Subreg defs can be reordered but they all must
221 : // dominate uses anyway.
222 970 : auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
223 970 : return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
224 : }
225 :
226 : static SmallVector<RegisterMaskPair, 8>
227 6080 : collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
228 : const MachineRegisterInfo &MRI) {
229 : SmallVector<RegisterMaskPair, 8> Res;
230 36875 : for (const auto &MO : MI.operands()) {
231 30795 : if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg()))
232 22445 : continue;
233 13930 : if (!MO.isUse() || !MO.readsReg())
234 : continue;
235 :
236 8350 : auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
237 :
238 8350 : auto Reg = MO.getReg();
239 : auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) {
240 : return RM.RegUnit == Reg;
241 : });
242 8350 : if (I != Res.end())
243 : I->LaneMask |= UsedMask;
244 : else
245 15420 : Res.push_back(RegisterMaskPair(Reg, UsedMask));
246 : }
247 6080 : return Res;
248 : }
249 :
250 : ///////////////////////////////////////////////////////////////////////////////
251 : // GCNRPTracker
252 :
253 239970 : LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
254 : SlotIndex SI,
255 : const LiveIntervals &LIS,
256 : const MachineRegisterInfo &MRI) {
257 : LaneBitmask LiveMask;
258 : const auto &LI = LIS.getInterval(Reg);
259 239970 : if (LI.hasSubRanges()) {
260 260049 : for (const auto &S : LI.subranges())
261 196263 : if (S.liveAt(SI)) {
262 : LiveMask |= S.LaneMask;
263 : assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
264 : LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
265 : }
266 176184 : } else if (LI.liveAt(SI)) {
267 8630 : LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
268 : }
269 239970 : return LiveMask;
270 : }
271 :
272 20213 : GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
273 : const LiveIntervals &LIS,
274 : const MachineRegisterInfo &MRI) {
275 : GCNRPTracker::LiveRegSet LiveRegs;
276 825693 : for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
277 805480 : auto Reg = TargetRegisterInfo::index2VirtReg(I);
278 : if (!LIS.hasInterval(Reg))
279 566480 : continue;
280 239000 : auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
281 239000 : if (LiveMask.any())
282 9878 : LiveRegs[Reg] = LiveMask;
283 : }
284 20213 : return LiveRegs;
285 : }
286 :
287 43106 : void GCNRPTracker::reset(const MachineInstr &MI,
288 : const LiveRegSet *LiveRegsCopy,
289 : bool After) {
290 43106 : const MachineFunction &MF = *MI.getMF();
291 43106 : MRI = &MF.getRegInfo();
292 43106 : if (LiveRegsCopy) {
293 22893 : if (&LiveRegs != LiveRegsCopy)
294 : LiveRegs = *LiveRegsCopy;
295 : } else {
296 40426 : LiveRegs = After ? getLiveRegsAfter(MI, LIS)
297 20213 : : getLiveRegsBefore(MI, LIS);
298 : }
299 :
300 43106 : MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
301 43106 : }
302 :
303 20 : void GCNUpwardRPTracker::reset(const MachineInstr &MI,
304 : const LiveRegSet *LiveRegsCopy) {
305 20 : GCNRPTracker::reset(MI, LiveRegsCopy, true);
306 20 : }
307 :
308 6080 : void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
309 : assert(MRI && "call reset first");
310 :
311 6080 : LastTrackedMI = &MI;
312 :
313 : if (MI.isDebugInstr())
314 0 : return;
315 :
316 6080 : auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
317 :
318 : // calc pressure at the MI (defs + uses)
319 6080 : auto AtMIPressure = CurPressure;
320 13790 : for (const auto &U : RegUses) {
321 7710 : auto LiveMask = LiveRegs[U.RegUnit];
322 15420 : AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
323 : }
324 : // update max pressure
325 6080 : MaxPressure = max(AtMIPressure, MaxPressure);
326 :
327 11680 : for (const auto &MO : MI.defs()) {
328 5600 : if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg()) ||
329 : MO.isDead())
330 20 : continue;
331 :
332 5580 : auto Reg = MO.getReg();
333 5580 : auto I = LiveRegs.find(Reg);
334 5580 : if (I == LiveRegs.end())
335 : continue;
336 5580 : auto &LiveMask = I->second;
337 5580 : auto PrevMask = LiveMask;
338 5580 : LiveMask &= ~getDefRegMask(MO, *MRI);
339 5580 : CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
340 5580 : if (LiveMask.none())
341 : LiveRegs.erase(I);
342 : }
343 13790 : for (const auto &U : RegUses) {
344 7710 : auto &LiveMask = LiveRegs[U.RegUnit];
345 7710 : auto PrevMask = LiveMask;
346 : LiveMask |= U.LaneMask;
347 7710 : CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
348 : }
349 : assert(CurPressure == getRegPressure(*MRI, LiveRegs));
350 : }
351 :
352 43086 : bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
353 : const LiveRegSet *LiveRegsCopy) {
354 43086 : MRI = &MI.getParent()->getParent()->getRegInfo();
355 43086 : LastTrackedMI = nullptr;
356 43086 : MBBEnd = MI.getParent()->end();
357 43086 : NextMI = &MI;
358 43086 : NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
359 43086 : if (NextMI == MBBEnd)
360 : return false;
361 43086 : GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
362 43086 : return true;
363 : }
364 :
365 632278 : bool GCNDownwardRPTracker::advanceBeforeNext() {
366 : assert(MRI && "call reset first");
367 :
368 632278 : NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
369 632278 : if (NextMI == MBBEnd)
370 : return false;
371 :
372 631549 : SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
373 : assert(SI.isValid());
374 :
375 : // Remove dead registers or mask bits.
376 7506396 : for (auto &It : LiveRegs) {
377 6874847 : const LiveInterval &LI = LIS.getInterval(It.first);
378 6874847 : if (LI.hasSubRanges()) {
379 8152091 : for (const auto &S : LI.subranges()) {
380 6246028 : if (!S.liveAt(SI)) {
381 1539419 : auto PrevMask = It.second;
382 1539419 : It.second &= ~S.LaneMask;
383 1539419 : CurPressure.inc(It.first, PrevMask, It.second, *MRI);
384 : }
385 : }
386 4968784 : } else if (!LI.liveAt(SI)) {
387 269355 : auto PrevMask = It.second;
388 269355 : It.second = LaneBitmask::getNone();
389 269355 : CurPressure.inc(It.first, PrevMask, It.second, *MRI);
390 : }
391 6874847 : if (It.second.none())
392 371955 : LiveRegs.erase(It.first);
393 : }
394 :
395 631549 : MaxPressure = max(MaxPressure, CurPressure);
396 :
397 631549 : return true;
398 : }
399 :
400 653488 : void GCNDownwardRPTracker::advanceToNext() {
401 653488 : LastTrackedMI = &*NextMI++;
402 :
403 : // Add new registers or mask bits.
404 1256388 : for (const auto &MO : LastTrackedMI->defs()) {
405 602900 : if (!MO.isReg())
406 26074 : continue;
407 602900 : unsigned Reg = MO.getReg();
408 602900 : if (!TargetRegisterInfo::isVirtualRegister(Reg))
409 : continue;
410 576826 : auto &LiveMask = LiveRegs[Reg];
411 576826 : auto PrevMask = LiveMask;
412 576826 : LiveMask |= getDefRegMask(MO, *MRI);
413 576826 : CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
414 : }
415 :
416 653488 : MaxPressure = max(MaxPressure, CurPressure);
417 653488 : }
418 :
419 340244 : bool GCNDownwardRPTracker::advance() {
420 : // If we have just called reset live set is actual.
421 340244 : if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
422 0 : return false;
423 340244 : advanceToNext();
424 340244 : return true;
425 : }
426 :
427 21621 : bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
428 361865 : while (NextMI != End)
429 340244 : if (!advance()) return false;
430 : return true;
431 : }
432 :
433 21448 : bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
434 : MachineBasicBlock::const_iterator End,
435 : const LiveRegSet *LiveRegsCopy) {
436 21448 : reset(*Begin, LiveRegsCopy);
437 21448 : return advance(End);
438 : }
439 :
440 : #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
441 : LLVM_DUMP_METHOD
442 : static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
443 : const GCNRPTracker::LiveRegSet &TrackedLR,
444 : const TargetRegisterInfo *TRI) {
445 : for (auto const &P : TrackedLR) {
446 : auto I = LISLR.find(P.first);
447 : if (I == LISLR.end()) {
448 : dbgs() << " " << printReg(P.first, TRI)
449 : << ":L" << PrintLaneMask(P.second)
450 : << " isn't found in LIS reported set\n";
451 : }
452 : else if (I->second != P.second) {
453 : dbgs() << " " << printReg(P.first, TRI)
454 : << " masks doesn't match: LIS reported "
455 : << PrintLaneMask(I->second)
456 : << ", tracked "
457 : << PrintLaneMask(P.second)
458 : << '\n';
459 : }
460 : }
461 : for (auto const &P : LISLR) {
462 : auto I = TrackedLR.find(P.first);
463 : if (I == TrackedLR.end()) {
464 : dbgs() << " " << printReg(P.first, TRI)
465 : << ":L" << PrintLaneMask(P.second)
466 : << " isn't found in tracked set\n";
467 : }
468 : }
469 : }
470 :
471 : bool GCNUpwardRPTracker::isValid() const {
472 : const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
473 : const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
474 : const auto &TrackedLR = LiveRegs;
475 :
476 : if (!isEqual(LISLR, TrackedLR)) {
477 : dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
478 : " LIS reported livesets mismatch:\n";
479 : printLivesAt(SI, LIS, *MRI);
480 : reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
481 : return false;
482 : }
483 :
484 : auto LISPressure = getRegPressure(*MRI, LISLR);
485 : if (LISPressure != CurPressure) {
486 : dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
487 : CurPressure.print(dbgs());
488 : dbgs() << "LIS rpt: ";
489 : LISPressure.print(dbgs());
490 : return false;
491 : }
492 : return true;
493 : }
494 :
495 : void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
496 : const MachineRegisterInfo &MRI) {
497 : const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
498 : for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
499 : unsigned Reg = TargetRegisterInfo::index2VirtReg(I);
500 : auto It = LiveRegs.find(Reg);
501 : if (It != LiveRegs.end() && It->second.any())
502 : OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
503 : << PrintLaneMask(It->second);
504 : }
505 : OS << '\n';
506 : }
507 : #endif
|