39#define DEBUG_TYPE "arm64eccalllowering"
41STATISTIC(Arm64ECCallsLowered,
"Number of Arm64EC calls lowered");
50enum ThunkArgTranslation : uint8_t {
59 ThunkArgTranslation Translation;
62class AArch64Arm64ECCallLowering :
public ModulePass {
77 int cfguard_module_flag = 0;
104 ThunkArgInfo canonicalizeThunkType(
Type *
T,
Align Alignment,
bool Ret,
110void AArch64Arm64ECCallLowering::getThunkType(
114 Out << (
TT == Arm64ECThunkType::Entry ?
"$ientry_thunk$cdecl$"
115 :
"$iexit_thunk$cdecl$");
126 if (TT == Arm64ECThunkType::Exit)
130 bool HasSretPtr =
false;
131 getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
132 X64ArgTypes, ArgTranslations, HasSretPtr);
134 getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
135 ArgTranslations, HasSretPtr);
137 Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes,
false);
139 X64Ty = FunctionType::get(X64RetTy, X64ArgTypes,
false);
142void AArch64Arm64ECCallLowering::getThunkArgTypes(
149 if (FT->isVarArg()) {
172 for (
int i = HasSretPtr ? 1 : 0; i < 4; i++) {
175 ArgTranslations.
push_back(ThunkArgTranslation::Direct);
181 ArgTranslations.
push_back(ThunkArgTranslation::Direct);
184 if (TT != Arm64ECThunkType::Entry) {
188 ArgTranslations.
push_back(ThunkArgTranslation::Direct);
197 if (
I == FT->getNumParams()) {
202 for (
unsigned E = FT->getNumParams();
I != E; ++
I) {
206 uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(
I);
212 auto [Arm64Ty, X64Ty, ArgTranslation] =
213 canonicalizeThunkType(FT->getParamType(
I), ParamAlign,
214 false, ArgSizeBytes, Out);
217 ArgTranslations.
push_back(ArgTranslation);
221void AArch64Arm64ECCallLowering::getThunkRetType(
226 Type *
T = FT->getReturnType();
230 uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
232 int64_t ArgSizeBytes = 0;
235 if (FT->getNumParams()) {
239 if (FT->getNumParams() > 1) {
243 SRetAttr1 = AttrList.
getParamAttr(1, Attribute::StructRet);
244 InRegAttr1 = AttrList.
getParamAttr(1, Attribute::InReg);
267 canonicalizeThunkType(SRetType, SRetAlign,
true, ArgSizeBytes,
271 Arm64ArgTypes.
push_back(FT->getParamType(0));
272 X64ArgTypes.
push_back(FT->getParamType(0));
273 ArgTranslations.
push_back(ThunkArgTranslation::Direct);
286 canonicalizeThunkType(
T,
Align(),
true, ArgSizeBytes, Out);
287 Arm64RetTy =
info.Arm64Ty;
288 X64RetTy =
info.X64Ty;
298ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType(
302 auto direct = [](
Type *
T) {
303 return ThunkArgInfo{
T,
T, ThunkArgTranslation::Direct};
306 auto bitcast = [
this](
Type *Arm64Ty,
uint64_t SizeInBytes) {
307 return ThunkArgInfo{Arm64Ty,
309 ThunkArgTranslation::Bitcast};
312 auto pointerIndirection = [
this](
Type *Arm64Ty) {
313 return ThunkArgInfo{Arm64Ty, PtrTy,
314 ThunkArgTranslation::PointerIndirection};
317 if (
T->isFloatTy()) {
322 if (
T->isDoubleTy()) {
327 if (
T->isFloatingPointTy()) {
329 "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
332 auto &
DL =
M->getDataLayout();
334 if (
auto *StructTy = dyn_cast<StructType>(
T))
335 if (StructTy->getNumElements() == 1)
336 T = StructTy->getElementType(0);
338 if (
T->isArrayTy()) {
339 Type *ElementTy =
T->getArrayElementType();
340 uint64_t ElementCnt =
T->getArrayNumElements();
341 uint64_t ElementSizePerBytes =
DL.getTypeSizeInBits(ElementTy) / 8;
342 uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
344 Out << (ElementTy->
isFloatTy() ?
"F" :
"D") << TotalSizeBytes;
345 if (Alignment.
value() >= 16 && !
Ret)
346 Out <<
"a" << Alignment.
value();
347 if (TotalSizeBytes <= 8) {
350 return bitcast(
T, TotalSizeBytes);
353 return pointerIndirection(
T);
355 }
else if (
T->isFloatingPointTy()) {
361 if ((
T->isIntegerTy() ||
T->isPointerTy()) &&
DL.getTypeSizeInBits(
T) <= 64) {
363 return direct(I64Ty);
372 if (Alignment.
value() >= 16 && !Ret)
373 Out <<
"a" << Alignment.
value();
380 return pointerIndirection(
T);
392 getThunkType(FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty,
393 X64Ty, ArgTranslations);
394 if (
Function *
F =
M->getFunction(ExitThunkName))
400 F->setSection(
".wowthk$aa");
401 F->setComdat(
M->getOrInsertComdat(ExitThunkName));
403 F->addFnAttr(
"frame-pointer",
"all");
407 if (FT->getNumParams()) {
408 auto SRet =
Attrs.getParamAttr(0, Attribute::StructRet);
409 auto InReg =
Attrs.getParamAttr(0, Attribute::InReg);
410 if (SRet.isValid() && !InReg.isValid())
411 F->addParamAttr(1, SRet);
418 M->getOrInsertGlobal(
"__os_arm64x_dispatch_call_no_redirect", PtrTy);
420 auto &
DL =
M->getDataLayout();
424 auto X64TyOffset = 1;
425 Args.push_back(
F->arg_begin());
428 if (
RetTy != X64Ty->getReturnType()) {
432 if (
DL.getTypeStoreSize(
RetTy) > 8) {
440 make_range(X64Ty->param_begin() + X64TyOffset, X64Ty->param_end()),
456 if (ArgTranslation != ThunkArgTranslation::Direct) {
457 Value *Mem = IRB.CreateAlloca(Arg.getType());
458 IRB.CreateStore(&Arg, Mem);
459 if (ArgTranslation == ThunkArgTranslation::Bitcast) {
461 Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy)));
463 assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);
467 Args.push_back(&Arg);
473 Callee = IRB.CreateBitCast(Callee, PtrTy);
478 if (
RetTy != X64Ty->getReturnType()) {
481 if (
DL.getTypeStoreSize(
RetTy) > 8) {
482 RetVal = IRB.CreateLoad(
RetTy, Args[1]);
485 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
486 RetVal = IRB.CreateLoad(
RetTy, CastAlloca);
490 if (
RetTy->isVoidTy())
493 IRB.CreateRet(RetVal);
504 getThunkType(
F->getFunctionType(),
F->getAttributes(),
505 Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty,
507 if (
Function *
F =
M->getFunction(EntryThunkName))
513 Thunk->setSection(
".wowthk$aa");
514 Thunk->setComdat(
M->getOrInsertComdat(EntryThunkName));
516 Thunk->addFnAttr(
"frame-pointer",
"all");
522 Type *X64RetType = X64Ty->getReturnType();
524 bool TransformDirectToSRet = X64RetType->
isVoidTy() && !
RetTy->isVoidTy();
525 unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
526 unsigned PassthroughArgSize =
527 (
F->isVarArg() ? 5 :
Thunk->arg_size()) - ThunkArgOffset;
528 assert(ArgTranslations.
size() == (
F->isVarArg() ? 5 : PassthroughArgSize));
532 for (
unsigned i = 0; i != PassthroughArgSize; ++i) {
533 Value *Arg =
Thunk->getArg(i + ThunkArgOffset);
534 Type *ArgTy = Arm64Ty->getParamType(i);
535 ThunkArgTranslation ArgTranslation = ArgTranslations[i];
536 if (ArgTranslation != ThunkArgTranslation::Direct) {
538 if (ArgTranslation == ThunkArgTranslation::Bitcast) {
539 Value *CastAlloca = IRB.CreateAlloca(ArgTy);
540 IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy));
541 Arg = IRB.CreateLoad(ArgTy, CastAlloca);
543 assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);
544 Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy));
558 Thunk->addParamAttr(5, Attribute::InReg);
560 Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20));
564 Args.push_back(IRB.getInt64(0));
569 Callee = IRB.CreateBitCast(Callee, PtrTy);
572 auto SRetAttr =
F->getAttributes().getParamAttr(0, Attribute::StructRet);
573 auto InRegAttr =
F->getAttributes().getParamAttr(0, Attribute::InReg);
574 if (SRetAttr.isValid() && !InRegAttr.isValid()) {
575 Thunk->addParamAttr(1, SRetAttr);
576 Call->addParamAttr(0, SRetAttr);
580 if (TransformDirectToSRet) {
581 IRB.CreateStore(RetVal, IRB.CreateBitCast(
Thunk->getArg(1), PtrTy));
582 }
else if (X64RetType !=
RetTy) {
583 Value *CastAlloca = IRB.CreateAlloca(X64RetType);
584 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
585 RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
595 IRB.CreateRet(RetVal);
608 getThunkType(
F->getFunctionType(),
F->getAttributes(),
609 Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
612 assert(MangledName &&
"Can't guest exit to function that's already native");
613 std::string ThunkName = *MangledName;
614 if (ThunkName[0] ==
'?' && ThunkName.find(
"@") != std::string::npos) {
615 ThunkName.insert(ThunkName.find(
"@"),
"$exit_thunk");
617 ThunkName.append(
"$exit_thunk");
621 GuestExit->setComdat(
M->getOrInsertComdat(ThunkName));
624 "arm64ec_unmangled_name",
628 "arm64ec_ecmangled_name",
631 F->setMetadata(
"arm64ec_hasguestexit",
MDNode::get(
M->getContext(), {}));
637 if (cfguard_module_flag == 2 && !
F->hasFnAttribute(
"guard_nocf"))
638 GuardFn = GuardFnCFGlobal;
640 GuardFn = GuardFnGlobal;
641 LoadInst *GuardCheckLoad =
B.CreateLoad(GuardFnPtrType, GuardFn);
645 Function *
Thunk = buildExitThunk(
F->getFunctionType(),
F->getAttributes());
647 GuardFnType, GuardCheckLoad,
648 {
B.CreateBitCast(
F,
B.getPtrTy()),
B.CreateBitCast(Thunk,
B.getPtrTy())});
653 Value *GuardRetVal =
B.CreateBitCast(GuardCheck, PtrTy);
656 Args.push_back(&Arg);
660 if (
Call->getType()->isVoidTy())
665 auto SRetAttr =
F->getAttributes().getParamAttr(0, Attribute::StructRet);
666 auto InRegAttr =
F->getAttributes().getParamAttr(0, Attribute::InReg);
667 if (SRetAttr.isValid() && !InRegAttr.isValid()) {
669 Call->addParamAttr(0, SRetAttr);
676void AArch64Arm64ECCallLowering::lowerCall(
CallBase *CB) {
678 "Only applicable for Windows targets");
691 if (cfguard_module_flag == 2 && !CB->
hasFnAttr(
"guard_nocf"))
692 GuardFn = GuardFnCFGlobal;
694 GuardFn = GuardFnGlobal;
695 LoadInst *GuardCheckLoad =
B.CreateLoad(GuardFnPtrType, GuardFn);
701 B.CreateCall(GuardFnType, GuardCheckLoad,
702 {
B.CreateBitCast(CalledOperand,
B.getPtrTy()),
703 B.CreateBitCast(Thunk,
B.getPtrTy())},
709 Value *GuardRetVal =
B.CreateBitCast(GuardCheck, CalledOperand->
getType());
713bool AArch64Arm64ECCallLowering::runOnModule(
Module &
Mod) {
721 mdconst::extract_or_null<ConstantInt>(
M->getModuleFlag(
"cfguard")))
722 cfguard_module_flag = MD->getZExtValue();
724 PtrTy = PointerType::getUnqual(
M->getContext());
728 GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy},
false);
729 GuardFnPtrType = PointerType::get(GuardFnType, 0);
731 M->getOrInsertGlobal(
"__os_arm64x_check_icall_cfg", GuardFnPtrType);
733 M->getOrInsertGlobal(
"__os_arm64x_check_icall", GuardFnPtrType);
737 if (!
F.isDeclaration() &&
740 processFunction(
F, DirectCalledFns);
749 if (!
F.isDeclaration() && (!
F.hasLocalLinkage() ||
F.hasAddressTaken()) &&
755 {&
F, buildEntryThunk(&
F), Arm64ECThunkType::Entry});
760 {
F, buildExitThunk(
F->getFunctionType(),
F->getAttributes()),
761 Arm64ECThunkType::Exit});
762 if (!
F->hasDLLImportStorageClass())
764 {buildGuestExitThunk(
F),
F, Arm64ECThunkType::GuestExit});
767 if (!ThunkMapping.
empty()) {
769 for (ThunkInfo &Thunk : ThunkMapping) {
773 ConstantInt::get(
M->getContext(),
APInt(32, uint8_t(
Thunk.Kind)))}));
777 ThunkMappingArrayElems.
size()),
778 ThunkMappingArrayElems);
781 "llvm.arm64ec.symbolmap");
787bool AArch64Arm64ECCallLowering::processFunction(
798 if (!
F.hasLocalLinkage() ||
F.hasAddressTaken()) {
799 if (std::optional<std::string> MangledName =
801 F.setMetadata(
"arm64ec_unmangled_name",
804 if (
F.hasComdat() &&
F.getComdat()->getName() ==
F.getName()) {
805 Comdat *MangledComdat =
M->getOrInsertComdat(MangledName.value());
809 User->setComdat(MangledComdat);
811 F.setName(MangledName.value());
821 auto *CB = dyn_cast<CallBase>(&
I);
833 F->isIntrinsic() || !
F->isDeclaration())
841 ++Arm64ECCallsLowered;
845 if (IndirectCalls.
empty())
854char AArch64Arm64ECCallLowering::ID = 0;
856 "AArch64Arm64ECCallLowering",
false,
false)
859 return new AArch64Arm64ECCallLowering;
static cl::opt< bool > LowerDirectToIndirect("arm64ec-lower-direct-to-indirect", cl::Hidden, cl::init(true))
static cl::opt< bool > GenerateThunks("arm64ec-generate-thunks", cl::Hidden, cl::init(true))
OperandBundleDefT< Value * > OperandBundleDef
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
Module.h This file contains the declarations for the Module class.
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallString class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
static SymbolRef::Type getType(const Symbol *Sym)
Class for arbitrary precision integers.
This class represents an incoming formal argument to a Function.
static ArrayType * get(Type *ElementType, uint64_t NumElements)
This static method is the primary way to construct an ArrayType.
Attribute getParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) const
Return the attribute object that exists at the arg index.
MaybeAlign getParamAlignment(unsigned ArgNo) const
Return the alignment for the specified function parameter.
bool isValid() const
Return true if the attribute is any kind of attribute.
Type * getValueAsType() const
Return the attribute's value as a Type.
LLVM Basic Block Representation.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
bool isInlineAsm() const
Check if this call is an inline asm statement.
void setCallingConv(CallingConv::ID CC)
std::optional< OperandBundleUse > getOperandBundle(StringRef Name) const
Return an operand bundle by name, if present.
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
bool hasFnAttr(Attribute::AttrKind Kind) const
Determine whether this call has the given attribute.
CallingConv::ID getCallingConv() const
Value * getCalledOperand() const
FunctionType * getFunctionType() const
void setCalledOperand(Value *V)
AttributeList getAttributes() const
Return the parameter attributes for this call.
This class represents a function call, abstracting a target machine's calling convention.
static Constant * get(ArrayType *T, ArrayRef< Constant * > V)
static Constant * getBitCast(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static Constant * getAnon(ArrayRef< Constant * > V, bool Packed=false)
Return an anonymous struct that has the specified elements.
This is an important base class in LLVM.
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
@ WeakODRLinkage
Same, but only replaced by something equivalent.
@ ExternalLinkage
Externally visible function.
@ LinkOnceODRLinkage
Same, but only replaced by something equivalent.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr it the function does not...
An instruction for reading from memory.
static MDTuple * get(LLVMContext &Context, ArrayRef< Metadata * > MDs)
static MDString * get(LLVMContext &Context, StringRef Str)
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
const std::string & getTargetTriple() const
Get the target triple which is a string describing the target host.
Comdat * getOrInsertComdat(StringRef Name)
Return the Comdat in the module with the specified name.
A container for an operand bundle being viewed as a set of values rather than a set of uses.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
A vector that has set insertion semantics.
bool insert(const value_type &X)
Insert a new element into the SetVector.
SmallString - A SmallString is just a SmallVector with methods and accessors that make it work better...
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Triple - Helper class for working with autoconf configuration names.
bool isOSWindows() const
Tests whether the OS is Windows.
The instances of the Type class are immutable: once they are created, they are never changed.
bool isPointerTy() const
True if this is an instance of PointerType.
bool isFloatTy() const
Return true if this is 'float', a 32-bit IEEE fp type.
static IntegerType * getIntNTy(LLVMContext &C, unsigned N)
static Type * getVoidTy(LLVMContext &C)
bool isDoubleTy() const
Return true if this is 'double', a 64-bit IEEE fp type.
static IntegerType * getInt64Ty(LLVMContext &C)
bool isVoidTy() const
Return true if this is 'void'.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
A raw_ostream that discards all output.
This class implements an extremely fast bulk output stream that can only output to a stream.
A raw_ostream that writes to an SmallVector or SmallString.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
constexpr char Attrs[]
Key for Kernel::Metadata::mAttrs.
@ ARM64EC_Thunk_Native
Calling convention used in the ARM64EC ABI to implement calls between ARM64 code and thunks.
@ CFGuard_Check
Special calling convention on Windows for calling the Control Guard Check ICall funtion.
@ ARM64EC_Thunk_X64
Calling convention used in the ARM64EC ABI to implement calls between x64 code and thunks.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
initializer< Ty > init(const Ty &Val)
This is an optimization pass for GlobalISel generic memory operations.
std::optional< std::string > getArm64ECMangledFunctionName(StringRef Name)
detail::zippy< detail::zip_first, T, U, Args... > zip_equal(T &&t, U &&u, Args &&...args)
zip iterator that assumes that all iteratees have the same length.
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
void initializeAArch64Arm64ECCallLoweringPass(PassRegistry &)
ModulePass * createAArch64Arm64ECCallLoweringPass()
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
SmallVector< ValueTypeFromRangeType< R >, Size > to_vector(R &&Range)
Given a range of type R, iterate the entire range and return a SmallVector with elements of the vecto...
This struct is a compact representation of a valid (non-zero power of two) alignment.
uint64_t value() const
This is a hole in the type system and should not be abused.
Align valueOrOne() const
For convenience, returns a valid alignment or 1 if undefined.