LLVM  7.0.0svn
WebAssemblyFixFunctionBitcasts.cpp
Go to the documentation of this file.
1 //===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 ///
10 /// \file
11 /// Fix bitcasted functions.
12 ///
13 /// WebAssembly requires caller and callee signatures to match, however in LLVM,
14 /// some amount of slop is vaguely permitted. Detect mismatch by looking for
15 /// bitcasts of functions and rewrite them to use wrapper functions instead.
16 ///
17 /// This doesn't catch all cases, such as when a function's address is taken in
18 /// one place and casted in another, but it works for many common cases.
19 ///
20 /// Note that LLVM already optimizes away function bitcasts in common cases by
21 /// dropping arguments as needed, so this pass only ends up getting used in less
22 /// common cases.
23 ///
24 //===----------------------------------------------------------------------===//
25 
26 #include "WebAssembly.h"
27 #include "llvm/IR/CallSite.h"
28 #include "llvm/IR/Constants.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/IR/Operator.h"
32 #include "llvm/Pass.h"
33 #include "llvm/Support/Debug.h"
35 using namespace llvm;
36 
37 #define DEBUG_TYPE "wasm-fix-function-bitcasts"
38 
// Hidden boolean command-line flag "wasm-temporary-workarounds"; it defaults to
// true.
// NOTE(review): the opening line of this declaration (listing line 39) is
// missing from this extract. The cross-reference index at the end of the page
// shows it as `static cl::opt< bool > TemporaryWorkarounds(...)` -- confirm
// against the upstream file.
40  "wasm-temporary-workarounds",
41  cl::desc("Apply certain temporary workarounds"),
42  cl::init(true), cl::Hidden);
43 
44 namespace {
// Module pass that looks for functions called through a bitcast to a
// different function type and redirects those calls to wrapper functions
// (see runOnModule).
45 class FixFunctionBitcasts final : public ModulePass {
// Human-readable pass name.
46  StringRef getPassName() const override {
47  return "WebAssembly Fix Function Bitcasts";
48  }
49 
// Declares that this pass leaves the control-flow graph intact.
// NOTE(review): listing line 52 is missing from this extract; presumably it
// forwards to ModulePass::getAnalysisUsage(AU) -- confirm upstream.
50  void getAnalysisUsage(AnalysisUsage &AU) const override {
51  AU.setPreservesCFG();
53  }
54 
// Pass entry point; defined out of line below.
55  bool runOnModule(Module &M) override;
56 
57 public:
// Pass identifier; handed to the ModulePass constructor.
58  static char ID;
59  FixFunctionBitcasts() : ModulePass(ID) {}
60 };
61 } // End anonymous namespace
62 
// Registers the pass under the DEBUG_TYPE name ("wasm-fix-function-bitcasts")
// with the description string below.
// NOTE(review): the two trailing `false` arguments are presumably the
// "CFG-only" and "is-analysis" flags -- confirm against PassSupport.h.
// NOTE(review): listing line 63 is missing just above this statement;
// presumably it defines FixFunctionBitcasts::ID -- confirm upstream.
64 INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE,
65  "Fix mismatching bitcasts for WebAssembly", false, false)
66 
// Heap-allocates a new FixFunctionBitcasts pass and returns it to the caller.
// NOTE(review): the function header (listing line 67) is missing from this
// extract; the cross-reference index at the end of the page lists it as
// `ModulePass * createWebAssemblyFixFunctionBitcasts()` -- confirm upstream.
68  return new FixFunctionBitcasts();
69 }
70 
71 // Recursively descend the def-use lists from V to find non-bitcast users of
72 // bitcasts of V.
73 static void FindUses(Value *V, Function &F,
74  SmallVectorImpl<std::pair<Use *, Function *>> &Uses,
75  SmallPtrSetImpl<Constant *> &ConstantBCs) {
76  for (Use &U : V->uses()) {
77  if (BitCastOperator *BC = dyn_cast<BitCastOperator>(U.getUser()))
78  FindUses(BC, F, Uses, ConstantBCs);
79  else if (U.get()->getType() != F.getType()) {
80  CallSite CS(U.getUser());
81  if (!CS)
82  // Skip uses that aren't immediately called
83  continue;
84  Value *Callee = CS.getCalledValue();
85  if (Callee != V)
86  // Skip calls where the function isn't the callee
87  continue;
88  if (isa<Constant>(U.get())) {
89  // Only add constant bitcasts to the list once; they get RAUW'd
90  auto c = ConstantBCs.insert(cast<Constant>(U.get()));
91  if (!c.second)
92  continue;
93  }
94  Uses.push_back(std::make_pair(&U, &F));
95  }
96  }
97 }
98 
99 // Create a wrapper function with type Ty that calls F (which may have a
100 // different type). Attempt to support common bitcasted function idioms:
101 // - Call with more arguments than needed: arguments are dropped
102 // - Call with fewer arguments than needed: arguments are filled in with undef
103 // - Return value is not needed: drop it
104 // - Return value needed but not present: supply an undef
105 //
106 // For now, return nullptr without creating a wrapper if the wrapper cannot
107 // be generated due to incompatible types.
// NOTE(review): the function header (listing line 108) is missing from this
// extract. The cross-reference index at the end of the page gives the
// signature as `static Function * CreateWrapper(Function *F, FunctionType *Ty)`
// -- confirm upstream.
109  Module *M = F->getParent();
110 
// The wrapper is a private function named "bitcast" of the requested type Ty,
// created in F's module, with a single basic block named "body".
111  Function *Wrapper =
112  Function::Create(Ty, Function::PrivateLinkage, "bitcast", M);
113  BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
114 
115  // Determine what arguments to pass.
// NOTE(review): listing lines 116, 119 and 120 are missing, so the
// declarations of Args, PI and PE are not visible here. Presumably Args is the
// argument list for the call to F and PI/PE iterate over F's parameter types
// -- confirm upstream.
117  Function::arg_iterator AI = Wrapper->arg_begin();
118  Function::arg_iterator AE = Wrapper->arg_end();
// Walk the wrapper's arguments and the PI..PE types in lock step. A type
// mismatch at any position abandons the wrapper: it is erased and nullptr is
// returned.
121  for (; AI != AE && PI != PE; ++AI, ++PI) {
122  if (AI->getType() != *PI) {
123  Wrapper->eraseFromParent();
124  return nullptr;
125  }
126  Args.push_back(&*AI);
127  }
// Any remaining PI..PE entries have no matching wrapper argument: pass undef.
128  for (; PI != PE; ++PI)
129  Args.push_back(UndefValue::get(*PI));
// Surplus wrapper arguments are forwarded only when F is variadic; otherwise
// they are dropped.
130  if (F->isVarArg())
131  for (; AI != AE; ++AI)
132  Args.push_back(&*AI);
133 
134  CallInst *Call = CallInst::Create(F, Args, "", BB);
135 
136  // Determine what value to return.
// Ty returns void: the call result (if any) is discarded.
137  if (Ty->getReturnType()->isVoidTy())
138  ReturnInst::Create(M->getContext(), BB);
// Ty expects a value but F returns void.
// NOTE(review): listing line 140 is missing; per the idiom list above this
// function, an undef of Ty's return type is presumably returned here --
// confirm upstream.
139  else if (F->getFunctionType()->getReturnType()->isVoidTy())
141  BB);
// Return types agree: pass the call result straight through.
142  else if (F->getFunctionType()->getReturnType() == Ty->getReturnType())
143  ReturnInst::Create(M->getContext(), Call, BB);
// Incompatible return types: give up, erase the wrapper, and report failure.
144  else {
145  Wrapper->eraseFromParent();
146  return nullptr;
147  }
148 
149  return Wrapper;
150 }
151 
// Pass entry point. Collects every call made through a bitcast of a function,
// creates (and caches) a wrapper of the casted type for each one, and redirects
// the call to that wrapper. Always returns true.
// NOTE(review): listing lines 155, 170-171, 189, 195 and 231-232 are missing
// from this extract, so the declarations of Uses, Wrappers and Ty, the contents
// of MainArgTys, and two statements before `delete CallMain` are not visible
// here -- consult the upstream file before editing.
152 bool FixFunctionBitcasts::runOnModule(Module &M) {
153  Function *Main = nullptr;
154  CallInst *CallMain = nullptr;
156  SmallPtrSet<Constant *, 2> ConstantBCs;
157 
158  // Collect all the places that need wrappers.
159  for (Function &F : M) {
160  FindUses(&F, F, Uses, ConstantBCs);
161 
162  // If we have a "main" function, and its type isn't
163  // "int main(int argc, char *argv[])", create an artificial call with it
164  // bitcasted to that type so that we generate a wrapper for it, so that
165  // the C runtime can call it.
// This branch is only taken when the temporary-workarounds flag is off and
// "main" is defined (not merely declared) in this module.
166  if (!TemporaryWorkarounds && !F.isDeclaration() && F.getName() == "main") {
167  Main = &F;
168  LLVMContext &C = M.getContext();
169  Type *MainArgTys[] = {
172  };
173  FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,
174  /*isVarArg=*/false);
175  if (F.getFunctionType() != MainTy) {
176  Value *Args[] = {
177  UndefValue::get(MainArgTys[0]),
178  UndefValue::get(MainArgTys[1])
179  };
180  Value *Casted = ConstantExpr::getBitCast(Main,
181  PointerType::get(MainTy, 0));
// The call is created without an insertion point; it only serves as a carrier
// for a Use of the casted main and is deleted at the end of this function.
182  CallMain = CallInst::Create(Casted, Args, "call_main");
// Operand 2 follows the two argument operands.
// NOTE(review): presumably this is the callee slot of the call -- confirm.
183  Use *UseMain = &CallMain->getOperandUse(2);
184  Uses.push_back(std::make_pair(UseMain, &F));
185  }
186  }
187  }
188 
190 
// Rewrite each collected use to go through a wrapper.
191  for (auto &UseFunc : Uses) {
192  Use *U = UseFunc.first;
193  Function *F = UseFunc.second;
194  PointerType *PTy = cast<PointerType>(U->get()->getType());
196 
197  // If the function is casted to something like i8* as a "generic pointer"
198  // to be later casted to something else, we can't generate a wrapper for it.
199  // Just ignore such casts for now.
200  if (!Ty)
201  continue;
202 
203  // Bitcasted vararg functions occur in Emscripten's implementation of
204  // EM_ASM, so suppress wrappers for them for now.
205  if (TemporaryWorkarounds && (Ty->isVarArg() || F->isVarArg()))
206  continue;
207 
// Wrappers is keyed by (function, casted type). Pair.second is true only when
// the key was newly inserted, so CreateWrapper runs at most once per key; a
// failed creation stays cached as nullptr.
208  auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
209  if (Pair.second)
210  Pair.first->second = CreateWrapper(F, Ty);
211 
212  Function *Wrapper = Pair.first->second;
213  if (!Wrapper)
214  continue;
215 
// A constant bitcast is replaced at every one of its uses at once; any other
// value has only this single use redirected.
216  if (isa<Constant>(U->get()))
217  U->get()->replaceAllUsesWith(Wrapper);
218  else
219  U->set(Wrapper);
220  }
221 
222  // If we created a wrapper for main, rename the wrapper so that it's the
223  // one that gets called from startup.
224  if (CallMain) {
225  Main->setName("__original_main");
// The callee of the temporary call is read back here and cast to Function;
// the wrapper then takes over the name "main" together with the original
// main's linkage and visibility.
226  Function *MainWrapper =
227  cast<Function>(CallMain->getCalledValue()->stripPointerCasts());
228  MainWrapper->setName("main");
229  MainWrapper->setLinkage(Main->getLinkage());
230  MainWrapper->setVisibility(Main->getVisibility());
// The temporary call was never inserted into a block, so it is freed directly.
233  delete CallMain;
234  }
235 
// Reports the module as modified unconditionally.
236  return true;
237 }
void setVisibility(VisibilityTypes V)
Definition: GlobalValue.h:238
bool isVarArg() const
isVarArg - Return true if this function takes a variable number of arguments.
Definition: Function.h:163
uint64_t CallInst * C
iterator_range< use_iterator > uses()
Definition: Value.h:354
This class represents an incoming formal argument to a Function.
Definition: Argument.h:30
Compute iterated dominance frontiers using a linear time algorithm.
Definition: AllocatorList.h:24
A Module instance is used to store all the information related to an LLVM module. ...
Definition: Module.h:63
This class represents a function call, abstracting a target machine's calling convention.
static PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space...
Definition: Type.cpp:617
Like Internal, but omit from symbol table.
Definition: GlobalValue.h:57
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:714
const Use & getOperandUse(unsigned i) const
Definition: User.h:183
arg_iterator arg_end()
Definition: Function.h:666
F(f)
static CallInst * Create(Value *Func, ArrayRef< Value *> Args, ArrayRef< OperandBundleDef > Bundles=None, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
param_iterator param_end() const
Definition: DerivedTypes.h:129
This file contains the entry points for global functions defined in the LLVM WebAssembly back-end...
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:344
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, Instruction *InsertBefore=nullptr)
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:191
LLVMContext & getContext() const
Get the global data context.
Definition: Module.h:242
A Use represents the edge between a Value definition and its users.
Definition: Use.h:56
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: APFloat.h:42
void setName(const Twine &Name)
Change the name of the value.
Definition: Value.cpp:295
Class to represent function types.
Definition: DerivedTypes.h:103
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition: Pass.cpp:92
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:245
INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE, "Fix mismatching bitcasts for WebAssembly", false, false) ModulePass *llvm
static cl::opt< bool > TemporaryWorkarounds("wasm-temporary-workarounds", cl::desc("Apply certain temporary workarounds"), cl::init(true), cl::Hidden)
LinkageTypes getLinkage() const
Definition: GlobalValue.h:450
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:439
amdgpu Simplify well known AMD library false Value * Callee
Class to represent pointers.
Definition: DerivedTypes.h:467
static Function * CreateWrapper(Function *F, FunctionType *Ty)
static Constant * getBitCast(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:1750
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:141
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:410
VisibilityTypes getVisibility() const
Definition: GlobalValue.h:232
LLVM Basic Block Representation.
Definition: BasicBlock.h:59
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:46
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:69
This file contains the declarations for the subclasses of Constant, which represent the different fla...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:371
param_iterator param_begin() const
Definition: DerivedTypes.h:128
Represent the analysis usage information of a pass.
static FunctionType * get(Type *Result, ArrayRef< Type *> Params, bool isVarArg)
This static method is the primary way of constructing a FunctionType.
Definition: Type.cpp:297
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:101
arg_iterator arg_begin()
Definition: Function.h:657
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1392
const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs, and aliases.
Definition: Value.cpp:539
static PointerType * getInt8PtrTy(LLVMContext &C, unsigned AS=0)
Definition: Type.cpp:220
ModulePass * createWebAssemblyFixFunctionBitcasts()
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements...
Definition: SmallPtrSet.h:418
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:861
Module.h This file contains the declarations for the Module class.
Type::subtype_iterator param_iterator
Definition: DerivedTypes.h:126
Type * getReturnType() const
Definition: DerivedTypes.h:124
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:286
void setLinkage(LinkageTypes LT)
Definition: GlobalValue.h:444
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:150
static IntegerType * getInt32Ty(LLVMContext &C)
Definition: Type.cpp:176
static void FindUses(Value *V, Function &F, SmallVectorImpl< std::pair< Use *, Function *>> &Uses, SmallPtrSetImpl< Constant *> &ConstantBCs)
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:225
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:323
void eraseFromParent()
eraseFromParent - This method unlinks 'this' from the containing module and deletes it...
Definition: Function.cpp:210
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:565
LLVM Value Representation.
Definition: Value.h:73
const Value * getCalledValue() const
Get a pointer to the function that is invoked by this instruction.
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:49
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, const Twine &N="", Module *M=nullptr)
Definition: Function.h:136
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
Type * getElementType() const
Definition: DerivedTypes.h:486
PointerType * getType() const
Global values are always pointers.
Definition: GlobalValue.h:273