33#define DEBUG_TYPE "si-lower-sgpr-spills"
41struct LaneVGPRInsertPt {
52 "amdgpu-num-vgprs-for-wwm-alloc",
53 cl::desc(
"Max num VGPRs for whole-wave register allocation."),
56class SILowerSGPRSpills {
75 : LIS(LIS), Indexes(Indexes), MDT(MDT), MCI(MCI) {}
80 void updateLaneVGPRDomInstr(
109char SILowerSGPRSpillsLegacy::ID = 0;
112 "SI lower SGPR spill instructions",
false,
false)
131 assert(
Success &&
"spillCalleeSavedRegisters should always succeed");
136 Indexes->repairIndexesInRange(&SaveBlock, SaveBlock.begin(),
I);
140 LIS->removeAllRegUnitsForPhysReg(CS.getReg());
155 I == RestoreBlock.
begin() ?
I : std::prev(
I);
186void SILowerSGPRSpills::calculateSaveRestoreBlocks(
MachineFunction &MF) {
196 "Multiple save points not yet supported!");
200 "Multiple restore points not yet supported!");
202 MachineBasicBlock *RestoreBlock = RestorePoint.first;
213 for (MachineBasicBlock &
MBB : MF) {
231bool SILowerSGPRSpills::spillCalleeSavedRegs(
232 MachineFunction &MF, SmallVectorImpl<int> &CalleeSavedFIs) {
236 const SIFrameLowering *TFI =
ST.getFrameLowering();
238 RegScavenger *
RS =
nullptr;
242 TFI->determineCalleeSavesSGPR(MF, SavedRegs, RS);
245 if (!
F.hasFnAttribute(Attribute::Naked)) {
250 std::vector<CalleeSavedInfo> CSI;
252 MCRegister RetAddrReg =
TRI->getReturnAddressReg(MF);
253 MCRegister RetAddrRegSub0 =
TRI->getSubReg(RetAddrReg, AMDGPU::sub0);
254 MCRegister RetAddrRegSub1 =
TRI->getSubReg(RetAddrReg, AMDGPU::sub1);
255 bool SpillRetAddrReg =
false;
257 for (
unsigned I = 0; CSRegs[
I]; ++
I) {
258 MCRegister
Reg = CSRegs[
I];
261 if (
Reg == RetAddrRegSub0 ||
Reg == RetAddrRegSub1) {
262 SpillRetAddrReg =
true;
266 const TargetRegisterClass *RC =
TRI->getMinimalPhysRegClass(
Reg);
268 TRI->getSpillAlign(*RC),
true,
269 nullptr,
TRI->getSpillStackID(*RC));
271 CSI.emplace_back(
Reg, JunkFI);
279 if (SpillRetAddrReg) {
280 const TargetRegisterClass *RC =
TRI->getMinimalPhysRegClass(RetAddrReg);
283 true,
nullptr,
TRI->getSpillStackID(*RC));
284 CSI.push_back(CalleeSavedInfo(RetAddrReg, JunkFI));
289 for (MachineBasicBlock *SaveBlock : SaveBlocks)
293 assert(SaveBlocks.size() == 1 &&
"shrink wrapping not fully implemented");
296 for (MachineBasicBlock *RestoreBlock : RestoreBlocks)
305MachineBasicBlock *SILowerSGPRSpills::getCycleDomBB(
MachineCycle *
C) {
308 if (
C->isReducible()) {
309 if (
auto *IDom = MDT->
getNode(
C->getHeader())->getIDom())
310 return IDom->getBlock();
315 const SmallVectorImpl<MachineBasicBlock *> &Entries =
C->getEntries();
316 assert(!Entries.
empty() &&
"Expected cycle to have at least one entry.");
317 MachineBasicBlock *EntryBB = Entries[0];
318 for (
unsigned I = 1;
I < Entries.
size(); ++
I)
323void SILowerSGPRSpills::updateLaneVGPRDomInstr(
325 DenseMap<Register, LaneVGPRInsertPt> &LaneVGPRDomInstr) {
332 SIMachineFunctionInfo *FuncInfo =
337 for (
auto &Spill : VGPRSpills) {
338 if (PrevLaneVGPR ==
Spill.VGPR)
341 PrevLaneVGPR =
Spill.VGPR;
343 if (
Spill.Lane == 0 &&
I == LaneVGPRDomInstr.
end()) {
344 LaneVGPRDomInstr[
Spill.VGPR] = insertPt(
MBB, InsertPt);
347 LaneVGPRInsertPt Prev =
I->second;
348 MachineBasicBlock *PrevInsertMBB = Prev.MBB;
350 MachineBasicBlock *DomMBB = PrevInsertMBB;
356 if (PrevInsertPt ==
MBB->
end() ||
357 MDT->
dominates(&*InsertPt, &*PrevInsertPt))
358 I->second = insertPt(
MBB, InsertPt);
368 I->second = insertPt(
MBB, InsertPt);
369 else if (DomMBB != PrevInsertMBB)
375void SILowerSGPRSpills::determineRegsForWWMAllocation(MachineFunction &MF,
376 BitVector &RegMask) {
379 SIMachineFunctionInfo *MFI = MF.
getInfo<SIMachineFunctionInfo>();
381 BitVector ReservedRegs =
TRI->getReservedRegs(MF);
382 BitVector NonWwmAllocMask(
TRI->getNumRegs());
388 unsigned NumRegs = MaxNumVGPRsForWwmAllocation;
392 auto [MaxNumVGPRs, MaxNumAGPRs] =
ST.getMaxNumVectorRegs(MF.
getFunction());
396 for (
unsigned Reg = AMDGPU::VGPR0 + MaxNumVGPRs - 1;
397 (
I < NumRegs) && (
Reg >= AMDGPU::VGPR0); --
Reg) {
399 !MRI.isPhysRegUsed(
Reg,
true)) {
400 TRI->markSuperRegs(RegMask,
Reg);
407 TRI->markSuperRegs(RegMask, AMDGPU::VGPR0);
409 "cannot find enough VGPRs for wwm-regalloc");
413bool SILowerSGPRSpillsLegacy::runOnMachineFunction(MachineFunction &MF) {
414 auto *LISWrapper = getAnalysisIfAvailable<LiveIntervalsWrapperPass>();
415 LiveIntervals *LIS = LISWrapper ? &LISWrapper->getLIS() :
nullptr;
416 auto *SIWrapper = getAnalysisIfAvailable<SlotIndexesWrapperPass>();
417 SlotIndexes *Indexes = SIWrapper ? &SIWrapper->getSI() :
nullptr;
418 MachineDominatorTree *MDT =
419 &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
420 MachineCycleInfo *MCI =
421 &getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
422 return SILowerSGPRSpills(LIS, Indexes, MDT, MCI).run(MF);
425bool SILowerSGPRSpills::run(MachineFunction &MF) {
427 TII =
ST.getInstrInfo();
434 calculateSaveRestoreBlocks(MF);
435 SmallVector<int> CalleeSavedFIs;
436 bool HasCSRs = spillCalleeSavedRegs(MF, CalleeSavedFIs);
440 SIMachineFunctionInfo *FuncInfo = MF.
getInfo<SIMachineFunctionInfo>();
444 RestoreBlocks.
clear();
448 bool MadeChange =
false;
449 bool SpilledToVirtVGPRLanes =
false;
453 const bool HasSGPRSpillToVGPR =
TRI->spillSGPRToVGPR() &&
455 if (HasSGPRSpillToVGPR) {
466 DenseMap<Register, LaneVGPRInsertPt> LaneVGPRDomInstr;
468 for (MachineBasicBlock &
MBB : MF) {
470 if (!
TII->isSGPRSpill(
MI))
473 if (
MI.getOperand(0).isUndef()) {
476 MI.eraseFromParent();
480 int FI =
TII->getNamedOperand(
MI, AMDGPU::OpName::addr)->getIndex();
484 if (IsCalleeSaveSGPRSpill) {
497 bool Spilled =
TRI->eliminateSGPRToVGPRSpillFrameIndex(
498 MI, FI,
nullptr, Indexes, LIS,
true);
501 "failed to spill SGPR to physical VGPR lane when allocated");
504 MachineInstrSpan MIS(&
MI, &
MBB);
506 bool Spilled =
TRI->eliminateSGPRToVGPRSpillFrameIndex(
507 MI, FI,
nullptr, Indexes, LIS);
510 "failed to spill SGPR to virtual VGPR lane when allocated");
512 updateLaneVGPRDomInstr(FI, &
MBB, MIS.
begin(), LaneVGPRDomInstr);
513 SpilledToVirtVGPRLanes =
true;
520 LaneVGPRInsertPt IP = LaneVGPRDomInstr[
Reg];
522 MachineBasicBlock *AdjMBB = getCycleDomBB(
C);
526 MachineBasicBlock &
Block = *IP.MBB;
544 BitVector WwmRegMask(
TRI->getNumRegs());
546 determineRegsForWWMAllocation(MF, WwmRegMask);
548 BitVector NonWwmRegMask(WwmRegMask);
549 NonWwmRegMask.flip().clearBitsNotInMask(
TRI->getAllVGPRRegMask());
556 for (MachineBasicBlock &
MBB : MF)
569 if (SpilledToVirtVGPRLanes) {
570 const TargetRegisterClass *RC =
TRI->getWaveMaskRegClass();
574 Register UnusedLowSGPR =
TRI->findUnusedRegister(MRI, RC, MF);
575 if (UnusedLowSGPR &&
TRI->getHWRegIndex(UnusedLowSGPR) <
585 RestoreBlocks.
clear();
598 SILowerSGPRSpills(LIS, Indexes, MDT, &MCI).
run(MF);
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
Provides AMDGPU specific target descriptions.
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
AMD GCN specific subclass of TargetSubtarget.
const HexagonInstrInfo * TII
Register const TargetRegisterInfo * TRI
Promote Memory to Register
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
static void insertCSRRestores(MachineBasicBlock &RestoreBlock, std::vector< CalleeSavedInfo > &CSI)
Insert restore code for the callee-saved registers used in the function.
SmallVector< MachineBasicBlock *, 4 > MBBVector
static void insertCSRSaves(MachineBasicBlock &SaveBlock, ArrayRef< CalleeSavedInfo > CSI)
Insert spill code for the callee-saved registers used in the function.
static void updateLiveness(MachineFunction &MF)
Helper function to update the liveness information for the callee-saved registers.
This file declares the machine register scavenger class.
static void insertCSRRestores(MachineBasicBlock &RestoreBlock, MutableArrayRef< CalleeSavedInfo > CSI, SlotIndexes *Indexes, LiveIntervals *LIS)
Insert restore code for the callee-saved registers used in the function.
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesAll()
Set by analyses that do not transform their input at all.
Represent a constant reference to an array (0 or more elements consecutively in memory),...
bool test(unsigned Idx) const
Returns true if bit Idx is set.
The CalleeSavedInfo class tracks the information need to locate where a callee saved register is in t...
iterator find(const_arg_type_t< KeyT > Val)
NodeT * findNearestCommonDominator(NodeT *A, NodeT *B) const
Find nearest common dominator basic block for basic block A and B.
DomTreeNodeBase< NodeT > * getNode(const NodeT *BB) const
getNode - return the (Post)DominatorTree node for the specified basic block.
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
CycleT * getTopLevelParentCycle(const BlockT *Block) const
const HexagonRegisterInfo & getRegisterInfo() const
LLVM_ABI void emitError(const Instruction *I, const Twine &ErrorStr)
emitError - Emit an error message to the currently installed error handler with optional location inf...
void removeAllRegUnitsForPhysReg(MCRegister Reg)
Remove associated live ranges for the register units associated with Reg.
SlotIndex InsertMachineInstrInMaps(MachineInstr &MI)
LiveInterval & createAndComputeVirtRegInterval(Register Reg)
An RAII based helper class to modify MachineFunctionProperties when running pass.
bool isEHFuncletEntry() const
Returns true if this is the entry block of an EH funclet.
LLVM_ABI iterator getFirstTerminator()
Returns an iterator to the first terminator instruction of this basic block.
bool isReturnBlock() const
Convenience function that returns true if the block ends in a return instruction.
LLVM_ABI void sortUniqueLiveIns()
Sorts and uniques the LiveIns vector.
LLVM_ABI DebugLoc findDebugLoc(instr_iterator MBBI)
Find the next valid DebugLoc starting at MBBI, skipping any debug instructions.
void addLiveIn(MCRegister PhysReg, LaneBitmask LaneMask=LaneBitmask::getAll())
Adds the specified register as a live in.
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
MachineInstrBundleIterator< MachineInstr > iterator
LLVM_ABI Result run(MachineFunction &MF, MachineFunctionAnalysisManager &MFAM)
Legacy analysis pass which computes a MachineCycleInfo.
Analysis pass which computes a MachineDominatorTree.
Analysis pass which computes a MachineDominatorTree.
DominatorTree Class - Concrete subclass of DominatorTreeBase that is used to compute a normal dominat...
bool dominates(const MachineInstr *A, const MachineInstr *B) const
LLVM_ABI int CreateStackObject(uint64_t Size, Align Alignment, bool isSpillSlot, const AllocaInst *Alloca=nullptr, uint8_t ID=0)
Create a new statically sized stack object, returning a nonnegative identifier to represent it.
void setCalleeSavedInfoValid(bool v)
int getObjectIndexEnd() const
Return one past the maximum frame object index.
bool hasStackObjects() const
Return true if there are any stack objects in this function.
uint8_t getStackID(int ObjectIdx) const
const SaveRestorePoints & getRestorePoints() const
const SaveRestorePoints & getSavePoints() const
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
Properties which a MachineFunction may have at a given point in time.
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
MachineFrameInfo & getFrameInfo()
getFrameInfo - Return the frame info object for the current function.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
Ty * getInfo()
getInfo - Keep track of various per-function pieces of information for backends that would like to do...
const MachineBasicBlock & front() const
MachineInstrSpan provides an interface to get an iteration range containing the instruction it was in...
Representation of each machine instruction.
LLVM_ABI const MCPhysReg * getCalleeSavedRegs() const
Returns list of callee saved registers.
Represent a mutable reference to an array (0 or more elements consecutively in memory),...
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses run(MachineFunction &MF, MachineFunctionAnalysisManager &MFAM)
void setSGPRForEXECCopy(Register Reg)
void setFlag(Register Reg, uint8_t Flag)
ArrayRef< SIRegisterInfo::SpilledReg > getSGPRSpillToVirtualVGPRLanes(int FrameIndex) const
Register getSGPRForEXECCopy() const
bool allocateSGPRSpillToVGPRLane(MachineFunction &MF, int FI, bool SpillToPhysVGPRLane=false, bool IsPrologEpilog=false)
bool removeDeadFrameIndices(MachineFrameInfo &MFI, bool ResetSGPRSpillStackIDs)
If ResetSGPRSpillStackIDs is true, reset the stack ID from sgpr-spill to the default stack.
void updateNonWWMRegMask(BitVector &RegMask)
bool hasSpilledSGPRs() const
ArrayRef< Register > getSGPRSpillVGPRs() const
SlotIndex insertMachineInstrInMaps(MachineInstr &MI, bool Late=false)
Insert the given machine instruction into the mapping.
LLVM_ABI void removeMachineInstrFromMaps(MachineInstr &MI, bool AllowBundled=false)
Removes machine instruction (bundle) MI from the mapping.
LLVM_ABI void repairIndexesInRange(MachineBasicBlock *MBB, MachineBasicBlock::iterator Begin, MachineBasicBlock::iterator End)
Repair indexes after adding and removing instructions.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Information about stack frame layout on the target.
void restoreCalleeSavedRegister(MachineBasicBlock &MBB, MachineBasicBlock::iterator MI, const CalleeSavedInfo &CS, const TargetInstrInfo *TII, const TargetRegisterInfo *TRI) const
virtual bool spillCalleeSavedRegisters(MachineBasicBlock &MBB, MachineBasicBlock::iterator MI, ArrayRef< CalleeSavedInfo > CSI, const TargetRegisterInfo *TRI) const
spillCalleeSavedRegisters - Issues instruction(s) to spill all callee saved registers and returns tru...
virtual bool restoreCalleeSavedRegisters(MachineBasicBlock &MBB, MachineBasicBlock::iterator MI, MutableArrayRef< CalleeSavedInfo > CSI, const TargetRegisterInfo *TRI) const
restoreCalleeSavedRegisters - Issues instruction(s) to restore all callee saved registers and returns...
TargetInstrInfo - Interface to description of machine instruction set.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
virtual const TargetFrameLowering * getFrameLowering() const
virtual const TargetInstrInfo * getInstrInfo() const
virtual const TargetRegisterInfo * getRegisterInfo() const =0
Return the target's register information.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
initializer< Ty > init(const Ty &Val)
This is an optimization pass for GlobalISel generic memory operations.
MachineInstrBuilder BuildMI(MachineFunction &MF, const MIMetadata &MIMD, const MCInstrDesc &MCID)
Builder interface. Specify how to create the initial instruction itself.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
AnalysisManager< MachineFunction > MachineFunctionAnalysisManager
void clearDebugInfoForSpillFIs(MachineFrameInfo &MFI, MachineBasicBlock &MBB, const BitVector &SpillFIs)
Replace frame index operands with null registers in debug value instructions for the specified spill ...
auto reverse(ContainerTy &&C)
char & SILowerSGPRSpillsLegacyID
uint16_t MCPhysReg
An unsigned integer type large enough to represent all physical registers, but not necessarily virtua...
ArrayRef(const T &OneElt) -> ArrayRef< T >
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
MachineCycleInfo::CycleT MachineCycle