LLVM  14.0.0git
WebAssemblyFixFunctionBitcasts.cpp
Go to the documentation of this file.
1 //===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// Fix bitcasted functions.
11 ///
12 /// WebAssembly requires caller and callee signatures to match, however in LLVM,
13 /// some amount of slop is vaguely permitted. Detect mismatch by looking for
14 /// bitcasts of functions and rewrite them to use wrapper functions instead.
15 ///
/// This doesn't catch all cases, such as when a function's address is taken in
/// one place and cast in another, but it works for many common cases.
18 ///
19 /// Note that LLVM already optimizes away function bitcasts in common cases by
20 /// dropping arguments as needed, so this pass only ends up getting used in less
21 /// common cases.
22 ///
23 //===----------------------------------------------------------------------===//
24 
#include "WebAssembly.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
33 using namespace llvm;
34 
35 #define DEBUG_TYPE "wasm-fix-function-bitcasts"
36 
37 namespace {
38 class FixFunctionBitcasts final : public ModulePass {
39  StringRef getPassName() const override {
40  return "WebAssembly Fix Function Bitcasts";
41  }
42 
43  void getAnalysisUsage(AnalysisUsage &AU) const override {
44  AU.setPreservesCFG();
46  }
47 
48  bool runOnModule(Module &M) override;
49 
50 public:
51  static char ID;
52  FixFunctionBitcasts() : ModulePass(ID) {}
53 };
54 } // End anonymous namespace
55 
57 INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE,
58  "Fix mismatching bitcasts for WebAssembly", false, false)
59 
61  return new FixFunctionBitcasts();
62 }
63 
64 // Recursively descend the def-use lists from V to find non-bitcast users of
65 // bitcasts of V.
66 static void findUses(Value *V, Function &F,
67  SmallVectorImpl<std::pair<CallBase *, Function *>> &Uses) {
68  for (User *U : V->users()) {
69  if (auto *BC = dyn_cast<BitCastOperator>(U))
70  findUses(BC, F, Uses);
71  else if (auto *A = dyn_cast<GlobalAlias>(U))
72  findUses(A, F, Uses);
73  else if (auto *CB = dyn_cast<CallBase>(U)) {
74  Value *Callee = CB->getCalledOperand();
75  if (Callee != V)
76  // Skip calls where the function isn't the callee
77  continue;
78  if (CB->getFunctionType() == F.getValueType())
79  // Skip uses that are immediately called
80  continue;
81  Uses.push_back(std::make_pair(CB, &F));
82  }
83  }
84 }
85 
86 // Create a wrapper function with type Ty that calls F (which may have a
87 // different type). Attempt to support common bitcasted function idioms:
88 // - Call with more arguments than needed: arguments are dropped
89 // - Call with fewer arguments than needed: arguments are filled in with undef
90 // - Return value is not needed: drop it
91 // - Return value needed but not present: supply an undef
92 //
93 // If the all the argument types of trivially castable to one another (i.e.
94 // I32 vs pointer type) then we don't create a wrapper at all (return nullptr
95 // instead).
96 //
97 // If there is a type mismatch that we know would result in an invalid wasm
98 // module then generate wrapper that contains unreachable (i.e. abort at
99 // runtime). Such programs are deep into undefined behaviour territory,
100 // but we choose to fail at runtime rather than generate and invalid module
101 // or fail at compiler time. The reason we delay the error is that we want
102 // to support the CMake which expects to be able to compile and link programs
103 // that refer to functions with entirely incorrect signatures (this is how
104 // CMake detects the existence of a function in a toolchain).
105 //
106 // For bitcasts that involve struct types we don't know at this stage if they
107 // would be equivalent at the wasm level and so we can't know if we need to
108 // generate a wrapper.
110  Module *M = F->getParent();
111 
113  F->getName() + "_bitcast", M);
114  BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
115  const DataLayout &DL = BB->getModule()->getDataLayout();
116 
117  // Determine what arguments to pass.
119  Function::arg_iterator AI = Wrapper->arg_begin();
120  Function::arg_iterator AE = Wrapper->arg_end();
121  FunctionType::param_iterator PI = F->getFunctionType()->param_begin();
122  FunctionType::param_iterator PE = F->getFunctionType()->param_end();
123  bool TypeMismatch = false;
124  bool WrapperNeeded = false;
125 
126  Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
127  Type *RtnType = Ty->getReturnType();
128 
129  if ((F->getFunctionType()->getNumParams() != Ty->getNumParams()) ||
130  (F->getFunctionType()->isVarArg() != Ty->isVarArg()) ||
131  (ExpectedRtnType != RtnType))
132  WrapperNeeded = true;
133 
134  for (; AI != AE && PI != PE; ++AI, ++PI) {
135  Type *ArgType = AI->getType();
136  Type *ParamType = *PI;
137 
138  if (ArgType == ParamType) {
139  Args.push_back(&*AI);
140  } else {
141  if (CastInst::isBitOrNoopPointerCastable(ArgType, ParamType, DL)) {
142  Instruction *PtrCast =
143  CastInst::CreateBitOrPointerCast(AI, ParamType, "cast");
144  BB->getInstList().push_back(PtrCast);
145  Args.push_back(PtrCast);
146  } else if (ArgType->isStructTy() || ParamType->isStructTy()) {
147  LLVM_DEBUG(dbgs() << "createWrapper: struct param type in bitcast: "
148  << F->getName() << "\n");
149  WrapperNeeded = false;
150  } else {
151  LLVM_DEBUG(dbgs() << "createWrapper: arg type mismatch calling: "
152  << F->getName() << "\n");
153  LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: "
154  << *ParamType << " Got: " << *ArgType << "\n");
155  TypeMismatch = true;
156  break;
157  }
158  }
159  }
160 
161  if (WrapperNeeded && !TypeMismatch) {
162  for (; PI != PE; ++PI)
163  Args.push_back(UndefValue::get(*PI));
164  if (F->isVarArg())
165  for (; AI != AE; ++AI)
166  Args.push_back(&*AI);
167 
168  CallInst *Call = CallInst::Create(F, Args, "", BB);
169 
170  Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
171  Type *RtnType = Ty->getReturnType();
172  // Determine what value to return.
173  if (RtnType->isVoidTy()) {
174  ReturnInst::Create(M->getContext(), BB);
175  } else if (ExpectedRtnType->isVoidTy()) {
176  LLVM_DEBUG(dbgs() << "Creating dummy return: " << *RtnType << "\n");
177  ReturnInst::Create(M->getContext(), UndefValue::get(RtnType), BB);
178  } else if (RtnType == ExpectedRtnType) {
179  ReturnInst::Create(M->getContext(), Call, BB);
180  } else if (CastInst::isBitOrNoopPointerCastable(ExpectedRtnType, RtnType,
181  DL)) {
182  Instruction *Cast =
183  CastInst::CreateBitOrPointerCast(Call, RtnType, "cast");
184  BB->getInstList().push_back(Cast);
185  ReturnInst::Create(M->getContext(), Cast, BB);
186  } else if (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) {
187  LLVM_DEBUG(dbgs() << "createWrapper: struct return type in bitcast: "
188  << F->getName() << "\n");
189  WrapperNeeded = false;
190  } else {
191  LLVM_DEBUG(dbgs() << "createWrapper: return type mismatch calling: "
192  << F->getName() << "\n");
193  LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType
194  << " Got: " << *RtnType << "\n");
195  TypeMismatch = true;
196  }
197  }
198 
199  if (TypeMismatch) {
200  // Create a new wrapper that simply contains `unreachable`.
201  Wrapper->eraseFromParent();
203  F->getName() + "_bitcast_invalid", M);
204  BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
205  new UnreachableInst(M->getContext(), BB);
206  Wrapper->setName(F->getName() + "_bitcast_invalid");
207  } else if (!WrapperNeeded) {
208  LLVM_DEBUG(dbgs() << "createWrapper: no wrapper needed: " << F->getName()
209  << "\n");
210  Wrapper->eraseFromParent();
211  return nullptr;
212  }
213  LLVM_DEBUG(dbgs() << "createWrapper: " << F->getName() << "\n");
214  return Wrapper;
215 }
216 
217 // Test whether a main function with type FuncTy should be rewritten to have
218 // type MainTy.
219 static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy) {
220  // Only fix the main function if it's the standard zero-arg form. That way,
221  // the standard cases will work as expected, and users will see signature
222  // mismatches from the linker for non-standard cases.
223  return FuncTy->getReturnType() == MainTy->getReturnType() &&
224  FuncTy->getNumParams() == 0 &&
225  !FuncTy->isVarArg();
226 }
227 
228 bool FixFunctionBitcasts::runOnModule(Module &M) {
229  LLVM_DEBUG(dbgs() << "********** Fix Function Bitcasts **********\n");
230 
231  Function *Main = nullptr;
232  CallInst *CallMain = nullptr;
234 
235  // Collect all the places that need wrappers.
236  for (Function &F : M) {
237  // Skip to fix when the function is swiftcc because swiftcc allows
238  // bitcast type difference for swiftself and swifterror.
239  if (F.getCallingConv() == CallingConv::Swift)
240  continue;
241  findUses(&F, F, Uses);
242 
243  // If we have a "main" function, and its type isn't
244  // "int main(int argc, char *argv[])", create an artificial call with it
245  // bitcasted to that type so that we generate a wrapper for it, so that
246  // the C runtime can call it.
247  if (F.getName() == "main") {
248  Main = &F;
249  LLVMContext &C = M.getContext();
250  Type *MainArgTys[] = {Type::getInt32Ty(C),
252  FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,
253  /*isVarArg=*/false);
254  if (shouldFixMainFunction(F.getFunctionType(), MainTy)) {
255  LLVM_DEBUG(dbgs() << "Found `main` function with incorrect type: "
256  << *F.getFunctionType() << "\n");
257  Value *Args[] = {UndefValue::get(MainArgTys[0]),
258  UndefValue::get(MainArgTys[1])};
259  Value *Casted =
260  ConstantExpr::getBitCast(Main, PointerType::get(MainTy, 0));
261  CallMain = CallInst::Create(MainTy, Casted, Args, "call_main");
262  Uses.push_back(std::make_pair(CallMain, &F));
263  }
264  }
265  }
266 
268 
269  for (auto &UseFunc : Uses) {
270  CallBase *CB = UseFunc.first;
271  Function *F = UseFunc.second;
272  FunctionType *Ty = CB->getFunctionType();
273 
274  auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
275  if (Pair.second)
276  Pair.first->second = createWrapper(F, Ty);
277 
278  Function *Wrapper = Pair.first->second;
279  if (!Wrapper)
280  continue;
281 
283  }
284 
285  // If we created a wrapper for main, rename the wrapper so that it's the
286  // one that gets called from startup.
287  if (CallMain) {
288  Main->setName("__original_main");
289  auto *MainWrapper =
290  cast<Function>(CallMain->getCalledOperand()->stripPointerCasts());
291  delete CallMain;
292  if (Main->isDeclaration()) {
293  // The wrapper is not needed in this case as we don't need to export
294  // it to anyone else.
295  MainWrapper->eraseFromParent();
296  } else {
297  // Otherwise give the wrapper the same linkage as the original main
298  // function, so that it can be called from the same places.
299  MainWrapper->setName("main");
300  MainWrapper->setLinkage(Main->getLinkage());
301  MainWrapper->setVisibility(Main->getVisibility());
302  }
303  }
304 
305  return true;
306 }
shouldFixMainFunction
static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy)
Definition: WebAssemblyFixFunctionBitcasts.cpp:219
llvm::Argument
This class represents an incoming formal argument to a Function.
Definition: Argument.h:29
llvm
This is an optimization pass for GlobalISel generic memory operations.
Definition: AllocatorList.h:23
WebAssembly.h
M
We currently emits eax Perhaps this is what we really should generate is Is imull three or four cycles eax eax The current instruction priority is based on pattern complexity The former is more complex because it folds a load so the latter will not be emitted Perhaps we should use AddedComplexity to give LEA32r a higher priority We should always try to match LEA first since the LEA matching code does some estimate to determine whether the match is profitable if we care more about code then imull is better It s two bytes shorter than movl leal On a Pentium M
Definition: README.txt:252
llvm::DataLayout
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:113
llvm::Type::getInt8PtrTy
static PointerType * getInt8PtrTy(LLVMContext &C, unsigned AS=0)
Definition: Type.cpp:293
llvm::ModulePass
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:238
llvm::GlobalValue::getLinkage
LinkageTypes getLinkage() const
Definition: GlobalValue.h:467
llvm::Function
Definition: Function.h:62
Pass.h
llvm::PointerType::get
static PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space.
Definition: Type.cpp:729
llvm::SmallVector
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1168
Wrapper
amdgpu aa AMDGPU Address space based Alias Analysis Wrapper
Definition: AMDGPUAliasAnalysis.cpp:31
llvm::ConstantExpr::getBitCast
static Constant * getBitCast(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2233
llvm::FunctionType::get
static FunctionType * get(Type *Result, ArrayRef< Type * > Params, bool isVarArg)
This static method is the primary way of constructing a FunctionType.
Definition: Type.cpp:363
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
Module.h
llvm::CallBase::getFunctionType
FunctionType * getFunctionType() const
Definition: InstrTypes.h:1256
Operator.h
llvm::FunctionType::getNumParams
unsigned getNumParams() const
Return the number of fixed parameters this function type requires.
Definition: DerivedTypes.h:139
llvm::Type::getInt32Ty
static IntegerType * getInt32Ty(LLVMContext &C)
Definition: Type.cpp:241
LLVM_DEBUG
#define LLVM_DEBUG(X)
Definition: Debug.h:101
F
#define F(x, y, z)
Definition: MD5.cpp:56
llvm::FunctionType::isVarArg
bool isVarArg() const
Definition: DerivedTypes.h:123
Uses
SmallPtrSet< MachineInstr *, 2 > Uses
Definition: ARMLowOverheadLoops.cpp:589
llvm::BasicBlock
LLVM Basic Block Representation.
Definition: BasicBlock.h:58
llvm::dbgs
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
findUses
static void findUses(Value *V, Function &F, SmallVectorImpl< std::pair< CallBase *, Function * >> &Uses)
Definition: WebAssemblyFixFunctionBitcasts.cpp:66
llvm::GlobalValue::isDeclaration
bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
Definition: Globals.cpp:228
Constants.h
createWrapper
static Function * createWrapper(Function *F, FunctionType *Ty)
Definition: WebAssemblyFixFunctionBitcasts.cpp:109
llvm::User
Definition: User.h:44
C
(vector float) vec_cmpeq(*A, *B) C
Definition: README_ALTIVEC.txt:86
llvm::CallInst::Create
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
Definition: Instructions.h:1530
llvm::AnalysisUsage
Represent the analysis usage information of a pass.
Definition: PassAnalysisSupport.h:47
llvm::Instruction
Definition: Instruction.h:45
llvm::Value::setName
void setName(const Twine &Name)
Change the name of the value.
Definition: Value.cpp:376
llvm::UndefValue::get
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1796
DEBUG_TYPE
#define DEBUG_TYPE
Definition: WebAssemblyFixFunctionBitcasts.cpp:35
llvm::FunctionType::param_iterator
Type::subtype_iterator param_iterator
Definition: DerivedTypes.h:126
llvm::Instruction::eraseFromParent
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:78
llvm::GlobalValue::getVisibility
VisibilityTypes getVisibility() const
Definition: GlobalValue.h:229
llvm::LLVMContext
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:68
llvm::DenseMap
Definition: DenseMap.h:714
llvm::Function::Create
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition: Function.h:139
llvm::Type::isVoidTy
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:138
llvm::Module
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:67
INITIALIZE_PASS
INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE, "Fix mismatching bitcasts for WebAssembly", false, false) ModulePass *llvm
Definition: WebAssemblyFixFunctionBitcasts.cpp:57
llvm::AnalysisUsage::setPreservesCFG
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:253
llvm::StringRef
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:57
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
llvm::BasicBlock::Create
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:100
DL
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Definition: AArch64SLSHardening.cpp:76
llvm::CastInst::isBitOrNoopPointerCastable
static bool isBitOrNoopPointerCastable(Type *SrcTy, Type *DestTy, const DataLayout &DL)
Check whether a bitcast, inttoptr, or ptrtoint cast between these types is valid and a no-op.
Definition: Instructions.cpp:3417
llvm::CallingConv::Swift
@ Swift
Definition: CallingConv.h:73
llvm::DenseMapBase< DenseMap< KeyT, ValueT, DenseMapInfo< KeyT >, llvm::detail::DenseMapPair< KeyT, ValueT > >, KeyT, ValueT, DenseMapInfo< KeyT >, llvm::detail::DenseMapPair< KeyT, ValueT > >::insert
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:207
llvm::Value::stripPointerCasts
const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs and address space casts.
Definition: Value.cpp:685
Callee
amdgpu Simplify well known AMD library false FunctionCallee Callee
Definition: AMDGPULibCalls.cpp:206
llvm::CastInst::CreateBitOrPointerCast
static CastInst * CreateBitOrPointerCast(Value *S, Type *Ty, const Twine &Name="", Instruction *InsertBefore=nullptr)
Create a BitCast, a PtrToInt, or an IntToPTr cast instruction.
Definition: Instructions.cpp:3312
llvm::ReturnInst::Create
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, Instruction *InsertBefore=nullptr)
Definition: Instructions.h:3046
llvm::CallBase::getCalledOperand
Value * getCalledOperand() const
Definition: InstrTypes.h:1391
llvm::CallBase::setCalledOperand
void setCalledOperand(Value *V)
Definition: InstrTypes.h:1431
llvm::GlobalValue::PrivateLinkage
@ PrivateLinkage
Like Internal, but omit from symbol table.
Definition: GlobalValue.h:56
Instructions.h
llvm::Type::isStructTy
bool isStructTy() const
True if this is an instance of StructType.
Definition: Type.h:211
llvm::createWebAssemblyFixFunctionBitcasts
ModulePass * createWebAssemblyFixFunctionBitcasts()
llvm::SmallVectorImpl
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: APFloat.h:43
llvm::CallBase
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Definition: InstrTypes.h:1176
llvm::Pass::getAnalysisUsage
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:93
llvm::CallInst
This class represents a function call, abstracting a target machine's calling convention.
Definition: Instructions.h:1487
BB
Common register allocation spilling lr str ldr sxth r3 ldr mla r4 can lr mov lr str ldr sxth r3 mla r4 and then merge mul and lr str ldr sxth r3 mla r4 It also increase the likelihood the store may become dead bb27 Successors according to LLVM BB
Definition: README.txt:39
llvm::UnreachableInst
This function has undefined behavior.
Definition: Instructions.h:4742
llvm::AMDGPU::HSAMD::Kernel::Key::Args
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
Definition: AMDGPUMetadata.h:389
raw_ostream.h
llvm::FunctionType::getReturnType
Type * getReturnType() const
Definition: DerivedTypes.h:124
llvm::Value
LLVM Value Representation.
Definition: Value.h:74
Debug.h
llvm::Value::users
iterator_range< user_iterator > users()
Definition: Value.h:421
llvm::FunctionType
Class to represent function types.
Definition: DerivedTypes.h:103
llvm::Intrinsic::ID
unsigned ID
Definition: TargetTransformInfo.h:38