20#define DEBUG_TYPE "machine-scheduler"
24 if (
S1.size() != S2.
size())
27 for (
const auto &
P :
S1) {
29 if (
I == S2.
end() ||
I->second !=
P.second)
38unsigned GCNRegPressure::getRegKind(
Register Reg,
41 const auto *
const RC =
MRI.getRegClass(Reg);
46 : STI->isAGPRClass(RC)
60 if (NewMask < PrevMask) {
65 switch (
auto Kind = getRegKind(Reg,
MRI)) {
75 assert(PrevMask < NewMask);
80 if (PrevMask.
none()) {
84 Sign *
TRI->getRegClassWeight(
MRI.getRegClass(Reg)).RegWeight;
93 unsigned MaxOccupancy)
const {
96 const auto SGPROcc = std::min(MaxOccupancy,
99 std::min(MaxOccupancy,
100 ST.getOccupancyWithNumVGPRs(
getVGPRNum(ST.hasGFX90AInsts())));
101 const auto OtherSGPROcc = std::min(MaxOccupancy,
102 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
103 const auto OtherVGPROcc =
104 std::min(MaxOccupancy,
105 ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts())));
107 const auto Occ = std::min(SGPROcc, VGPROcc);
108 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
112 return Occ > OtherOcc;
114 unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF);
115 unsigned MaxSGPRs = ST.getMaxNumSGPRs(MF);
118 unsigned ExcessSGPR = std::max(
static_cast<int>(
getSGPRNum() - MaxSGPRs), 0);
119 unsigned OtherExcessSGPR =
120 std::max(
static_cast<int>(O.getSGPRNum() - MaxSGPRs), 0);
122 auto WaveSize = ST.getWavefrontSize();
124 unsigned VGPRForSGPRSpills = (ExcessSGPR + (WaveSize - 1)) / WaveSize;
125 unsigned OtherVGPRForSGPRSpills =
126 (OtherExcessSGPR + (WaveSize - 1)) / WaveSize;
128 unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs();
132 unsigned ExcessVGPR =
133 std::max(
static_cast<int>(
getVGPRNum(ST.hasGFX90AInsts()) +
134 VGPRForSGPRSpills - MaxVGPRs),
136 unsigned OtherExcessVGPR =
137 std::max(
static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) +
138 OtherVGPRForSGPRSpills - MaxVGPRs),
142 unsigned ExcessArchVGPR = std::max(
143 static_cast<int>(
getVGPRNum(
false) + VGPRForSGPRSpills - MaxArchVGPRs),
145 unsigned OtherExcessArchVGPR =
146 std::max(
static_cast<int>(O.getVGPRNum(
false) + OtherVGPRForSGPRSpills -
150 unsigned ExcessAGPR = std::max(
151 static_cast<int>(ST.hasGFX90AInsts() ? (
getAGPRNum() - MaxArchVGPRs)
154 unsigned OtherExcessAGPR = std::max(
155 static_cast<int>(ST.hasGFX90AInsts() ? (O.getAGPRNum() - MaxArchVGPRs)
156 : (O.getAGPRNum() - MaxVGPRs)),
159 bool ExcessRP = ExcessSGPR || ExcessVGPR || ExcessArchVGPR || ExcessAGPR;
160 bool OtherExcessRP = OtherExcessSGPR || OtherExcessVGPR ||
161 OtherExcessArchVGPR || OtherExcessAGPR;
165 if (ExcessRP || OtherExcessRP) {
168 int VGPRDiff = ((OtherExcessVGPR + OtherExcessArchVGPR + OtherExcessAGPR) -
169 (ExcessVGPR + ExcessArchVGPR + ExcessAGPR));
171 int SGPRDiff = OtherExcessSGPR - ExcessSGPR;
176 unsigned PureExcessVGPR =
177 std::max(
static_cast<int>(
getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs),
179 std::max(
static_cast<int>(
getVGPRNum(
false) - MaxArchVGPRs), 0);
180 unsigned OtherPureExcessVGPR =
182 static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs),
184 std::max(
static_cast<int>(O.getVGPRNum(
false) - MaxArchVGPRs), 0);
189 if (PureExcessVGPR != OtherPureExcessVGPR)
197 bool SGPRImportant = SGPROcc < VGPROcc;
198 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
201 if (SGPRImportant != OtherSGPRImportant) {
202 SGPRImportant =
false;
206 bool SGPRFirst = SGPRImportant;
207 for (
int I = 2;
I > 0; --
I, SGPRFirst = !SGPRFirst) {
210 auto OtherSW =
O.getSGPRTuplesWeight();
215 auto OtherVW =
O.getVGPRTuplesWeight();
222 return SGPRImportant ? (
getSGPRNum() <
O.getSGPRNum()):
224 O.getVGPRNum(
ST.hasGFX90AInsts()));
230 <<
"AGPRs: " << RP.getAGPRNum();
233 << ST->getOccupancyWithNumVGPRs(RP.getVGPRNum(ST->hasGFX90AInsts()))
235 OS <<
", SGPRs: " << RP.getSGPRNum();
237 OS <<
"(O" << ST->getOccupancyWithNumSGPRs(RP.getSGPRNum()) <<
')';
238 OS <<
", LVGPR WT: " << RP.getVGPRTuplesWeight()
239 <<
", LSGPR WT: " << RP.getSGPRTuplesWeight();
241 OS <<
" -> Occ: " << RP.getOccupancy(*ST);
255 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.
getSubReg());
263 auto &
TRI = *
MRI.getTargetRegisterInfo();
264 for (
const auto &MO :
MI.operands()) {
265 if (!MO.isReg() || !MO.getReg().isVirtual())
267 if (!MO.isUse() || !MO.readsReg())
272 return RM.RegUnit == Reg;
275 auto &
P =
I == RegMaskPairs.
end()
279 P.LaneMask |= MO.getSubReg() ?
TRI.getSubRegIndexLaneMask(MO.getSubReg())
280 :
MRI.getMaxLaneMaskForVReg(Reg);
284 for (
auto &
P : RegMaskPairs) {
286 if (!LI.hasSubRanges())
310 if (Property(SR, Pos))
311 Result |= SR.LaneMask;
313 }
else if (Property(LI, Pos)) {
314 Result = TrackLaneMasks ?
MRI.getMaxLaneMaskForVReg(RegUnit)
336 bool Upward =
false) {
342 bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx)
343 : (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx);
347 unsigned SubRegIdx = MO.getSubReg();
349 LastUseMask &= ~UseMask;
350 if (LastUseMask.
none())
372 if ((S.LaneMask & LaneMaskFilter).any() && S.liveAt(SI)) {
373 LiveMask |= S.LaneMask;
374 assert(LiveMask == (LiveMask &
MRI.getMaxLaneMaskForVReg(LI.
reg())));
376 }
else if (LI.
liveAt(SI)) {
377 LiveMask =
MRI.getMaxLaneMaskForVReg(LI.
reg());
379 LiveMask &= LaneMaskFilter;
387 for (
unsigned I = 0, E =
MRI.getNumVirtRegs();
I != E; ++
I) {
393 LiveRegs[Reg] = LiveMask;
428 const LiveRange::Segment *S = LR.getSegmentContaining(Pos);
429 return S != nullptr && S->end == Pos.getRegSlot();
441 if (
MI.isDebugInstr())
446 bool HasECDefs =
false;
448 if (!MO.getReg().isVirtual())
455 if (MO.isEarlyClobber()) {
467 LiveMask &= ~DefMask;
476 DefPressure += ECDefPressure;
485 LiveMask |= U.LaneMask;
501 MRI = &
MI.getParent()->getParent()->getRegInfo();
503 MBBEnd =
MI.getParent()->end();
506 if (NextMI == MBBEnd)
513 bool UseInternalIterator) {
517 if (UseInternalIterator) {
519 return NextMI == MBBEnd;
521 assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
524 SI = NextMI == MBBEnd
536 for (
auto &MO : CurrMI->
operands()) {
537 if (!MO.isReg() || !MO.getReg().isVirtual())
539 if (MO.isUse() && !MO.readsReg())
541 if (!UseInternalIterator && MO.isDef())
543 if (!SeenRegs.
insert(MO.getReg()).second)
555 auto PrevMask = It->second;
556 It->second &= ~S.LaneMask;
562 }
else if (!LI.
liveAt(SI)) {
575 return UseInternalIterator && (NextMI == MBBEnd);
579 bool UseInternalIterator) {
580 if (UseInternalIterator) {
590 for (
const auto &MO : CurrMI->
all_defs()) {
592 if (!Reg.isVirtual())
595 auto PrevMask = LiveMask;
604 if (UseInternalIterator && NextMI == MBBEnd)
609 if (!UseInternalIterator) {
617 while (NextMI !=
End)
625 reset(*Begin, LiveRegsCopy);
633 for (
auto const &
P : TrackedLR) {
634 auto I = LISLR.
find(
P.first);
636 OS << Pfx << printReg(P.first, TRI) <<
":L" << PrintLaneMask(P.second)
637 <<
" isn't found in LIS reported set\n";
638 }
else if (
I->second !=
P.second) {
639 OS << Pfx << printReg(P.first, TRI)
640 <<
" masks doesn't match: LIS reported " << PrintLaneMask(I->second)
641 <<
", tracked " << PrintLaneMask(P.second) <<
'\n';
644 for (
auto const &
P : LISLR) {
645 auto I = TrackedLR.find(
P.first);
646 if (
I == TrackedLR.end()) {
647 OS << Pfx << printReg(P.first, TRI) <<
":L" << PrintLaneMask(P.second)
648 <<
" isn't found in tracked set\n";
657 assert(!
MI->isDebugOrPseudoInstr() &&
"Expect a nondebug instruction.");
670 if (!Reg.isVirtual())
673 if (LastUseMask.
none())
684 if (IdxPos ==
MBB->
end()) {
692 if (LastUseMask.
none())
698 TempPressure.
inc(Reg, LiveMask, NewMask, *
MRI);
704 if (!Reg.isVirtual())
709 TempPressure.
inc(Reg, LiveMask, NewMask, *
MRI);
720 if (!
isEqual(LISLR, TrackedLR)) {
721 dbgs() <<
"\nGCNUpwardRPTracker error: Tracked and"
722 " LIS reported livesets mismatch:\n"
730 dbgs() <<
"GCNUpwardRPTracker error: Pressure sets different\nTracked: "
741 for (
unsigned I = 0, E =
MRI.getNumVirtRegs();
I != E; ++
I) {
742 Register Reg = Register::index2VirtReg(I);
743 auto It = LiveRegs.find(Reg);
744 if (It != LiveRegs.end() && It->second.any())
745 OS <<
' ' << printVRegOrUnit(Reg, TRI) <<
':'
746 << PrintLaneMask(It->second);
755 "amdgpu-print-rp-downward",
756 cl::desc(
"Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"),
771 auto IsInOneSegment = [Begin,
End](
const LiveRange &LR) ->
bool {
772 auto *Segment = LR.getSegmentContaining(Begin);
773 return Segment && Segment->contains(
End);
780 if ((SR.LaneMask & Mask) == SR.LaneMask && IsInOneSegment(SR))
781 LiveThroughMask |= SR.LaneMask;
785 if ((RegMask & Mask) == RegMask && IsInOneSegment(LI))
786 LiveThroughMask = RegMask;
789 return LiveThroughMask;
795 const LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
802 OS <<
"---\nname: " << MF.
getName() <<
"\nbody: |\n";
807 <<
format(
" %-5d", RP.getVGPRNum(
false));
813 if (LISLR != TrackedLR) {
822 for (
auto &
MBB : MF) {
865 if (!
MI.isDebugInstr())
874 ReportLISMismatchIfAny(LiveIn,
getLiveRegs(MBBStartSlot, LIS,
MRI));
876 OS <<
PFX " SGPR VGPR\n";
878 for (
auto &
MI :
MBB) {
879 if (!
MI.isDebugInstr()) {
880 auto &[RPBeforeInstr, RPAtInstr] =
883 OS << printRP(RPBeforeInstr) <<
'\n' << printRP(RPAtInstr) <<
" ";
888 OS << printRP(RPAtMBBEnd) <<
'\n';
892 ReportLISMismatchIfAny(LiveOut,
getLiveRegs(MBBEndSlot, LIS,
MRI));
895 for (
auto [Reg, Mask] : LiveIn) {
897 if (MaskIntersection.
any()) {
899 MRI, LIS, Reg, MBBStartSlot, MBBEndSlot, MaskIntersection);
901 LiveThrough[Reg] = LTMask;
unsigned const MachineRegisterInfo * MRI
static cl::opt< bool > UseDownwardTracker("amdgpu-print-rp-downward", cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"), cl::init(false), cl::Hidden)
static void collectVirtualRegUses(SmallVectorImpl< RegisterMaskPair > &RegMaskPairs, const MachineInstr &MI, const LiveIntervals &LIS, const MachineRegisterInfo &MRI)
static LaneBitmask getDefRegMask(const MachineOperand &MO, const MachineRegisterInfo &MRI)
static LaneBitmask getRegLiveThroughMask(const MachineRegisterInfo &MRI, const LiveIntervals &LIS, Register Reg, SlotIndex Begin, SlotIndex End, LaneBitmask Mask=LaneBitmask::getAll())
This file defines the GCNRegPressure class, which tracks registry pressure by bookkeeping number of S...
unsigned const TargetRegisterInfo * TRI
static bool InRange(int64_t Value, unsigned short Shift, int LBound, int HBound)
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask, SlotIndex PriorUseIdx, SlotIndex NextUseIdx, const MachineRegisterInfo &MRI, const LiveIntervals *LIS)
Helper to find a vreg use between two indices [PriorUseIdx, NextUseIdx).
static LaneBitmask getLanesWithProperty(const LiveIntervals &LIS, const MachineRegisterInfo &MRI, bool TrackLaneMasks, Register RegUnit, SlotIndex Pos, LaneBitmask SafeDefault, bool(*Property)(const LiveRange &LR, SlotIndex Pos))
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
iterator find(const_arg_type_t< KeyT > Val)
bool erase(const KeyT &Val)
const ValueT & at(const_arg_type_t< KeyT > Val) const
at - Return the entry for the specified key, or abort if no such entry exists.
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
bool advanceBeforeNext(MachineInstr *MI=nullptr, bool UseInternalIterator=true)
Move to the state right before the next MI or after the end of MBB.
bool advance(MachineInstr *MI=nullptr, bool UseInternalIterator=true)
Move to the state at the next MI.
GCNRegPressure bumpDownwardPressure(const MachineInstr *MI, const SIRegisterInfo *TRI) const
Mostly copy/paste from CodeGen/RegisterPressure.cpp Calculate the impact MI will have on CurPressure ...
bool reset(const MachineInstr &MI, const LiveRegSet *LiveRegs=nullptr)
Reset tracker to the point before the MI filling LiveRegs upon this point using LIS.
void advanceToNext(MachineInstr *MI=nullptr, bool UseInternalIterator=true)
Move to the state at the MI, advanceBeforeNext has to be called first.
GCNRegPressure getPressure() const
const decltype(LiveRegs) & getLiveRegs() const
const MachineInstr * LastTrackedMI
GCNRegPressure CurPressure
GCNRegPressure MaxPressure
void reset(const MachineInstr &MI, const LiveRegSet *LiveRegsCopy, bool After)
LaneBitmask getLastUsedLanes(Register RegUnit, SlotIndex Pos) const
Mostly copy/paste from CodeGen/RegisterPressure.cpp.
const MachineRegisterInfo * MRI
const LiveIntervals & LIS
void reset(const MachineRegisterInfo &MRI, SlotIndex SI)
reset tracker at the specified slot index SI.
void recede(const MachineInstr &MI)
Move to the state of RP just before the MI .
const GCNRegPressure & getMaxPressure() const
bool isValid() const
returns whether the tracker's state after receding MI corresponds to reported by LIS.
A live range for subregisters.
LiveInterval - This class represents the liveness of a register, or stack slot.
bool hasSubRanges() const
Returns true if subregister liveness information is available.
iterator_range< subrange_iterator > subranges()
bool hasInterval(Register Reg) const
SlotIndexes * getSlotIndexes() const
SlotIndex getInstructionIndex(const MachineInstr &Instr) const
Returns the base index of the given instruction.
SlotIndex getMBBEndIdx(const MachineBasicBlock *mbb) const
Return the last index in the given basic block.
LiveRange * getCachedRegUnit(unsigned Unit)
Return the live range for register unit Unit if it has already been computed, or nullptr if it hasn't...
LiveInterval & getInterval(Register Reg)
This class represents the liveness of a register, stack slot, etc.
bool liveAt(SlotIndex index) const
void printName(raw_ostream &os, unsigned printNameFlags=PrintNameIr, ModuleSlotTracker *moduleSlotTracker=nullptr) const
Print the basic block's name as:
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
StringRef getName() const
getName - Return the name of the corresponding LLVM function.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Representation of each machine instruction.
iterator_range< mop_iterator > operands()
iterator_range< filtered_mop_iterator > all_defs()
Returns an iterator range over all operands that are (explicit or implicit) register defs.
MachineOperand class - Representation of each machine instruction operand.
unsigned getSubReg() const
bool isReg() const
isReg - Tests if this is a MO_Register operand.
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
const TargetRegisterInfo * getTargetRegisterInfo() const
virtual void print(raw_ostream &OS, const Module *M) const
print - Print out the internal state of the pass.
Simple wrapper around std::function<void(raw_ostream&)>.
List of registers defined and used by a machine instruction.
void collect(const MachineInstr &MI, const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI, bool TrackLaneMasks, bool IgnoreDead)
Analyze the given instruction MI and fill in the Uses, Defs and DeadDefs list based on the MachineOpe...
void adjustLaneLiveness(const LiveIntervals &LIS, const MachineRegisterInfo &MRI, SlotIndex Pos, MachineInstr *AddFlagsMI=nullptr)
Use liveness information to find out which uses/defs are partially undefined/dead and adjust the Regi...
SmallVector< RegisterMaskPair, 8 > Uses
List of virtual registers and register units read by the instruction.
SmallVector< RegisterMaskPair, 8 > Defs
List of virtual registers and register units defined by the instruction which are not dead.
Wrapper class representing virtual and physical registers.
static Register index2VirtReg(unsigned Index)
Convert a 0-based index to a virtual register number.
constexpr bool isVirtual() const
Return true if the specified register number is in the virtual register namespace.
static unsigned getNumCoveredRegs(LaneBitmask LM)
static bool isSGPRClass(const TargetRegisterClass *RC)
SlotIndex - An opaque wrapper around machine indexes.
SlotIndex getDeadSlot() const
Returns the dead def kill slot for the current instruction.
SlotIndex getBaseIndex() const
Returns the base index for associated with this index.
SlotIndex getRegSlot(bool EC=false) const
Returns the register use/def slot in the current instruction for a normal or early-clobber def.
SlotIndex getMBBEndIdx(unsigned Num) const
Returns the last index in the given basic block number.
SlotIndex getMBBStartIdx(unsigned Num) const
Returns the first index in the given basic block number.
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
A Use represents the edge between a Value definition and its users.
LLVM Value Representation.
An efficient, type-erasing, non-owning reference to a callable.
This class implements an extremely fast bulk output stream that can only output to a stream.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
initializer< Ty > init(const Ty &Val)
This is an optimization pass for GlobalISel generic memory operations.
GCNRegPressure max(const GCNRegPressure &P1, const GCNRegPressure &P2)
LaneBitmask getLiveLaneMask(unsigned Reg, SlotIndex SI, const LiveIntervals &LIS, const MachineRegisterInfo &MRI, LaneBitmask LaneMaskFilter=LaneBitmask::getAll())
bool isEqual(const GCNRPTracker::LiveRegSet &S1, const GCNRPTracker::LiveRegSet &S2)
GCNRegPressure getRegPressure(const MachineRegisterInfo &MRI, Range &&LiveRegs)
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr)
IterT skipDebugInstructionsForward(IterT It, IterT End, bool SkipPseudoOp=true)
Increment It until it points to a non-debug instruction or to End and return the resulting iterator.
GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS, const MachineRegisterInfo &MRI)
GCNRPTracker::LiveRegSet getLiveRegsAfter(const MachineInstr &MI, const LiveIntervals &LIS)
auto reverse(ContainerTy &&C)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
format_object< Ts... > format(const char *Fmt, const Ts &... Vals)
These are helper functions used to produce formatted output.
char & GCNRegPressurePrinterID
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
GCNRPTracker::LiveRegSet getLiveRegsBefore(const MachineInstr &MI, const LiveIntervals &LIS)
Printable reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, const GCNRPTracker::LiveRegSet &TrackedL, const TargetRegisterInfo *TRI, StringRef Pfx=" ")
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
bool runOnMachineFunction(MachineFunction &MF) override
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
unsigned getVGPRTuplesWeight() const
unsigned getVGPRNum(bool UnifiedVGPRFile) const
void inc(unsigned Reg, LaneBitmask PrevMask, LaneBitmask NewMask, const MachineRegisterInfo &MRI)
unsigned getAGPRNum() const
unsigned getSGPRNum() const
unsigned getSGPRTuplesWeight() const
friend Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST)
bool less(const MachineFunction &MF, const GCNRegPressure &O, unsigned MaxOccupancy=std::numeric_limits< unsigned >::max()) const
Compares this GCNRegpressure to O, returning true if this is less.
static constexpr LaneBitmask getAll()
constexpr bool none() const
constexpr bool any() const
static constexpr LaneBitmask getNone()