LLVM 20.0.0git
GCNRegPressure.cpp
Go to the documentation of this file.
1//===- GCNRegPressure.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the GCNRegPressure class.
11///
12//===----------------------------------------------------------------------===//
13
14#include "GCNRegPressure.h"
15#include "AMDGPU.h"
17
18using namespace llvm;
19
20#define DEBUG_TYPE "machine-scheduler"
21
23 const GCNRPTracker::LiveRegSet &S2) {
24 if (S1.size() != S2.size())
25 return false;
26
27 for (const auto &P : S1) {
28 auto I = S2.find(P.first);
29 if (I == S2.end() || I->second != P.second)
30 return false;
31 }
32 return true;
33}
34
35///////////////////////////////////////////////////////////////////////////////
36// GCNRegPressure
37
/// Classify the virtual register \p Reg into one of the pressure kinds
/// SGPR32 / SGPR_TUPLE, AGPR32 / AGPR_TUPLE or VGPR32 / VGPR_TUPLE.
///
/// The kind is derived from the register class that \p MRI reports for
/// \p Reg: SGPR classes are tested first, then AGPR classes, and every other
/// class is treated as VGPR. Within each bank, a class whose size is exactly
/// 32 bits yields the *32 kind; any other size yields the *_TUPLE kind.
///
/// \p Reg must be a virtual register (asserted).
38unsigned GCNRegPressure::getRegKind(Register Reg,
39 const MachineRegisterInfo &MRI) {
40 assert(Reg.isVirtual());
41 const auto *const RC = MRI.getRegClass(Reg);
// Note: despite its name, STI is the target register info, downcast
// (unchecked static_cast) to SIRegisterInfo.
42 const auto *STI =
43 static_cast<const SIRegisterInfo *>(MRI.getTargetRegisterInfo());
// Nested conditional: SGPR ? ... : (AGPR ? ... : VGPR-fallback).
44 return STI->isSGPRClass(RC)
45 ? (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE)
46 : STI->isAGPRClass(RC)
47 ? (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE)
48 : (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
49}
50
/// Update the per-kind counters in Value for a change of the live lane mask
/// of \p Reg from \p PrevMask to \p NewMask.
///
/// If \p NewMask compares less than \p PrevMask the two masks are swapped and
/// the update is applied with a negative sign, i.e. the same code path handles
/// both increases and decreases of pressure.
///
/// \param Reg      register whose liveness changed; its kind is obtained via
///                 getRegKind().
/// \param PrevMask lane mask before the change.
/// \param NewMask  lane mask after the change.
/// \param MRI      used to look up the register class of \p Reg.
51void GCNRegPressure::inc(unsigned Reg,
52 LaneBitmask PrevMask,
53 LaneBitmask NewMask,
54 const MachineRegisterInfo &MRI) {
// NOTE(review): the condition guarding this early return (listing lines
// 55-56) is not visible here; presumably it filters out updates that need no
// accounting -- confirm against the original file.
57 return;
58
59 int Sign = 1;
// Normalize so that PrevMask is the smaller mask; Sign remembers the
// direction of the change.
60 if (NewMask < PrevMask) {
61 std::swap(NewMask, PrevMask);
62 Sign = -1;
63 }
64
65 switch (auto Kind = getRegKind(Reg, MRI)) {
// 32-bit kinds: the counter of that kind changes by exactly one.
66 case SGPR32:
67 case VGPR32:
68 case AGPR32:
69 Value[Kind] += Sign;
70 break;
71
72 case SGPR_TUPLE:
73 case VGPR_TUPLE:
74 case AGPR_TUPLE:
75 assert(PrevMask < NewMask);
76
// Tuples also contribute to the matching 32-bit counter, by the value
// getNumCoveredRegs() returns for the lanes present in NewMask but not in
// PrevMask.
77 Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
78 Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
79
// The tuple counter itself only changes when the smaller mask is empty,
// i.e. on a transition between "no lanes" and "some lanes"; it changes by the
// register class weight of Reg.
80 if (PrevMask.none()) {
81 assert(NewMask.any());
82 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
83 Value[Kind] +=
84 Sign * TRI->getRegClassWeight(MRI.getRegClass(Reg)).RegWeight;
85 }
86 break;
87
88 default: llvm_unreachable("Unknown register kind");
89 }
90}
91
93 unsigned MaxOccupancy) const {
94 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
95
96 const auto SGPROcc = std::min(MaxOccupancy,
97 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
98 const auto VGPROcc =
99 std::min(MaxOccupancy,
100 ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts())));
101 const auto OtherSGPROcc = std::min(MaxOccupancy,
102 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
103 const auto OtherVGPROcc =
104 std::min(MaxOccupancy,
105 ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts())));
106
107 const auto Occ = std::min(SGPROcc, VGPROcc);
108 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
109
110 // Give first precedence to the better occupancy.
111 if (Occ != OtherOcc)
112 return Occ > OtherOcc;
113
114 unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF);
115 unsigned MaxSGPRs = ST.getMaxNumSGPRs(MF);
116
117 // SGPR excess pressure conditions
118 unsigned ExcessSGPR = std::max(static_cast<int>(getSGPRNum() - MaxSGPRs), 0);
119 unsigned OtherExcessSGPR =
120 std::max(static_cast<int>(O.getSGPRNum() - MaxSGPRs), 0);
121
122 auto WaveSize = ST.getWavefrontSize();
123 // The number of virtual VGPRs required to handle excess SGPR
124 unsigned VGPRForSGPRSpills = (ExcessSGPR + (WaveSize - 1)) / WaveSize;
125 unsigned OtherVGPRForSGPRSpills =
126 (OtherExcessSGPR + (WaveSize - 1)) / WaveSize;
127
128 unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs();
129
130 // Unified excess pressure conditions, accounting for VGPRs used for SGPR
131 // spills
132 unsigned ExcessVGPR =
133 std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) +
134 VGPRForSGPRSpills - MaxVGPRs),
135 0);
136 unsigned OtherExcessVGPR =
137 std::max(static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) +
138 OtherVGPRForSGPRSpills - MaxVGPRs),
139 0);
140 // Arch VGPR excess pressure conditions, accounting for VGPRs used for SGPR
141 // spills
142 unsigned ExcessArchVGPR = std::max(
143 static_cast<int>(getVGPRNum(false) + VGPRForSGPRSpills - MaxArchVGPRs),
144 0);
145 unsigned OtherExcessArchVGPR =
146 std::max(static_cast<int>(O.getVGPRNum(false) + OtherVGPRForSGPRSpills -
147 MaxArchVGPRs),
148 0);
149 // AGPR excess pressure conditions
150 unsigned ExcessAGPR = std::max(
151 static_cast<int>(ST.hasGFX90AInsts() ? (getAGPRNum() - MaxArchVGPRs)
152 : (getAGPRNum() - MaxVGPRs)),
153 0);
154 unsigned OtherExcessAGPR = std::max(
155 static_cast<int>(ST.hasGFX90AInsts() ? (O.getAGPRNum() - MaxArchVGPRs)
156 : (O.getAGPRNum() - MaxVGPRs)),
157 0);
158
159 bool ExcessRP = ExcessSGPR || ExcessVGPR || ExcessArchVGPR || ExcessAGPR;
160 bool OtherExcessRP = OtherExcessSGPR || OtherExcessVGPR ||
161 OtherExcessArchVGPR || OtherExcessAGPR;
162
163 // Give second precedence to the reduced number of spills to hold the register
164 // pressure.
165 if (ExcessRP || OtherExcessRP) {
166 // The difference in excess VGPR pressure, after including VGPRs used for
167 // SGPR spills
168 int VGPRDiff = ((OtherExcessVGPR + OtherExcessArchVGPR + OtherExcessAGPR) -
169 (ExcessVGPR + ExcessArchVGPR + ExcessAGPR));
170
171 int SGPRDiff = OtherExcessSGPR - ExcessSGPR;
172
173 if (VGPRDiff != 0)
174 return VGPRDiff > 0;
175 if (SGPRDiff != 0) {
176 unsigned PureExcessVGPR =
177 std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs),
178 0) +
179 std::max(static_cast<int>(getVGPRNum(false) - MaxArchVGPRs), 0);
180 unsigned OtherPureExcessVGPR =
181 std::max(
182 static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs),
183 0) +
184 std::max(static_cast<int>(O.getVGPRNum(false) - MaxArchVGPRs), 0);
185
186 // If we have a special case where there is a tie in excess VGPR, but one
187 // of the pressures has VGPR usage from SGPR spills, prefer the pressure
188 // with SGPR spills.
189 if (PureExcessVGPR != OtherPureExcessVGPR)
190 return SGPRDiff < 0;
191 // If both pressures have the same excess pressure before and after
192 // accounting for SGPR spills, prefer fewer SGPR spills.
193 return SGPRDiff > 0;
194 }
195 }
196
197 bool SGPRImportant = SGPROcc < VGPROcc;
198 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
199
200 // If both pressures disagree on what is more important compare vgprs.
201 if (SGPRImportant != OtherSGPRImportant) {
202 SGPRImportant = false;
203 }
204
205 // Give third precedence to lower register tuple pressure.
206 bool SGPRFirst = SGPRImportant;
207 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
208 if (SGPRFirst) {
209 auto SW = getSGPRTuplesWeight();
210 auto OtherSW = O.getSGPRTuplesWeight();
211 if (SW != OtherSW)
212 return SW < OtherSW;
213 } else {
214 auto VW = getVGPRTuplesWeight();
215 auto OtherVW = O.getVGPRTuplesWeight();
216 if (VW != OtherVW)
217 return VW < OtherVW;
218 }
219 }
220
221 // Give final precedence to lower general RP.
222 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
223 (getVGPRNum(ST.hasGFX90AInsts()) <
224 O.getVGPRNum(ST.hasGFX90AInsts()));
225}
226
228 return Printable([&RP, ST](raw_ostream &OS) {
229 OS << "VGPRs: " << RP.Value[GCNRegPressure::VGPR32] << ' '
230 << "AGPRs: " << RP.getAGPRNum();
231 if (ST)
232 OS << "(O"
233 << ST->getOccupancyWithNumVGPRs(RP.getVGPRNum(ST->hasGFX90AInsts()))
234 << ')';
235 OS << ", SGPRs: " << RP.getSGPRNum();
236 if (ST)
237 OS << "(O" << ST->getOccupancyWithNumSGPRs(RP.getSGPRNum()) << ')';
238 OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight()
239 << ", LSGPR WT: " << RP.getSGPRTuplesWeight();
240 if (ST)
241 OS << " -> Occ: " << RP.getOccupancy(*ST);
242 OS << '\n';
243 });
244}
245
247 const MachineRegisterInfo &MRI) {
248 assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual());
249
250 // We don't rely on read-undef flag because in case of tentative schedule
251 // tracking it isn't set correctly yet. This works correctly however since
252 // use mask has been tracked before using LIS.
253 return MO.getSubReg() == 0 ?
254 MRI.getMaxLaneMaskForVReg(MO.getReg()) :
255 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
256}
257
258static void
260 const MachineInstr &MI, const LiveIntervals &LIS,
261 const MachineRegisterInfo &MRI) {
262
263 auto &TRI = *MRI.getTargetRegisterInfo();
264 for (const auto &MO : MI.operands()) {
265 if (!MO.isReg() || !MO.getReg().isVirtual())
266 continue;
267 if (!MO.isUse() || !MO.readsReg())
268 continue;
269
270 Register Reg = MO.getReg();
271 auto I = llvm::find_if(RegMaskPairs, [Reg](const RegisterMaskPair &RM) {
272 return RM.RegUnit == Reg;
273 });
274
275 auto &P = I == RegMaskPairs.end()
276 ? RegMaskPairs.emplace_back(Reg, LaneBitmask::getNone())
277 : *I;
278
279 P.LaneMask |= MO.getSubReg() ? TRI.getSubRegIndexLaneMask(MO.getSubReg())
280 : MRI.getMaxLaneMaskForVReg(Reg);
281 }
282
283 SlotIndex InstrSI;
284 for (auto &P : RegMaskPairs) {
285 auto &LI = LIS.getInterval(P.RegUnit);
286 if (!LI.hasSubRanges())
287 continue;
288
289 // For a tentative schedule LIS isn't updated yet but livemask should
290 // remain the same on any schedule. Subreg defs can be reordered but they
291 // all must dominate uses anyway.
292 if (!InstrSI)
293 InstrSI = LIS.getInstructionIndex(MI).getBaseIndex();
294
295 P.LaneMask = getLiveLaneMask(LI, InstrSI, MRI, P.LaneMask);
296 }
297}
298
299/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
301 const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
302 bool TrackLaneMasks, Register RegUnit, SlotIndex Pos,
303 LaneBitmask SafeDefault,
304 function_ref<bool(const LiveRange &LR, SlotIndex Pos)> Property) {
305 if (RegUnit.isVirtual()) {
306 const LiveInterval &LI = LIS.getInterval(RegUnit);
307 LaneBitmask Result;
308 if (TrackLaneMasks && LI.hasSubRanges()) {
309 for (const LiveInterval::SubRange &SR : LI.subranges()) {
310 if (Property(SR, Pos))
311 Result |= SR.LaneMask;
312 }
313 } else if (Property(LI, Pos)) {
314 Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(RegUnit)
316 }
317
318 return Result;
319 }
320
321 const LiveRange *LR = LIS.getCachedRegUnit(RegUnit);
322 if (LR == nullptr)
323 return SafeDefault;
324 return Property(*LR, Pos) ? LaneBitmask::getAll() : LaneBitmask::getNone();
325}
326
327/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
328/// Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}.
329/// The query starts with a lane bitmask which gets lanes/bits removed for every
330/// use we find.
///
/// Which end of the interval is inclusive depends on \p Upward:
///  - Upward == false (default): PriorUseIdx <= slot <  NextUseIdx
///  - Upward == true:            PriorUseIdx <  slot <= NextUseIdx
/// where "slot" is the register slot of the instruction containing the use.
///
/// \returns the lanes of \p LastUseMask that were not covered by any use found
/// in the interval, or LaneBitmask::getNone() as soon as all of them are
/// covered.
///
/// NOTE(review): the body reads MRI, but the parameter line declaring it
/// (listing line 333) is not visible in this listing.
331static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask,
332 SlotIndex PriorUseIdx, SlotIndex NextUseIdx,
334 const SIRegisterInfo *TRI,
335 const LiveIntervals *LIS,
336 bool Upward = false) {
// Walk all non-debug use operands of Reg; undef uses are ignored.
337 for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) {
338 if (MO.isUndef())
339 continue;
340 const MachineInstr *MI = MO.getParent();
341 SlotIndex InstSlot = LIS->getInstructionIndex(*MI).getRegSlot();
342 bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx)
343 : (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx);
344 if (!InRange)
345 continue;
346
// Remove the lanes read by this use from the candidate mask; stop early
// once nothing is left.
347 unsigned SubRegIdx = MO.getSubReg();
348 LaneBitmask UseMask = TRI->getSubRegIndexLaneMask(SubRegIdx);
349 LastUseMask &= ~UseMask;
350 if (LastUseMask.none())
351 return LaneBitmask::getNone();
352 }
353 return LastUseMask;
354}
355
356///////////////////////////////////////////////////////////////////////////////
357// GCNRPTracker
358
360 const LiveIntervals &LIS,
362 LaneBitmask LaneMaskFilter) {
363 return getLiveLaneMask(LIS.getInterval(Reg), SI, MRI, LaneMaskFilter);
364}
365
368 LaneBitmask LaneMaskFilter) {
369 LaneBitmask LiveMask;
370 if (LI.hasSubRanges()) {
371 for (const auto &S : LI.subranges())
372 if ((S.LaneMask & LaneMaskFilter).any() && S.liveAt(SI)) {
373 LiveMask |= S.LaneMask;
374 assert(LiveMask == (LiveMask & MRI.getMaxLaneMaskForVReg(LI.reg())));
375 }
376 } else if (LI.liveAt(SI)) {
377 LiveMask = MRI.getMaxLaneMaskForVReg(LI.reg());
378 }
379 LiveMask &= LaneMaskFilter;
380 return LiveMask;
381}
382
384 const LiveIntervals &LIS,
385 const MachineRegisterInfo &MRI) {
387 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
388 auto Reg = Register::index2VirtReg(I);
389 if (!LIS.hasInterval(Reg))
390 continue;
391 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
392 if (LiveMask.any())
393 LiveRegs[Reg] = LiveMask;
394 }
395 return LiveRegs;
396}
397
399 const LiveRegSet *LiveRegsCopy,
400 bool After) {
401 const MachineFunction &MF = *MI.getMF();
402 MRI = &MF.getRegInfo();
403 if (LiveRegsCopy) {
404 if (&LiveRegs != LiveRegsCopy)
405 LiveRegs = *LiveRegsCopy;
406 } else {
409 }
410
412}
413
415 const LiveRegSet &LiveRegs_) {
416 MRI = &MRI_;
417 LiveRegs = LiveRegs_;
418 LastTrackedMI = nullptr;
419 MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_);
420}
421
422/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
424 SlotIndex Pos) const {
426 LIS, *MRI, true, RegUnit, Pos.getBaseIndex(), LaneBitmask::getNone(),
427 [](const LiveRange &LR, SlotIndex Pos) {
428 const LiveRange::Segment *S = LR.getSegmentContaining(Pos);
429 return S != nullptr && S->end == Pos.getRegSlot();
430 });
431}
432
433////////////////////////////////////////////////////////////////////////////////
434// GCNUpwardRPTracker
435
437 assert(MRI && "call reset first");
438
439 LastTrackedMI = &MI;
440
441 if (MI.isDebugInstr())
442 return;
443
444 // Kill all defs.
445 GCNRegPressure DefPressure, ECDefPressure;
446 bool HasECDefs = false;
447 for (const MachineOperand &MO : MI.all_defs()) {
448 if (!MO.getReg().isVirtual())
449 continue;
450
451 Register Reg = MO.getReg();
452 LaneBitmask DefMask = getDefRegMask(MO, *MRI);
453
454 // Treat a def as fully live at the moment of definition: keep a record.
455 if (MO.isEarlyClobber()) {
456 ECDefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI);
457 HasECDefs = true;
458 } else
459 DefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI);
460
461 auto I = LiveRegs.find(Reg);
462 if (I == LiveRegs.end())
463 continue;
464
465 LaneBitmask &LiveMask = I->second;
466 LaneBitmask PrevMask = LiveMask;
467 LiveMask &= ~DefMask;
468 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
469 if (LiveMask.none())
471 }
472
473 // Update MaxPressure with defs pressure.
474 DefPressure += CurPressure;
475 if (HasECDefs)
476 DefPressure += ECDefPressure;
477 MaxPressure = max(DefPressure, MaxPressure);
478
479 // Make uses alive.
481 collectVirtualRegUses(RegUses, MI, LIS, *MRI);
482 for (const RegisterMaskPair &U : RegUses) {
483 LaneBitmask &LiveMask = LiveRegs[U.RegUnit];
484 LaneBitmask PrevMask = LiveMask;
485 LiveMask |= U.LaneMask;
486 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
487 }
488
489 // Update MaxPressure with uses plus early-clobber defs pressure.
490 MaxPressure = HasECDefs ? max(CurPressure + ECDefPressure, MaxPressure)
492
494}
495
496////////////////////////////////////////////////////////////////////////////////
497// GCNDownwardRPTracker
498
500 const LiveRegSet *LiveRegsCopy) {
501 MRI = &MI.getParent()->getParent()->getRegInfo();
502 LastTrackedMI = nullptr;
503 MBBEnd = MI.getParent()->end();
504 NextMI = &MI;
505 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
506 if (NextMI == MBBEnd)
507 return false;
508 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
509 return true;
510}
511
513 bool UseInternalIterator) {
514 assert(MRI && "call reset first");
515 SlotIndex SI;
516 const MachineInstr *CurrMI;
517 if (UseInternalIterator) {
518 if (!LastTrackedMI)
519 return NextMI == MBBEnd;
520
521 assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
522 CurrMI = LastTrackedMI;
523
524 SI = NextMI == MBBEnd
525 ? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
527 } else { //! UseInternalIterator
529 CurrMI = MI;
530 }
531
532 assert(SI.isValid());
533
534 // Remove dead registers or mask bits.
535 SmallSet<Register, 8> SeenRegs;
536 for (auto &MO : CurrMI->operands()) {
537 if (!MO.isReg() || !MO.getReg().isVirtual())
538 continue;
539 if (MO.isUse() && !MO.readsReg())
540 continue;
541 if (!UseInternalIterator && MO.isDef())
542 continue;
543 if (!SeenRegs.insert(MO.getReg()).second)
544 continue;
545 const LiveInterval &LI = LIS.getInterval(MO.getReg());
546 if (LI.hasSubRanges()) {
547 auto It = LiveRegs.end();
548 for (const auto &S : LI.subranges()) {
549 if (!S.liveAt(SI)) {
550 if (It == LiveRegs.end()) {
551 It = LiveRegs.find(MO.getReg());
552 if (It == LiveRegs.end())
553 llvm_unreachable("register isn't live");
554 }
555 auto PrevMask = It->second;
556 It->second &= ~S.LaneMask;
557 CurPressure.inc(MO.getReg(), PrevMask, It->second, *MRI);
558 }
559 }
560 if (It != LiveRegs.end() && It->second.none())
561 LiveRegs.erase(It);
562 } else if (!LI.liveAt(SI)) {
563 auto It = LiveRegs.find(MO.getReg());
564 if (It == LiveRegs.end())
565 llvm_unreachable("register isn't live");
566 CurPressure.inc(MO.getReg(), It->second, LaneBitmask::getNone(), *MRI);
567 LiveRegs.erase(It);
568 }
569 }
570
572
573 LastTrackedMI = nullptr;
574
575 return UseInternalIterator && (NextMI == MBBEnd);
576}
577
579 bool UseInternalIterator) {
580 if (UseInternalIterator) {
581 LastTrackedMI = &*NextMI++;
582 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
583 } else {
585 }
586
587 const MachineInstr *CurrMI = LastTrackedMI;
588
589 // Add new registers or mask bits.
590 for (const auto &MO : CurrMI->all_defs()) {
591 Register Reg = MO.getReg();
592 if (!Reg.isVirtual())
593 continue;
594 auto &LiveMask = LiveRegs[Reg];
595 auto PrevMask = LiveMask;
596 LiveMask |= getDefRegMask(MO, *MRI);
597 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
598 }
599
601}
602
/// Advance the tracker over one instruction by running advanceBeforeNext()
/// followed by advanceToNext(), forwarding \p MI and \p UseInternalIterator to
/// both.
///
/// \param MI                  forwarded unchanged to the two helpers.
/// \param UseInternalIterator when true, the tracker's own NextMI position is
///                            consulted and the call is a no-op at the end of
///                            the block.
/// \returns false only when \p UseInternalIterator is true and NextMI is
/// already at MBBEnd (nothing is done in that case); true otherwise.
603bool GCNDownwardRPTracker::advance(MachineInstr *MI, bool UseInternalIterator) {
604 if (UseInternalIterator && NextMI == MBBEnd)
605 return false;
606
607 advanceBeforeNext(MI, UseInternalIterator);
608 advanceToNext(MI, UseInternalIterator);
// When not using the internal iterator, a second advanceBeforeNext() pass is
// run, this time with UseInternalIterator forced to true.
609 if (!UseInternalIterator) {
610 // We must remove any dead def lanes from the current RP
611 advanceBeforeNext(MI, true);
612 }
613 return true;
614}
615
617 while (NextMI != End)
618 if (!advance()) return false;
619 return true;
620}
621
624 const LiveRegSet *LiveRegsCopy) {
625 reset(*Begin, LiveRegsCopy);
626 return advance(End);
627}
628
630 const GCNRPTracker::LiveRegSet &TrackedLR,
631 const TargetRegisterInfo *TRI, StringRef Pfx) {
632 return Printable([&LISLR, &TrackedLR, TRI, Pfx](raw_ostream &OS) {
633 for (auto const &P : TrackedLR) {
634 auto I = LISLR.find(P.first);
635 if (I == LISLR.end()) {
636 OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
637 << " isn't found in LIS reported set\n";
638 } else if (I->second != P.second) {
639 OS << Pfx << printReg(P.first, TRI)
640 << " masks doesn't match: LIS reported " << PrintLaneMask(I->second)
641 << ", tracked " << PrintLaneMask(P.second) << '\n';
642 }
643 }
644 for (auto const &P : LISLR) {
645 auto I = TrackedLR.find(P.first);
646 if (I == TrackedLR.end()) {
647 OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
648 << " isn't found in tracked set\n";
649 }
650 }
651 });
652}
653
656 const SIRegisterInfo *TRI) const {
657 assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction.");
658
659 SlotIndex SlotIdx;
660 SlotIdx = LIS.getInstructionIndex(*MI).getRegSlot();
661
662 // Account for register pressure similar to RegPressureTracker::recede().
663 RegisterOperands RegOpers;
664 RegOpers.collect(*MI, *TRI, *MRI, true, /*IgnoreDead=*/false);
665 RegOpers.adjustLaneLiveness(LIS, *MRI, SlotIdx);
666 GCNRegPressure TempPressure = CurPressure;
667
668 for (const RegisterMaskPair &Use : RegOpers.Uses) {
669 Register Reg = Use.RegUnit;
670 if (!Reg.isVirtual())
671 continue;
672 LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx);
673 if (LastUseMask.none())
674 continue;
675 // The LastUseMask is queried from the liveness information of instruction
676 // which may be further down the schedule. Some lanes may actually not be
677 // last uses for the current position.
678 // FIXME: allow the caller to pass in the list of vreg uses that remain
679 // to be bottom-scheduled to avoid searching uses at each query.
680 SlotIndex CurrIdx;
681 const MachineBasicBlock *MBB = MI->getParent();
684 if (IdxPos == MBB->end()) {
685 CurrIdx = LIS.getMBBEndIdx(MBB);
686 } else {
687 CurrIdx = LIS.getInstructionIndex(*IdxPos).getRegSlot();
688 }
689
690 LastUseMask =
691 findUseBetween(Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, TRI, &LIS);
692 if (LastUseMask.none())
693 continue;
694
695 LaneBitmask LiveMask =
696 LiveRegs.contains(Reg) ? LiveRegs.at(Reg) : LaneBitmask(0);
697 LaneBitmask NewMask = LiveMask & ~LastUseMask;
698 TempPressure.inc(Reg, LiveMask, NewMask, *MRI);
699 }
700
701 // Generate liveness for defs.
702 for (const RegisterMaskPair &Def : RegOpers.Defs) {
703 Register Reg = Def.RegUnit;
704 if (!Reg.isVirtual())
705 continue;
706 LaneBitmask LiveMask =
707 LiveRegs.contains(Reg) ? LiveRegs.at(Reg) : LaneBitmask(0);
708 LaneBitmask NewMask = LiveMask | Def.LaneMask;
709 TempPressure.inc(Reg, LiveMask, NewMask, *MRI);
710 }
711
712 return TempPressure;
713}
714
716 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
717 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
718 const auto &TrackedLR = LiveRegs;
719
720 if (!isEqual(LISLR, TrackedLR)) {
721 dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
722 " LIS reported livesets mismatch:\n"
723 << print(LISLR, *MRI);
724 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
725 return false;
726 }
727
728 auto LISPressure = getRegPressure(*MRI, LISLR);
729 if (LISPressure != CurPressure) {
730 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "
731 << print(CurPressure) << "LIS rpt: " << print(LISPressure);
732 return false;
733 }
734 return true;
735}
736
738 const MachineRegisterInfo &MRI) {
739 return Printable([&LiveRegs, &MRI](raw_ostream &OS) {
740 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
741 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
742 Register Reg = Register::index2VirtReg(I);
743 auto It = LiveRegs.find(Reg);
744 if (It != LiveRegs.end() && It->second.any())
745 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
746 << PrintLaneMask(It->second);
747 }
748 OS << '\n';
749 });
750}
751
/// Write this register pressure, formatted by print(), to the debug stream
/// returned by dbgs().
752void GCNRegPressure::dump() const { dbgs() << print(*this); }
753
755 "amdgpu-print-rp-downward",
756 cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"),
757 cl::init(false), cl::Hidden);
758
761
762INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp", "", true, true)
763
764// Return lanemask of Reg's subregs that are live-through at [Begin, End] and
765// are fully covered by Mask.
766static LaneBitmask
768 Register Reg, SlotIndex Begin, SlotIndex End,
769 LaneBitmask Mask = LaneBitmask::getAll()) {
770
771 auto IsInOneSegment = [Begin, End](const LiveRange &LR) -> bool {
772 auto *Segment = LR.getSegmentContaining(Begin);
773 return Segment && Segment->contains(End);
774 };
775
776 LaneBitmask LiveThroughMask;
777 const LiveInterval &LI = LIS.getInterval(Reg);
778 if (LI.hasSubRanges()) {
779 for (auto &SR : LI.subranges()) {
780 if ((SR.LaneMask & Mask) == SR.LaneMask && IsInOneSegment(SR))
781 LiveThroughMask |= SR.LaneMask;
782 }
783 } else {
784 LaneBitmask RegMask = MRI.getMaxLaneMaskForVReg(Reg);
785 if ((RegMask & Mask) == RegMask && IsInOneSegment(LI))
786 LiveThroughMask = RegMask;
787 }
788
789 return LiveThroughMask;
790}
791
793 const MachineRegisterInfo &MRI = MF.getRegInfo();
794 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
795 const LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
796
797 auto &OS = dbgs();
798
799// Leading spaces are important for YAML syntax.
800#define PFX " "
801
802 OS << "---\nname: " << MF.getName() << "\nbody: |\n";
803
804 auto printRP = [](const GCNRegPressure &RP) {
805 return Printable([&RP](raw_ostream &OS) {
806 OS << format(PFX " %-5d", RP.getSGPRNum())
807 << format(" %-5d", RP.getVGPRNum(false));
808 });
809 };
810
811 auto ReportLISMismatchIfAny = [&](const GCNRPTracker::LiveRegSet &TrackedLR,
812 const GCNRPTracker::LiveRegSet &LISLR) {
813 if (LISLR != TrackedLR) {
814 OS << PFX " mis LIS: " << llvm::print(LISLR, MRI)
815 << reportMismatch(LISLR, TrackedLR, TRI, PFX " ");
816 }
817 };
818
819 // Register pressure before and at an instruction (in program order).
821
822 for (auto &MBB : MF) {
823 RP.clear();
824 RP.reserve(MBB.size());
825
826 OS << PFX;
828 OS << ":\n";
829
830 SlotIndex MBBStartSlot = LIS.getSlotIndexes()->getMBBStartIdx(&MBB);
831 SlotIndex MBBEndSlot = LIS.getSlotIndexes()->getMBBEndIdx(&MBB);
832
833 GCNRPTracker::LiveRegSet LiveIn, LiveOut;
834 GCNRegPressure RPAtMBBEnd;
835
836 if (UseDownwardTracker) {
837 if (MBB.empty()) {
838 LiveIn = LiveOut = getLiveRegs(MBBStartSlot, LIS, MRI);
839 RPAtMBBEnd = getRegPressure(MRI, LiveIn);
840 } else {
841 GCNDownwardRPTracker RPT(LIS);
842 RPT.reset(MBB.front());
843
844 LiveIn = RPT.getLiveRegs();
845
846 while (!RPT.advanceBeforeNext()) {
847 GCNRegPressure RPBeforeMI = RPT.getPressure();
848 RPT.advanceToNext();
849 RP.emplace_back(RPBeforeMI, RPT.getPressure());
850 }
851
852 LiveOut = RPT.getLiveRegs();
853 RPAtMBBEnd = RPT.getPressure();
854 }
855 } else {
856 GCNUpwardRPTracker RPT(LIS);
857 RPT.reset(MRI, MBBEndSlot);
858
859 LiveOut = RPT.getLiveRegs();
860 RPAtMBBEnd = RPT.getPressure();
861
862 for (auto &MI : reverse(MBB)) {
863 RPT.resetMaxPressure();
864 RPT.recede(MI);
865 if (!MI.isDebugInstr())
866 RP.emplace_back(RPT.getPressure(), RPT.getMaxPressure());
867 }
868
869 LiveIn = RPT.getLiveRegs();
870 }
871
872 OS << PFX " Live-in: " << llvm::print(LiveIn, MRI);
874 ReportLISMismatchIfAny(LiveIn, getLiveRegs(MBBStartSlot, LIS, MRI));
875
876 OS << PFX " SGPR VGPR\n";
877 int I = 0;
878 for (auto &MI : MBB) {
879 if (!MI.isDebugInstr()) {
880 auto &[RPBeforeInstr, RPAtInstr] =
881 RP[UseDownwardTracker ? I : (RP.size() - 1 - I)];
882 ++I;
883 OS << printRP(RPBeforeInstr) << '\n' << printRP(RPAtInstr) << " ";
884 } else
885 OS << PFX " ";
886 MI.print(OS);
887 }
888 OS << printRP(RPAtMBBEnd) << '\n';
889
890 OS << PFX " Live-out:" << llvm::print(LiveOut, MRI);
892 ReportLISMismatchIfAny(LiveOut, getLiveRegs(MBBEndSlot, LIS, MRI));
893
894 GCNRPTracker::LiveRegSet LiveThrough;
895 for (auto [Reg, Mask] : LiveIn) {
896 LaneBitmask MaskIntersection = Mask & LiveOut.lookup(Reg);
897 if (MaskIntersection.any()) {
899 MRI, LIS, Reg, MBBStartSlot, MBBEndSlot, MaskIntersection);
900 if (LTMask.any())
901 LiveThrough[Reg] = LTMask;
902 }
903 }
904 OS << PFX " Live-thr:" << llvm::print(LiveThrough, MRI);
905 OS << printRP(getRegPressure(MRI, LiveThrough)) << '\n';
906 }
907 OS << "...\n";
908 return false;
909
910#undef PFX
911}
unsigned const MachineRegisterInfo * MRI
aarch64 promote const
static const LLT S1
MachineBasicBlock & MBB
bool End
Definition: ELF_riscv.cpp:480
#define PFX
static cl::opt< bool > UseDownwardTracker("amdgpu-print-rp-downward", cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"), cl::init(false), cl::Hidden)
static void collectVirtualRegUses(SmallVectorImpl< RegisterMaskPair > &RegMaskPairs, const MachineInstr &MI, const LiveIntervals &LIS, const MachineRegisterInfo &MRI)
static LaneBitmask getDefRegMask(const MachineOperand &MO, const MachineRegisterInfo &MRI)
static LaneBitmask getRegLiveThroughMask(const MachineRegisterInfo &MRI, const LiveIntervals &LIS, Register Reg, SlotIndex Begin, SlotIndex End, LaneBitmask Mask=LaneBitmask::getAll())
This file defines the GCNRegPressure class, which tracks register pressure by bookkeeping number of S...
IRTranslator LLVM IR MI
#define I(x, y, z)
Definition: MD5.cpp:58
unsigned const TargetRegisterInfo * TRI
static bool InRange(int64_t Value, unsigned short Shift, int LBound, int HBound)
#define P(N)
if(PassOpts->AAPipeline)
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:38
static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask, SlotIndex PriorUseIdx, SlotIndex NextUseIdx, const MachineRegisterInfo &MRI, const LiveIntervals *LIS)
Helper to find a vreg use between two indices [PriorUseIdx, NextUseIdx).
static LaneBitmask getLanesWithProperty(const LiveIntervals &LIS, const MachineRegisterInfo &MRI, bool TrackLaneMasks, Register RegUnit, SlotIndex Pos, LaneBitmask SafeDefault, bool(*Property)(const LiveRange &LR, SlotIndex Pos))
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
raw_pwrite_stream & OS
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:194
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:156
bool erase(const KeyT &Val)
Definition: DenseMap.h:321
unsigned size() const
Definition: DenseMap.h:99
iterator end()
Definition: DenseMap.h:84
const ValueT & at(const_arg_type_t< KeyT > Val) const
at - Return the entry for the specified key, or abort if no such entry exists.
Definition: DenseMap.h:202
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition: DenseMap.h:147
bool advanceBeforeNext(MachineInstr *MI=nullptr, bool UseInternalIterator=true)
Move to the state right before the next MI or after the end of MBB.
bool advance(MachineInstr *MI=nullptr, bool UseInternalIterator=true)
Move to the state at the next MI.
GCNRegPressure bumpDownwardPressure(const MachineInstr *MI, const SIRegisterInfo *TRI) const
Mostly copy/paste from CodeGen/RegisterPressure.cpp. Calculate the impact MI will have on CurPressure ...
bool reset(const MachineInstr &MI, const LiveRegSet *LiveRegs=nullptr)
Reset tracker to the point before the MI filling LiveRegs upon this point using LIS.
void advanceToNext(MachineInstr *MI=nullptr, bool UseInternalIterator=true)
Move to the state at the MI, advanceBeforeNext has to be called first.
GCNRegPressure getPressure() const
const decltype(LiveRegs) & getLiveRegs() const
const MachineInstr * LastTrackedMI
GCNRegPressure CurPressure
GCNRegPressure MaxPressure
void reset(const MachineInstr &MI, const LiveRegSet *LiveRegsCopy, bool After)
LaneBitmask getLastUsedLanes(Register RegUnit, SlotIndex Pos) const
Mostly copy/paste from CodeGen/RegisterPressure.cpp.
const MachineRegisterInfo * MRI
const LiveIntervals & LIS
void reset(const MachineRegisterInfo &MRI, SlotIndex SI)
reset tracker at the specified slot index SI.
void recede(const MachineInstr &MI)
Move to the state of RP just before the MI .
const GCNRegPressure & getMaxPressure() const
bool isValid() const
Returns whether the tracker's state after receding MI corresponds to the state reported by LIS.
A live range for subregisters.
Definition: LiveInterval.h:694
LiveInterval - This class represents the liveness of a register, or stack slot.
Definition: LiveInterval.h:687
Register reg() const
Definition: LiveInterval.h:718
bool hasSubRanges() const
Returns true if subregister liveness information is available.
Definition: LiveInterval.h:810
iterator_range< subrange_iterator > subranges()
Definition: LiveInterval.h:782
bool hasInterval(Register Reg) const
SlotIndexes * getSlotIndexes() const
SlotIndex getInstructionIndex(const MachineInstr &Instr) const
Returns the base index of the given instruction.
SlotIndex getMBBEndIdx(const MachineBasicBlock *mbb) const
Return the last index in the given basic block.
LiveRange * getCachedRegUnit(unsigned Unit)
Return the live range for register unit Unit if it has already been computed, or nullptr if it hasn't...
LiveInterval & getInterval(Register Reg)
This class represents the liveness of a register, stack slot, etc.
Definition: LiveInterval.h:157
bool liveAt(SlotIndex index) const
Definition: LiveInterval.h:401
void printName(raw_ostream &os, unsigned printNameFlags=PrintNameIr, ModuleSlotTracker *moduleSlotTracker=nullptr) const
Print the basic block's name as:
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
StringRef getName() const
getName - Return the name of the corresponding LLVM function.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Representation of each machine instruction.
Definition: MachineInstr.h:69
iterator_range< mop_iterator > operands()
Definition: MachineInstr.h:691
iterator_range< filtered_mop_iterator > all_defs()
Returns an iterator range over all operands that are (explicit or implicit) register defs.
Definition: MachineInstr.h:762
MachineOperand class - Representation of each machine instruction operand.
unsigned getSubReg() const
bool isReg() const
isReg - Tests if this is a MO_Register operand.
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
const TargetRegisterInfo * getTargetRegisterInfo() const
virtual void print(raw_ostream &OS, const Module *M) const
print - Print out the internal state of the pass.
Definition: Pass.cpp:130
Simple wrapper around std::function<void(raw_ostream&)>.
Definition: Printable.h:38
List of registers defined and used by a machine instruction.
void collect(const MachineInstr &MI, const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI, bool TrackLaneMasks, bool IgnoreDead)
Analyze the given instruction MI and fill in the Uses, Defs and DeadDefs list based on the MachineOpe...
void adjustLaneLiveness(const LiveIntervals &LIS, const MachineRegisterInfo &MRI, SlotIndex Pos, MachineInstr *AddFlagsMI=nullptr)
Use liveness information to find out which uses/defs are partially undefined/dead and adjust the Regi...
SmallVector< RegisterMaskPair, 8 > Uses
List of virtual registers and register units read by the instruction.
SmallVector< RegisterMaskPair, 8 > Defs
List of virtual registers and register units defined by the instruction which are not dead.
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
static Register index2VirtReg(unsigned Index)
Convert a 0-based index to a virtual register number.
Definition: Register.h:84
constexpr bool isVirtual() const
Return true if the specified register number is in the virtual register namespace.
Definition: Register.h:91
static unsigned getNumCoveredRegs(LaneBitmask LM)
static bool isSGPRClass(const TargetRegisterClass *RC)
SlotIndex - An opaque wrapper around machine indexes.
Definition: SlotIndexes.h:65
SlotIndex getDeadSlot() const
Returns the dead def kill slot for the current instruction.
Definition: SlotIndexes.h:242
SlotIndex getBaseIndex() const
Returns the base index associated with this index.
Definition: SlotIndexes.h:224
SlotIndex getRegSlot(bool EC=false) const
Returns the register use/def slot in the current instruction for a normal or early-clobber def.
Definition: SlotIndexes.h:237
SlotIndex getMBBEndIdx(unsigned Num) const
Returns the last index in the given basic block number.
Definition: SlotIndexes.h:470
SlotIndex getMBBStartIdx(unsigned Num) const
Returns the first index in the given basic block number.
Definition: SlotIndexes.h:460
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:132
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:181
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:573
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:937
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:51
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
LLVM Value Representation.
Definition: Value.h:74
An efficient, type-erasing, non-owning reference to a callable.
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
GCNRegPressure max(const GCNRegPressure &P1, const GCNRegPressure &P2)
LaneBitmask getLiveLaneMask(unsigned Reg, SlotIndex SI, const LiveIntervals &LIS, const MachineRegisterInfo &MRI, LaneBitmask LaneMaskFilter=LaneBitmask::getAll())
bool isEqual(const GCNRPTracker::LiveRegSet &S1, const GCNRPTracker::LiveRegSet &S2)
GCNRegPressure getRegPressure(const MachineRegisterInfo &MRI, Range &&LiveRegs)
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr)
IterT skipDebugInstructionsForward(IterT It, IterT End, bool SkipPseudoOp=true)
Increment It until it points to a non-debug instruction or to End and return the resulting iterator.
GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS, const MachineRegisterInfo &MRI)
GCNRPTracker::LiveRegSet getLiveRegsAfter(const MachineInstr &MI, const LiveIntervals &LIS)
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:420
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
format_object< Ts... > format(const char *Fmt, const Ts &... Vals)
These are helper functions used to produce formatted output.
Definition: Format.h:125
char & GCNRegPressurePrinterID
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1766
GCNRPTracker::LiveRegSet getLiveRegsBefore(const MachineInstr &MI, const LiveIntervals &LIS)
Printable reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, const GCNRPTracker::LiveRegSet &TrackedL, const TargetRegisterInfo *TRI, StringRef Pfx=" ")
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
unsigned getVGPRTuplesWeight() const
unsigned getVGPRNum(bool UnifiedVGPRFile) const
void inc(unsigned Reg, LaneBitmask PrevMask, LaneBitmask NewMask, const MachineRegisterInfo &MRI)
unsigned getAGPRNum() const
unsigned getSGPRNum() const
unsigned getSGPRTuplesWeight() const
friend Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST)
bool less(const MachineFunction &MF, const GCNRegPressure &O, unsigned MaxOccupancy=std::numeric_limits< unsigned >::max()) const
Compares this GCNRegPressure to O, returning true if this is less.
static constexpr LaneBitmask getAll()
Definition: LaneBitmask.h:82
constexpr bool none() const
Definition: LaneBitmask.h:52
constexpr bool any() const
Definition: LaneBitmask.h:53
static constexpr LaneBitmask getNone()
Definition: LaneBitmask.h:81