//===- MachineSMEABIPass.cpp ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements the SME ABI requirements for ZA state. This includes
// implementing the lazy (and agnostic) ZA state save schemes around calls.
//
//===----------------------------------------------------------------------===//
//
// This pass works by collecting instructions that require ZA to be in a
// specific state (e.g., "ACTIVE" or "SAVED") and inserting the necessary state
// transitions to ensure ZA is in the required state before instructions. State
// transitions represent actions such as setting up or restoring a lazy save.
// Certain points within a function may also have predefined states independent
// of any instructions, for example, a "shared_za" function is always entered
// and exited in the "ACTIVE" state.
//
// To handle ZA state across control flow, we make use of edge bundling. This
// assigns each block an "incoming" and "outgoing" edge bundle (representing
// incoming and outgoing edges). Initially, these are unique to each block;
// then, in the process of forming bundles, the outgoing bundle of a block is
// joined with the incoming bundle of all successors. The result is that each
// bundle can be assigned a single ZA state, which ensures the state required
// by all of a block's successors is the same, and that each basic block will
// always be entered with the same ZA state. This eliminates the need for
// splitting edges to insert state transitions or "phi" nodes for ZA states.
//
// See below for a simple example of edge bundling.
//
// The following shows a conditionally executed basic block (BB1):
//
// if (cond)
//   BB1
// BB2
//
//   Initial Bundles            Joined Bundles
//
//   ┌──0──┐                    ┌──0──┐
//   │ BB0 │                    │ BB0 │
//   └──1──┘                    └──1──┘
//      ├───────┐                  ├───────┐
//      ▼       │                  ▼       │
//   ┌──2──┐    │     ─────►    ┌──1──┐    │
//   │ BB1 │    ▼               │ BB1 │    ▼
//   └──3──┘ ┌──4──┐            └──1──┘ ┌──1──┐
//      └───►4 BB2 │               └───►1 BB2 │
//           └──5──┘                    └──2──┘
//
// On the left are the initial per-block bundles, and on the right are the
// joined bundles (which are the result of the EdgeBundles analysis).

#include "AArch64InstrInfo.h"
#include "AArch64MachineFunctionInfo.h"
#include "AArch64Subtarget.h"
#include "MCTargetDesc/AArch64AddressingModes.h"
#include "Utils/AArch64SMEAttributes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/EdgeBundles.h"
#include "llvm/CodeGen/LiveRegUnits.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/Support/Debug.h"

using namespace llvm;

#define DEBUG_TYPE "aarch64-machine-sme-abi"

namespace {

// Note: For agnostic ZA, we assume the function is always entered/exited in
// the "ACTIVE" state -- this _may_ not be the case (since OFF is also a
// possibility), but for the purpose of placing ZA saves/restores, that does
// not matter.
enum ZAState : uint8_t {
  // Any/unknown state (not valid)
  ANY = 0,

  // ZA is in use and active (i.e. within the accumulator)
  ACTIVE,

  // ZA is active, but ZT0 has been saved.
  // This handles the edge case of sharedZA && !sharesZT0.
  ACTIVE_ZT0_SAVED,

  // A ZA save has been set up or committed (i.e. ZA is dormant or off)
  // If the function uses ZT0 it must also be saved.
  LOCAL_SAVED,

  // ZA has been committed to the lazy save buffer of the current function.
  // If the function uses ZT0 it must also be saved.
  // ZA is off.
  LOCAL_COMMITTED,

  // The ZA/ZT0 state on entry to the function.
  ENTRY,

  // ZA is off.
  OFF,

  // The number of ZA states (not a valid state)
  NUM_ZA_STATE
};

/// A bitmask enum to record live physical registers that the "emit*" routines
/// may need to preserve. Note: This only tracks registers we may clobber.
enum LiveRegs : uint8_t {
  None = 0,
  NZCV = 1 << 0,
  W0 = 1 << 1,
  W0_HI = 1 << 2,
  X0 = W0 | W0_HI,
  LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ W0_HI)
};

/// Holds the virtual registers that live physical registers have been saved
/// to.
struct PhysRegSave {
  LiveRegs PhysLiveRegs;
  Register StatusFlags = AArch64::NoRegister;
  Register X0Save = AArch64::NoRegister;
};

/// Contains the needed ZA state (and live registers) at an instruction. That
/// is the state ZA must be in _before_ "InsertPt".
struct InstInfo {
  ZAState NeededState{ZAState::ANY};
  MachineBasicBlock::iterator InsertPt;
  LiveRegs PhysLiveRegs = LiveRegs::None;
};

/// Contains the needed ZA state for each instruction in a block. Instructions
/// that do not require a ZA state are not recorded.
struct BlockInfo {
  SmallVector<InstInfo> Insts;
  ZAState FixedEntryState{ZAState::ANY};
  ZAState DesiredIncomingState{ZAState::ANY};
  ZAState DesiredOutgoingState{ZAState::ANY};
  LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
  LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
};

/// Contains the needed ZA state information for all blocks within a function.
struct FunctionInfo {
  SmallVector<BlockInfo> Blocks;
  std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
  LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
};

/// State/helpers that are only needed when emitting code to handle
/// saving/restoring ZA.
class EmitContext {
public:
  EmitContext() = default;

  /// Get or create a TPIDR2 block in \p MF.
  int getTPIDR2Block(MachineFunction &MF) {
    if (TPIDR2BlockFI)
      return *TPIDR2BlockFI;
    MachineFrameInfo &MFI = MF.getFrameInfo();
    TPIDR2BlockFI = MFI.CreateStackObject(16, Align(16), false);
    return *TPIDR2BlockFI;
  }

  /// Get or create the agnostic ZA buffer pointer in \p MF.
  Register getAgnosticZABufferPtr(MachineFunction &MF) {
    if (AgnosticZABufferPtr != AArch64::NoRegister)
      return AgnosticZABufferPtr;
    Register BufferPtr =
        MF.getInfo<AArch64FunctionInfo>()->getEarlyAllocSMESaveBuffer();
    AgnosticZABufferPtr =
        BufferPtr != AArch64::NoRegister
            ? BufferPtr
            : MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
    return AgnosticZABufferPtr;
  }

  /// Get or create the ZT0 save slot in \p MF.
  int getZT0SaveSlot(MachineFunction &MF) {
    if (ZT0SaveFI)
      return *ZT0SaveFI;
    MachineFrameInfo &MFI = MF.getFrameInfo();
    ZT0SaveFI = MFI.CreateSpillStackObject(64, Align(16));
    return *ZT0SaveFI;
  }

  /// Returns true if the function must allocate a ZA save buffer on entry.
  /// This will be the case if, at any point in the function, a ZA save was
  /// emitted.
  bool needsSaveBuffer() const {
    assert(!(TPIDR2BlockFI && AgnosticZABufferPtr) &&
           "Cannot have both a TPIDR2 block and agnostic ZA buffer");
    return TPIDR2BlockFI || AgnosticZABufferPtr != AArch64::NoRegister;
  }

private:
  std::optional<int> ZT0SaveFI;
  std::optional<int> TPIDR2BlockFI;
  Register AgnosticZABufferPtr = AArch64::NoRegister;
};

/// Checks if \p State is a legal edge bundle state. For a state to be a legal
/// bundle state, it must be possible to transition from it to any other bundle
/// state without losing any ZA state. This is the case for ACTIVE/LOCAL_SAVED,
/// as you can transition between those states by saving/restoring ZA. The OFF
/// state would not be legal, as transitioning to it drops the content of ZA.
static bool isLegalEdgeBundleZAState(ZAState State) {
  switch (State) {
  case ZAState::ACTIVE:           // ZA state within the accumulator/ZT0.
  case ZAState::ACTIVE_ZT0_SAVED: // ZT0 is saved (ZA is active).
  case ZAState::LOCAL_SAVED:      // ZA state may be saved on the stack.
  case ZAState::LOCAL_COMMITTED:  // ZA state is saved on the stack.
    return true;
  default:
    return false;
  }
}

StringRef getZAStateString(ZAState State) {
#define MAKE_CASE(V)                                                           \
  case V:                                                                      \
    return #V;
  switch (State) {
    MAKE_CASE(ZAState::ANY)
    MAKE_CASE(ZAState::ACTIVE)
    MAKE_CASE(ZAState::ACTIVE_ZT0_SAVED)
    MAKE_CASE(ZAState::LOCAL_SAVED)
    MAKE_CASE(ZAState::LOCAL_COMMITTED)
    MAKE_CASE(ZAState::ENTRY)
    MAKE_CASE(ZAState::OFF)
  default:
    llvm_unreachable("Unexpected ZAState");
  }
#undef MAKE_CASE
}

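/// Returns true if \p MO is a physical register operand that overlaps ZA (an
/// MPR128 tile register) or ZT0.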
static bool isZAorZTRegOp(const TargetRegisterInfo &TRI,
                          const MachineOperand &MO) {
  if (!MO.isReg() || !MO.getReg().isPhysical())
    return false;
  return any_of(TRI.subregs_inclusive(MO.getReg()), [](const MCPhysReg &SR) {
    return AArch64::MPR128RegClass.contains(SR) ||
           AArch64::ZTRRegClass.contains(SR);
  });
}

/// Returns the required ZA state needed before \p MI and an iterator pointing
/// to where any code required to change the ZA state should be inserted.
static std::pair<ZAState, MachineBasicBlock::iterator>
getInstNeededZAState(const TargetRegisterInfo &TRI, MachineInstr &MI,
                     SMEAttrs SMEFnAttrs) {
  MachineBasicBlock::iterator InsertPt = MI.getIterator();

  // Note: InOutZAUsePseudo, RequiresZASavePseudo, and RequiresZT0SavePseudo
  // are intended to mark the position immediately before a call. Due to
  // SelectionDAG constraints, these markers occur after the ADJCALLSTACKDOWN,
  // so we use std::prev(InsertPt) to get the position before the call.

  if (MI.getOpcode() == AArch64::InOutZAUsePseudo)
    return {ZAState::ACTIVE, std::prev(InsertPt)};

  // Note: If we need to save both ZA and ZT0 we use RequiresZASavePseudo.
  if (MI.getOpcode() == AArch64::RequiresZASavePseudo)
    return {ZAState::LOCAL_SAVED, std::prev(InsertPt)};

  // If we only need to save ZT0 there are two cases to consider:
  // 1. The function has ZA state (that we don't need to save).
  //    - In this case we switch to the "ACTIVE_ZT0_SAVED" state.
  //      This only saves ZT0.
  // 2. The function does not have ZA state.
  //    - In this case we switch to the "LOCAL_COMMITTED" state.
  //      This saves ZT0 and turns ZA off.
  if (MI.getOpcode() == AArch64::RequiresZT0SavePseudo) {
    return {SMEFnAttrs.hasZAState() ? ZAState::ACTIVE_ZT0_SAVED
                                    : ZAState::LOCAL_COMMITTED,
            std::prev(InsertPt)};
  }

  if (MI.isReturn()) {
    bool ZAOffAtReturn = SMEFnAttrs.hasPrivateZAInterface();
    return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt};
  }

  for (auto &MO : MI.operands()) {
    if (isZAorZTRegOp(TRI, MO))
      return {ZAState::ACTIVE, InsertPt};
  }

  return {ZAState::ANY, InsertPt};
}

struct MachineSMEABI : public MachineFunctionPass {
  inline static char ID = 0;

  MachineSMEABI(CodeGenOptLevel OptLevel = CodeGenOptLevel::Default)
      : MachineFunctionPass(ID), OptLevel(OptLevel) {}

  bool runOnMachineFunction(MachineFunction &MF) override;

  StringRef getPassName() const override { return "Machine SME ABI pass"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<EdgeBundlesWrapperLegacy>();
    AU.addRequired<MachineOptimizationRemarkEmitterPass>();
    AU.addPreservedID(MachineLoopInfoID);
    AU.addPreservedID(MachineDominatorsID);
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  /// Collects the needed ZA state (and live registers) before each instruction
  /// within the machine function.
  FunctionInfo collectNeededZAStates(SMEAttrs SMEFnAttrs);

  /// Assigns each edge bundle a ZA state based on the needed states of blocks
  /// that have incoming or outgoing edges in that bundle.
  SmallVector<ZAState> assignBundleZAStates(const EdgeBundles &Bundles,
                                            const FunctionInfo &FnInfo);

  /// Inserts code to handle changes between ZA states within the function.
  /// E.g., ACTIVE -> LOCAL_SAVED will insert code required to save ZA.
  void insertStateChanges(EmitContext &, const FunctionInfo &FnInfo,
                          const EdgeBundles &Bundles,
                          ArrayRef<ZAState> BundleStates);

  /// Propagates desired states forwards (from predecessors -> successors) if
  /// \p Forwards, otherwise, propagates backwards (from successors ->
  /// predecessors).
  void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true);

  void emitZT0SaveRestore(EmitContext &, MachineBasicBlock &MBB,
                          MachineBasicBlock::iterator MBBI, bool IsSave);

  // Emission routines for private and shared ZA functions (using lazy saves).
  void emitSMEPrologue(MachineBasicBlock &MBB,
                       MachineBasicBlock::iterator MBBI);
  void emitRestoreLazySave(EmitContext &, MachineBasicBlock &MBB,
                           MachineBasicBlock::iterator MBBI,
                           LiveRegs PhysLiveRegs);
  void emitSetupLazySave(EmitContext &, MachineBasicBlock &MBB,
                         MachineBasicBlock::iterator MBBI);
  void emitAllocateLazySaveBuffer(EmitContext &, MachineBasicBlock &MBB,
                                  MachineBasicBlock::iterator MBBI);
  void emitZAMode(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
                  bool ClearTPIDR2, bool On);

  // Emission routines for agnostic ZA functions.
  void emitSetupFullZASave(MachineBasicBlock &MBB,
                           MachineBasicBlock::iterator MBBI,
                           LiveRegs PhysLiveRegs);
  // Emit a "full" ZA save or restore. It is "full" in the sense that this
  // function will emit a call to __arm_sme_save or __arm_sme_restore, which
  // handles saving and restoring both ZA and ZT0.
  void emitFullZASaveRestore(EmitContext &, MachineBasicBlock &MBB,
                             MachineBasicBlock::iterator MBBI,
                             LiveRegs PhysLiveRegs, bool IsSave);
  void emitAllocateFullZASaveBuffer(EmitContext &, MachineBasicBlock &MBB,
                                    MachineBasicBlock::iterator MBBI,
                                    LiveRegs PhysLiveRegs);

  /// Attempts to find an insertion point before \p Inst where the status flags
  /// are not live. If \p Inst is `Block.Insts.end()` a point before the end of
  /// the block is found.
  std::pair<MachineBasicBlock::iterator, LiveRegs>
  findStateChangeInsertionPoint(MachineBasicBlock &MBB, const BlockInfo &Block,
                                SmallVectorImpl<InstInfo>::const_iterator Inst);
  void emitStateChange(EmitContext &, MachineBasicBlock &MBB,
                       MachineBasicBlock::iterator MBBI, ZAState From,
                       ZAState To, LiveRegs PhysLiveRegs);

  // Helpers for switching between lazy/full ZA save/restore routines.
  void emitZASave(EmitContext &Context, MachineBasicBlock &MBB,
                  MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
    if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
      return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs,
                                   /*IsSave=*/true);
    return emitSetupLazySave(Context, MBB, MBBI);
  }
  void emitZARestore(EmitContext &Context, MachineBasicBlock &MBB,
                     MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
    if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
      return emitFullZASaveRestore(Context, MBB, MBBI, PhysLiveRegs,
                                   /*IsSave=*/false);
    return emitRestoreLazySave(Context, MBB, MBBI, PhysLiveRegs);
  }
  void emitAllocateZASaveBuffer(EmitContext &Context, MachineBasicBlock &MBB,
                                MachineBasicBlock::iterator MBBI,
                                LiveRegs PhysLiveRegs) {
    if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
      return emitAllocateFullZASaveBuffer(Context, MBB, MBBI, PhysLiveRegs);
    return emitAllocateLazySaveBuffer(Context, MBB, MBBI);
  }

  /// Collects the reachable calls from \p MBBI marked with \p Marker. This is
  /// intended to be used to emit lazy save remarks. Note: This stops at the
  /// first marked call along any path.
  void collectReachableMarkedCalls(const MachineBasicBlock &MBB,
                                   MachineBasicBlock::const_iterator MBBI,
                                   SmallVectorImpl<const MachineInstr *> &Calls,
                                   unsigned Marker) const;

  void emitCallSaveRemarks(const MachineBasicBlock &MBB,
                           MachineBasicBlock::const_iterator MBBI, DebugLoc DL,
                           unsigned Marker, StringRef RemarkName,
                           StringRef SaveName) const;

  /// Save live physical registers to virtual registers.
  PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
                                MachineBasicBlock::iterator MBBI, DebugLoc DL);
  /// Restore physical registers from a save of their previous values.
  void restorePhyRegSave(const PhysRegSave &RegSave, MachineBasicBlock &MBB,
                         MachineBasicBlock::iterator MBBI, DebugLoc DL);

private:
  CodeGenOptLevel OptLevel = CodeGenOptLevel::Default;

  MachineFunction *MF = nullptr;
  const AArch64Subtarget *Subtarget = nullptr;
  const AArch64RegisterInfo *TRI = nullptr;
  const AArch64FunctionInfo *AFI = nullptr;
  const TargetInstrInfo *TII = nullptr;
  MachineOptimizationRemarkEmitter *ORE = nullptr;
  MachineRegisterInfo *MRI = nullptr;
  MachineLoopInfo *MLI = nullptr;
};

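/// Computes the set of tracked physical registers (NZCV, W0, W0_HI) that are
/// live in \p LiveUnits.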
static LiveRegs getPhysLiveRegs(LiveRegUnits const &LiveUnits) {
  LiveRegs PhysLiveRegs = LiveRegs::None;
  if (!LiveUnits.available(AArch64::NZCV))
    PhysLiveRegs |= LiveRegs::NZCV;
  // We have to track W0 and X0 separately as otherwise things can get
  // confused if we attempt to preserve X0 but only W0 was defined.
  if (!LiveUnits.available(AArch64::W0))
    PhysLiveRegs |= LiveRegs::W0;
  if (!LiveUnits.available(AArch64::W0_HI))
    PhysLiveRegs |= LiveRegs::W0_HI;
  return PhysLiveRegs;
}

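/// Marks the registers in \p PhysLiveRegs as live in \p LiveUnits.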
static void setPhysLiveRegs(LiveRegUnits &LiveUnits, LiveRegs PhysLiveRegs) {
  if (PhysLiveRegs & LiveRegs::NZCV)
    LiveUnits.addReg(AArch64::NZCV);
  if (PhysLiveRegs & LiveRegs::W0)
    LiveUnits.addReg(AArch64::W0);
  if (PhysLiveRegs & LiveRegs::W0_HI)
    LiveUnits.addReg(AArch64::W0_HI);
}

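/// Returns true if \p Opc starts a call sequence. Used (in asserts) to
/// validate state change insertion points.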
[[maybe_unused]] bool isCallStartOpcode(unsigned Opc) {
  switch (Opc) {
  case AArch64::TLSDESC_CALLSEQ:
  case AArch64::TLSDESC_AUTH_CALLSEQ:
  case AArch64::ADJCALLSTACKDOWN:
    return true;
  default:
    return false;
  }
}

FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
  assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
          SMEFnAttrs.hasZAState()) &&
         "Expected function to have ZA/ZT0 state!");

  SmallVector<BlockInfo> Blocks(MF->getNumBlockIDs());
  LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
  std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;

  for (MachineBasicBlock &MBB : *MF) {
    BlockInfo &Block = Blocks[MBB.getNumber()];

    if (MBB.isEntryBlock()) {
      // Entry block:
      Block.FixedEntryState = ZAState::ENTRY;
    } else if (MBB.isEHPad()) {
      // EH entry block:
      Block.FixedEntryState = ZAState::LOCAL_COMMITTED;
    }

    LiveRegUnits LiveUnits(*TRI);
    LiveUnits.addLiveOuts(MBB);

    Block.PhysLiveRegsAtExit = getPhysLiveRegs(LiveUnits);
    auto FirstTerminatorInsertPt = MBB.getFirstTerminator();
    auto FirstNonPhiInsertPt = MBB.getFirstNonPHI();
    for (MachineInstr &MI : reverse(MBB)) {
      MachineBasicBlock::iterator MBBI(MI);
      LiveUnits.stepBackward(MI);
      LiveRegs PhysLiveRegs = getPhysLiveRegs(LiveUnits);
      // The SMEStateAllocPseudo marker is added to a function if the save
      // buffer was allocated in SelectionDAG. It marks the end of the
      // allocation -- which is a safe point for this pass to insert any TPIDR2
      // block setup.
      if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) {
        AfterSMEProloguePt = MBBI;
        PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
      }
      // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
      auto [NeededState, InsertPt] = getInstNeededZAState(*TRI, MI, SMEFnAttrs);
      assert((InsertPt == MBBI || isCallStartOpcode(InsertPt->getOpcode())) &&
             "Unexpected state change insertion point!");
      // TODO: Do something to avoid state changes where NZCV is live.
      if (MBBI == FirstTerminatorInsertPt)
        Block.PhysLiveRegsAtExit = PhysLiveRegs;
      if (MBBI == FirstNonPhiInsertPt)
        Block.PhysLiveRegsAtEntry = PhysLiveRegs;
      if (NeededState != ZAState::ANY)
        Block.Insts.push_back({NeededState, InsertPt, PhysLiveRegs});
    }

    // Reverse vector (as we had to iterate backwards for liveness).
    std::reverse(Block.Insts.begin(), Block.Insts.end());

    // Record the desired states on entry/exit of this block. These are the
    // states that would not incur a state transition.
    if (!Block.Insts.empty()) {
      Block.DesiredIncomingState = Block.Insts.front().NeededState;
      Block.DesiredOutgoingState = Block.Insts.back().NeededState;
    }
  }

  return FunctionInfo{std::move(Blocks), AfterSMEProloguePt,
                      PhysLiveRegsAfterSMEPrologue};
}

void MachineSMEABI::propagateDesiredStates(FunctionInfo &FnInfo,
                                           bool Forwards) {
  // If `Forwards`, this propagates desired states from predecessors to
  // successors, otherwise, this propagates states from successors to
  // predecessors.
  auto GetBlockState = [](BlockInfo &Block, bool Incoming) -> ZAState & {
    return Incoming ? Block.DesiredIncomingState : Block.DesiredOutgoingState;
  };

  SmallVector<MachineBasicBlock *> Worklist;
  for (auto [BlockID, BlockInfo] : enumerate(FnInfo.Blocks)) {
    if (!isLegalEdgeBundleZAState(GetBlockState(BlockInfo, Forwards)))
      Worklist.push_back(MF->getBlockNumbered(BlockID));
  }

  while (!Worklist.empty()) {
    MachineBasicBlock *MBB = Worklist.pop_back_val();
    BlockInfo &Block = FnInfo.Blocks[MBB->getNumber()];

    // Pick a legal edge bundle state that matches the majority of
    // predecessors/successors.
    int StateCounts[ZAState::NUM_ZA_STATE] = {0};
    for (MachineBasicBlock *PredOrSucc :
         Forwards ? predecessors(MBB) : successors(MBB)) {
      BlockInfo &PredOrSuccBlock = FnInfo.Blocks[PredOrSucc->getNumber()];
      ZAState State = GetBlockState(PredOrSuccBlock, !Forwards);
      if (isLegalEdgeBundleZAState(State))
        StateCounts[State]++;
    }

    ZAState PropagatedState = ZAState(max_element(StateCounts) - StateCounts);
    ZAState &CurrentState = GetBlockState(Block, Forwards);
    if (PropagatedState != CurrentState) {
      CurrentState = PropagatedState;
      ZAState &OtherState = GetBlockState(Block, !Forwards);
      // Propagate to the incoming/outgoing state if that is also "ANY".
      if (OtherState == ZAState::ANY)
        OtherState = PropagatedState;
      // Push any successors/predecessors that may need updating to the
      // worklist.
      for (MachineBasicBlock *SuccOrPred :
           Forwards ? successors(MBB) : predecessors(MBB)) {
        BlockInfo &SuccOrPredBlock = FnInfo.Blocks[SuccOrPred->getNumber()];
        if (!isLegalEdgeBundleZAState(GetBlockState(SuccOrPredBlock, Forwards)))
          Worklist.push_back(SuccOrPred);
      }
    }
  }
}

/// Assigns each edge bundle a ZA state based on the needed states of blocks
/// that have incoming or outgoing edges in that bundle.
SmallVector<ZAState>
MachineSMEABI::assignBundleZAStates(const EdgeBundles &Bundles,
                                    const FunctionInfo &FnInfo) {
  SmallVector<ZAState> BundleStates(Bundles.getNumBundles());
  for (unsigned I = 0, E = Bundles.getNumBundles(); I != E; ++I) {
    LLVM_DEBUG(dbgs() << "Assigning ZA state for edge bundle: " << I << '\n');

    // Attempt to assign a ZA state for this bundle that minimizes state
    // transitions. Edges within loops are given a higher weight as we assume
    // they will be executed more than once.
    int EdgeStateCounts[ZAState::NUM_ZA_STATE] = {0};
    for (unsigned BlockID : Bundles.getBlocks(I)) {
      LLVM_DEBUG(dbgs() << "- bb." << BlockID);

      const BlockInfo &Block = FnInfo.Blocks[BlockID];
      bool InEdge = Bundles.getBundle(BlockID, /*Out=*/false) == I;
      bool OutEdge = Bundles.getBundle(BlockID, /*Out=*/true) == I;

      bool LegalInEdge =
          InEdge && isLegalEdgeBundleZAState(Block.DesiredIncomingState);
      bool LegalOutEdge =
          OutEdge && isLegalEdgeBundleZAState(Block.DesiredOutgoingState);
      if (LegalInEdge) {
        LLVM_DEBUG(dbgs() << " DesiredIncomingState: "
                          << getZAStateString(Block.DesiredIncomingState));
        EdgeStateCounts[Block.DesiredIncomingState]++;
      }
      if (LegalOutEdge) {
        LLVM_DEBUG(dbgs() << " DesiredOutgoingState: "
                          << getZAStateString(Block.DesiredOutgoingState));
        EdgeStateCounts[Block.DesiredOutgoingState]++;
      }
      if (!LegalInEdge && !LegalOutEdge)
        LLVM_DEBUG(dbgs() << " (no state preference)");
      LLVM_DEBUG(dbgs() << '\n');
    }

    ZAState BundleState =
        ZAState(max_element(EdgeStateCounts) - EdgeStateCounts);

    if (BundleState == ZAState::ANY)
      BundleState = ZAState::ACTIVE;

    LLVM_DEBUG({
      dbgs() << "Chosen ZA state: " << getZAStateString(BundleState) << '\n'
             << "Edge counts:";
      for (auto [State, Count] : enumerate(EdgeStateCounts))
        dbgs() << " " << getZAStateString(ZAState(State)) << ": " << Count;
      dbgs() << "\n\n";
    });

    BundleStates[I] = BundleState;
  }

  return BundleStates;
}

std::pair<MachineBasicBlock::iterator, LiveRegs>
MachineSMEABI::findStateChangeInsertionPoint(
    MachineBasicBlock &MBB, const BlockInfo &Block,
    SmallVectorImpl<InstInfo>::const_iterator Inst) {
  LiveRegs PhysLiveRegs;
  MachineBasicBlock::iterator InsertPt;
  if (Inst != Block.Insts.end()) {
    InsertPt = Inst->InsertPt;
    PhysLiveRegs = Inst->PhysLiveRegs;
  } else {
    InsertPt = MBB.getFirstTerminator();
    PhysLiveRegs = Block.PhysLiveRegsAtExit;
  }

  if (PhysLiveRegs == LiveRegs::None)
    return {InsertPt, PhysLiveRegs}; // Nothing to do (no live regs).

  // Find the previous state change. We cannot move before this point.
  MachineBasicBlock::iterator PrevStateChangeI;
  if (Inst == Block.Insts.begin()) {
    PrevStateChangeI = MBB.begin();
  } else {
    // Note: `std::prev(Inst)` is the previous InstInfo. We only create an
    // InstInfo object for instructions that require a specific ZA state, so
    // the previous InstInfo is the site of the previous state change in the
    // block (which can be several MIs earlier).
    PrevStateChangeI = std::prev(Inst)->InsertPt;
  }

  // Note: LiveUnits will only accurately track X0 and NZCV.
  LiveRegUnits LiveUnits(*TRI);
  setPhysLiveRegs(LiveUnits, PhysLiveRegs);
  auto BestCandidate = std::make_pair(InsertPt, PhysLiveRegs);
  for (MachineBasicBlock::iterator I = InsertPt; I != PrevStateChangeI; --I) {
    // Don't move before/into a call (which may have a state change before it).
    if (I->getOpcode() == TII->getCallFrameDestroyOpcode() || I->isCall())
      break;
    LiveUnits.stepBackward(*I);
    LiveRegs CurrentPhysLiveRegs = getPhysLiveRegs(LiveUnits);
    // Find places where NZCV is available, but keep looking for locations
    // where both NZCV and X0 are available, which can avoid some copies.
    if (!(CurrentPhysLiveRegs & LiveRegs::NZCV))
      BestCandidate = {I, CurrentPhysLiveRegs};
    if (CurrentPhysLiveRegs == LiveRegs::None)
      break;
  }
  return BestCandidate;
}

void MachineSMEABI::insertStateChanges(EmitContext &Context,
                                       const FunctionInfo &FnInfo,
                                       const EdgeBundles &Bundles,
                                       ArrayRef<ZAState> BundleStates) {
  for (MachineBasicBlock &MBB : *MF) {
    const BlockInfo &Block = FnInfo.Blocks[MBB.getNumber()];
    ZAState InState = BundleStates[Bundles.getBundle(MBB.getNumber(),
                                                     /*Out=*/false)];

    ZAState CurrentState = Block.FixedEntryState;
    if (CurrentState == ZAState::ANY)
      CurrentState = InState;

    for (auto &Inst : Block.Insts) {
      if (CurrentState != Inst.NeededState) {
        auto [InsertPt, PhysLiveRegs] =
            findStateChangeInsertionPoint(MBB, Block, &Inst);
        emitStateChange(Context, MBB, InsertPt, CurrentState, Inst.NeededState,
                        PhysLiveRegs);
        CurrentState = Inst.NeededState;
      }
    }

    if (MBB.succ_empty())
      continue;

    ZAState OutState =
        BundleStates[Bundles.getBundle(MBB.getNumber(), /*Out=*/true)];
    if (CurrentState != OutState) {
      auto [InsertPt, PhysLiveRegs] =
          findStateChangeInsertionPoint(MBB, Block, Block.Insts.end());
      emitStateChange(Context, MBB, InsertPt, CurrentState, OutState,
                      PhysLiveRegs);
    }
  }
}

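/// Returns a DebugLoc for instructions inserted at \p MBBI, falling back to
/// the last instruction's location at the end of \p MBB.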
static DebugLoc getDebugLoc(const MachineBasicBlock &MBB,
                            MachineBasicBlock::const_iterator MBBI) {
  if (MBB.empty())
    return DebugLoc();
  return MBBI != MBB.end() ? MBBI->getDebugLoc() : MBB.back().getDebugLoc();
}

/// Finds the first call (as determined by MachineInstr::isCall()) starting
/// from \p MBBI in \p MBB marked with \p Marker (which is a marker opcode such
/// as RequiresZASavePseudo). If a marker is found, the call following it (if
/// any) is pushed to \p Calls and the function returns true.
static bool findMarkedCall(const MachineBasicBlock &MBB,
                           MachineBasicBlock::const_iterator MBBI,
                           SmallVectorImpl<const MachineInstr *> &Calls,
                           unsigned Marker, unsigned CallDestroyOpcode) {
  auto IsMarker = [&](auto &MI) { return MI.getOpcode() == Marker; };
  auto MarkerInst = std::find_if(MBBI, MBB.end(), IsMarker);
  if (MarkerInst == MBB.end())
    return false;
  auto I = MarkerInst;
  while (++I != MBB.end()) {
    if (I->isCall() || I->getOpcode() == CallDestroyOpcode)
      break;
  }
  if (I != MBB.end() && I->isCall())
    Calls.push_back(&*I);
  // Note: This function always returns true if a "Marker" was found.
  return true;
}

void MachineSMEABI::collectReachableMarkedCalls(
    const MachineBasicBlock &StartMBB,
    MachineBasicBlock::const_iterator StartInst,
    SmallVectorImpl<const MachineInstr *> &Calls, unsigned Marker) const {
  assert(Marker == AArch64::InOutZAUsePseudo ||
         Marker == AArch64::RequiresZASavePseudo ||
         Marker == AArch64::RequiresZT0SavePseudo);
  unsigned CallDestroyOpcode = TII->getCallFrameDestroyOpcode();
  if (findMarkedCall(StartMBB, StartInst, Calls, Marker, CallDestroyOpcode))
    return;

  SmallPtrSet<const MachineBasicBlock *, 8> Visited;
  SmallVector<const MachineBasicBlock *> Worklist(StartMBB.succ_rbegin(),
                                                  StartMBB.succ_rend());
  while (!Worklist.empty()) {
    const MachineBasicBlock *MBB = Worklist.pop_back_val();
    auto [_, Inserted] = Visited.insert(MBB);
    if (!Inserted)
      continue;

    if (!findMarkedCall(*MBB, MBB->begin(), Calls, Marker, CallDestroyOpcode))
      Worklist.append(MBB->succ_rbegin(), MBB->succ_rend());
  }
}

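/// Returns the name of the callee of \p CallInst (from a symbol or global
/// operand), or an empty StringRef if unknown.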
static StringRef getCalleeName(const MachineInstr &CallInst) {
  assert(CallInst.isCall() && "expected a call");
  for (const MachineOperand &MO : CallInst.operands()) {
    if (MO.isSymbol())
      return MO.getSymbolName();
    if (MO.isGlobal())
      return MO.getGlobal()->getName();
  }
  return {};
}

void MachineSMEABI::emitCallSaveRemarks(const MachineBasicBlock &MBB,
                                        MachineBasicBlock::const_iterator MBBI,
                                        DebugLoc DL, unsigned Marker,
                                        StringRef RemarkName,
                                        StringRef SaveName) const {
  auto SaveRemark = [&](DebugLoc DL, const MachineBasicBlock &MBB) {
    return MachineOptimizationRemarkAnalysis("sme", RemarkName, DL, &MBB);
  };
  StringRef StateName = Marker == AArch64::RequiresZT0SavePseudo ? "ZT0" : "ZA";
  ORE->emit([&] {
    return SaveRemark(DL, MBB) << SaveName << " of " << StateName
                               << " emitted in '" << MF->getName() << "'";
  });
  if (!ORE->allowExtraAnalysis("sme"))
    return;
  SmallVector<const MachineInstr *> CallsRequiringSaves;
  collectReachableMarkedCalls(MBB, MBBI, CallsRequiringSaves, Marker);
  for (const MachineInstr *CallInst : CallsRequiringSaves) {
    auto R = SaveRemark(CallInst->getDebugLoc(), *CallInst->getParent());
    R << "call";
    if (StringRef CalleeName = getCalleeName(*CallInst); !CalleeName.empty())
      R << " to '" << CalleeName << "'";
    R << " requires " << StateName << " save";
    ORE->emit(R);
  }
}

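/// Emits code to set up a lazy save of ZA by pointing TPIDR2_EL0 at the
/// function's TPIDR2 block.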
void MachineSMEABI::emitSetupLazySave(EmitContext &Context,
                                      MachineBasicBlock &MBB,
                                      MachineBasicBlock::iterator MBBI) {
  DebugLoc DL = getDebugLoc(MBB, MBBI);

  emitCallSaveRemarks(MBB, MBBI, DL, AArch64::RequiresZASavePseudo,
                      "SMELazySaveZA", "lazy save");

  // Get pointer to TPIDR2 block.
  Register TPIDR2 = MRI->createVirtualRegister(&AArch64::GPR64spRegClass);
  Register TPIDR2Ptr = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2)
      .addFrameIndex(Context.getTPIDR2Block(*MF))
      .addImm(0)
      .addImm(0);
  BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), TPIDR2Ptr)
      .addReg(TPIDR2);
  // Set TPIDR2_EL0 to point to TPIDR2 block.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
      .addImm(AArch64SysReg::TPIDR2_EL0)
      .addReg(TPIDR2Ptr);
}

PhysRegSave MachineSMEABI::createPhysRegSave(LiveRegs PhysLiveRegs,
                                             MachineBasicBlock &MBB,
                                             MachineBasicBlock::iterator MBBI,
                                             DebugLoc DL) {
  PhysRegSave RegSave{PhysLiveRegs};
  if (PhysLiveRegs & LiveRegs::NZCV) {
    RegSave.StatusFlags = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS), RegSave.StatusFlags)
        .addImm(AArch64SysReg::NZCV)
        .addReg(AArch64::NZCV, RegState::Implicit);
  }
  // Note: Preserving X0 is "free" as this is before register allocation, so
  // the register allocator is still able to optimize these copies.
  if (PhysLiveRegs & LiveRegs::W0) {
    RegSave.X0Save = MRI->createVirtualRegister(PhysLiveRegs & LiveRegs::W0_HI
                                                    ? &AArch64::GPR64RegClass
                                                    : &AArch64::GPR32RegClass);
    BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), RegSave.X0Save)
        .addReg(PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0);
  }
  return RegSave;
}

void MachineSMEABI::restorePhyRegSave(const PhysRegSave &RegSave,
                                      MachineBasicBlock &MBB,
                                      MachineBasicBlock::iterator MBBI,
                                      DebugLoc DL) {
  if (RegSave.StatusFlags != AArch64::NoRegister)
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
        .addImm(AArch64SysReg::NZCV)
        .addReg(RegSave.StatusFlags)
        .addReg(AArch64::NZCV, RegState::ImplicitDefine);

  if (RegSave.X0Save != AArch64::NoRegister)
    BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY),
            RegSave.PhysLiveRegs & LiveRegs::W0_HI ? AArch64::X0 : AArch64::W0)
        .addReg(RegSave.X0Save);
}

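/// Emits code to restore ZA after a lazy save. This re-enables ZA, then, if a
/// callee committed the save (zeroing TPIDR2_EL0), calls __arm_tpidr2_restore
/// to reload ZA from the TPIDR2 block, and finally resets TPIDR2_EL0 to null.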
void MachineSMEABI::emitRestoreLazySave(EmitContext &Context,
                                        MachineBasicBlock &MBB,
                                        MachineBasicBlock::iterator MBBI,
                                        LiveRegs PhysLiveRegs) {
  auto *TLI = Subtarget->getTargetLowering();
  DebugLoc DL = getDebugLoc(MBB, MBBI);
  Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
  Register TPIDR2 = AArch64::X0;

  // TODO: Emit these within the restore MBB to prevent unnecessary saves.
  PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);

  // Enable ZA.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
      .addImm(AArch64SVCR::SVCRZA)
      .addImm(1);
  // Get current TPIDR2_EL0.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS), TPIDR2EL0)
      .addImm(AArch64SysReg::TPIDR2_EL0);
  // Get pointer to TPIDR2 block.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), TPIDR2)
      .addFrameIndex(Context.getTPIDR2Block(*MF))
      .addImm(0)
      .addImm(0);
  // (Conditionally) restore ZA state.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::RestoreZAPseudo))
      .addReg(TPIDR2EL0)
      .addReg(TPIDR2)
      .addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_RESTORE))
      .addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
  // Zero TPIDR2_EL0.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
      .addImm(AArch64SysReg::TPIDR2_EL0)
      .addReg(AArch64::XZR);

  restorePhyRegSave(RegSave, MBB, MBBI, DL);
}

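/// Sets PSTATE.ZA to \p On (1) or off (0), optionally zeroing TPIDR2_EL0
/// first.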
void MachineSMEABI::emitZAMode(MachineBasicBlock &MBB,
                               MachineBasicBlock::iterator MBBI,
                               bool ClearTPIDR2, bool On) {
  DebugLoc DL = getDebugLoc(MBB, MBBI);

  if (ClearTPIDR2)
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSR))
        .addImm(AArch64SysReg::TPIDR2_EL0)
        .addReg(AArch64::XZR);

  // Enable or disable ZA.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
      .addImm(AArch64SVCR::SVCRZA)
      .addImm(On ? 1 : 0);
}

void MachineSMEABI::emitAllocateLazySaveBuffer(
    EmitContext &Context, MachineBasicBlock &MBB,
    MachineBasicBlock::iterator MBBI) {
  MachineFrameInfo &MFI = MF->getFrameInfo();
  DebugLoc DL = getDebugLoc(MBB, MBBI);
  Register SP = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
  Register SVL = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
  Register Buffer = AFI->getEarlyAllocSMESaveBuffer();

  // Calculate SVL.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::RDSVLI_XI), SVL).addImm(1);

  // 1. Allocate the lazy save buffer.
  if (Buffer == AArch64::NoRegister) {
    // TODO: On Windows, we allocate the lazy save buffer in SelectionDAG (so
    // Buffer != AArch64::NoRegister). This is done to reuse the existing
    // expansions (which can insert stack checks). This works, but it means we
    // will always allocate the lazy save buffer (even if the function contains
    // no lazy saves). If we want to handle Windows here, we'll need to
    // implement something similar to LowerWindowsDYNAMIC_STACKALLOC.
    assert(!Subtarget->isTargetWindows() &&
           "Lazy ZA save is not yet supported on Windows");
    Buffer = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
    // Get original stack pointer.
    BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), SP)
        .addReg(AArch64::SP);
    // Allocate a lazy-save buffer object of the size given, normally SVL * SVL
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSUBXrrr), Buffer)
        .addReg(SVL)
        .addReg(SVL)
        .addReg(SP);
    BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), AArch64::SP)
        .addReg(Buffer);
    // We have just allocated a variable sized object, tell this to PEI.
    MFI.CreateVariableSizedObject(Align(16), nullptr);
  }

  // 2. Setup the TPIDR2 block.
  {
    // Note: Supporting big-endian targets would just need to do `SVL << 48`
    // here. It is not implemented as we generally don't support big-endian
    // SVE/SME.
    if (!Subtarget->isLittleEndian())
      reportFatalInternalError(
          "TPIDR2 block initialization is not supported on big-endian targets");

    // Store buffer pointer and num_za_save_slices.
    // Bytes 10-15 are implicitly zeroed.
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::STPXi))
        .addReg(Buffer)
        .addReg(SVL)
        .addFrameIndex(Context.getTPIDR2Block(*MF))
        .addImm(0);
  }
}

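// Immediate mask for the ZERO instruction that selects all eight 64-bit ZA
// tiles (ZA0.D-ZA7.D), i.e., zeroing all of ZA.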
static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111;

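/// Emits the SME prologue. For functions with a private-ZA interface this
/// commits any incoming lazy save (zeroing ZA/ZT0 if the function has "new"
/// state) and enables ZA. For shared-ZA functions it only needs to zero
/// ZA/ZT0 for "new" state.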
void MachineSMEABI::emitSMEPrologue(MachineBasicBlock &MBB,
                                    MachineBasicBlock::iterator MBBI) {
  auto *TLI = Subtarget->getTargetLowering();
  DebugLoc DL = getDebugLoc(MBB, MBBI);

  bool ZeroZA = AFI->getSMEFnAttrs().isNewZA();
  bool ZeroZT0 = AFI->getSMEFnAttrs().isNewZT0();
  if (AFI->getSMEFnAttrs().hasPrivateZAInterface()) {
    // Get current TPIDR2_EL0.
    Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS))
        .addReg(TPIDR2EL0, RegState::Define)
        .addImm(AArch64SysReg::TPIDR2_EL0);
    // If TPIDR2_EL0 is non-zero, commit the lazy save.
    // NOTE: Functions that only use ZT0 don't need to zero ZA.
    auto CommitZASave =
        BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo))
            .addReg(TPIDR2EL0)
            .addImm(ZeroZA)
            .addImm(ZeroZT0)
            .addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_SAVE))
            .addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
    if (ZeroZA)
      CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
    if (ZeroZT0)
      CommitZASave.addDef(AArch64::ZT0, RegState::ImplicitDefine);
    // Enable ZA (as ZA could have previously been in the OFF state).
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
        .addImm(AArch64SVCR::SVCRZA)
        .addImm(1);
  } else if (AFI->getSMEFnAttrs().hasSharedZAInterface()) {
    if (ZeroZA)
      BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_M))
          .addImm(ZERO_ALL_ZA_MASK)
          .addDef(AArch64::ZAB0, RegState::ImplicitDefine);
    if (ZeroZT0)
      BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_T)).addDef(AArch64::ZT0);
  }
}

void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
                                          MachineBasicBlock &MBB,
                                          MachineBasicBlock::iterator MBBI,
                                          LiveRegs PhysLiveRegs, bool IsSave) {
  auto *TLI = Subtarget->getTargetLowering();
  DebugLoc DL = getDebugLoc(MBB, MBBI);

  if (IsSave)
    emitCallSaveRemarks(MBB, MBBI, DL, AArch64::RequiresZASavePseudo,
                        "SMEFullZASave", "full save");

  PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);

  // Copy the buffer pointer into X0.
  Register BufferPtr = AArch64::X0;
  BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
      .addReg(Context.getAgnosticZABufferPtr(*MF));

  // Call __arm_sme_save/__arm_sme_restore.
  BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
      .addReg(BufferPtr, RegState::Implicit)
      .addExternalSymbol(TLI->getLibcallName(
          IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE))
      .addRegMask(TRI->getCallPreservedMask(
          *MF,
          CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));

  restorePhyRegSave(RegSave, MBB, MBBI, DL);
}

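/// Emits a save or restore of ZT0 to/from the function's ZT0 spill slot.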
void MachineSMEABI::emitZT0SaveRestore(EmitContext &Context,
                                       MachineBasicBlock &MBB,
                                       MachineBasicBlock::iterator MBBI,
                                       bool IsSave) {
  DebugLoc DL = getDebugLoc(MBB, MBBI);

  // Note: This will report calls that _only_ need ZT0 saved. Calls that save
  // both ZA and ZT0 will be under the SMELazySaveZA remark. This prevents
  // reporting the same calls twice.
  if (IsSave)
    emitCallSaveRemarks(MBB, MBBI, DL, AArch64::RequiresZT0SavePseudo,
                        "SMEZT0Save", "spill");

  Register ZT0Save = MRI->createVirtualRegister(&AArch64::GPR64spRegClass);

  BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), ZT0Save)
      .addFrameIndex(Context.getZT0SaveSlot(*MF))
      .addImm(0)
      .addImm(0);

  if (IsSave) {
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::STR_TX))
        .addReg(AArch64::ZT0)
        .addReg(ZT0Save);
  } else {
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDR_TX), AArch64::ZT0)
        .addReg(ZT0Save);
  }
}

void MachineSMEABI::emitAllocateFullZASaveBuffer(
    EmitContext &Context, MachineBasicBlock &MBB,
    MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
  // Buffer already allocated in SelectionDAG.
  if (AFI->getEarlyAllocSMESaveBuffer())
    return;

  DebugLoc DL = getDebugLoc(MBB, MBBI);
  Register BufferPtr = Context.getAgnosticZABufferPtr(*MF);
  Register BufferSize = MRI->createVirtualRegister(&AArch64::GPR64RegClass);

  PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);

  // Calculate the SME state size.
  {
    auto *TLI = Subtarget->getTargetLowering();
    const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
        .addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_SME_STATE_SIZE))
        .addReg(AArch64::X0, RegState::ImplicitDefine)
        .addRegMask(TRI->getCallPreservedMask(
            *MF, CallingConv::
                     AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
    BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferSize)
        .addReg(AArch64::X0);
  }

  // Allocate a buffer object of the size given __arm_sme_state_size.
  {
    MachineFrameInfo &MFI = MF->getFrameInfo();
    BuildMI(MBB, MBBI, DL, TII->get(AArch64::SUBXrx64), AArch64::SP)
        .addReg(AArch64::SP)
        .addReg(BufferSize)
        .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
    BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
        .addReg(AArch64::SP);

    // We have just allocated a variable sized object, tell this to PEI.
    MFI.CreateVariableSizedObject(Align(16), nullptr);
  }

  restorePhyRegSave(RegSave, MBB, MBBI, DL);
}

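/// Helper for encoding a (From, To) pair of ZA states as a single value, so
/// state transitions can be matched in a switch statement. For example,
/// `transitionFrom(ZAState::ACTIVE).to(ZAState::OFF)` packs ACTIVE into the
/// high nibble and OFF into the low nibble.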
struct FromState {
  ZAState From;

  constexpr uint8_t to(ZAState To) const {
    static_assert(NUM_ZA_STATE < 16, "expected ZAState to fit in 4-bits");
    return uint8_t(From) << 4 | uint8_t(To);
  }
};

constexpr FromState transitionFrom(ZAState From) { return FromState{From}; }

void MachineSMEABI::emitStateChange(EmitContext &Context,
                                    MachineBasicBlock &MBB,
                                    MachineBasicBlock::iterator InsertPt,
                                    ZAState From, ZAState To,
                                    LiveRegs PhysLiveRegs) {
  // ZA not used.
  if (From == ZAState::ANY || To == ZAState::ANY)
    return;

  // If we're exiting from the ENTRY state that means that the function has not
  // used ZA, so in the case of private ZA/ZT0 functions we can omit any set
  // up.
  if (From == ZAState::ENTRY && To == ZAState::OFF)
    return;

  // TODO: Avoid setting up the save buffer if there's no transition to
  // LOCAL_SAVED.
  if (From == ZAState::ENTRY) {
    assert(&MBB == &MBB.getParent()->front() &&
           "ENTRY state only valid in entry block");
    emitSMEPrologue(MBB, MBB.getFirstNonPHI());
    if (To == ZAState::ACTIVE)
      return; // Nothing more to do (ZA is active after the prologue).

    // Note: "emitSMEPrologue" zeros ZA, so we may need to set up a lazy save
    // if "To" is "ZAState::LOCAL_SAVED". It may be possible to improve this
    // case by changing the placement of the zero instruction.
    From = ZAState::ACTIVE;
  }

  SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
  bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface();
  bool HasZT0State = SMEFnAttrs.hasZT0State();
  bool HasZAState = IsAgnosticZA || SMEFnAttrs.hasZAState();

  switch (transitionFrom(From).to(To)) {
  // This section handles: ACTIVE <-> ACTIVE_ZT0_SAVED
  case transitionFrom(ZAState::ACTIVE).to(ZAState::ACTIVE_ZT0_SAVED):
    emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
    break;
  case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::ACTIVE):
    emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
    break;

  // This section handles: ACTIVE[_ZT0_SAVED] -> LOCAL_SAVED
  case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_SAVED):
  case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::LOCAL_SAVED):
    if (HasZT0State && From == ZAState::ACTIVE)
      emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
    if (HasZAState)
      emitZASave(Context, MBB, InsertPt, PhysLiveRegs);
    break;

  // This section handles: ACTIVE -> LOCAL_COMMITTED
  case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_COMMITTED):
    // TODO: We could support ZA state here, but this transition is currently
    // only possible when we _don't_ have ZA state.
    assert(HasZT0State && !HasZAState && "Expect to only have ZT0 state.");
    emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
    emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/false);
    break;

  // This section handles: LOCAL_COMMITTED -> (OFF|LOCAL_SAVED)
  case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::OFF):
  case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::LOCAL_SAVED):
    // These transitions are a no-op.
    break;

  // This section handles: LOCAL_(SAVED|COMMITTED) -> ACTIVE[_ZT0_SAVED]
  case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE):
  case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE_ZT0_SAVED):
  case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::ACTIVE):
    if (HasZAState)
      emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
    else
      emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/true);
    if (HasZT0State && To == ZAState::ACTIVE)
      emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
    break;

  // This section handles transitions to OFF (not previously covered)
  case transitionFrom(ZAState::ACTIVE).to(ZAState::OFF):
  case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::OFF):
  case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::OFF):
    assert(SMEFnAttrs.hasPrivateZAInterface() &&
           "Did not expect to turn ZA off in shared/agnostic ZA function");
    emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED,
               /*On=*/false);
    break;

  default:
    dbgs() << "Error: Transition from " << getZAStateString(From) << " to "
           << getZAStateString(To) << '\n';
    llvm_unreachable("Unimplemented state transition");
  }
}

} // end anonymous namespace

INITIALIZE_PASS(MachineSMEABI, "aarch64-machine-sme-abi", "Machine SME ABI",
                false, false)

bool MachineSMEABI::runOnMachineFunction(MachineFunction &MF) {
  if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
    return false;

  AFI = MF.getInfo<AArch64FunctionInfo>();
  SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
  if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State() &&
      !SMEFnAttrs.hasAgnosticZAInterface())
    return false;

  assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");

  this->MF = &MF;
  Subtarget = &MF.getSubtarget<AArch64Subtarget>();
  ORE = &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE();
  TII = Subtarget->getInstrInfo();
  TRI = Subtarget->getRegisterInfo();
  MRI = &MF.getRegInfo();

  const EdgeBundles &Bundles =
      getAnalysis<EdgeBundlesWrapperLegacy>().getEdgeBundles();

  FunctionInfo FnInfo = collectNeededZAStates(SMEFnAttrs);

  if (OptLevel != CodeGenOptLevel::None) {
    // Propagate desired states forward, then backwards. Most of the
    // propagation should be done in the forward step, and backwards
    // propagation is then used to fill in the gaps. Note: Doing both in one
    // step can give poor results. For example, consider this subgraph:
    //
    //           ┌─────┐
    //         ┌─┤ BB0 ◄───┐
    //         │ └─┬───┘   │
    //         │ ┌─▼───◄──┐│
    //         │ │ BB1 │  ││
    //         │ └─┬┬──┘  ││
    //         │   │└─────┘│
    //         │ ┌─▼───┐   │
    //         │ │ BB2 ├───┘
    //         │ └─┬───┘
    //         │ ┌─▼───┐
    //         └─► BB3 │
    //           └─────┘
    //
    // If:
    // - "BB0" and "BB2" (outer loop) have no state preference
    // - "BB1" (inner loop) desires the ACTIVE state on entry/exit
    // - "BB3" desires the LOCAL_SAVED state on entry
    //
    // If we propagate forwards first, ACTIVE is propagated from BB1 to BB2,
    // then from BB2 to BB0, which results in the inner and outer loops having
    // the "ACTIVE" state. This avoids any state changes in the loops.
    //
    // If we propagate backwards first, we _could_ propagate LOCAL_SAVED from
    // BB3 to BB0, which would result in a transition from ACTIVE ->
    // LOCAL_SAVED in the outer loop.
    for (bool Forwards : {true, false})
      propagateDesiredStates(FnInfo, Forwards);
  }

  SmallVector<ZAState> BundleStates = assignBundleZAStates(Bundles, FnInfo);

  EmitContext Context;
  insertStateChanges(Context, FnInfo, Bundles, BundleStates);

  if (Context.needsSaveBuffer()) {
    if (FnInfo.AfterSMEProloguePt) {
      // Note: With inline stack probes the AfterSMEProloguePt may not be in
      // the entry block (due to the probing loop).
      MachineBasicBlock::iterator MBBI = *FnInfo.AfterSMEProloguePt;
      emitAllocateZASaveBuffer(Context, *MBBI->getParent(), MBBI,
                               FnInfo.PhysLiveRegsAfterSMEPrologue);
    } else {
      MachineBasicBlock &EntryBlock = MF.front();
      emitAllocateZASaveBuffer(
          Context, EntryBlock, EntryBlock.getFirstNonPHI(),
          FnInfo.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry);
    }
  }

  return true;
}

FunctionPass *llvm::createMachineSMEABIPass(CodeGenOptLevel OptLevel) {
  return new MachineSMEABI(OptLevel);
}