LLVM 20.0.0git
SMEPeepholeOpt.cpp
Go to the documentation of this file.
1//===- SMEPeepholeOpt.cpp - SME peephole optimization pass-----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This pass tries to remove back-to-back (smstart, smstop) and
9// (smstop, smstart) sequences. The pass is conservative when it cannot
10// determine that it is safe to remove these sequences.
11//===----------------------------------------------------------------------===//
12
13#include "AArch64InstrInfo.h"
15#include "AArch64Subtarget.h"
21
22using namespace llvm;
23
24#define DEBUG_TYPE "aarch64-sme-peephole-opt"
25
26namespace {
27
/// Machine-function pass that erases adjacent smstart/smstop (or
/// smstop/smstart) instructions which cancel each other out within a basic
/// block, together with VG save/restore pseudos found between them.
28struct SMEPeepholeOpt : public MachineFunctionPass {
29 static char ID;
30
31 SMEPeepholeOpt() : MachineFunctionPass(ID) {
33 }
34
35 bool runOnMachineFunction(MachineFunction &MF) override;
36
37 StringRef getPassName() const override {
38 return "SME Peephole Optimization pass";
39 }
40
// The pass only erases instructions and never edits branches or blocks, so
// the CFG is preserved.
41 void getAnalysisUsage(AnalysisUsage &AU) const override {
42 AU.setPreservesCFG();
44 }
45
/// Removes cancelling start/stop pairs in \p MBB. Returns true if the block
/// was modified. \p HasRemovedAllSMChanges is set to true only when the block
/// had at least one counted streaming-mode change and as many changes were
/// removed as were counted; it is a per-block result.
46 bool optimizeStartStopPairs(MachineBasicBlock &MBB,
47 bool &HasRemovedAllSMChanges) const;
48};
49
50char SMEPeepholeOpt::ID = 0;
51
52} // end anonymous namespace
53
// A conditional smstart/smstop is represented by MSRpstatePseudo;
// MSRpstatesvcrImm1 is treated as the unconditional form.
55 return MI->getOpcode() == AArch64::MSRpstatePseudo;
56}
57
59 const MachineInstr *MI2) {
// Returns true if MI1 followed by MI2 cancel each other out. This requires:
//  - the same SVCR target (operand 0),
//  - opposite directions (operand 1),
//  - the same kind (both conditional or both unconditional).
// Conditional pairs must additionally agree on operand 2, on the reg-mask
// operand (4), and on the register holding the original pstate.sm value
// (operand 3).
60 // We only consider the same type of streaming mode change here, i.e.
61 // start/stop SM, or start/stop ZA pairs.
62 if (MI1->getOperand(0).getImm() != MI2->getOperand(0).getImm())
63 return false;
64
65 // One must be 'start', the other must be 'stop'
66 if (MI1->getOperand(1).getImm() == MI2->getOperand(1).getImm())
67 return false;
68
69 bool IsConditional = isConditionalStartStop(MI2);
70 if (isConditionalStartStop(MI1) != IsConditional)
71 return false;
72
73 if (!IsConditional)
74 return true;
75
76 // Check to make sure the conditional start/stop pairs are identical.
77 if (MI1->getOperand(2).getImm() != MI2->getOperand(2).getImm())
78 return false;
79
80 // Ensure reg masks are identical.
// This compares the mask pointers, not their contents. Distinct but
// equal-valued masks are treated as different, which is conservative.
81 if (MI1->getOperand(4).getRegMask() != MI2->getOperand(4).getRegMask())
82 return false;
83
84 // This optimisation is unlikely to happen in practice for conditional
85 // smstart/smstop pairs as the virtual registers for pstate.sm will always
86 // be different.
87 // TODO: For this optimisation to apply to conditional smstart/smstop,
88 // this pass will need to do more work to remove redundant calls to
89 // __arm_sme_state.
90
91 // Only consider conditional start/stop pairs which read the same register
92 // holding the original value of pstate.sm, as some conditional start/stops
93 // require the state on entry to the function.
// NOTE(review): when operand 3 is a register in only one of the two
// instructions, this check is skipped and the pair is accepted. Confirm that
// such mixed operands cannot occur.
94 if (MI1->getOperand(3).isReg() && MI2->getOperand(3).isReg()) {
95 Register Reg1 = MI1->getOperand(3).getReg();
96 Register Reg2 = MI2->getOperand(3).getReg();
97 if (Reg1.isPhysical() || Reg2.isPhysical() || Reg1 != Reg2)
98 return false;
99 }
100
101 return true;
102}
103
105 assert((MI->getOpcode() == AArch64::MSRpstatesvcrImm1 ||
106 MI->getOpcode() == AArch64::MSRpstatePseudo) &&
107 "Expected MI to be a smstart/smstop instruction");
// Operand 0 selects which PSTATE field is toggled. Only the SM and SMZA
// targets count as a streaming-mode change; any other target (e.g. a ZA-only
// start/stop) returns false.
108 return MI->getOperand(0).getImm() == AArch64SVCR::SVCRSM ||
109 MI->getOperand(0).getImm() == AArch64SVCR::SVCRSMZA;
110}
111
114 const MachineOperand &MO) {
// Returns true if MO is a register operand that is, overlaps, or may be
// allocated to a register in the ZPR or PPR classes (SVE vector / predicate
// registers). Non-register operands return false.
115 if (!MO.isReg())
116 return false;
117
118 Register R = MO.getReg();
// Physical register: test the register itself and every sub-register.
119 if (R.isPhysical())
120 return llvm::any_of(TRI.subregs_inclusive(R), [](const MCPhysReg &SR) {
121 return AArch64::ZPRRegClass.contains(SR) ||
122 AArch64::PPRRegClass.contains(SR);
123 });
124
// Virtual register: its class overlaps ZPR or PPR if a common subclass exists.
125 const TargetRegisterClass *RC = MRI.getRegClass(R);
126 return TRI.getCommonSubClass(&AArch64::ZPRRegClass, RC) ||
127 TRI.getCommonSubClass(&AArch64::PPRRegClass, RC);
128}
129
/// Scans \p MBB for a start/stop instruction followed by its matching
/// counterpart, with only streaming-mode-agnostic instructions in between,
/// and erases both (plus any VG save/restore pseudos seen in between).
///
/// \returns true if any instruction was erased.
/// \p HasRemovedAllSMChanges is set to true iff NumSMChanges is non-zero and
/// equals NumSMChangesRemoved for this block. It describes this block only.
130bool SMEPeepholeOpt::optimizeStartStopPairs(
131 MachineBasicBlock &MBB, bool &HasRemovedAllSMChanges) const {
133 const TargetRegisterInfo &TRI =
135
136 bool Changed = false;
137 MachineInstr *Prev = nullptr;
139
140 // Convenience function to reset the matching of a sequence.
141 auto Reset = [&]() {
142 Prev = nullptr;
143 ToBeRemoved.clear();
144 };
145
146 // Walk through instructions in the block trying to find pairs of smstart
147 // and smstop nodes that cancel each other out. We only permit a limited
148 // set of instructions to appear between them, otherwise we reset our
149 // tracking.
150 unsigned NumSMChanges = 0;
151 unsigned NumSMChangesRemoved = 0;
153 switch (MI.getOpcode()) {
154 case AArch64::MSRpstatesvcrImm1:
155 case AArch64::MSRpstatePseudo: {
157 NumSMChanges++;
158
159 if (!Prev)
160 Prev = &MI;
161 else if (isMatchingStartStopPair(Prev, &MI)) {
162 // If they match, we can remove them, and possibly any instructions
163 // that we marked for deletion in between.
164 Prev->eraseFromParent();
165 MI.eraseFromParent();
166 for (MachineInstr *TBR : ToBeRemoved)
167 TBR->eraseFromParent();
168 ToBeRemoved.clear();
169 Prev = nullptr;
170 Changed = true;
// NOTE(review): every matched pair adds 2 here, including pairs whose
// target (operand 0) is not a streaming-mode change. If the guard before
// `NumSMChanges++` (not visible in this view) only counts streaming-mode
// changes, the two counters can disagree. Confirm.
171 NumSMChangesRemoved += 2;
172 } else {
// Not a match: drop the pending sequence and start a new one here.
173 Reset();
174 Prev = &MI;
175 }
// Start/stop instructions never reach the agnostic-instruction check below.
176 continue;
177 }
178 default:
179 if (!Prev)
180 // Avoid doing expensive checks when Prev is nullptr.
181 continue;
182 break;
183 }
184
185 // Test if the instructions in between the start/stop sequence are agnostic
186 // of streaming mode. If not, the algorithm should reset.
187 switch (MI.getOpcode()) {
188 default:
189 Reset();
190 break;
191 case AArch64::COALESCER_BARRIER_FPR16:
192 case AArch64::COALESCER_BARRIER_FPR32:
193 case AArch64::COALESCER_BARRIER_FPR64:
194 case AArch64::COALESCER_BARRIER_FPR128:
195 case AArch64::COPY:
196 // These instructions should be safe when executed on their own, but
197 // the code remains conservative when SVE registers are used. There may
198 // exist subtle cases where executing a COPY in a different mode results
199 // in different behaviour, even if we can't yet come up with any
200 // concrete example/test-case.
201 if (isSVERegOp(TRI, MRI, MI.getOperand(0)) ||
202 isSVERegOp(TRI, MRI, MI.getOperand(1)))
203 Reset();
204 break;
205 case AArch64::ADJCALLSTACKDOWN:
206 case AArch64::ADJCALLSTACKUP:
207 case AArch64::ANDXri:
208 case AArch64::ADDXri:
209 // We permit these as they don't generate SVE/NEON instructions.
210 break;
211 case AArch64::VGRestorePseudo:
212 case AArch64::VGSavePseudo:
213 // When the smstart/smstop are removed, we should also remove
214 // the pseudos that save/restore the VG value for CFI info.
215 ToBeRemoved.push_back(&MI);
216 break;
217 case AArch64::MSRpstatesvcrImm1:
218 case AArch64::MSRpstatePseudo:
219 llvm_unreachable("Should have been handled");
220 }
221 }
222
// Per-block result only. It does not imply that other blocks of the function
// are free of streaming-mode changes.
223 HasRemovedAllSMChanges =
224 NumSMChanges && (NumSMChanges == NumSMChangesRemoved);
225 return Changed;
226}
227
// Registers the pass under the command-line name "aarch64-sme-peephole-opt".
// Both trailing arguments (cfg-only, is-analysis) are false.
228INITIALIZE_PASS(SMEPeepholeOpt, "aarch64-sme-peephole-opt",
229 "SME Peephole Optimization", false, false)
230
231bool SMEPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
232 if (skipFunction(MF.getFunction()))
233 return false;
234
235 if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
236 return false;
237
238 assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
239
240 bool Changed = false;
241 bool FunctionHasAllSMChangesRemoved = false;
242
243 // Even if the block lives in a function with no SME attributes attached we
244 // still have to analyze all the blocks because we may call a streaming
245 // function that requires smstart/smstop pairs.
246 for (MachineBasicBlock &MBB : MF) {
247 bool BlockHasAllSMChangesRemoved;
248 Changed |= optimizeStartStopPairs(MBB, BlockHasAllSMChangesRemoved);
249 FunctionHasAllSMChangesRemoved |= BlockHasAllSMChangesRemoved;
250 }
251
252 AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
253 if (FunctionHasAllSMChangesRemoved)
254 AFI->setHasStreamingModeChanges(false);
255
256 return Changed;
257}
258
259FunctionPass *llvm::createSMEPeepholeOptPass() { return new SMEPeepholeOpt(); }
unsigned const MachineRegisterInfo * MRI
MachineBasicBlock & MBB
IRTranslator LLVM IR MI
unsigned const TargetRegisterInfo * TRI
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:38
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static bool isSVERegOp(const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI, const MachineOperand &MO)
static bool isMatchingStartStopPair(const MachineInstr *MI1, const MachineInstr *MI2)
static bool isConditionalStartStop(const MachineInstr *MI)
static bool ChangesStreamingMode(const MachineInstr *MI)
This file defines the SmallVector class.
AArch64FunctionInfo - This class is derived from MachineFunctionInfo and contains private AArch64-spe...
void setHasStreamingModeChanges(bool HasChanges)
Represent the analysis usage information of a pass.
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:256
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:310
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
virtual bool runOnMachineFunction(MachineFunction &MF)=0
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Representation of each machine instruction.
Definition: MachineInstr.h:69
void eraseFromParent()
Unlink 'this' from the containing basic block and delete it.
const MachineOperand & getOperand(unsigned i) const
Definition: MachineInstr.h:585
MachineOperand class - Representation of each machine instruction operand.
int64_t getImm() const
bool isReg() const
isReg - Tests if this is a MO_Register operand.
Register getReg() const
getReg - Returns the register number.
const uint32_t * getRegMask() const
getRegMask - Returns a bit mask of registers preserved by this RegMask operand.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
constexpr bool isPhysical() const
Return true if the specified register number is in the physical register namespace.
Definition: Register.h:95
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:51
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
virtual const TargetRegisterInfo * getRegisterInfo() const
getRegisterInfo - If register information is available, return it.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:657
FunctionPass * createSMEPeepholeOptPass()
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1746
void initializeSMEPeepholeOptPass(PassRegistry &)