LLVM  9.0.0svn
WebAssemblyFixFunctionBitcasts.cpp
Go to the documentation of this file.
1 //===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// Fix bitcasted functions.
11 ///
12 /// WebAssembly requires caller and callee signatures to match, however in LLVM,
13 /// some amount of slop is vaguely permitted. Detect mismatch by looking for
14 /// bitcasts of functions and rewrite them to use wrapper functions instead.
15 ///
16 /// This doesn't catch all cases, such as when a function's address is taken in
17 /// one place and casted in another, but it works for many common cases.
18 ///
19 /// Note that LLVM already optimizes away function bitcasts in common cases by
20 /// dropping arguments as needed, so this pass only ends up getting used in less
21 /// common cases.
22 ///
23 //===----------------------------------------------------------------------===//
24 
25 #include "WebAssembly.h"
26 #include "llvm/IR/CallSite.h"
27 #include "llvm/IR/Constants.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/IR/Operator.h"
31 #include "llvm/Pass.h"
32 #include "llvm/Support/Debug.h"
34 using namespace llvm;
35 
36 #define DEBUG_TYPE "wasm-fix-function-bitcasts"
37 
38 namespace {
39 class FixFunctionBitcasts final : public ModulePass {
40  StringRef getPassName() const override {
41  return "WebAssembly Fix Function Bitcasts";
42  }
43 
44  void getAnalysisUsage(AnalysisUsage &AU) const override {
45  AU.setPreservesCFG();
47  }
48 
49  bool runOnModule(Module &M) override;
50 
51 public:
52  static char ID;
53  FixFunctionBitcasts() : ModulePass(ID) {}
54 };
55 } // End anonymous namespace
56 
58 INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE,
59  "Fix mismatching bitcasts for WebAssembly", false, false)
60 
62  return new FixFunctionBitcasts();
63 }
64 
65 // Recursively descend the def-use lists from V to find non-bitcast users of
66 // bitcasts of V.
67 static void findUses(Value *V, Function &F,
68  SmallVectorImpl<std::pair<Use *, Function *>> &Uses,
69  SmallPtrSetImpl<Constant *> &ConstantBCs) {
70  for (Use &U : V->uses()) {
71  if (auto *BC = dyn_cast<BitCastOperator>(U.getUser()))
72  findUses(BC, F, Uses, ConstantBCs);
73  else if (U.get()->getType() != F.getType()) {
74  CallSite CS(U.getUser());
75  if (!CS)
76  // Skip uses that aren't immediately called
77  continue;
78  Value *Callee = CS.getCalledValue();
79  if (Callee != V)
80  // Skip calls where the function isn't the callee
81  continue;
82  if (isa<Constant>(U.get())) {
83  // Only add constant bitcasts to the list once; they get RAUW'd
84  auto C = ConstantBCs.insert(cast<Constant>(U.get()));
85  if (!C.second)
86  continue;
87  }
88  Uses.push_back(std::make_pair(&U, &F));
89  }
90  }
91 }
92 
93 // Create a wrapper function with type Ty that calls F (which may have a
94 // different type). Attempt to support common bitcasted function idioms:
95 // - Call with more arguments than needed: arguments are dropped
96 // - Call with fewer arguments than needed: arguments are filled in with undef
97 // - Return value is not needed: drop it
98 // - Return value needed but not present: supply an undef
99 //
100 // If the all the argument types of trivially castable to one another (i.e.
101 // I32 vs pointer type) then we don't create a wrapper at all (return nullptr
102 // instead).
103 //
104 // If there is a type mismatch that we know would result in an invalid wasm
105 // module then generate wrapper that contains unreachable (i.e. abort at
106 // runtime). Such programs are deep into undefined behaviour territory,
107 // but we choose to fail at runtime rather than generate and invalid module
108 // or fail at compiler time. The reason we delay the error is that we want
109 // to support the CMake which expects to be able to compile and link programs
110 // that refer to functions with entirely incorrect signatures (this is how
111 // CMake detects the existence of a function in a toolchain).
112 //
113 // For bitcasts that involve struct types we don't know at this stage if they
114 // would be equivalent at the wasm level and so we can't know if we need to
115 // generate a wrapper.
117  Module *M = F->getParent();
118 
120  F->getName() + "_bitcast", M);
121  BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
122  const DataLayout &DL = BB->getModule()->getDataLayout();
123 
124  // Determine what arguments to pass.
126  Function::arg_iterator AI = Wrapper->arg_begin();
127  Function::arg_iterator AE = Wrapper->arg_end();
130  bool TypeMismatch = false;
131  bool WrapperNeeded = false;
132 
133  Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
134  Type *RtnType = Ty->getReturnType();
135 
136  if ((F->getFunctionType()->getNumParams() != Ty->getNumParams()) ||
137  (F->getFunctionType()->isVarArg() != Ty->isVarArg()) ||
138  (ExpectedRtnType != RtnType))
139  WrapperNeeded = true;
140 
141  for (; AI != AE && PI != PE; ++AI, ++PI) {
142  Type *ArgType = AI->getType();
143  Type *ParamType = *PI;
144 
145  if (ArgType == ParamType) {
146  Args.push_back(&*AI);
147  } else {
148  if (CastInst::isBitOrNoopPointerCastable(ArgType, ParamType, DL)) {
149  Instruction *PtrCast =
150  CastInst::CreateBitOrPointerCast(AI, ParamType, "cast");
151  BB->getInstList().push_back(PtrCast);
152  Args.push_back(PtrCast);
153  } else if (ArgType->isStructTy() || ParamType->isStructTy()) {
154  LLVM_DEBUG(dbgs() << "createWrapper: struct param type in bitcast: "
155  << F->getName() << "\n");
156  WrapperNeeded = false;
157  } else {
158  LLVM_DEBUG(dbgs() << "createWrapper: arg type mismatch calling: "
159  << F->getName() << "\n");
160  LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: "
161  << *ParamType << " Got: " << *ArgType << "\n");
162  TypeMismatch = true;
163  break;
164  }
165  }
166  }
167 
168  if (WrapperNeeded && !TypeMismatch) {
169  for (; PI != PE; ++PI)
170  Args.push_back(UndefValue::get(*PI));
171  if (F->isVarArg())
172  for (; AI != AE; ++AI)
173  Args.push_back(&*AI);
174 
175  CallInst *Call = CallInst::Create(F, Args, "", BB);
176 
177  Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
178  Type *RtnType = Ty->getReturnType();
179  // Determine what value to return.
180  if (RtnType->isVoidTy()) {
181  ReturnInst::Create(M->getContext(), BB);
182  } else if (ExpectedRtnType->isVoidTy()) {
183  LLVM_DEBUG(dbgs() << "Creating dummy return: " << *RtnType << "\n");
184  ReturnInst::Create(M->getContext(), UndefValue::get(RtnType), BB);
185  } else if (RtnType == ExpectedRtnType) {
186  ReturnInst::Create(M->getContext(), Call, BB);
187  } else if (CastInst::isBitOrNoopPointerCastable(ExpectedRtnType, RtnType,
188  DL)) {
189  Instruction *Cast =
190  CastInst::CreateBitOrPointerCast(Call, RtnType, "cast");
191  BB->getInstList().push_back(Cast);
192  ReturnInst::Create(M->getContext(), Cast, BB);
193  } else if (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) {
194  LLVM_DEBUG(dbgs() << "createWrapper: struct return type in bitcast: "
195  << F->getName() << "\n");
196  WrapperNeeded = false;
197  } else {
198  LLVM_DEBUG(dbgs() << "createWrapper: return type mismatch calling: "
199  << F->getName() << "\n");
200  LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType
201  << " Got: " << *RtnType << "\n");
202  TypeMismatch = true;
203  }
204  }
205 
206  if (TypeMismatch) {
207  // Create a new wrapper that simply contains `unreachable`.
208  Wrapper->eraseFromParent();
210  F->getName() + "_bitcast_invalid", M);
211  BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
212  new UnreachableInst(M->getContext(), BB);
213  Wrapper->setName(F->getName() + "_bitcast_invalid");
214  } else if (!WrapperNeeded) {
215  LLVM_DEBUG(dbgs() << "createWrapper: no wrapper needed: " << F->getName()
216  << "\n");
217  Wrapper->eraseFromParent();
218  return nullptr;
219  }
220  LLVM_DEBUG(dbgs() << "createWrapper: " << F->getName() << "\n");
221  return Wrapper;
222 }
223 
224 // Test whether a main function with type FuncTy should be rewritten to have
225 // type MainTy.
226 static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy) {
227  // Only fix the main function if it's the standard zero-arg form. That way,
228  // the standard cases will work as expected, and users will see signature
229  // mismatches from the linker for non-standard cases.
230  return FuncTy->getReturnType() == MainTy->getReturnType() &&
231  FuncTy->getNumParams() == 0 &&
232  !FuncTy->isVarArg();
233 }
234 
235 bool FixFunctionBitcasts::runOnModule(Module &M) {
236  LLVM_DEBUG(dbgs() << "********** Fix Function Bitcasts **********\n");
237 
238  Function *Main = nullptr;
239  CallInst *CallMain = nullptr;
241  SmallPtrSet<Constant *, 2> ConstantBCs;
242 
243  // Collect all the places that need wrappers.
244  for (Function &F : M) {
245  findUses(&F, F, Uses, ConstantBCs);
246 
247  // If we have a "main" function, and its type isn't
248  // "int main(int argc, char *argv[])", create an artificial call with it
249  // bitcasted to that type so that we generate a wrapper for it, so that
250  // the C runtime can call it.
251  if (F.getName() == "main") {
252  Main = &F;
253  LLVMContext &C = M.getContext();
254  Type *MainArgTys[] = {Type::getInt32Ty(C),
256  FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,
257  /*isVarArg=*/false);
258  if (shouldFixMainFunction(F.getFunctionType(), MainTy)) {
259  LLVM_DEBUG(dbgs() << "Found `main` function with incorrect type: "
260  << *F.getFunctionType() << "\n");
261  Value *Args[] = {UndefValue::get(MainArgTys[0]),
262  UndefValue::get(MainArgTys[1])};
263  Value *Casted =
264  ConstantExpr::getBitCast(Main, PointerType::get(MainTy, 0));
265  CallMain = CallInst::Create(MainTy, Casted, Args, "call_main");
266  Use *UseMain = &CallMain->getOperandUse(2);
267  Uses.push_back(std::make_pair(UseMain, &F));
268  }
269  }
270  }
271 
273 
274  for (auto &UseFunc : Uses) {
275  Use *U = UseFunc.first;
276  Function *F = UseFunc.second;
277  auto *PTy = cast<PointerType>(U->get()->getType());
278  auto *Ty = dyn_cast<FunctionType>(PTy->getElementType());
279 
280  // If the function is casted to something like i8* as a "generic pointer"
281  // to be later casted to something else, we can't generate a wrapper for it.
282  // Just ignore such casts for now.
283  if (!Ty)
284  continue;
285 
286  auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
287  if (Pair.second)
288  Pair.first->second = createWrapper(F, Ty);
289 
290  Function *Wrapper = Pair.first->second;
291  if (!Wrapper)
292  continue;
293 
294  if (isa<Constant>(U->get()))
295  U->get()->replaceAllUsesWith(Wrapper);
296  else
297  U->set(Wrapper);
298  }
299 
300  // If we created a wrapper for main, rename the wrapper so that it's the
301  // one that gets called from startup.
302  if (CallMain) {
303  Main->setName("__original_main");
304  auto *MainWrapper =
305  cast<Function>(CallMain->getCalledValue()->stripPointerCasts());
306  delete CallMain;
307  if (Main->isDeclaration()) {
308  // The wrapper is not needed in this case as we don't need to export
309  // it to anyone else.
310  MainWrapper->eraseFromParent();
311  } else {
312  // Otherwise give the wrapper the same linkage as the original main
313  // function, so that it can be called from the same places.
314  MainWrapper->setName("main");
315  MainWrapper->setLinkage(Main->getLinkage());
316  MainWrapper->setVisibility(Main->getVisibility());
317  }
318  }
319 
320  return true;
321 }
bool isVarArg() const
isVarArg - Return true if this function takes a variable number of arguments.
Definition: Function.h:176
uint64_t CallInst * C
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:67
A parsed version of the target data layout string in and methods for querying it. ...
Definition: DataLayout.h:110
iterator_range< use_iterator > uses()
Definition: Value.h:354
This class represents an incoming formal argument to a Function.
Definition: Argument.h:29
This class represents lattice values for constants.
Definition: AllocatorList.h:23
A Module instance is used to store all the information related to an LLVM module. ...
Definition: Module.h:64
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
This class represents a function call, abstracting a target machine's calling convention.
static PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space...
Definition: Type.cpp:629
Like Internal, but omit from symbol table.
Definition: GlobalValue.h:56
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:705
static CastInst * CreateBitOrPointerCast(Value *S, Type *Ty, const Twine &Name="", Instruction *InsertBefore=nullptr)
Create a BitCast, a PtrToInt, or an IntToPTr cast instruction.
const Use & getOperandUse(unsigned i) const
Definition: User.h:182
arg_iterator arg_end()
Definition: Function.h:679
F(f)
param_iterator param_end() const
Definition: DerivedTypes.h:128
This file contains the entry points for global functions defined in the LLVM WebAssembly back-end...
static bool isBitOrNoopPointerCastable(Type *SrcTy, Type *DestTy, const DataLayout &DL)
Check whether a bitcast, inttoptr, or ptrtoint cast between these types is valid and a no-op...
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:343
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, Instruction *InsertBefore=nullptr)
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:221
amdgpu aa AMDGPU Address space based Alias Analysis Wrapper
LLVMContext & getContext() const
Get the global data context.
Definition: Module.h:243
A Use represents the edge between a Value definition and its users.
Definition: Use.h:55
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: APFloat.h:41
static void findUses(Value *V, Function &F, SmallVectorImpl< std::pair< Use *, Function *>> &Uses, SmallPtrSetImpl< Constant *> &ConstantBCs)
void setName(const Twine &Name)
Change the name of the value.
Definition: Value.cpp:285
Class to represent function types.
Definition: DerivedTypes.h:102
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:91
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:244
INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE, "Fix mismatching bitcasts for WebAssembly", false, false) ModulePass *llvm
bool isVarArg() const
Definition: DerivedTypes.h:122
LinkageTypes getLinkage() const
Definition: GlobalValue.h:450
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:429
static Constant * getBitCast(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:1772
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:140
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition: Function.h:135
VisibilityTypes getVisibility() const
Definition: GlobalValue.h:232
Value * getCalledValue() const
Definition: InstrTypes.h:1194
LLVM Basic Block Representation.
Definition: BasicBlock.h:57
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:45
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:68
This function has undefined behavior.
This file contains the declarations for the subclasses of Constant, which represent the different fla...
unsigned getNumParams() const
Return the number of fixed parameters this function type requires.
Definition: DerivedTypes.h:138
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:370
param_iterator param_begin() const
Definition: DerivedTypes.h:127
Represent the analysis usage information of a pass.
static Function * createWrapper(Function *F, FunctionType *Ty)
static FunctionType * get(Type *Result, ArrayRef< Type *> Params, bool isVarArg)
This static method is the primary way of constructing a FunctionType.
Definition: Type.cpp:296
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:99
arg_iterator arg_begin()
Definition: Function.h:670
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1414
const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs, and aliases.
Definition: Value.cpp:529
size_t size() const
Definition: SmallVector.h:52
static PointerType * getInt8PtrTy(LLVMContext &C, unsigned AS=0)
Definition: Type.cpp:219
ModulePass * createWebAssemblyFixFunctionBitcasts()
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements...
Definition: SmallPtrSet.h:417
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:839
Module.h This file contains the declarations for the Module class.
Type::subtype_iterator param_iterator
Definition: DerivedTypes.h:125
Type * getReturnType() const
Definition: DerivedTypes.h:123
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:285
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:132
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:163
amdgpu Simplify well known AMD library false FunctionCallee Callee
static IntegerType * getInt32Ty(LLVMContext &C)
Definition: Type.cpp:175
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:214
static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy)
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:224
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:322
void eraseFromParent()
eraseFromParent - This method unlinks &#39;this&#39; from the containing module and deletes it...
Definition: Function.cpp:213
bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
Definition: Globals.cpp:205
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:565
LLVM Value Representation.
Definition: Value.h:72
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:48
#define LLVM_DEBUG(X)
Definition: Debug.h:122
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
PointerType * getType() const
Global values are always pointers.
Definition: GlobalValue.h:273
bool isStructTy() const
True if this is an instance of StructType.
Definition: Type.h:217