20#define DEBUG_TYPE "machine-scheduler"
24 if (
S1.size() != S2.
size())
27 for (
const auto &
P :
S1) {
29 if (
I == S2.
end() ||
I->second !=
P.second)
38unsigned GCNRegPressure::getRegKind(
Register Reg,
41 const auto RC =
MRI.getRegClass(Reg);
45 : STI->isAGPRClass(RC)
59 if (NewMask < PrevMask) {
64 switch (
auto Kind = getRegKind(Reg,
MRI)) {
74 assert(PrevMask < NewMask);
79 if (PrevMask.
none()) {
83 Sign *
TRI->getRegClassWeight(
MRI.getRegClass(Reg)).RegWeight;
92 unsigned MaxOccupancy)
const {
95 const auto SGPROcc = std::min(MaxOccupancy,
98 std::min(MaxOccupancy,
99 ST.getOccupancyWithNumVGPRs(
getVGPRNum(ST.hasGFX90AInsts())));
100 const auto OtherSGPROcc = std::min(MaxOccupancy,
101 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
102 const auto OtherVGPROcc =
103 std::min(MaxOccupancy,
104 ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts())));
106 const auto Occ = std::min(SGPROcc, VGPROcc);
107 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
111 return Occ > OtherOcc;
113 unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF);
114 unsigned MaxSGPRs = ST.getMaxNumSGPRs(MF);
117 unsigned ExcessSGPR = std::max(
static_cast<int>(
getSGPRNum() - MaxSGPRs), 0);
118 unsigned OtherExcessSGPR =
119 std::max(
static_cast<int>(O.getSGPRNum() - MaxSGPRs), 0);
121 auto WaveSize = ST.getWavefrontSize();
123 unsigned VGPRForSGPRSpills = (ExcessSGPR + (WaveSize - 1)) / WaveSize;
124 unsigned OtherVGPRForSGPRSpills =
125 (OtherExcessSGPR + (WaveSize - 1)) / WaveSize;
127 unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs();
131 unsigned ExcessVGPR =
132 std::max(
static_cast<int>(
getVGPRNum(ST.hasGFX90AInsts()) +
133 VGPRForSGPRSpills - MaxVGPRs),
135 unsigned OtherExcessVGPR =
136 std::max(
static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) +
137 OtherVGPRForSGPRSpills - MaxVGPRs),
141 unsigned ExcessArchVGPR = std::max(
142 static_cast<int>(
getVGPRNum(
false) + VGPRForSGPRSpills - MaxArchVGPRs),
144 unsigned OtherExcessArchVGPR =
145 std::max(
static_cast<int>(O.getVGPRNum(
false) + OtherVGPRForSGPRSpills -
149 unsigned ExcessAGPR = std::max(
150 static_cast<int>(ST.hasGFX90AInsts() ? (
getAGPRNum() - MaxArchVGPRs)
153 unsigned OtherExcessAGPR = std::max(
154 static_cast<int>(ST.hasGFX90AInsts() ? (O.getAGPRNum() - MaxArchVGPRs)
155 : (O.getAGPRNum() - MaxVGPRs)),
158 bool ExcessRP = ExcessSGPR || ExcessVGPR || ExcessArchVGPR || ExcessAGPR;
159 bool OtherExcessRP = OtherExcessSGPR || OtherExcessVGPR ||
160 OtherExcessArchVGPR || OtherExcessAGPR;
164 if (ExcessRP || OtherExcessRP) {
167 int VGPRDiff = ((OtherExcessVGPR + OtherExcessArchVGPR + OtherExcessAGPR) -
168 (ExcessVGPR + ExcessArchVGPR + ExcessAGPR));
170 int SGPRDiff = OtherExcessSGPR - ExcessSGPR;
175 unsigned PureExcessVGPR =
176 std::max(
static_cast<int>(
getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs),
178 std::max(
static_cast<int>(
getVGPRNum(
false) - MaxArchVGPRs), 0);
179 unsigned OtherPureExcessVGPR =
181 static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs),
183 std::max(
static_cast<int>(O.getVGPRNum(
false) - MaxArchVGPRs), 0);
188 if (PureExcessVGPR != OtherPureExcessVGPR)
196 bool SGPRImportant = SGPROcc < VGPROcc;
197 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
200 if (SGPRImportant != OtherSGPRImportant) {
201 SGPRImportant =
false;
205 bool SGPRFirst = SGPRImportant;
206 for (
int I = 2;
I > 0; --
I, SGPRFirst = !SGPRFirst) {
209 auto OtherSW = O.getSGPRTuplesWeight();
214 auto OtherVW = O.getVGPRTuplesWeight();
221 return SGPRImportant ? (
getSGPRNum() < O.getSGPRNum()):
223 O.getVGPRNum(ST.hasGFX90AInsts()));
229 <<
"AGPRs: " << RP.getAGPRNum();
232 << ST->getOccupancyWithNumVGPRs(RP.getVGPRNum(ST->hasGFX90AInsts()))
234 OS <<
", SGPRs: " << RP.getSGPRNum();
236 OS <<
"(O" << ST->getOccupancyWithNumSGPRs(RP.getSGPRNum()) <<
')';
237 OS <<
", LVGPR WT: " << RP.getVGPRTuplesWeight()
238 <<
", LSGPR WT: " << RP.getSGPRTuplesWeight();
240 OS <<
" -> Occ: " << RP.getOccupancy(*ST);
254 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.
getSubReg());
262 for (
const auto &MO :
MI.operands()) {
263 if (!MO.isReg() || !MO.getReg().isVirtual())
265 if (!MO.isUse() || !MO.readsReg())
270 return RM.RegUnit == Reg;
276 if (!LI.hasSubRanges())
277 UseMask =
MRI.getMaxLaneMaskForVReg(Reg);
306 LiveMask |= S.LaneMask;
307 assert(LiveMask == (LiveMask &
MRI.getMaxLaneMaskForVReg(LI.
reg())));
309 }
else if (LI.
liveAt(SI)) {
310 LiveMask =
MRI.getMaxLaneMaskForVReg(LI.
reg());
319 for (
unsigned I = 0, E =
MRI.getNumVirtRegs();
I != E; ++
I) {
325 LiveRegs[Reg] = LiveMask;
362 if (
MI.isDebugInstr())
367 bool HasECDefs =
false;
369 if (!MO.getReg().isVirtual())
376 if (MO.isEarlyClobber()) {
388 LiveMask &= ~DefMask;
397 DefPressure += ECDefPressure;
406 LiveMask |= U.LaneMask;
422 MRI = &
MI.getParent()->getParent()->getRegInfo();
424 MBBEnd =
MI.getParent()->end();
427 if (NextMI == MBBEnd)
436 return NextMI == MBBEnd;
438 assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
448 if (!MO.isReg() || !MO.getReg().isVirtual())
450 if (MO.isUse() && !MO.readsReg())
452 if (!SeenRegs.
insert(MO.getReg()).second)
464 auto PrevMask = It->second;
465 It->second &= ~S.LaneMask;
471 }
else if (!LI.
liveAt(SI)) {
484 return NextMI == MBBEnd;
494 if (!Reg.isVirtual())
497 auto PrevMask = LiveMask;
506 if (NextMI == MBBEnd)
514 while (NextMI !=
End)
522 reset(*Begin, LiveRegsCopy);
530 for (
auto const &
P : TrackedLR) {
531 auto I = LISLR.
find(
P.first);
533 OS << Pfx << printReg(P.first, TRI) <<
":L" << PrintLaneMask(P.second)
534 <<
" isn't found in LIS reported set\n";
535 }
else if (
I->second !=
P.second) {
536 OS << Pfx << printReg(P.first, TRI)
537 <<
" masks doesn't match: LIS reported " << PrintLaneMask(I->second)
538 <<
", tracked " << PrintLaneMask(P.second) <<
'\n';
541 for (
auto const &
P : LISLR) {
542 auto I = TrackedLR.find(
P.first);
543 if (
I == TrackedLR.end()) {
544 OS << Pfx << printReg(P.first, TRI) <<
":L" << PrintLaneMask(P.second)
545 <<
" isn't found in tracked set\n";
556 if (!
isEqual(LISLR, TrackedLR)) {
557 dbgs() <<
"\nGCNUpwardRPTracker error: Tracked and"
558 " LIS reported livesets mismatch:\n"
566 dbgs() <<
"GCNUpwardRPTracker error: Pressure sets different\nTracked: "
577 for (
unsigned I = 0, E =
MRI.getNumVirtRegs();
I != E; ++
I) {
578 Register Reg = Register::index2VirtReg(I);
579 auto It = LiveRegs.find(Reg);
580 if (It != LiveRegs.end() && It->second.any())
581 OS <<
' ' << printVRegOrUnit(Reg, TRI) <<
':'
582 << PrintLaneMask(It->second);
591 "amdgpu-print-rp-downward",
592 cl::desc(
"Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"),
607 auto IsInOneSegment = [Begin,
End](
const LiveRange &LR) ->
bool {
608 auto *Segment = LR.getSegmentContaining(Begin);
609 return Segment && Segment->contains(
End);
616 if ((SR.LaneMask & Mask) == SR.LaneMask && IsInOneSegment(SR))
617 LiveThroughMask |= SR.LaneMask;
621 if ((RegMask & Mask) == RegMask && IsInOneSegment(LI))
622 LiveThroughMask = RegMask;
625 return LiveThroughMask;
631 const LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
638 OS <<
"---\nname: " << MF.
getName() <<
"\nbody: |\n";
643 <<
format(
" %-5d", RP.getVGPRNum(
false));
649 if (LISLR != TrackedLR) {
658 for (
auto &
MBB : MF) {
701 if (!
MI.isDebugInstr())
710 ReportLISMismatchIfAny(LiveIn,
getLiveRegs(MBBStartSlot, LIS,
MRI));
712 OS <<
PFX " SGPR VGPR\n";
714 for (
auto &
MI :
MBB) {
715 if (!
MI.isDebugInstr()) {
716 auto &[RPBeforeInstr, RPAtInstr] =
719 OS << printRP(RPBeforeInstr) <<
'\n' << printRP(RPAtInstr) <<
" ";
724 OS << printRP(RPAtMBBEnd) <<
'\n';
728 ReportLISMismatchIfAny(LiveOut,
getLiveRegs(MBBEndSlot, LIS,
MRI));
731 for (
auto [Reg, Mask] : LiveIn) {
733 if (MaskIntersection.
any()) {
735 MRI, LIS, Reg, MBBStartSlot, MBBEndSlot, MaskIntersection);
737 LiveThrough[Reg] = LTMask;
unsigned const MachineRegisterInfo * MRI
static cl::opt< bool > UseDownwardTracker("amdgpu-print-rp-downward", cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"), cl::init(false), cl::Hidden)
static void collectVirtualRegUses(SmallVectorImpl< RegisterMaskPair > &RegMaskPairs, const MachineInstr &MI, const LiveIntervals &LIS, const MachineRegisterInfo &MRI)
static LaneBitmask getDefRegMask(const MachineOperand &MO, const MachineRegisterInfo &MRI)
static LaneBitmask getRegLiveThroughMask(const MachineRegisterInfo &MRI, const LiveIntervals &LIS, Register Reg, SlotIndex Begin, SlotIndex End, LaneBitmask Mask=LaneBitmask::getAll())
This file defines the GCNRegPressure class, which tracks registry pressure by bookkeeping number of S...
unsigned const TargetRegisterInfo * TRI
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
iterator find(const_arg_type_t< KeyT > Val)
bool erase(const KeyT &Val)
bool reset(const MachineInstr &MI, const LiveRegSet *LiveRegs=nullptr)
GCNRegPressure getPressure() const
const decltype(LiveRegs) & getLiveRegs() const
const MachineInstr * LastTrackedMI
GCNRegPressure CurPressure
GCNRegPressure MaxPressure
void reset(const MachineInstr &MI, const LiveRegSet *LiveRegsCopy, bool After)
const MachineRegisterInfo * MRI
const LiveIntervals & LIS
void recede(const MachineInstr &MI)
void reset(const MachineRegisterInfo &MRI_, const LiveRegSet &LiveRegs_)
const GCNRegPressure & getMaxPressure() const
LiveInterval - This class represents the liveness of a register, or stack slot.
bool hasSubRanges() const
Returns true if subregister liveness information is available.
iterator_range< subrange_iterator > subranges()
bool hasInterval(Register Reg) const
SlotIndexes * getSlotIndexes() const
SlotIndex getInstructionIndex(const MachineInstr &Instr) const
Returns the base index of the given instruction.
LiveInterval & getInterval(Register Reg)
This class represents the liveness of a register, stack slot, etc.
bool liveAt(SlotIndex index) const
void printName(raw_ostream &os, unsigned printNameFlags=PrintNameIr, ModuleSlotTracker *moduleSlotTracker=nullptr) const
Print the basic block's name as:
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
StringRef getName() const
getName - Return the name of the corresponding LLVM function.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Representation of each machine instruction.
iterator_range< mop_iterator > operands()
iterator_range< filtered_mop_iterator > all_defs()
Returns an iterator range over all operands that are (explicit or implicit) register defs.
MachineOperand class - Representation of each machine instruction operand.
unsigned getSubReg() const
bool isReg() const
isReg - Tests if this is a MO_Register operand.
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
const TargetRegisterInfo * getTargetRegisterInfo() const
virtual void print(raw_ostream &OS, const Module *M) const
print - Print out the internal state of the pass.
Simple wrapper around std::function<void(raw_ostream&)>.
Wrapper class representing virtual and physical registers.
static Register index2VirtReg(unsigned Index)
Convert a 0-based index to a virtual register number.
constexpr bool isVirtual() const
Return true if the specified register number is in the virtual register namespace.
static unsigned getNumCoveredRegs(LaneBitmask LM)
static bool isSGPRClass(const TargetRegisterClass *RC)
SlotIndex - An opaque wrapper around machine indexes.
SlotIndex getDeadSlot() const
Returns the dead def kill slot for the current instruction.
SlotIndex getBaseIndex() const
Returns the base index for associated with this index.
SlotIndex getMBBEndIdx(unsigned Num) const
Returns the last index in the given basic block number.
SlotIndex getMBBStartIdx(unsigned Num) const
Returns the first index in the given basic block number.
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
LLVM Value Representation.
This class implements an extremely fast bulk output stream that can only output to a stream.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
initializer< Ty > init(const Ty &Val)
This is an optimization pass for GlobalISel generic memory operations.
GCNRegPressure max(const GCNRegPressure &P1, const GCNRegPressure &P2)
bool isEqual(const GCNRPTracker::LiveRegSet &S1, const GCNRPTracker::LiveRegSet &S2)
GCNRegPressure getRegPressure(const MachineRegisterInfo &MRI, Range &&LiveRegs)
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr)
IterT skipDebugInstructionsForward(IterT It, IterT End, bool SkipPseudoOp=true)
Increment It until it points to a non-debug instruction or to End and return the resulting iterator.
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS, const MachineRegisterInfo &MRI)
GCNRPTracker::LiveRegSet getLiveRegsAfter(const MachineInstr &MI, const LiveIntervals &LIS)
auto reverse(ContainerTy &&C)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
LaneBitmask getLiveLaneMask(unsigned Reg, SlotIndex SI, const LiveIntervals &LIS, const MachineRegisterInfo &MRI)
format_object< Ts... > format(const char *Fmt, const Ts &... Vals)
These are helper functions used to produce formatted output.
char & GCNRegPressurePrinterID
GCNRPTracker::LiveRegSet getLiveRegsBefore(const MachineInstr &MI, const LiveIntervals &LIS)
Printable reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, const GCNRPTracker::LiveRegSet &TrackedL, const TargetRegisterInfo *TRI, StringRef Pfx=" ")
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
unsigned getVGPRTuplesWeight() const
unsigned getVGPRNum(bool UnifiedVGPRFile) const
void inc(unsigned Reg, LaneBitmask PrevMask, LaneBitmask NewMask, const MachineRegisterInfo &MRI)
unsigned getAGPRNum() const
unsigned getSGPRNum() const
unsigned getSGPRTuplesWeight() const
friend Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST)
bool less(const MachineFunction &MF, const GCNRegPressure &O, unsigned MaxOccupancy=std::numeric_limits< unsigned >::max()) const
Compares this GCNRegpressure to O, returning true if this is less.
constexpr bool none() const
constexpr bool any() const
static constexpr LaneBitmask getNone()