112 using namespace llvm;
114 #define DEBUG_TYPE "aarch64-speculation-hardening" 116 #define AARCH64_SPECULATION_HARDENING_NAME "AArch64 speculation hardening pass" 119 cl::desc(
"Sanitize loads from memory."),
142 unsigned MisspeculatingTaintReg;
143 unsigned MisspeculatingTaintReg32Bit;
144 bool UseControlFlowSpeculationBarrier;
150 bool &UsesFullSpeculationBarrier);
160 unsigned TmpReg)
const;
170 bool UsesFullSpeculationBarrier);
173 bool UsesFullSpeculationBarrier);
185 bool AArch64SpeculationHardening::endsWithCondControlFlow(
193 if (analyzeBranchCondCode.
empty())
201 FBB = MBB.getFallThrough();
209 assert(MBB.succ_size() == 2);
211 assert(analyzeBranchCondCode.
size() == 1 &&
"unknown Cond array format");
216 void AArch64SpeculationHardening::insertFullSpeculationBarrier(
220 BuildMI(MBB, MBBI, DL,
TII->get(AArch64::DSB)).addImm(0xf);
221 BuildMI(MBB, MBBI, DL,
TII->get(AArch64::ISB)).addImm(0xf);
224 void AArch64SpeculationHardening::insertTrackingCode(
227 if (UseControlFlowSpeculationBarrier) {
228 insertFullSpeculationBarrier(SplitEdgeBB, SplitEdgeBB.
begin(), DL);
230 BuildMI(SplitEdgeBB, SplitEdgeBB.
begin(), DL,
TII->get(AArch64::CSELXr))
231 .addDef(MisspeculatingTaintReg)
232 .
addUse(MisspeculatingTaintReg)
239 bool AArch64SpeculationHardening::instrumentControlFlow(
241 LLVM_DEBUG(
dbgs() <<
"Instrument control flow tracking on MBB: " << MBB);
248 if (!endsWithCondControlFlow(MBB, TBB, FBB, CondCode)) {
260 assert(SplitEdgeTBB !=
nullptr);
261 assert(SplitEdgeFBB !=
nullptr);
267 insertTrackingCode(*SplitEdgeTBB, CondCode, DL);
268 insertTrackingCode(*SplitEdgeFBB, InvCondCode, DL);
285 bool TmpRegisterNotAvailableEverywhere =
false;
306 unsigned TmpReg = RS.
FindUnusedReg(&AArch64::GPR64commonRegClass);
308 << ((TmpReg == 0) ?
"no register " :
"register ");
310 dbgs() <<
"to be available at MI " <<
MI);
312 TmpRegisterNotAvailableEverywhere =
true;
314 ReturnInstructions.
push_back({&MI, TmpReg});
316 CallInstructions.
push_back({&MI, TmpReg});
319 if (TmpRegisterNotAvailableEverywhere) {
324 insertFullSpeculationBarrier(MBB, MBB.
begin(),
326 UsesFullSpeculationBarrier =
true;
329 for (
auto MI_Reg : ReturnInstructions) {
330 assert(MI_Reg.second != 0);
333 <<
" About to insert Reg to SP taint propagation with temp register " 335 <<
" on instruction: " << *MI_Reg.first);
336 insertRegToSPTaintPropagation(MBB, MI_Reg.first, MI_Reg.second);
340 for (
auto MI_Reg : CallInstructions) {
341 assert(MI_Reg.second != 0);
343 "propagation with temp register " 345 <<
" around instruction: " << *MI_Reg.first);
347 insertSPToRegTaintPropagation(
350 insertRegToSPTaintPropagation(MBB, MI_Reg.first, MI_Reg.second);
357 void AArch64SpeculationHardening::insertSPToRegTaintPropagation(
362 if (UseControlFlowSpeculationBarrier) {
363 insertFullSpeculationBarrier(MBB, MBBI,
DebugLoc());
369 .addDef(AArch64::XZR)
375 .addDef(MisspeculatingTaintReg)
381 void AArch64SpeculationHardening::insertRegToSPTaintPropagation(
383 unsigned TmpReg)
const {
387 if (UseControlFlowSpeculationBarrier)
410 bool AArch64SpeculationHardening::functionUsesHardeningRegister(
418 if (
MI.readsRegister(MisspeculatingTaintReg,
TRI) ||
419 MI.modifiesRegister(MisspeculatingTaintReg,
TRI))
429 bool AArch64SpeculationHardening::makeGPRSpeculationSafe(
433 AArch64::GPR64allRegClass.
contains(Reg));
440 if (Reg == AArch64::SP || Reg == AArch64::WSP)
444 if (RegsAlreadyMasked[Reg])
447 const bool Is64Bit = AArch64::GPR64allRegClass.contains(Reg);
448 LLVM_DEBUG(
dbgs() <<
"About to harden register : " << Reg <<
"\n");
450 TII->get(Is64Bit ? AArch64::SpeculationSafeValueX
451 : AArch64::SpeculationSafeValueW))
454 RegsAlreadyMasked.set(Reg);
463 RegsAlreadyMasked.reset();
467 for (; MBBI !=
E; MBBI = NextMBBI) {
469 NextMBBI = std::next(MBBI);
484 return Op.isReg() && (AArch64::GPR32allRegClass.contains(
Op.getReg()) ||
485 AArch64::GPR64allRegClass.
contains(
Op.getReg()));
491 bool HardenLoadedData = AllDefsAreGPR;
492 bool HardenAddressLoadedFrom = !HardenLoadedData;
499 RegsAlreadyMasked.reset(*AI);
507 if (HardenLoadedData)
517 Modified |= makeGPRSpeculationSafe(MBB, NextMBBI, MI,
Def.getReg());
520 if (HardenAddressLoadedFrom)
524 unsigned Reg =
Use.getReg();
535 if (!(AArch64::GPR32allRegClass.
contains(Reg) ||
536 AArch64::GPR64allRegClass.
contains(Reg)))
538 Modified |= makeGPRSpeculationSafe(MBB, MBBI, MI, Reg);
546 bool AArch64SpeculationHardening::expandSpeculationSafeValue(
548 bool UsesFullSpeculationBarrier) {
556 case AArch64::SpeculationSafeValueW:
559 case AArch64::SpeculationSafeValueX:
563 if (!UseControlFlowSpeculationBarrier && !UsesFullSpeculationBarrier) {
571 RegsNeedingCSDBBeforeUse.set(*AI);
575 Is64Bit ?
TII->get(AArch64::ANDXrs) :
TII->get(AArch64::ANDWrs))
578 .
addUse(Is64Bit ? MisspeculatingTaintReg
579 : MisspeculatingTaintReg32Bit)
591 assert(!UseControlFlowSpeculationBarrier &&
"No need to insert CSDBs when " 592 "control flow miss-speculation " 593 "is already blocked");
595 BuildMI(MBB, MBBI, DL,
TII->get(AArch64::HINT)).addImm(0x14);
596 RegsNeedingCSDBBeforeUse.reset();
600 bool AArch64SpeculationHardening::lowerSpeculationSafeValuePseudos(
604 RegsNeedingCSDBBeforeUse.reset();
626 bool NeedToEmitBarrier =
false;
628 NeedToEmitBarrier =
true;
629 if (!NeedToEmitBarrier)
631 if (
Op.isReg() && RegsNeedingCSDBBeforeUse[
Op.getReg()]) {
632 NeedToEmitBarrier =
true;
636 if (NeedToEmitBarrier && !UsesFullSpeculationBarrier)
637 Modified |= insertCSDB(MBB, MBBI, DL);
640 expandSpeculationSafeValue(MBB, MBBI, UsesFullSpeculationBarrier);
645 if (RegsNeedingCSDBBeforeUse.any() && !UsesFullSpeculationBarrier)
646 Modified |= insertCSDB(MBB, MBBI, DL);
651 bool AArch64SpeculationHardening::runOnMachineFunction(
MachineFunction &MF) {
655 MisspeculatingTaintReg = AArch64::X16;
656 MisspeculatingTaintReg32Bit = AArch64::W16;
659 RegsNeedingCSDBBeforeUse.resize(
TRI->getNumRegs());
660 RegsAlreadyMasked.resize(
TRI->getNumRegs());
661 UseControlFlowSpeculationBarrier = functionUsesHardeningRegister(MF);
668 dbgs() <<
"***** AArch64SpeculationHardening - automatic insertion of " 669 "SpeculationSafeValue intrinsics *****\n");
671 Modified |= slhLoads(MBB);
677 <<
"***** AArch64SpeculationHardening - track control flow *****\n");
682 EntryBlocks.
push_back(LPI.LandingPadBlock);
683 for (
auto Entry : EntryBlocks)
684 insertSPToRegTaintPropagation(
685 *Entry, Entry->SkipPHIsLabelsAndDebug(Entry->begin()));
688 for (
auto &MBB : MF) {
689 bool UsesFullSpeculationBarrier =
false;
690 Modified |= instrumentControlFlow(MBB, UsesFullSpeculationBarrier);
692 lowerSpeculationSafeValuePseudos(MBB, UsesFullSpeculationBarrier);
700 return new AArch64SpeculationHardening();
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
instr_iterator instr_begin()
bool isCall(QueryType Type=AnyInBundle) const
instr_iterator instr_end()
This class represents lattice values for constants.
iterator_range< mop_iterator > uses()
Returns a range that includes all operands that are register uses.
virtual const TargetRegisterInfo * getRegisterInfo() const
getRegisterInfo - If register information is available, return it.
void push_back(const T &Elt)
const DebugLoc & getDebugLoc() const
Returns the debug location id of this MachineInstr.
static CondCode getInvertedCondCode(CondCode Code)
unsigned getReg() const
getReg - Returns the register number.
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly...
unsigned const TargetRegisterInfo * TRI
return AArch64::GPR64RegClass contains(Reg)
bool analyzeBranch(MachineBasicBlock &MBB, MachineBasicBlock *&TBB, MachineBasicBlock *&FBB, SmallVectorImpl< MachineOperand > &Cond, bool AllowModify) const override
Analyze the branching code at the end of MBB, returning true if it cannot be understood (e...
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
const HexagonInstrInfo * TII
Printable printReg(unsigned Reg, const TargetRegisterInfo *TRI=nullptr, unsigned SubIdx=0, const MachineRegisterInfo *MRI=nullptr)
Prints virtual and physical registers with or without a TRI instance.
A Use represents the edge between a Value definition and its users.
const MachineInstrBuilder & addUse(unsigned RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register use operand.
void eraseFromParent()
Unlink 'this' from the containing basic block and delete it.
bool isTerminator(QueryType Type=AnyInBundle) const
Returns true if this instruction part of the terminator for a basic block.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
void forward()
Move the internal MBB iterator and update register states.
This structure is used to retain landing pad info for the current function.
unsigned FindUnusedReg(const TargetRegisterClass *RC) const
Find an unused register of the specified register class.
void initializeAArch64SpeculationHardeningPass(PassRegistry &)
CondCode
ISD::CondCode enum - These are ordered carefully to make the bitfields below work out...
virtual const TargetInstrInfo * getInstrInfo() const
TargetInstrInfo - Interface to description of machine instruction set.
bool isReturn(QueryType Type=AnyInBundle) const
MachineInstrBuilder BuildMI(MachineFunction &MF, const DebugLoc &DL, const MCInstrDesc &MCID)
Builder interface. Specify how to create the initial instruction itself.
initializer< Ty > init(const Ty &Val)
void addLiveIn(MCPhysReg PhysReg, LaneBitmask LaneMask=LaneBitmask::getAll())
Adds the specified register as a live in.
This file declares the machine register scavenger class.
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
MCRegAliasIterator enumerates all registers aliasing Reg.
iterator_range< mop_iterator > defs()
Returns a range over all explicit operands that are register definitions.
FunctionPass class - This class is used to implement most global optimizations.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
const MachineBasicBlock & front() const
const std::vector< LandingPadInfo > & getLandingPads() const
Return a reference to the landing pad info for the current function.
MachineOperand class - Representation of each machine instruction operand.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
FunctionPass * createAArch64SpeculationHardeningPass()
Returns an instance of the pseudo instruction expansion pass.
const Function & getFunction() const
Return the LLVM function that this machine code represents.
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
cl::opt< bool > HardenLoads("aarch64-slh-loads", cl::Hidden, cl::desc("Sanitize loads from memory."), cl::init(true))
Representation of each machine instruction.
const MachineInstrBuilder & addImm(int64_t Val) const
Add a new immediate operand.
void enterBasicBlock(MachineBasicBlock &MBB)
Start tracking liveness from the begin of basic block MBB.
static DebugLoc getDebugLoc(MachineBasicBlock::instr_iterator FirstMI, MachineBasicBlock::instr_iterator LastMI)
Return the first found DebugLoc that has a DILocation, given a range of instructions.
LLVM_NODISCARD bool empty() const
INITIALIZE_PASS(AArch64SpeculationHardening, "aarch64-speculation-hardening", AARCH64_SPECULATION_HARDENING_NAME, false, false) bool AArch64SpeculationHardening
#define AARCH64_SPECULATION_HARDENING_NAME
bool mayLoad(QueryType Type=AnyInBundle) const
Return true if this instruction could possibly read memory.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
MachineBasicBlock * SplitCriticalEdge(MachineBasicBlock *Succ, Pass &P)
Split the critical edge from this block to the given successor block, and return the newly created bl...
#define LLVM_FALLTHROUGH
LLVM_FALLTHROUGH - Mark fallthrough cases in switch statements.
StringRef - Represent a constant reference to a string, i.e.
const MachineOperand & getOperand(unsigned i) const