42#define DEBUG_TYPE "amdgpu-rewrite-agpr-copy-mfma"
45 "Controls which MFMA chains are rewritten to AGPR form");
50 "Number of MFMA instructions rewritten to use AGPR form");
55class AMDGPURewriteAGPRCopyMFMAImpl {
76 TRI(*ST.getRegisterInfo()), MRI(MF.getRegInfo()), VRM(VRM), LRM(LRM),
77 LIS(LIS), LSS(LSS), RegClassInfo(RegClassInfo) {}
79 bool isRewriteCandidate(
const MachineInstr &
MI)
const {
88 MCRegister getAssignedAGPR(
Register VReg)
const {
89 MCRegister PhysReg = VRM.getPhys(VReg);
95 const TargetRegisterClass *AssignedRC = TRI.getPhysRegBaseClass(PhysReg);
96 return TRI.isAGPRClass(AssignedRC) ? PhysReg : MCRegister();
99 bool tryReassigningMFMAChain(MachineInstr &
MFMA,
Register MFMAHintReg,
112 bool recomputeRegClassExceptRewritable(
113 Register Reg, SmallVectorImpl<MachineInstr *> &RewriteCandidates,
114 SmallSetVector<Register, 4> &RewriteRegs)
const;
116 bool tryFoldCopiesToAGPR(
Register VReg, MCRegister AssignedAGPR)
const;
117 bool tryFoldCopiesFromAGPR(
Register VReg, MCRegister AssignedAGPR)
const;
121 void replaceSpillWithCopyToVReg(MachineInstr &SpillMI,
int SpillFI,
128 SpillReferenceMap &Map)
const;
132 void eliminateSpillsOfReassignedVGPRs()
const;
134 bool run(MachineFunction &MF)
const;
137bool AMDGPURewriteAGPRCopyMFMAImpl::recomputeRegClassExceptRewritable(
143 while (!Worklist.
empty()) {
148 const TargetRegisterClass *NewRC =
TRI.getLargestLegalSuperClass(OldRC, MF);
155 MachineInstr *
MI = MO.getParent();
161 if (isRewriteCandidate(*
MI)) {
163 const MCInstrDesc &AGPRDesc =
TII.get(AGPROp);
164 const TargetRegisterClass *NewRC =
165 TII.getRegClass(AGPRDesc, MO.getOperandNo());
166 if (!
TRI.hasAGPRs(NewRC))
169 const MachineOperand *VDst =
170 TII.getNamedOperand(*
MI, AMDGPU::OpName::vdst);
171 const MachineOperand *Src2 =
172 TII.getNamedOperand(*
MI, AMDGPU::OpName::src2);
173 for (
const MachineOperand *
Op : {VDst, Src2}) {
181 if (OtherReg !=
Reg && RewriteRegs.
insert(OtherReg))
188 dbgs() <<
"Attempting to replace VGPR MFMA with AGPR version:"
207 unsigned OpNo = &MO - &
MI->getOperand(0);
208 NewRC =
MI->getRegClassConstraintEffect(OpNo, NewRC, &
TII, &
TRI);
209 if (!NewRC || NewRC == OldRC) {
211 <<
" cannot be reassigned to "
212 << (NewRC ?
TRI.getRegClassName(NewRC) :
"NULL")
222bool AMDGPURewriteAGPRCopyMFMAImpl::tryReassigningMFMAChain(
226 SmallVector<MachineInstr *, 4> RewriteCandidates = {&
MFMA};
227 SmallSetVector<Register, 4> RewriteRegs;
231 RewriteRegs.
insert(MFMAHintReg);
242 if (!recomputeRegClassExceptRewritable(MFMAHintReg, RewriteCandidates,
244 LLVM_DEBUG(
dbgs() <<
"Could not recompute the regclass of dst reg "
265 using RecoloringStack =
267 RecoloringStack TentativeReassignments;
269 for (
Register RewriteReg : RewriteRegs) {
271 TentativeReassignments.push_back({&LI, VRM.
getPhys(RewriteReg)});
276 !attemptReassignmentsToAGPR(RewriteRegs, PhysRegHint)) {
278 for (
auto [LI, OldAssign] : TentativeReassignments) {
281 LRM.
assign(*LI, OldAssign);
289 for (
Register InterferingReg : RewriteRegs) {
290 const TargetRegisterClass *EquivalentAGPRRegClass =
292 MRI.
setRegClass(InterferingReg, EquivalentAGPRRegClass);
295 for (MachineInstr *RewriteCandidate : RewriteCandidates) {
298 RewriteCandidate->setDesc(
TII.get(NewMFMAOp));
299 ++NumMFMAsRewrittenToAGPR;
308bool AMDGPURewriteAGPRCopyMFMAImpl::attemptReassignmentsToAGPR(
309 SmallSetVector<Register, 4> &InterferingRegs,
MCPhysReg PrefPhysReg)
const {
314 for (
Register InterferingReg : InterferingRegs) {
315 LiveInterval &ReassignLI = LIS.
getInterval(InterferingReg);
316 const TargetRegisterClass *EquivalentAGPRRegClass =
319 MCPhysReg Assignable = AMDGPU::NoRegister;
320 if (EquivalentAGPRRegClass->
contains(PrefPhysReg) &&
330 Assignable = PrefPhysReg;
333 RegClassInfo.
getOrder(EquivalentAGPRRegClass);
345 <<
" to a free AGPR\n");
351 LRM.
assign(ReassignLI, Assignable);
363bool AMDGPURewriteAGPRCopyMFMAImpl::tryFoldCopiesToAGPR(
364 Register VReg, MCRegister AssignedAGPR)
const {
365 bool MadeChange =
false;
388 if (isRewriteCandidate(CopySrcDefMI) &&
389 tryReassigningMFMAChain(
390 CopySrcDefMI, CopySrcDefMI.getOperand(0).getReg(), AssignedAGPR))
405bool AMDGPURewriteAGPRCopyMFMAImpl::tryFoldCopiesFromAGPR(
406 Register VReg, MCRegister AssignedAGPR)
const {
407 bool MadeChange =
false;
416 if (!CopyUseMO.readsReg())
419 MachineInstr &CopyUseMI = *CopyUseMO.getParent();
420 if (isRewriteCandidate(CopyUseMI)) {
421 if (tryReassigningMFMAChain(CopyUseMI, CopyDstReg,
431void AMDGPURewriteAGPRCopyMFMAImpl::replaceSpillWithCopyToVReg(
432 MachineInstr &SpillMI,
int SpillFI,
Register VReg)
const {
435 MachineInstr *NewCopy;
449void AMDGPURewriteAGPRCopyMFMAImpl::collectSpillIndexUses(
452 SmallSet<int, 4> NeededFrameIndexes;
453 for (
const LiveInterval *LI : StackIntervals)
456 for (MachineBasicBlock &
MBB : MF) {
457 for (MachineInstr &
MI :
MBB) {
458 for (MachineOperand &MO :
MI.operands()) {
459 if (!MO.isFI() || !NeededFrameIndexes.
count(MO.getIndex()))
462 if (
TII.isVGPRSpill(
MI)) {
463 SmallVector<MachineInstr *, 4> &References =
Map[MO.getIndex()];
472 NeededFrameIndexes.
erase(MO.getIndex());
473 Map.erase(MO.getIndex());
479void AMDGPURewriteAGPRCopyMFMAImpl::eliminateSpillsOfReassignedVGPRs()
const {
484 MachineFrameInfo &MFI = MF.getFrameInfo();
487 StackIntervals.
reserve(NumSlots);
489 for (
auto &[Slot, LI] : LSS) {
493 const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
494 if (
TRI.hasVGPRs(RC))
498 sort(StackIntervals, [](
const LiveInterval *
A,
const LiveInterval *
B) {
501 if (
A->weight() !=
B->weight())
502 return A->weight() >
B->weight();
504 if (
A->getSize() !=
B->getSize())
505 return A->getSize() >
B->getSize();
508 return A->reg().stackSlotIndex() <
B->reg().stackSlotIndex();
523 DenseMap<int, SmallVector<MachineInstr *, 4>> SpillSlotReferences;
524 collectSpillIndexUses(StackIntervals, SpillSlotReferences);
526 for (LiveInterval *LI : StackIntervals) {
528 auto SpillReferences = SpillSlotReferences.find(Slot);
529 if (SpillReferences == SpillSlotReferences.end())
532 const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
535 <<
" by reassigning\n");
546 const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
549 for (MachineInstr *SpillMI : SpillReferences->second)
550 replaceSpillWithCopyToVReg(*SpillMI, Slot, NewVReg);
557 LRM.
assign(NewLI, PhysReg);
564bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF)
const {
567 if (!
ST.hasGFX90AInsts())
572 LLVM_DEBUG(
dbgs() <<
"skipping function that did not allocate AGPRs\n");
576 bool MadeChange =
false;
579 Register VReg = Register::index2VirtReg(
I);
580 MCRegister AssignedAGPR = getAssignedAGPR(VReg);
584 if (tryFoldCopiesToAGPR(VReg, AssignedAGPR))
586 if (tryFoldCopiesFromAGPR(VReg, AssignedAGPR))
594 eliminateSpillsOfReassignedVGPRs();
599class AMDGPURewriteAGPRCopyMFMALegacy :
public MachineFunctionPass {
602 RegisterClassInfo RegClassInfo;
604 AMDGPURewriteAGPRCopyMFMALegacy() : MachineFunctionPass(
ID) {}
606 bool runOnMachineFunction(MachineFunction &MF)
override;
608 StringRef getPassName()
const override {
609 return "AMDGPU Rewrite AGPR-Copy-MFMA";
612 void getAnalysisUsage(AnalysisUsage &AU)
const override {
631 "AMDGPU Rewrite AGPR-Copy-MFMA",
false,
false)
639char AMDGPURewriteAGPRCopyMFMALegacy::
ID = 0;
642 AMDGPURewriteAGPRCopyMFMALegacy::
ID;
644bool AMDGPURewriteAGPRCopyMFMALegacy::runOnMachineFunction(
646 if (skipFunction(MF.getFunction()))
651 auto &VRM = getAnalysis<VirtRegMapWrapperLegacy>().getVRM();
652 auto &LRM = getAnalysis<LiveRegMatrixWrapperLegacy>().getLRM();
653 auto &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
654 auto &LSS = getAnalysis<LiveStacksWrapperLegacy>().getLS();
655 AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, LSS, RegClassInfo);
669 AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, LSS, RegClassInfo);
674 .preserve<LiveStacksAnalysis>()
676 .preserve<SlotIndexesAnalysis>()
678 .preserve<LiveRegMatrixAnalysis>();
MachineInstrBuilder & UseMI
AMDGPU Rewrite AGPR Copy MFMA
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
This file provides an implementation of debug counters.
#define DEBUG_COUNTER(VARNAME, COUNTERNAME, DESC)
AMD GCN specific subclass of TargetSubtarget.
const HexagonInstrInfo * TII
Register const TargetRegisterInfo * TRI
Promote Memory to Register
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Interface definition for SIRegisterInfo.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
PreservedAnalyses run(MachineFunction &MF, MachineFunctionAnalysisManager &MFAM)
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
Represents analyses that only rely on functions' control flow.
static bool shouldExecute(CounterInfo &Counter)
LiveInterval & getInterval(Register Reg)
LiveInterval & createAndComputeVirtRegInterval(Register Reg)
SlotIndex ReplaceMachineInstrInMaps(MachineInstr &MI, MachineInstr &NewMI)
bool isPhysRegUsed(MCRegister PhysReg) const
Returns true if the given PhysReg has any live intervals assigned.
void unassign(const LiveInterval &VirtReg, bool ClearAllReferencingSegments=false)
Unassign VirtReg from its PhysReg.
@ IK_Free
No interference, go ahead and assign.
void assign(const LiveInterval &VirtReg, MCRegister PhysReg)
Assign VirtReg to PhysReg.
InterferenceKind checkInterference(const LiveInterval &VirtReg, MCRegister PhysReg)
Check for interference before assigning VirtReg to PhysReg.
unsigned getNumIntervals() const
bool isSpillSlotObjectIndex(int ObjectIdx) const
Returns true if the specified index corresponds to a spill slot.
void RemoveStackObject(int ObjectIdx)
Remove or mark dead a statically sized stack object.
bool isDeadObjectIndex(int ObjectIdx) const
Returns true if the specified index corresponds to a dead object.
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
Register getReg(unsigned Idx) const
Get the register for the operand index.
const MachineInstrBuilder & addReg(Register RegNo, RegState Flags={}, unsigned SubReg=0) const
Add a new virtual register operand.
const MachineInstrBuilder & add(const MachineOperand &MO) const
const MachineBasicBlock * getParent() const
bool mayStore(QueryType Type=AnyInBundle) const
Return true if this instruction could possibly modify memory.
const DebugLoc & getDebugLoc() const
Returns the debug location id of this MachineInstr.
const MachineOperand & getOperand(unsigned i) const
LLVM_ABI MachineInstrBundleIterator< MachineInstr > eraseFromParent()
Unlink 'this' from the containing basic block and delete it.
bool isReg() const
isReg - Tests if this is a MO_Register operand.
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
const TargetRegisterClass * getRegClass(Register Reg) const
Return the register class of the specified virtual register.
iterator_range< def_instr_iterator > def_instructions(Register Reg) const
LLVM_ABI Register createVirtualRegister(const TargetRegisterClass *RegClass, StringRef Name="")
createVirtualRegister - Create and return a new virtual register in the function with the specified r...
LLVM_ABI void setRegClass(Register Reg, const TargetRegisterClass *RC)
setRegClass - Set the register class of the specified virtual register.
iterator_range< use_instr_iterator > use_instructions(Register Reg) const
iterator_range< reg_nodbg_iterator > reg_nodbg_operands(Register Reg) const
unsigned getNumVirtRegs() const
getNumVirtRegs - Return the number of virtual registers created.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
LLVM_ABI void runOnMachineFunction(const MachineFunction &MF, bool Rev=false)
runOnFunction - Prepare to answer questions about MF.
ArrayRef< MCPhysReg > getOrder(const TargetRegisterClass *RC) const
getOrder - Returns the preferred allocation order for RC.
Wrapper class representing virtual and physical registers.
int stackSlotIndex() const
Compute the frame index from a register value representing a stack slot.
constexpr bool isVirtual() const
Return true if the specified register number is in the virtual register namespace.
constexpr bool isPhysical() const
Return true if the specified register number is in the physical register namespace.
bool insert(const value_type &X)
Insert a new element into the SetVector.
A SetVector that performs no allocations if smaller than a certain size.
size_type count(const T &V) const
count - Return 1 if the element is in the set, 0 otherwise.
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
void reserve(size_type N)
void push_back(const T &Elt)
bool contains(Register Reg) const
Return true if the specified register is included in this register class.
MCRegister getPhys(Register virtReg) const
returns the physical register mapped to the specified virtual register
bool hasPhys(Register virtReg) const
returns true if the specified virtual register is mapped to a physical register
LLVM_READONLY int32_t getMFMASrcCVDstAGPROp(uint32_t Opcode)
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
PointerTypeMap run(const Module &M)
Compute the PointerTypeMap for the module M.
This is an optimization pass for GlobalISel generic memory operations.
MachineInstrBuilder BuildMI(MachineFunction &MF, const MIMetadata &MIMD, const MCInstrDesc &MCID)
Builder interface. Specify how to create the initial instruction itself.
AnalysisManager< MachineFunction > MachineFunctionAnalysisManager
LLVM_ABI PreservedAnalyses getMachineFunctionPassPreservedAnalyses()
Returns the minimum set of Analyses that all machine function passes must preserve.
void sort(IteratorTy Start, IteratorTy End)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
uint16_t MCPhysReg
An unsigned integer type large enough to represent all physical registers, but not necessarily virtua...
DWARFExpression::Operation Op
ArrayRef(const T &OneElt) -> ArrayRef< T >
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
char & AMDGPURewriteAGPRCopyMFMALegacyID
LLVM_ABI Printable printReg(Register Reg, const TargetRegisterInfo *TRI=nullptr, unsigned SubIdx=0, const MachineRegisterInfo *MRI=nullptr)
Prints virtual and physical registers with or without a TRI instance.