LLVM 19.0.0git
AArch64Arm64ECCallLowering.cpp
Go to the documentation of this file.
1//===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file contains the IR transform to lower external or indirect calls for
11/// the ARM64EC calling convention. Such calls must go through the runtime, so
12/// we can translate the calling convention for calls into the emulator.
13///
14/// This subsumes Control Flow Guard handling.
15///
16//===----------------------------------------------------------------------===//
17
18#include "AArch64.h"
19#include "llvm/ADT/SetVector.h"
22#include "llvm/ADT/Statistic.h"
23#include "llvm/IR/CallingConv.h"
24#include "llvm/IR/IRBuilder.h"
25#include "llvm/IR/Instruction.h"
26#include "llvm/IR/Mangler.h"
27#include "llvm/IR/Module.h"
29#include "llvm/Object/COFF.h"
30#include "llvm/Pass.h"
33
34using namespace llvm;
35using namespace llvm::COFF;
36
38
39#define DEBUG_TYPE "arm64eccalllowering"
40
41STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");
42
43static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
44 cl::Hidden, cl::init(true));
45static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
46 cl::init(true));
47
48namespace {
49
50enum ThunkArgTranslation : uint8_t {
51 Direct,
52 Bitcast,
53 PointerIndirection,
54};
55
56struct ThunkArgInfo {
57 Type *Arm64Ty;
58 Type *X64Ty;
59 ThunkArgTranslation Translation;
60};
61
62class AArch64Arm64ECCallLowering : public ModulePass {
63public:
64 static char ID;
65 AArch64Arm64ECCallLowering() : ModulePass(ID) {
67 }
68
69 Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);
70 Function *buildEntryThunk(Function *F);
71 void lowerCall(CallBase *CB);
72 Function *buildGuestExitThunk(Function *F);
73 bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
74 bool runOnModule(Module &M) override;
75
76private:
77 int cfguard_module_flag = 0;
78 FunctionType *GuardFnType = nullptr;
79 PointerType *GuardFnPtrType = nullptr;
80 Constant *GuardFnCFGlobal = nullptr;
81 Constant *GuardFnGlobal = nullptr;
82 Module *M = nullptr;
83
84 Type *PtrTy;
85 Type *I64Ty;
86 Type *VoidTy;
87
88 void getThunkType(FunctionType *FT, AttributeList AttrList,
90 FunctionType *&Arm64Ty, FunctionType *&X64Ty,
91 SmallVector<ThunkArgTranslation> &ArgTranslations);
92 void getThunkRetType(FunctionType *FT, AttributeList AttrList,
93 raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
94 SmallVectorImpl<Type *> &Arm64ArgTypes,
95 SmallVectorImpl<Type *> &X64ArgTypes,
96 SmallVector<ThunkArgTranslation> &ArgTranslations,
97 bool &HasSretPtr);
98 void getThunkArgTypes(FunctionType *FT, AttributeList AttrList,
100 SmallVectorImpl<Type *> &Arm64ArgTypes,
101 SmallVectorImpl<Type *> &X64ArgTypes,
103 bool HasSretPtr);
104 ThunkArgInfo canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
105 uint64_t ArgSizeBytes, raw_ostream &Out);
106};
107
108} // end anonymous namespace
109
110void AArch64Arm64ECCallLowering::getThunkType(
112 raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty,
113 SmallVector<ThunkArgTranslation> &ArgTranslations) {
114 Out << (TT == Arm64ECThunkType::Entry ? "$ientry_thunk$cdecl$"
115 : "$iexit_thunk$cdecl$");
116
117 Type *Arm64RetTy;
118 Type *X64RetTy;
119
120 SmallVector<Type *> Arm64ArgTypes;
121 SmallVector<Type *> X64ArgTypes;
122
123 // The first argument to a thunk is the called function, stored in x9.
124 // For exit thunks, we pass the called function down to the emulator;
125 // for entry/guest exit thunks, we just call the Arm64 function directly.
126 if (TT == Arm64ECThunkType::Exit)
127 Arm64ArgTypes.push_back(PtrTy);
128 X64ArgTypes.push_back(PtrTy);
129
130 bool HasSretPtr = false;
131 getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
132 X64ArgTypes, ArgTranslations, HasSretPtr);
133
134 getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
135 ArgTranslations, HasSretPtr);
136
137 Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);
138
139 X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false);
140}
141
142void AArch64Arm64ECCallLowering::getThunkArgTypes(
144 raw_ostream &Out, SmallVectorImpl<Type *> &Arm64ArgTypes,
145 SmallVectorImpl<Type *> &X64ArgTypes,
146 SmallVectorImpl<ThunkArgTranslation> &ArgTranslations, bool HasSretPtr) {
147
148 Out << "$";
149 if (FT->isVarArg()) {
150 // We treat the variadic function's thunk as a normal function
151 // with the following type on the ARM side:
152 // rettype exitthunk(
153 // ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
154 //
155 // that can coverage all types of variadic function.
156 // x9 is similar to normal exit thunk, store the called function.
157 // x0-x3 is the arguments be stored in registers.
158 // x4 is the address of the arguments on the stack.
159 // x5 is the size of the arguments on the stack.
160 //
161 // On the x64 side, it's the same except that x5 isn't set.
162 //
163 // If both the ARM and X64 sides are sret, there are only three
164 // arguments in registers.
165 //
166 // If the X64 side is sret, but the ARM side isn't, we pass an extra value
167 // to/from the X64 side, and let SelectionDAG transform it into a memory
168 // location.
169 Out << "varargs";
170
171 // x0-x3
172 for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
173 Arm64ArgTypes.push_back(I64Ty);
174 X64ArgTypes.push_back(I64Ty);
175 ArgTranslations.push_back(ThunkArgTranslation::Direct);
176 }
177
178 // x4
179 Arm64ArgTypes.push_back(PtrTy);
180 X64ArgTypes.push_back(PtrTy);
181 ArgTranslations.push_back(ThunkArgTranslation::Direct);
182 // x5
183 Arm64ArgTypes.push_back(I64Ty);
184 if (TT != Arm64ECThunkType::Entry) {
185 // FIXME: x5 isn't actually used by the x64 side; revisit once we
186 // have proper isel for varargs
187 X64ArgTypes.push_back(I64Ty);
188 ArgTranslations.push_back(ThunkArgTranslation::Direct);
189 }
190 return;
191 }
192
193 unsigned I = 0;
194 if (HasSretPtr)
195 I++;
196
197 if (I == FT->getNumParams()) {
198 Out << "v";
199 return;
200 }
201
202 for (unsigned E = FT->getNumParams(); I != E; ++I) {
203#if 0
204 // FIXME: Need more information about argument size; see
205 // https://reviews.llvm.org/D132926
206 uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);
207 Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne();
208#else
209 uint64_t ArgSizeBytes = 0;
210 Align ParamAlign = Align();
211#endif
212 auto [Arm64Ty, X64Ty, ArgTranslation] =
213 canonicalizeThunkType(FT->getParamType(I), ParamAlign,
214 /*Ret*/ false, ArgSizeBytes, Out);
215 Arm64ArgTypes.push_back(Arm64Ty);
216 X64ArgTypes.push_back(X64Ty);
217 ArgTranslations.push_back(ArgTranslation);
218 }
219}
220
221void AArch64Arm64ECCallLowering::getThunkRetType(
222 FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
223 Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
224 SmallVectorImpl<Type *> &X64ArgTypes,
225 SmallVector<ThunkArgTranslation> &ArgTranslations, bool &HasSretPtr) {
226 Type *T = FT->getReturnType();
227#if 0
228 // FIXME: Need more information about argument size; see
229 // https://reviews.llvm.org/D132926
230 uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
231#else
232 int64_t ArgSizeBytes = 0;
233#endif
234 if (T->isVoidTy()) {
235 if (FT->getNumParams()) {
236 Attribute SRetAttr0 = AttrList.getParamAttr(0, Attribute::StructRet);
237 Attribute InRegAttr0 = AttrList.getParamAttr(0, Attribute::InReg);
238 Attribute SRetAttr1, InRegAttr1;
239 if (FT->getNumParams() > 1) {
240 // Also check the second parameter (for class methods, the first
241 // parameter is "this", and the second parameter is the sret pointer.)
242 // It doesn't matter which one is sret.
243 SRetAttr1 = AttrList.getParamAttr(1, Attribute::StructRet);
244 InRegAttr1 = AttrList.getParamAttr(1, Attribute::InReg);
245 }
246 if ((SRetAttr0.isValid() && InRegAttr0.isValid()) ||
247 (SRetAttr1.isValid() && InRegAttr1.isValid())) {
248 // sret+inreg indicates a call that returns a C++ class value. This is
249 // actually equivalent to just passing and returning a void* pointer
250 // as the first or second argument. Translate it that way, instead of
251 // trying to model "inreg" in the thunk's calling convention; this
252 // simplfies the rest of the code, and matches MSVC mangling.
253 Out << "i8";
254 Arm64RetTy = I64Ty;
255 X64RetTy = I64Ty;
256 return;
257 }
258 if (SRetAttr0.isValid()) {
259 // FIXME: Sanity-check the sret type; if it's an integer or pointer,
260 // we'll get screwy mangling/codegen.
261 // FIXME: For large struct types, mangle as an integer argument and
262 // integer return, so we can reuse more thunks, instead of "m" syntax.
263 // (MSVC mangles this case as an integer return with no argument, but
264 // that's a miscompile.)
265 Type *SRetType = SRetAttr0.getValueAsType();
266 Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();
267 canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,
268 Out);
269 Arm64RetTy = VoidTy;
270 X64RetTy = VoidTy;
271 Arm64ArgTypes.push_back(FT->getParamType(0));
272 X64ArgTypes.push_back(FT->getParamType(0));
273 ArgTranslations.push_back(ThunkArgTranslation::Direct);
274 HasSretPtr = true;
275 return;
276 }
277 }
278
279 Out << "v";
280 Arm64RetTy = VoidTy;
281 X64RetTy = VoidTy;
282 return;
283 }
284
285 auto info =
286 canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out);
287 Arm64RetTy = info.Arm64Ty;
288 X64RetTy = info.X64Ty;
289 if (X64RetTy->isPointerTy()) {
290 // If the X64 type is canonicalized to a pointer, that means it's
291 // passed/returned indirectly. For a return value, that means it's an
292 // sret pointer.
293 X64ArgTypes.push_back(X64RetTy);
294 X64RetTy = VoidTy;
295 }
296}
297
298ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType(
299 Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes,
300 raw_ostream &Out) {
301
302 auto direct = [](Type *T) {
303 return ThunkArgInfo{T, T, ThunkArgTranslation::Direct};
304 };
305
306 auto bitcast = [this](Type *Arm64Ty, uint64_t SizeInBytes) {
307 return ThunkArgInfo{Arm64Ty,
308 llvm::Type::getIntNTy(M->getContext(), SizeInBytes * 8),
309 ThunkArgTranslation::Bitcast};
310 };
311
312 auto pointerIndirection = [this](Type *Arm64Ty) {
313 return ThunkArgInfo{Arm64Ty, PtrTy,
314 ThunkArgTranslation::PointerIndirection};
315 };
316
317 if (T->isFloatTy()) {
318 Out << "f";
319 return direct(T);
320 }
321
322 if (T->isDoubleTy()) {
323 Out << "d";
324 return direct(T);
325 }
326
327 if (T->isFloatingPointTy()) {
329 "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
330 }
331
332 auto &DL = M->getDataLayout();
333
334 if (auto *StructTy = dyn_cast<StructType>(T))
335 if (StructTy->getNumElements() == 1)
336 T = StructTy->getElementType(0);
337
338 if (T->isArrayTy()) {
339 Type *ElementTy = T->getArrayElementType();
340 uint64_t ElementCnt = T->getArrayNumElements();
341 uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8;
342 uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
343 if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {
344 Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
345 if (Alignment.value() >= 16 && !Ret)
346 Out << "a" << Alignment.value();
347 if (TotalSizeBytes <= 8) {
348 // Arm64 returns small structs of float/double in float registers;
349 // X64 uses RAX.
350 return bitcast(T, TotalSizeBytes);
351 } else {
352 // Struct is passed directly on Arm64, but indirectly on X64.
353 return pointerIndirection(T);
354 }
355 } else if (T->isFloatingPointTy()) {
356 report_fatal_error("Only 32 and 64 bit floating points are supported for "
357 "ARM64EC thunks");
358 }
359 }
360
361 if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) {
362 Out << "i8";
363 return direct(I64Ty);
364 }
365
366 unsigned TypeSize = ArgSizeBytes;
367 if (TypeSize == 0)
368 TypeSize = DL.getTypeSizeInBits(T) / 8;
369 Out << "m";
370 if (TypeSize != 4)
371 Out << TypeSize;
372 if (Alignment.value() >= 16 && !Ret)
373 Out << "a" << Alignment.value();
374 // FIXME: Try to canonicalize Arm64Ty more thoroughly?
375 if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {
376 // Pass directly in an integer register
377 return bitcast(T, TypeSize);
378 } else {
379 // Passed directly on Arm64, but indirectly on X64.
380 return pointerIndirection(T);
381 }
382}
383
384// This function builds the "exit thunk", a function which translates
385// arguments and return values when calling x64 code from AArch64 code.
386Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
387 AttributeList Attrs) {
388 SmallString<256> ExitThunkName;
389 llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
390 FunctionType *Arm64Ty, *X64Ty;
391 SmallVector<ThunkArgTranslation> ArgTranslations;
392 getThunkType(FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty,
393 X64Ty, ArgTranslations);
394 if (Function *F = M->getFunction(ExitThunkName))
395 return F;
396
398 ExitThunkName, M);
399 F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
400 F->setSection(".wowthk$aa");
401 F->setComdat(M->getOrInsertComdat(ExitThunkName));
402 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
403 F->addFnAttr("frame-pointer", "all");
404 // Only copy sret from the first argument. For C++ instance methods, clang can
405 // stick an sret marking on a later argument, but it doesn't actually affect
406 // the ABI, so we can omit it. This avoids triggering a verifier assertion.
407 if (FT->getNumParams()) {
408 auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);
409 auto InReg = Attrs.getParamAttr(0, Attribute::InReg);
410 if (SRet.isValid() && !InReg.isValid())
411 F->addParamAttr(1, SRet);
412 }
413 // FIXME: Copy anything other than sret? Shouldn't be necessary for normal
414 // C ABI, but might show up in other cases.
415 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);
416 IRBuilder<> IRB(BB);
417 Value *CalleePtr =
418 M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy);
419 Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr);
420 auto &DL = M->getDataLayout();
422
423 // Pass the called function in x9.
424 auto X64TyOffset = 1;
425 Args.push_back(F->arg_begin());
426
427 Type *RetTy = Arm64Ty->getReturnType();
428 if (RetTy != X64Ty->getReturnType()) {
429 // If the return type is an array or struct, translate it. Values of size
430 // 8 or less go into RAX; bigger values go into memory, and we pass a
431 // pointer.
432 if (DL.getTypeStoreSize(RetTy) > 8) {
433 Args.push_back(IRB.CreateAlloca(RetTy));
434 X64TyOffset++;
435 }
436 }
437
438 for (auto [Arg, X64ArgType, ArgTranslation] : llvm::zip_equal(
439 make_range(F->arg_begin() + 1, F->arg_end()),
440 make_range(X64Ty->param_begin() + X64TyOffset, X64Ty->param_end()),
441 ArgTranslations)) {
442 // Translate arguments from AArch64 calling convention to x86 calling
443 // convention.
444 //
445 // For simple types, we don't need to do any translation: they're
446 // represented the same way. (Implicit sign extension is not part of
447 // either convention.)
448 //
449 // The big thing we have to worry about is struct types... but
450 // fortunately AArch64 clang is pretty friendly here: the cases that need
451 // translation are always passed as a struct or array. (If we run into
452 // some cases where this doesn't work, we can teach clang to mark it up
453 // with an attribute.)
454 //
455 // The first argument is the called function, stored in x9.
456 if (ArgTranslation != ThunkArgTranslation::Direct) {
457 Value *Mem = IRB.CreateAlloca(Arg.getType());
458 IRB.CreateStore(&Arg, Mem);
459 if (ArgTranslation == ThunkArgTranslation::Bitcast) {
460 Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));
461 Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy)));
462 } else {
463 assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);
464 Args.push_back(Mem);
465 }
466 } else {
467 Args.push_back(&Arg);
468 }
469 assert(Args.back()->getType() == X64ArgType);
470 }
471 // FIXME: Transfer necessary attributes? sret? anything else?
472
473 Callee = IRB.CreateBitCast(Callee, PtrTy);
474 CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args);
475 Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
476
477 Value *RetVal = Call;
478 if (RetTy != X64Ty->getReturnType()) {
479 // If we rewrote the return type earlier, convert the return value to
480 // the proper type.
481 if (DL.getTypeStoreSize(RetTy) > 8) {
482 RetVal = IRB.CreateLoad(RetTy, Args[1]);
483 } else {
484 Value *CastAlloca = IRB.CreateAlloca(RetTy);
485 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
486 RetVal = IRB.CreateLoad(RetTy, CastAlloca);
487 }
488 }
489
490 if (RetTy->isVoidTy())
491 IRB.CreateRetVoid();
492 else
493 IRB.CreateRet(RetVal);
494 return F;
495}
496
497// This function builds the "entry thunk", a function which translates
498// arguments and return values when calling AArch64 code from x64 code.
499Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
500 SmallString<256> EntryThunkName;
501 llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
502 FunctionType *Arm64Ty, *X64Ty;
503 SmallVector<ThunkArgTranslation> ArgTranslations;
504 getThunkType(F->getFunctionType(), F->getAttributes(),
505 Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty,
506 ArgTranslations);
507 if (Function *F = M->getFunction(EntryThunkName))
508 return F;
509
511 EntryThunkName, M);
512 Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
513 Thunk->setSection(".wowthk$aa");
514 Thunk->setComdat(M->getOrInsertComdat(EntryThunkName));
515 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
516 Thunk->addFnAttr("frame-pointer", "all");
517
518 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);
519 IRBuilder<> IRB(BB);
520
521 Type *RetTy = Arm64Ty->getReturnType();
522 Type *X64RetType = X64Ty->getReturnType();
523
524 bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
525 unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
526 unsigned PassthroughArgSize =
527 (F->isVarArg() ? 5 : Thunk->arg_size()) - ThunkArgOffset;
528 assert(ArgTranslations.size() == (F->isVarArg() ? 5 : PassthroughArgSize));
529
530 // Translate arguments to call.
532 for (unsigned i = 0; i != PassthroughArgSize; ++i) {
533 Value *Arg = Thunk->getArg(i + ThunkArgOffset);
534 Type *ArgTy = Arm64Ty->getParamType(i);
535 ThunkArgTranslation ArgTranslation = ArgTranslations[i];
536 if (ArgTranslation != ThunkArgTranslation::Direct) {
537 // Translate array/struct arguments to the expected type.
538 if (ArgTranslation == ThunkArgTranslation::Bitcast) {
539 Value *CastAlloca = IRB.CreateAlloca(ArgTy);
540 IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy));
541 Arg = IRB.CreateLoad(ArgTy, CastAlloca);
542 } else {
543 assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);
544 Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy));
545 }
546 }
547 assert(Arg->getType() == ArgTy);
548 Args.push_back(Arg);
549 }
550
551 if (F->isVarArg()) {
552 // The 5th argument to variadic entry thunks is used to model the x64 sp
553 // which is passed to the thunk in x4, this can be passed to the callee as
554 // the variadic argument start address after skipping over the 32 byte
555 // shadow store.
556
557 // The EC thunk CC will assign any argument marked as InReg to x4.
558 Thunk->addParamAttr(5, Attribute::InReg);
559 Value *Arg = Thunk->getArg(5);
560 Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20));
561 Args.push_back(Arg);
562
563 // Pass in a zero variadic argument size (in x5).
564 Args.push_back(IRB.getInt64(0));
565 }
566
567 // Call the function passed to the thunk.
568 Value *Callee = Thunk->getArg(0);
569 Callee = IRB.CreateBitCast(Callee, PtrTy);
570 CallInst *Call = IRB.CreateCall(Arm64Ty, Callee, Args);
571
572 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
573 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
574 if (SRetAttr.isValid() && !InRegAttr.isValid()) {
575 Thunk->addParamAttr(1, SRetAttr);
576 Call->addParamAttr(0, SRetAttr);
577 }
578
579 Value *RetVal = Call;
580 if (TransformDirectToSRet) {
581 IRB.CreateStore(RetVal, IRB.CreateBitCast(Thunk->getArg(1), PtrTy));
582 } else if (X64RetType != RetTy) {
583 Value *CastAlloca = IRB.CreateAlloca(X64RetType);
584 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
585 RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
586 }
587
588 // Return to the caller. Note that the isel has code to translate this
589 // "ret" to a tail call to __os_arm64x_dispatch_ret. (Alternatively, we
590 // could emit a tail call here, but that would require a dedicated calling
591 // convention, which seems more complicated overall.)
592 if (X64RetType->isVoidTy())
593 IRB.CreateRetVoid();
594 else
595 IRB.CreateRet(RetVal);
596
597 return Thunk;
598}
599
600// Builds the "guest exit thunk", a helper to call a function which may or may
601// not be an exit thunk. (We optimistically assume non-dllimport function
602// declarations refer to functions defined in AArch64 code; if the linker
603// can't prove that, we use this routine instead.)
604Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
605 llvm::raw_null_ostream NullThunkName;
606 FunctionType *Arm64Ty, *X64Ty;
607 SmallVector<ThunkArgTranslation> ArgTranslations;
608 getThunkType(F->getFunctionType(), F->getAttributes(),
609 Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
610 ArgTranslations);
611 auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
612 assert(MangledName && "Can't guest exit to function that's already native");
613 std::string ThunkName = *MangledName;
614 if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
615 ThunkName.insert(ThunkName.find("@"), "$exit_thunk");
616 } else {
617 ThunkName.append("$exit_thunk");
618 }
620 Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
621 GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
622 GuestExit->setSection(".wowthk$aa");
623 GuestExit->setMetadata(
624 "arm64ec_unmangled_name",
625 MDNode::get(M->getContext(),
626 MDString::get(M->getContext(), F->getName())));
627 GuestExit->setMetadata(
628 "arm64ec_ecmangled_name",
629 MDNode::get(M->getContext(),
630 MDString::get(M->getContext(), *MangledName)));
631 F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
632 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
633 IRBuilder<> B(BB);
634
635 // Load the global symbol as a pointer to the check function.
636 Value *GuardFn;
637 if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))
638 GuardFn = GuardFnCFGlobal;
639 else
640 GuardFn = GuardFnGlobal;
641 LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
642
643 // Create new call instruction. The CFGuard check should always be a call,
644 // even if the original CallBase is an Invoke or CallBr instruction.
645 Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
646 CallInst *GuardCheck = B.CreateCall(
647 GuardFnType, GuardCheckLoad,
648 {B.CreateBitCast(F, B.getPtrTy()), B.CreateBitCast(Thunk, B.getPtrTy())});
649
650 // Ensure that the first argument is passed in the correct register.
652
653 Value *GuardRetVal = B.CreateBitCast(GuardCheck, PtrTy);
655 for (Argument &Arg : GuestExit->args())
656 Args.push_back(&Arg);
657 CallInst *Call = B.CreateCall(Arm64Ty, GuardRetVal, Args);
658 Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
659
660 if (Call->getType()->isVoidTy())
661 B.CreateRetVoid();
662 else
663 B.CreateRet(Call);
664
665 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
666 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
667 if (SRetAttr.isValid() && !InRegAttr.isValid()) {
668 GuestExit->addParamAttr(0, SRetAttr);
669 Call->addParamAttr(0, SRetAttr);
670 }
671
672 return GuestExit;
673}
674
675// Lower an indirect call with inline code.
676void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
678 "Only applicable for Windows targets");
679
680 IRBuilder<> B(CB);
681 Value *CalledOperand = CB->getCalledOperand();
682
683 // If the indirect call is called within catchpad or cleanuppad,
684 // we need to copy "funclet" bundle of the call.
686 if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
687 Bundles.push_back(OperandBundleDef(*Bundle));
688
689 // Load the global symbol as a pointer to the check function.
690 Value *GuardFn;
691 if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))
692 GuardFn = GuardFnCFGlobal;
693 else
694 GuardFn = GuardFnGlobal;
695 LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
696
697 // Create new call instruction. The CFGuard check should always be a call,
698 // even if the original CallBase is an Invoke or CallBr instruction.
699 Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes());
700 CallInst *GuardCheck =
701 B.CreateCall(GuardFnType, GuardCheckLoad,
702 {B.CreateBitCast(CalledOperand, B.getPtrTy()),
703 B.CreateBitCast(Thunk, B.getPtrTy())},
704 Bundles);
705
706 // Ensure that the first argument is passed in the correct register.
708
709 Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType());
710 CB->setCalledOperand(GuardRetVal);
711}
712
713bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
714 if (!GenerateThunks)
715 return false;
716
717 M = &Mod;
718
719 // Check if this module has the cfguard flag and read its value.
720 if (auto *MD =
721 mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))
722 cfguard_module_flag = MD->getZExtValue();
723
724 PtrTy = PointerType::getUnqual(M->getContext());
725 I64Ty = Type::getInt64Ty(M->getContext());
726 VoidTy = Type::getVoidTy(M->getContext());
727
728 GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
729 GuardFnPtrType = PointerType::get(GuardFnType, 0);
730 GuardFnCFGlobal =
731 M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
732 GuardFnGlobal =
733 M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
734
735 SetVector<Function *> DirectCalledFns;
736 for (Function &F : Mod)
737 if (!F.isDeclaration() &&
738 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
739 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
740 processFunction(F, DirectCalledFns);
741
742 struct ThunkInfo {
743 Constant *Src;
744 Constant *Dst;
746 };
747 SmallVector<ThunkInfo> ThunkMapping;
748 for (Function &F : Mod) {
749 if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&
750 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
751 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {
752 if (!F.hasComdat())
753 F.setComdat(Mod.getOrInsertComdat(F.getName()));
754 ThunkMapping.push_back(
755 {&F, buildEntryThunk(&F), Arm64ECThunkType::Entry});
756 }
757 }
758 for (Function *F : DirectCalledFns) {
759 ThunkMapping.push_back(
760 {F, buildExitThunk(F->getFunctionType(), F->getAttributes()),
761 Arm64ECThunkType::Exit});
762 if (!F->hasDLLImportStorageClass())
763 ThunkMapping.push_back(
764 {buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit});
765 }
766
767 if (!ThunkMapping.empty()) {
768 SmallVector<Constant *> ThunkMappingArrayElems;
769 for (ThunkInfo &Thunk : ThunkMapping) {
770 ThunkMappingArrayElems.push_back(ConstantStruct::getAnon(
771 {ConstantExpr::getBitCast(Thunk.Src, PtrTy),
773 ConstantInt::get(M->getContext(), APInt(32, uint8_t(Thunk.Kind)))}));
774 }
775 Constant *ThunkMappingArray = ConstantArray::get(
776 llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(),
777 ThunkMappingArrayElems.size()),
778 ThunkMappingArrayElems);
779 new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,
780 GlobalValue::ExternalLinkage, ThunkMappingArray,
781 "llvm.arm64ec.symbolmap");
782 }
783
784 return true;
785}
786
787bool AArch64Arm64ECCallLowering::processFunction(
788 Function &F, SetVector<Function *> &DirectCalledFns) {
789 SmallVector<CallBase *, 8> IndirectCalls;
790
791 // For ARM64EC targets, a function definition's name is mangled differently
792 // from the normal symbol. We currently have no representation of this sort
793 // of symbol in IR, so we change the name to the mangled name, then store
794 // the unmangled name as metadata. Later passes that need the unmangled
795 // name (emitting the definition) can grab it from the metadata.
796 //
797 // FIXME: Handle functions with weak linkage?
798 if (!F.hasLocalLinkage() || F.hasAddressTaken()) {
799 if (std::optional<std::string> MangledName =
800 getArm64ECMangledFunctionName(F.getName().str())) {
801 F.setMetadata("arm64ec_unmangled_name",
802 MDNode::get(M->getContext(),
803 MDString::get(M->getContext(), F.getName())));
804 if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {
805 Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value());
806 SmallVector<GlobalObject *> ComdatUsers =
807 to_vector(F.getComdat()->getUsers());
808 for (GlobalObject *User : ComdatUsers)
809 User->setComdat(MangledComdat);
810 }
811 F.setName(MangledName.value());
812 }
813 }
814
815 // Iterate over the instructions to find all indirect call/invoke/callbr
816 // instructions. Make a separate list of pointers to indirect
817 // call/invoke/callbr instructions because the original instructions will be
818 // deleted as the checks are added.
819 for (BasicBlock &BB : F) {
820 for (Instruction &I : BB) {
821 auto *CB = dyn_cast<CallBase>(&I);
823 CB->isInlineAsm())
824 continue;
825
826 // We need to instrument any call that isn't directly calling an
827 // ARM64 function.
828 //
829 // FIXME: getCalledFunction() fails if there's a bitcast (e.g.
830 // unprototyped functions in C)
831 if (Function *F = CB->getCalledFunction()) {
832 if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
833 F->isIntrinsic() || !F->isDeclaration())
834 continue;
835
836 DirectCalledFns.insert(F);
837 continue;
838 }
839
840 IndirectCalls.push_back(CB);
841 ++Arm64ECCallsLowered;
842 }
843 }
844
845 if (IndirectCalls.empty())
846 return false;
847
848 for (CallBase *CB : IndirectCalls)
849 lowerCall(CB);
850
851 return true;
852}
853
854char AArch64Arm64ECCallLowering::ID = 0;
855INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
856 "AArch64Arm64ECCallLowering", false, false)
857
859 return new AArch64Arm64ECCallLowering;
860}
static cl::opt< bool > LowerDirectToIndirect("arm64ec-lower-direct-to-indirect", cl::Hidden, cl::init(true))
static cl::opt< bool > GenerateThunks("arm64ec-generate-thunks", cl::Hidden, cl::init(true))
OperandBundleDefT< Value * > OperandBundleDef
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
return RetTy
lazy value info
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
Module.h This file contains the declarations for the Module class.
Module * Mod
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:38
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallString class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:40
Class for arbitrary precision integers.
Definition: APInt.h:78
This class represents an incoming formal argument to a Function.
Definition: Argument.h:31
static ArrayType * get(Type *ElementType, uint64_t NumElements)
This static method is the primary way to construct an ArrayType.
Definition: Type.cpp:647
Attribute getParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) const
Return the attribute object that exists at the arg index.
Definition: Attributes.h:854
MaybeAlign getParamAlignment(unsigned ArgNo) const
Return the alignment for the specified function parameter.
bool isValid() const
Return true if the attribute is any kind of attribute.
Definition: Attributes.h:203
Type * getValueAsType() const
Return the attribute's value as a Type.
Definition: Attributes.cpp:398
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:202
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Definition: InstrTypes.h:1236
bool isInlineAsm() const
Check if this call is an inline asm statement.
Definition: InstrTypes.h:1532
void setCallingConv(CallingConv::ID CC)
Definition: InstrTypes.h:1527
std::optional< OperandBundleUse > getOperandBundle(StringRef Name) const
Return an operand bundle by name, if present.
Definition: InstrTypes.h:2143
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1465
bool hasFnAttr(Attribute::AttrKind Kind) const
Determine whether this call has the given attribute.
Definition: InstrTypes.h:1551
CallingConv::ID getCallingConv() const
Definition: InstrTypes.h:1523
Value * getCalledOperand() const
Definition: InstrTypes.h:1458
FunctionType * getFunctionType() const
Definition: InstrTypes.h:1323
void setCalledOperand(Value *V)
Definition: InstrTypes.h:1501
AttributeList getAttributes() const
Return the parameter attributes for this call.
Definition: InstrTypes.h:1542
This class represents a function call, abstracting a target machine's calling convention.
static Constant * get(ArrayType *T, ArrayRef< Constant * > V)
Definition: Constants.cpp:1292
static Constant * getBitCast(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2283
static Constant * getAnon(ArrayRef< Constant * > V, bool Packed=false)
Return an anonymous struct that has the specified elements.
Definition: Constants.h:477
This is an important base class in LLVM.
Definition: Constant.h:42
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition: Function.h:165
@ WeakODRLinkage
Same, but only replaced by something equivalent.
Definition: GlobalValue.h:57
@ ExternalLinkage
Externally visible function.
Definition: GlobalValue.h:52
@ LinkOnceODRLinkage
Same, but only replaced by something equivalent.
Definition: GlobalValue.h:55
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2671
const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr if the function does not...
Definition: Instruction.cpp:66
An instruction for reading from memory.
Definition: Instructions.h:174
static MDTuple * get(LLVMContext &Context, ArrayRef< Metadata * > MDs)
Definition: Metadata.h:1541
static MDString * get(LLVMContext &Context, StringRef Str)
Definition: Metadata.cpp:600
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:251
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overridden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
const std::string & getTargetTriple() const
Get the target triple which is a string describing the target host.
Definition: Module.h:297
Comdat * getOrInsertComdat(StringRef Name)
Return the Comdat in the module with the specified name.
Definition: Module.cpp:599
A container for an operand bundle being viewed as a set of values rather than a set of uses.
Definition: InstrTypes.h:1189
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
A vector that has set insertion semantics.
Definition: SetVector.h:57
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:162
SmallString - A SmallString is just a SmallVector with methods and accessors that make it work better...
Definition: SmallString.h:26
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:586
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
Triple - Helper class for working with autoconf configuration names.
Definition: Triple.h:44
bool isOSWindows() const
Tests whether the OS is Windows.
Definition: Triple.h:624
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:255
bool isFloatTy() const
Return true if this is 'float', a 32-bit IEEE fp type.
Definition: Type.h:154
static IntegerType * getIntNTy(LLVMContext &C, unsigned N)
static Type * getVoidTy(LLVMContext &C)
bool isDoubleTy() const
Return true if this is 'double', a 64-bit IEEE fp type.
Definition: Type.h:157
static IntegerType * getInt64Ty(LLVMContext &C)
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:140
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
A raw_ostream that discards all output.
Definition: raw_ostream.h:731
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
A raw_ostream that writes to an SmallVector or SmallString.
Definition: raw_ostream.h:691
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
constexpr char Attrs[]
Key for Kernel::Metadata::mAttrs.
Arm64ECThunkType
Definition: COFF.h:809
@ GuestExit
Definition: COFF.h:810
@ ARM64EC_Thunk_Native
Calling convention used in the ARM64EC ABI to implement calls between ARM64 code and thunks.
Definition: CallingConv.h:265
@ CFGuard_Check
Special calling convention on Windows for calling the Control Guard Check ICall function.
Definition: CallingConv.h:82
@ ARM64EC_Thunk_X64
Calling convention used in the ARM64EC ABI to implement calls between x64 code and thunks.
Definition: CallingConv.h:260
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition: CallingConv.h:24
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
std::optional< std::string > getArm64ECMangledFunctionName(StringRef Name)
Definition: Mangler.cpp:293
detail::zippy< detail::zip_first, T, U, Args... > zip_equal(T &&t, U &&u, Args &&...args)
zip iterator that assumes that all iteratees have the same length.
Definition: STLExtras.h:863
iterator_range< T > make_range(T x, T y)
Convenience function for iterating over sub-ranges.
void initializeAArch64Arm64ECCallLoweringPass(PassRegistry &)
ModulePass * createAArch64Arm64ECCallLoweringPass()
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition: Error.cpp:167
SmallVector< ValueTypeFromRangeType< R >, Size > to_vector(R &&Range)
Given a range of type R, iterate the entire range and return a SmallVector with elements of the vecto...
Definition: SmallVector.h:1312
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
uint64_t value() const
This is a hole in the type system and should not be abused.
Definition: Alignment.h:85
Align valueOrOne() const
For convenience, returns a valid alignment or 1 if undefined.
Definition: Alignment.h:141