LLVM 20.0.0git
SMEPeepholeOpt.cpp
Go to the documentation of this file.
//===- SMEPeepholeOpt.cpp - SME peephole optimization pass-----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This pass tries to remove back-to-back (smstart, smstop) and
// (smstop, smstart) sequences. The pass is conservative when it cannot
// determine that it is safe to remove these sequences.
//===----------------------------------------------------------------------===//
12
#include "AArch64InstrInfo.h"
#include "AArch64MachineFunctionInfo.h"
#include "AArch64Subtarget.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"

using namespace llvm;
24
25#define DEBUG_TYPE "aarch64-sme-peephole-opt"
26
27namespace {
28
29struct SMEPeepholeOpt : public MachineFunctionPass {
30 static char ID;
31
32 SMEPeepholeOpt() : MachineFunctionPass(ID) {
34 }
35
36 bool runOnMachineFunction(MachineFunction &MF) override;
37
38 StringRef getPassName() const override {
39 return "SME Peephole Optimization pass";
40 }
41
42 void getAnalysisUsage(AnalysisUsage &AU) const override {
43 AU.setPreservesCFG();
45 }
46
47 bool optimizeStartStopPairs(MachineBasicBlock &MBB,
48 bool &HasRemovedAllSMChanges) const;
49};
50
51char SMEPeepholeOpt::ID = 0;
52
53} // end anonymous namespace
54
56 return MI->getOpcode() == AArch64::MSRpstatePseudo;
57}
58
60 const MachineInstr *MI2) {
61 // We only consider the same type of streaming mode change here, i.e.
62 // start/stop SM, or start/stop ZA pairs.
63 if (MI1->getOperand(0).getImm() != MI2->getOperand(0).getImm())
64 return false;
65
66 // One must be 'start', the other must be 'stop'
67 if (MI1->getOperand(1).getImm() == MI2->getOperand(1).getImm())
68 return false;
69
70 bool IsConditional = isConditionalStartStop(MI2);
71 if (isConditionalStartStop(MI1) != IsConditional)
72 return false;
73
74 if (!IsConditional)
75 return true;
76
77 // Check to make sure the conditional start/stop pairs are identical.
78 if (MI1->getOperand(2).getImm() != MI2->getOperand(2).getImm())
79 return false;
80
81 // Ensure reg masks are identical.
82 if (MI1->getOperand(4).getRegMask() != MI2->getOperand(4).getRegMask())
83 return false;
84
85 // This optimisation is unlikely to happen in practice for conditional
86 // smstart/smstop pairs as the virtual registers for pstate.sm will always
87 // be different.
88 // TODO: For this optimisation to apply to conditional smstart/smstop,
89 // this pass will need to do more work to remove redundant calls to
90 // __arm_sme_state.
91
92 // Only consider conditional start/stop pairs which read the same register
93 // holding the original value of pstate.sm, as some conditional start/stops
94 // require the state on entry to the function.
95 if (MI1->getOperand(3).isReg() && MI2->getOperand(3).isReg()) {
96 Register Reg1 = MI1->getOperand(3).getReg();
97 Register Reg2 = MI2->getOperand(3).getReg();
98 if (Reg1.isPhysical() || Reg2.isPhysical() || Reg1 != Reg2)
99 return false;
100 }
101
102 return true;
103}
104
106 assert((MI->getOpcode() == AArch64::MSRpstatesvcrImm1 ||
107 MI->getOpcode() == AArch64::MSRpstatePseudo) &&
108 "Expected MI to be a smstart/smstop instruction");
109 return MI->getOperand(0).getImm() == AArch64SVCR::SVCRSM ||
110 MI->getOperand(0).getImm() == AArch64SVCR::SVCRSMZA;
111}
112
115 const MachineOperand &MO) {
116 if (!MO.isReg())
117 return false;
118
119 Register R = MO.getReg();
120 if (R.isPhysical())
121 return llvm::any_of(TRI.subregs_inclusive(R), [](const MCPhysReg &SR) {
122 return AArch64::ZPRRegClass.contains(SR) ||
123 AArch64::PPRRegClass.contains(SR);
124 });
125
126 const TargetRegisterClass *RC = MRI.getRegClass(R);
127 return TRI.getCommonSubClass(&AArch64::ZPRRegClass, RC) ||
128 TRI.getCommonSubClass(&AArch64::PPRRegClass, RC);
129}
130
131bool SMEPeepholeOpt::optimizeStartStopPairs(
132 MachineBasicBlock &MBB, bool &HasRemovedAllSMChanges) const {
134 const TargetRegisterInfo &TRI =
136
137 bool Changed = false;
138 MachineInstr *Prev = nullptr;
140
141 // Convenience function to reset the matching of a sequence.
142 auto Reset = [&]() {
143 Prev = nullptr;
144 ToBeRemoved.clear();
145 };
146
147 // Walk through instructions in the block trying to find pairs of smstart
148 // and smstop nodes that cancel each other out. We only permit a limited
149 // set of instructions to appear between them, otherwise we reset our
150 // tracking.
151 unsigned NumSMChanges = 0;
152 unsigned NumSMChangesRemoved = 0;
154 switch (MI.getOpcode()) {
155 case AArch64::MSRpstatesvcrImm1:
156 case AArch64::MSRpstatePseudo: {
158 NumSMChanges++;
159
160 if (!Prev)
161 Prev = &MI;
162 else if (isMatchingStartStopPair(Prev, &MI)) {
163 // If they match, we can remove them, and possibly any instructions
164 // that we marked for deletion in between.
165 Prev->eraseFromParent();
166 MI.eraseFromParent();
167 for (MachineInstr *TBR : ToBeRemoved)
168 TBR->eraseFromParent();
169 ToBeRemoved.clear();
170 Prev = nullptr;
171 Changed = true;
172 NumSMChangesRemoved += 2;
173 } else {
174 Reset();
175 Prev = &MI;
176 }
177 continue;
178 }
179 default:
180 if (!Prev)
181 // Avoid doing expensive checks when Prev is nullptr.
182 continue;
183 break;
184 }
185
186 // Test if the instructions in between the start/stop sequence are agnostic
187 // of streaming mode. If not, the algorithm should reset.
188 switch (MI.getOpcode()) {
189 default:
190 Reset();
191 break;
192 case AArch64::COALESCER_BARRIER_FPR16:
193 case AArch64::COALESCER_BARRIER_FPR32:
194 case AArch64::COALESCER_BARRIER_FPR64:
195 case AArch64::COALESCER_BARRIER_FPR128:
196 case AArch64::COPY:
197 // These instructions should be safe when executed on their own, but
198 // the code remains conservative when SVE registers are used. There may
199 // exist subtle cases where executing a COPY in a different mode results
200 // in different behaviour, even if we can't yet come up with any
201 // concrete example/test-case.
202 if (isSVERegOp(TRI, MRI, MI.getOperand(0)) ||
203 isSVERegOp(TRI, MRI, MI.getOperand(1)))
204 Reset();
205 break;
206 case AArch64::ADJCALLSTACKDOWN:
207 case AArch64::ADJCALLSTACKUP:
208 case AArch64::ANDXri:
209 case AArch64::ADDXri:
210 // We permit these as they don't generate SVE/NEON instructions.
211 break;
212 case AArch64::VGRestorePseudo:
213 case AArch64::VGSavePseudo:
214 // When the smstart/smstop are removed, we should also remove
215 // the pseudos that save/restore the VG value for CFI info.
216 ToBeRemoved.push_back(&MI);
217 break;
218 case AArch64::MSRpstatesvcrImm1:
219 case AArch64::MSRpstatePseudo:
220 llvm_unreachable("Should have been handled");
221 }
222 }
223
224 HasRemovedAllSMChanges =
225 NumSMChanges && (NumSMChanges == NumSMChangesRemoved);
226 return Changed;
227}
228
229INITIALIZE_PASS(SMEPeepholeOpt, "aarch64-sme-peephole-opt",
230 "SME Peephole Optimization", false, false)
231
232bool SMEPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
233 if (skipFunction(MF.getFunction()))
234 return false;
235
236 if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
237 return false;
238
239 assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
240
241 bool Changed = false;
242 bool FunctionHasAllSMChangesRemoved = false;
243
244 // Even if the block lives in a function with no SME attributes attached we
245 // still have to analyze all the blocks because we may call a streaming
246 // function that requires smstart/smstop pairs.
247 for (MachineBasicBlock &MBB : MF) {
248 bool BlockHasAllSMChangesRemoved;
249 Changed |= optimizeStartStopPairs(MBB, BlockHasAllSMChangesRemoved);
250 FunctionHasAllSMChangesRemoved |= BlockHasAllSMChangesRemoved;
251 }
252
253 AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
254 if (FunctionHasAllSMChangesRemoved)
255 AFI->setHasStreamingModeChanges(false);
256
257 return Changed;
258}
259
260FunctionPass *llvm::createSMEPeepholeOptPass() { return new SMEPeepholeOpt(); }
unsigned const MachineRegisterInfo * MRI
MachineBasicBlock & MBB
IRTranslator LLVM IR MI
unsigned const TargetRegisterInfo * TRI
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:38
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static bool isSVERegOp(const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI, const MachineOperand &MO)
static bool isMatchingStartStopPair(const MachineInstr *MI1, const MachineInstr *MI2)
static bool isConditionalStartStop(const MachineInstr *MI)
static bool ChangesStreamingMode(const MachineInstr *MI)
This file defines the SmallVector class.
AArch64FunctionInfo - This class is derived from MachineFunctionInfo and contains private AArch64-spe...
void setHasStreamingModeChanges(bool HasChanges)
Represent the analysis usage information of a pass.
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:256
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:310
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
virtual bool runOnMachineFunction(MachineFunction &MF)=0
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Representation of each machine instruction.
Definition: MachineInstr.h:69
void eraseFromParent()
Unlink 'this' from the containing basic block and delete it.
const MachineOperand & getOperand(unsigned i) const
Definition: MachineInstr.h:579
MachineOperand class - Representation of each machine instruction operand.
int64_t getImm() const
bool isReg() const
isReg - Tests if this is a MO_Register operand.
Register getReg() const
getReg - Returns the register number.
const uint32_t * getRegMask() const
getRegMask - Returns a bit mask of registers preserved by this RegMask operand.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
constexpr bool isPhysical() const
Return true if the specified register number is in the physical register namespace.
Definition: Register.h:95
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
virtual const TargetRegisterInfo * getRegisterInfo() const
getRegisterInfo - If register information is available, return it.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:656
FunctionPass * createSMEPeepholeOptPass()
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1729
void initializeSMEPeepholeOptPass(PassRegistry &)