LLVM  8.0.0svn
WebAssemblyFixFunctionBitcasts.cpp
Go to the documentation of this file.
1 //===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 ///
10 /// \file
11 /// Fix bitcasted functions.
12 ///
13 /// WebAssembly requires caller and callee signatures to match, however in LLVM,
14 /// some amount of slop is vaguely permitted. Detect mismatch by looking for
15 /// bitcasts of functions and rewrite them to use wrapper functions instead.
16 ///
17 /// This doesn't catch all cases, such as when a function's address is taken in
18 /// one place and casted in another, but it works for many common cases.
19 ///
20 /// Note that LLVM already optimizes away function bitcasts in common cases by
21 /// dropping arguments as needed, so this pass only ends up getting used in less
22 /// common cases.
23 ///
24 //===----------------------------------------------------------------------===//
25 
26 #include "WebAssembly.h"
27 #include "llvm/IR/CallSite.h"
28 #include "llvm/IR/Constants.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/IR/Operator.h"
32 #include "llvm/Pass.h"
33 #include "llvm/Support/Debug.h"
35 using namespace llvm;
36 
37 #define DEBUG_TYPE "wasm-fix-function-bitcasts"
38 
39 static cl::opt<bool>
40  TemporaryWorkarounds("wasm-temporary-workarounds",
41  cl::desc("Apply certain temporary workarounds"),
42  cl::init(true), cl::Hidden);
43 
44 namespace {
// Module pass that redirects calls made through bitcasts of functions to
// generated wrapper functions; the work is done in runOnModule.
45 class FixFunctionBitcasts final : public ModulePass {
46  StringRef getPassName() const override {
47  return "WebAssembly Fix Function Bitcasts";
48  }
49 
// Only call operands are rewritten and new wrapper functions are added, so
// the control flow of existing functions is preserved.
// NOTE(review): listing line 52 is missing from this page, so the body shown
// here may be incomplete — confirm against the real source.
50  void getAnalysisUsage(AnalysisUsage &AU) const override {
51  AU.setPreservesCFG();
53  }
54 
55  bool runOnModule(Module &M) override;
56 
57 public:
58  static char ID;
59  FixFunctionBitcasts() : ModulePass(ID) {}
60 };
61 } // End anonymous namespace
62 
// NOTE(review): listing lines 63 and 67 are missing from this page. Line 67 is
// presumably the `ModulePass *llvm::createWebAssemblyFixFunctionBitcasts()`
// signature (it appears in the tooltip index). Line 63 presumably defines the
// `FixFunctionBitcasts::ID` declared in the class — confirm both against the
// real source.
64 INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE,
65  "Fix mismatching bitcasts for WebAssembly", false, false)
66 
// Factory: returns a newly allocated instance of the pass.
68  return new FixFunctionBitcasts();
69 }
70 
71 // Recursively descend the def-use lists from V to find non-bitcast users of
72 // bitcasts of V.
73 static void FindUses(Value *V, Function &F,
74  SmallVectorImpl<std::pair<Use *, Function *>> &Uses,
75  SmallPtrSetImpl<Constant *> &ConstantBCs) {
76  for (Use &U : V->uses()) {
77  if (BitCastOperator *BC = dyn_cast<BitCastOperator>(U.getUser()))
78  FindUses(BC, F, Uses, ConstantBCs);
79  else if (U.get()->getType() != F.getType()) {
80  CallSite CS(U.getUser());
81  if (!CS)
82  // Skip uses that aren't immediately called
83  continue;
84  Value *Callee = CS.getCalledValue();
85  if (Callee != V)
86  // Skip calls where the function isn't the callee
87  continue;
88  if (isa<Constant>(U.get())) {
89  // Only add constant bitcasts to the list once; they get RAUW'd
90  auto c = ConstantBCs.insert(cast<Constant>(U.get()));
91  if (!c.second)
92  continue;
93  }
94  Uses.push_back(std::make_pair(&U, &F));
95  }
96  }
97 }
98 
99 // Create a wrapper function with type Ty that calls F (which may have a
100 // different type). Attempt to support common bitcasted function idioms:
101 // - Call with more arguments than needed: arguments are dropped
102 // - Call with fewer arguments than needed: arguments are filled in with undef
103 // - Return value is not needed: drop it
104 // - Return value needed but not present: supply an undef
105 //
106 // If all the argument types are trivially castable to one another (i.e.
107 // I32 vs pointer type) then we don't create a wrapper at all (return nullptr
108 // instead).
109 //
110 // If there is a type mismatch that we know would result in an invalid wasm
111 // module then generate a wrapper that contains unreachable (i.e. abort at
112 // runtime). Such programs are deep into undefined behaviour territory,
113 // but we choose to fail at runtime rather than generate an invalid module
114 // or fail at compile time. The reason we delay the error is that we want
115 // to support CMake, which expects to be able to compile and link programs
116 // that refer to functions with entirely incorrect signatures (this is how
117 // CMake detects the existence of a function in a toolchain).
118 //
119 // For bitcasts that involve struct types we don't know at this stage if they
120 // would be equivalent at the wasm level and so we can't know if we need to
121 // generate a wrapper.
// NOTE(review): this page omits listing lines 122, 125, 131, 134, 135 and 215.
// The tooltip index gives the signature as
// `static Function * CreateWrapper(Function *F, FunctionType *Ty)`. The other
// gaps presumably create or declare `Wrapper`, `Args`, `PI` and `PE`, which
// are used below without a visible declaration — confirm against the real
// source.
//
// Returns one of:
// - the wrapper, named <F>_bitcast;
// - a wrapper named <F>_bitcast_invalid whose body is just `unreachable`;
// - nullptr when no wrapper is needed.
123  Module *M = F->getParent();
124 
126  F->getName() + "_bitcast", M);
127  BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
128  const DataLayout &DL = BB->getModule()->getDataLayout();
129 
130  // Determine what arguments to pass.
132  Function::arg_iterator AI = Wrapper->arg_begin();
133  Function::arg_iterator AE = Wrapper->arg_end();
136  bool TypeMismatch = false;
137  bool WrapperNeeded = false;
138 
139  Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
140  Type *RtnType = Ty->getReturnType();
141 
// Any difference in parameter count, varargs-ness or return type forces a
// wrapper. Otherwise one is needed only if the loop below finds a problem.
142  if ((F->getFunctionType()->getNumParams() != Ty->getNumParams()) ||
143  (F->getFunctionType()->isVarArg() != Ty->isVarArg()) ||
144  (ExpectedRtnType != RtnType))
145  WrapperNeeded = true;
146 
// Pair each of the wrapper's own arguments with the parameter type F expects
// (PI/PE, presumably iterating F's parameter types). Identical types are
// forwarded and no-op-castable types are cast. Anything else either abandons
// wrapping (struct types) or is a hard mismatch.
147  for (; AI != AE && PI != PE; ++AI, ++PI) {
148  Type *ArgType = AI->getType();
149  Type *ParamType = *PI;
150 
151  if (ArgType == ParamType) {
152  Args.push_back(&*AI);
153  } else {
154  if (CastInst::isBitOrNoopPointerCastable(ArgType, ParamType, DL)) {
155  Instruction *PtrCast =
156  CastInst::CreateBitOrPointerCast(AI, ParamType, "cast");
157  BB->getInstList().push_back(PtrCast);
158  Args.push_back(PtrCast);
159  } else if (ArgType->isStructTy() || ParamType->isStructTy()) {
// Equivalence of struct types at the wasm level is unknown here, so do not
// wrap. Args is left incomplete, but it is never used: no call is built while
// WrapperNeeded is false. A later TypeMismatch switches to the `unreachable`
// wrapper instead.
160  LLVM_DEBUG(dbgs() << "CreateWrapper: struct param type in bitcast: "
161  << F->getName() << "\n");
162  WrapperNeeded = false;
163  } else {
164  LLVM_DEBUG(dbgs() << "CreateWrapper: arg type mismatch calling: "
165  << F->getName() << "\n");
166  LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: "
167  << *ParamType << " Got: " << *ArgType << "\n");
168  TypeMismatch = true;
169  break;
170  }
171  }
172  }
173 
174  if (WrapperNeeded && !TypeMismatch) {
// Parameters the caller did not supply are filled with undef. Extra caller
// arguments are forwarded only when F is variadic; otherwise they are dropped.
175  for (; PI != PE; ++PI)
176  Args.push_back(UndefValue::get(*PI));
177  if (F->isVarArg())
178  for (; AI != AE; ++AI)
179  Args.push_back(&*AI);
180 
181  CallInst *Call = CallInst::Create(F, Args, "", BB);
182 
// NOTE(review): these two declarations shadow the identical
// ExpectedRtnType/RtnType computed before the argument loop. They are
// redundant and could be removed.
183  Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
184  Type *RtnType = Ty->getReturnType();
185  // Determine what value to return.
186  if (RtnType->isVoidTy()) {
187  ReturnInst::Create(M->getContext(), BB);
188  } else if (ExpectedRtnType->isVoidTy()) {
189  LLVM_DEBUG(dbgs() << "Creating dummy return: " << *RtnType << "\n");
190  ReturnInst::Create(M->getContext(), UndefValue::get(RtnType), BB);
191  } else if (RtnType == ExpectedRtnType) {
192  ReturnInst::Create(M->getContext(), Call, BB);
193  } else if (CastInst::isBitOrNoopPointerCastable(ExpectedRtnType, RtnType,
194  DL)) {
195  Instruction *Cast =
196  CastInst::CreateBitOrPointerCast(Call, RtnType, "cast");
197  BB->getInstList().push_back(Cast);
198  ReturnInst::Create(M->getContext(), Cast, BB);
199  } else if (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) {
// Abandon wrapping. The half-built body (call, no terminator) is discarded
// when the wrapper is erased below.
200  LLVM_DEBUG(dbgs() << "CreateWrapper: struct return type in bitcast: "
201  << F->getName() << "\n");
202  WrapperNeeded = false;
203  } else {
204  LLVM_DEBUG(dbgs() << "CreateWrapper: return type mismatch calling: "
205  << F->getName() << "\n");
206  LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType
207  << " Got: " << *RtnType << "\n");
208  TypeMismatch = true;
209  }
210  }
211 
// NOTE(review): the name "_bitcast_invalid" is already passed on listing line
// 216, the continuation of the (missing) creation call. The setName below is
// therefore presumably redundant — confirm.
212  if (TypeMismatch) {
213  // Create a new wrapper that simply contains `unreachable`.
214  Wrapper->eraseFromParent();
216  F->getName() + "_bitcast_invalid", M);
217  BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
218  new UnreachableInst(M->getContext(), BB);
219  Wrapper->setName(F->getName() + "_bitcast_invalid");
220  } else if (!WrapperNeeded) {
221  LLVM_DEBUG(dbgs() << "CreateWrapper: no wrapper needed: " << F->getName()
222  << "\n");
223  Wrapper->eraseFromParent();
224  return nullptr;
225  }
226  LLVM_DEBUG(dbgs() << "CreateWrapper: " << F->getName() << "\n");
227  return Wrapper;
228 }
229 
// Entry point. Finds calls made through bitcasts of module functions, builds
// (or reuses) one wrapper per (function, called type) pair via CreateWrapper,
// and points those calls at the wrapper. When TemporaryWorkarounds is off, a
// mis-typed `main` also gets a standard-signature wrapper that takes over the
// name "main". Always returns true.
//
// NOTE(review): this page omits listing lines 233, 248, 265, 271, 302 and 303.
// `Uses`, `Wrappers` and `Ty` are used below without a visible declaration,
// and the second MainArgTys element is cut off — confirm against the real
// source.
230 bool FixFunctionBitcasts::runOnModule(Module &M) {
231  Function *Main = nullptr;
232  CallInst *CallMain = nullptr;
234  SmallPtrSet<Constant *, 2> ConstantBCs;
235 
236  // Collect all the places that need wrappers.
237  for (Function &F : M) {
238  FindUses(&F, F, Uses, ConstantBCs);
239 
240  // If we have a "main" function, and its type isn't
241  // "int main(int argc, char *argv[])", create an artificial call with it
242  // bitcasted to that type so that we generate a wrapper for it, so that
243  // the C runtime can call it.
244  if (!TemporaryWorkarounds && !F.isDeclaration() && F.getName() == "main") {
245  Main = &F;
246  LLVMContext &C = M.getContext();
247  Type *MainArgTys[] = {Type::getInt32Ty(C),
249  FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,
250  /*isVarArg=*/false);
251  if (F.getFunctionType() != MainTy) {
252  LLVM_DEBUG(dbgs() << "Found `main` function with incorrect type: "
253  << *F.getFunctionType() << "\n");
254  Value *Args[] = {UndefValue::get(MainArgTys[0]),
255  UndefValue::get(MainArgTys[1])};
256  Value *Casted =
257  ConstantExpr::getBitCast(Main, PointerType::get(MainTy, 0));
// CallMain is created without being inserted into any basic block. It exists
// only so that the loop below sees a bitcasted call of main, and it is deleted
// at the end of this function.
// NOTE(review): operand 2 is presumably the callee operand (after the two
// call arguments) — confirm.
258  CallMain = CallInst::Create(Casted, Args, "call_main");
259  Use *UseMain = &CallMain->getOperandUse(2);
260  Uses.push_back(std::make_pair(UseMain, &F));
261  }
262  }
263  }
264 
266 
267  for (auto &UseFunc : Uses) {
268  Use *U = UseFunc.first;
269  Function *F = UseFunc.second;
270  PointerType *PTy = cast<PointerType>(U->get()->getType());
272 
273  // If the function is casted to something like i8* as a "generic pointer"
274  // to be later casted to something else, we can't generate a wrapper for it.
275  // Just ignore such casts for now.
276  if (!Ty)
277  continue;
278 
// Memoize wrappers per (function, type). A nullptr result ("no wrapper
// needed") is cached too, so CreateWrapper runs at most once per pair.
279  auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
280  if (Pair.second)
281  Pair.first->second = CreateWrapper(F, Ty);
282 
283  Function *Wrapper = Pair.first->second;
284  if (!Wrapper)
285  continue;
286 
// A constant bitcast is shared by all of its users, so replace it everywhere
// at once. A non-constant use is patched individually.
287  if (isa<Constant>(U->get()))
288  U->get()->replaceAllUsesWith(Wrapper);
289  else
290  U->set(Wrapper);
291  }
292 
293  // If we created a wrapper for main, rename the wrapper so that it's the
294  // one that gets called from startup.
295  if (CallMain) {
296  Main->setName("__original_main");
297  Function *MainWrapper =
298  cast<Function>(CallMain->getCalledValue()->stripPointerCasts());
299  MainWrapper->setName("main");
300  MainWrapper->setLinkage(Main->getLinkage());
301  MainWrapper->setVisibility(Main->getVisibility());
304  delete CallMain;
305  }
306 
// NOTE(review): reports the module as modified even when no wrapper was
// created.
307  return true;
308 }
void setVisibility(VisibilityTypes V)
Definition: GlobalValue.h:239
bool isVarArg() const
isVarArg - Return true if this function takes a variable number of arguments.
Definition: Function.h:177
uint64_t CallInst * C
A parsed version of the target data layout string in and methods for querying it. ...
Definition: DataLayout.h:111
iterator_range< use_iterator > uses()
Definition: Value.h:355
This class represents an incoming formal argument to a Function.
Definition: Argument.h:30
Compute iterated dominance frontiers using a linear time algorithm.
Definition: AllocatorList.h:24
A Module instance is used to store all the information related to an LLVM module. ...
Definition: Module.h:64
This class represents a function call, abstracting a target machine's calling convention.
static PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space...
Definition: Type.cpp:617
Like Internal, but omit from symbol table.
Definition: GlobalValue.h:57
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:714
static CastInst * CreateBitOrPointerCast(Value *S, Type *Ty, const Twine &Name="", Instruction *InsertBefore=nullptr)
Create a BitCast, a PtrToInt, or an IntToPTr cast instruction.
const Use & getOperandUse(unsigned i) const
Definition: User.h:183
arg_iterator arg_end()
Definition: Function.h:680
F(f)
static CallInst * Create(Value *Func, ArrayRef< Value *> Args, ArrayRef< OperandBundleDef > Bundles=None, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
param_iterator param_end() const
Definition: DerivedTypes.h:129
This file contains the entry points for global functions defined in the LLVM WebAssembly back-end...
static bool isBitOrNoopPointerCastable(Type *SrcTy, Type *DestTy, const DataLayout &DL)
Check whether a bitcast, inttoptr, or ptrtoint cast between these types is valid and a no-op...
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:344
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, Instruction *InsertBefore=nullptr)
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:196
LLVMContext & getContext() const
Get the global data context.
Definition: Module.h:243
A Use represents the edge between a Value definition and its users.
Definition: Use.h:56
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: APFloat.h:42
void setName(const Twine &Name)
Change the name of the value.
Definition: Value.cpp:295
Class to represent function types.
Definition: DerivedTypes.h:103
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:92
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:245
INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE, "Fix mismatching bitcasts for WebAssembly", false, false) ModulePass *llvm
bool isVarArg() const
Definition: DerivedTypes.h:123
static cl::opt< bool > TemporaryWorkarounds("wasm-temporary-workarounds", cl::desc("Apply certain temporary workarounds"), cl::init(true), cl::Hidden)
LinkageTypes getLinkage() const
Definition: GlobalValue.h:451
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:439
amdgpu Simplify well known AMD library false Value * Callee
Class to represent pointers.
Definition: DerivedTypes.h:467
static Function * CreateWrapper(Function *F, FunctionType *Ty)
static Constant * getBitCast(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:1750
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:141
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:410
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition: Function.h:136
VisibilityTypes getVisibility() const
Definition: GlobalValue.h:233
LLVM Basic Block Representation.
Definition: BasicBlock.h:58
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:46
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:69
This function has undefined behavior.
This file contains the declarations for the subclasses of Constant, which represent the different fla...
unsigned getNumParams() const
Return the number of fixed parameters this function type requires.
Definition: DerivedTypes.h:139
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:371
param_iterator param_begin() const
Definition: DerivedTypes.h:128
Represent the analysis usage information of a pass.
static FunctionType * get(Type *Result, ArrayRef< Type *> Params, bool isVarArg)
This static method is the primary way of constructing a FunctionType.
Definition: Type.cpp:297
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:100
arg_iterator arg_begin()
Definition: Function.h:671
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1392
const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs, and aliases.
Definition: Value.cpp:539
size_t size() const
Definition: SmallVector.h:53
static PointerType * getInt8PtrTy(LLVMContext &C, unsigned AS=0)
Definition: Type.cpp:220
ModulePass * createWebAssemblyFixFunctionBitcasts()
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements...
Definition: SmallPtrSet.h:418
This is a &#39;vector&#39; (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:847
Module.h This file contains the declarations for the Module class.
Type::subtype_iterator param_iterator
Definition: DerivedTypes.h:126
Type * getReturnType() const
Definition: DerivedTypes.h:124
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:286
void setLinkage(LinkageTypes LT)
Definition: GlobalValue.h:445
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:133
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:164
static IntegerType * getInt32Ty(LLVMContext &C)
Definition: Type.cpp:176
static void FindUses(Value *V, Function &F, SmallVectorImpl< std::pair< Use *, Function *>> &Uses, SmallPtrSetImpl< Constant *> &ConstantBCs)
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:224
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:225
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:323
void eraseFromParent()
eraseFromParent - This method unlinks 'this' from the containing module and deletes it...
Definition: Function.cpp:215
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:566
LLVM Value Representation.
Definition: Value.h:73
const Value * getCalledValue() const
Get a pointer to the function that is invoked by this instruction.
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:49
#define LLVM_DEBUG(X)
Definition: Debug.h:123
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
Type * getElementType() const
Definition: DerivedTypes.h:486
PointerType * getType() const
Global values are always pointers.
Definition: GlobalValue.h:274
bool isStructTy() const
True if this is an instance of StructType.
Definition: Type.h:218