LLVM  4.0.0
WebAssemblyFixFunctionBitcasts.cpp
Go to the documentation of this file.
1 //===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 ///
10 /// \file
11 /// \brief Fix bitcasted functions.
12 ///
13 /// WebAssembly requires caller and callee signatures to match; however, in
14 /// LLVM some amount of slop is vaguely permitted. Detect mismatches by looking
15 /// for bitcasts of functions and rewrite them to use wrapper functions instead.
16 ///
17 /// This doesn't catch all cases, such as when a function's address is taken in
18 /// one place and cast in another, but it works for many common cases.
19 ///
20 /// Note that LLVM already optimizes away function bitcasts in common cases by
21 /// dropping arguments as needed, so this pass only ends up getting used in less
22 /// common cases.
23 ///
24 //===----------------------------------------------------------------------===//
25 
26 #include "WebAssembly.h"
27 #include "llvm/IR/Constants.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/IR/Operator.h"
31 #include "llvm/Pass.h"
32 #include "llvm/Support/Debug.h"
34 using namespace llvm;
35 
36 #define DEBUG_TYPE "wasm-fix-function-bitcasts"
37 
38 namespace {
39 class FixFunctionBitcasts final : public ModulePass {
40  StringRef getPassName() const override {
41  return "WebAssembly Fix Function Bitcasts";
42  }
43 
44  void getAnalysisUsage(AnalysisUsage &AU) const override {
45  AU.setPreservesCFG();
47  }
48 
49  bool runOnModule(Module &M) override;
50 
51 public:
52  static char ID;
53  FixFunctionBitcasts() : ModulePass(ID) {}
54 };
55 } // End anonymous namespace
56 
59  return new FixFunctionBitcasts();
60 }
61 
62 // Recursively descend the def-use lists from V to find non-bitcast users of
63 // bitcasts of V.
64 static void FindUses(Value *V, Function &F,
65  SmallVectorImpl<std::pair<Use *, Function *>> &Uses,
66  SmallPtrSetImpl<Constant *> &ConstantBCs) {
67  for (Use &U : V->uses()) {
68  if (BitCastOperator *BC = dyn_cast<BitCastOperator>(U.getUser()))
69  FindUses(BC, F, Uses, ConstantBCs);
70  else if (U.get()->getType() != F.getType()) {
71  if (isa<Constant>(U.get())) {
72  // Only add constant bitcasts to the list once; they get RAUW'd
73  auto c = ConstantBCs.insert(cast<Constant>(U.get()));
74  if (!c.second) continue;
75  }
76  Uses.push_back(std::make_pair(&U, &F));
77  }
78  }
79 }
80 
81 // Create a wrapper function with type Ty that calls F (which may have a
82 // different type). Attempt to support common bitcasted function idioms:
83 // - Call with more arguments than needed: arguments are dropped
84 // - Call with fewer arguments than needed: arguments are filled in with undef
85 // - Return value is not needed: drop it
86 // - Return value needed but not present: supply an undef
87 //
88 // For now, return nullptr without creating a wrapper if the wrapper cannot
89 // be generated due to incompatible types.
91  Module *M = F->getParent();
92 
93  Function *Wrapper =
94  Function::Create(Ty, Function::PrivateLinkage, "bitcast", M);
95  BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
96 
97  // Determine what arguments to pass.
99  Function::arg_iterator AI = Wrapper->arg_begin();
102  for (; AI != Wrapper->arg_end() && PI != PE; ++AI, ++PI) {
103  if (AI->getType() != *PI) {
104  Wrapper->eraseFromParent();
105  return nullptr;
106  }
107  Args.push_back(&*AI);
108  }
109  for (; PI != PE; ++PI)
110  Args.push_back(UndefValue::get(*PI));
111 
112  CallInst *Call = CallInst::Create(F, Args, "", BB);
113 
114  // Determine what value to return.
115  if (Ty->getReturnType()->isVoidTy())
116  ReturnInst::Create(M->getContext(), BB);
117  else if (F->getFunctionType()->getReturnType()->isVoidTy())
119  BB);
120  else if (F->getFunctionType()->getReturnType() == Ty->getReturnType())
122  else {
123  Wrapper->eraseFromParent();
124  return nullptr;
125  }
126 
127  return Wrapper;
128 }
129 
130 bool FixFunctionBitcasts::runOnModule(Module &M) {
132  SmallPtrSet<Constant *, 2> ConstantBCs;
133 
134  // Collect all the places that need wrappers.
135  for (Function &F : M) FindUses(&F, F, Uses, ConstantBCs);
136 
138 
139  for (auto &UseFunc : Uses) {
140  Use *U = UseFunc.first;
141  Function *F = UseFunc.second;
142  PointerType *PTy = cast<PointerType>(U->get()->getType());
144 
145  // If the function is casted to something like i8* as a "generic pointer"
146  // to be later casted to something else, we can't generate a wrapper for it.
147  // Just ignore such casts for now.
148  if (!Ty)
149  continue;
150 
151  auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
152  if (Pair.second)
153  Pair.first->second = CreateWrapper(F, Ty);
154 
155  Function *Wrapper = Pair.first->second;
156  if (!Wrapper)
157  continue;
158 
159  if (isa<Constant>(U->get()))
160  U->get()->replaceAllUsesWith(Wrapper);
161  else
162  U->set(Wrapper);
163  }
164 
165  return true;
166 }
iterator_range< use_iterator > uses()
Definition: Value.h:326
ModulePass * createWebAssemblyFixFunctionBitcasts()
A Module instance is used to store all the information related to an LLVM module. ...
Definition: Module.h:52
Type::subtype_iterator param_iterator
Definition: DerivedTypes.h:125
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
Definition: Pass.cpp:84
This class represents a function call, abstracting a target machine's calling convention.
Like Internal, but omit from symbol table.
Definition: GlobalValue.h:57
arg_iterator arg_end()
Definition: Function.h:559
Type * getElementType() const
Definition: DerivedTypes.h:462
This file contains the entry points for global functions defined in the LLVM WebAssembly back-end...
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:345
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, Instruction *InsertBefore=nullptr)
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:172
A Use represents the edge between a Value definition and its users.
Definition: Use.h:56
param_iterator param_end() const
Definition: DerivedTypes.h:127
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: APFloat.h:32
Class to represent function types.
Definition: DerivedTypes.h:102
#define F(x, y, z)
Definition: MD5.cpp:51
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:401
Class to represent pointers.
Definition: DerivedTypes.h:443
static Function * CreateWrapper(Function *F, FunctionType *Ty)
LLVM Basic Block Representation.
Definition: BasicBlock.h:51
This file contains the declarations for the subclasses of Constant, which represent the different fla...
param_iterator param_begin() const
Definition: DerivedTypes.h:126
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:368
Represent the analysis usage information of a pass.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:93
arg_iterator arg_begin()
Definition: Function.h:550
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1337
Iterator for intrusive lists based on ilist_node.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements...
Definition: SmallPtrSet.h:425
static CallInst * Create(Value *Func, ArrayRef< Value * > Args, ArrayRef< OperandBundleDef > Bundles=None, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:843
Module.h This file contains the declarations for the Module class.
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:230
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:276
void eraseFromParent() override
eraseFromParent - This method unlinks 'this' from the containing module and deletes it...
Definition: Function.cpp:246
PointerType * getType() const
Global values are always pointers.
Definition: GlobalValue.h:259
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.cpp:230
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:235
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:287
Type * getReturnType() const
Definition: DerivedTypes.h:123
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:537
LLVM Value Representation.
Definition: Value.h:71
static void FindUses(Value *V, Function &F, SmallVectorImpl< std::pair< Use *, Function * >> &Uses, SmallPtrSetImpl< Constant * > &ConstantBCs)
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:47
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, const Twine &N="", Module *M=nullptr)
Definition: Function.h:117
LLVMContext & getContext() const
Get the global data context.
Definition: Module.h:222
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:139