LLVM  6.0.0svn
WebAssemblyFixFunctionBitcasts.cpp
Go to the documentation of this file.
1 //===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 ///
10 /// \file
11 /// \brief Fix bitcasted functions.
12 ///
13 /// WebAssembly requires caller and callee signatures to match, however in LLVM,
14 /// some amount of slop is vaguely permitted. Detect mismatch by looking for
15 /// bitcasts of functions and rewrite them to use wrapper functions instead.
16 ///
17 /// This doesn't catch all cases, such as when a function's address is taken in
18 /// one place and casted in another, but it works for many common cases.
19 ///
20 /// Note that LLVM already optimizes away function bitcasts in common cases by
21 /// dropping arguments as needed, so this pass only ends up getting used in less
22 /// common cases.
23 ///
24 //===----------------------------------------------------------------------===//
25 
26 #include "WebAssembly.h"
27 #include "llvm/IR/CallSite.h"
28 #include "llvm/IR/Constants.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/IR/Operator.h"
32 #include "llvm/Pass.h"
33 #include "llvm/Support/Debug.h"
35 using namespace llvm;
36 
37 #define DEBUG_TYPE "wasm-fix-function-bitcasts"
38 
39 namespace {
40 class FixFunctionBitcasts final : public ModulePass {
41  StringRef getPassName() const override {
42  return "WebAssembly Fix Function Bitcasts";
43  }
44 
45  void getAnalysisUsage(AnalysisUsage &AU) const override {
46  AU.setPreservesCFG();
48  }
49 
50  bool runOnModule(Module &M) override;
51 
52 public:
53  static char ID;
54  FixFunctionBitcasts() : ModulePass(ID) {}
55 };
56 } // End anonymous namespace
57 
60  return new FixFunctionBitcasts();
61 }
62 
63 // Recursively descend the def-use lists from V to find non-bitcast users of
64 // bitcasts of V.
65 static void FindUses(Value *V, Function &F,
66  SmallVectorImpl<std::pair<Use *, Function *>> &Uses,
67  SmallPtrSetImpl<Constant *> &ConstantBCs) {
68  for (Use &U : V->uses()) {
69  if (BitCastOperator *BC = dyn_cast<BitCastOperator>(U.getUser()))
70  FindUses(BC, F, Uses, ConstantBCs);
71  else if (U.get()->getType() != F.getType()) {
72  CallSite CS(U.getUser());
73  if (!CS)
74  // Skip uses that aren't immediately called
75  continue;
76  Value *Callee = CS.getCalledValue();
77  if (Callee != V)
78  // Skip calls where the function isn't the callee
79  continue;
80  if (isa<Constant>(U.get())) {
81  // Only add constant bitcasts to the list once; they get RAUW'd
82  auto c = ConstantBCs.insert(cast<Constant>(U.get()));
83  if (!c.second)
84  continue;
85  }
86  Uses.push_back(std::make_pair(&U, &F));
87  }
88  }
89 }
90 
91 // Create a wrapper function with type Ty that calls F (which may have a
92 // different type). Attempt to support common bitcasted function idioms:
93 // - Call with more arguments than needed: arguments are dropped
94 // - Call with fewer arguments than needed: arguments are filled in with undef
95 // - Return value is not needed: drop it
96 // - Return value needed but not present: supply an undef
97 //
98 // For now, return nullptr without creating a wrapper if the wrapper cannot
99 // be generated due to incompatible types.
101  Module *M = F->getParent();
102 
103  Function *Wrapper =
104  Function::Create(Ty, Function::PrivateLinkage, "bitcast", M);
105  BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
106 
107  // Determine what arguments to pass.
109  Function::arg_iterator AI = Wrapper->arg_begin();
112  for (; AI != Wrapper->arg_end() && PI != PE; ++AI, ++PI) {
113  if (AI->getType() != *PI) {
114  Wrapper->eraseFromParent();
115  return nullptr;
116  }
117  Args.push_back(&*AI);
118  }
119  for (; PI != PE; ++PI)
120  Args.push_back(UndefValue::get(*PI));
121 
122  CallInst *Call = CallInst::Create(F, Args, "", BB);
123 
124  // Determine what value to return.
125  if (Ty->getReturnType()->isVoidTy())
126  ReturnInst::Create(M->getContext(), BB);
127  else if (F->getFunctionType()->getReturnType()->isVoidTy())
129  BB);
130  else if (F->getFunctionType()->getReturnType() == Ty->getReturnType())
131  ReturnInst::Create(M->getContext(), Call, BB);
132  else {
133  Wrapper->eraseFromParent();
134  return nullptr;
135  }
136 
137  return Wrapper;
138 }
139 
140 bool FixFunctionBitcasts::runOnModule(Module &M) {
142  SmallPtrSet<Constant *, 2> ConstantBCs;
143 
144  // Collect all the places that need wrappers.
145  for (Function &F : M) FindUses(&F, F, Uses, ConstantBCs);
146 
148 
149  for (auto &UseFunc : Uses) {
150  Use *U = UseFunc.first;
151  Function *F = UseFunc.second;
152  PointerType *PTy = cast<PointerType>(U->get()->getType());
154 
155  // If the function is casted to something like i8* as a "generic pointer"
156  // to be later casted to something else, we can't generate a wrapper for it.
157  // Just ignore such casts for now.
158  if (!Ty)
159  continue;
160 
161  // Wasm varargs are not ABI-compatible with non-varargs. Just ignore
162  // such casts for now.
163  if (Ty->isVarArg() || F->isVarArg())
164  continue;
165 
166  auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
167  if (Pair.second)
168  Pair.first->second = CreateWrapper(F, Ty);
169 
170  Function *Wrapper = Pair.first->second;
171  if (!Wrapper)
172  continue;
173 
174  if (isa<Constant>(U->get()))
175  U->get()->replaceAllUsesWith(Wrapper);
176  else
177  U->set(Wrapper);
178  }
179 
180  return true;
181 }
bool isVarArg() const
isVarArg - Return true if this function takes a variable number of arguments.
Definition: Function.h:158
iterator_range< use_iterator > uses()
Definition: Value.h:356
This class represents an incoming formal argument to a Function.
Definition: Argument.h:30
ModulePass * createWebAssemblyFixFunctionBitcasts()
Compute iterated dominance frontiers using a linear time algorithm.
Definition: AllocatorList.h:24
A Module instance is used to store all the information related to an LLVM module. ...
Definition: Module.h:63
This class represents a function call, abstracting a target machine's calling convention.
Like Internal, but omit from symbol table.
Definition: GlobalValue.h:57
arg_iterator arg_end()
Definition: Function.h:612
F(f)
static CallInst * Create(Value *Func, ArrayRef< Value *> Args, ArrayRef< OperandBundleDef > Bundles=None, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
param_iterator param_end() const
Definition: DerivedTypes.h:129
This file contains the entry points for global functions defined in the LLVM WebAssembly back-end...
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:344
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, Instruction *InsertBefore=nullptr)
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:191
LLVMContext & getContext() const
Get the global data context.
Definition: Module.h:237
A Use represents the edge between a Value definition and its users.
Definition: Use.h:56
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: APFloat.h:42
Class to represent function types.
Definition: DerivedTypes.h:103
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:91
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:245
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:430
amdgpu Simplify well known AMD library false Value * Callee
Class to represent pointers.
Definition: DerivedTypes.h:467
static Function * CreateWrapper(Function *F, FunctionType *Ty)
bool isVoidTy() const
Return true if this is 'void'.
Definition: Type.h:141
LLVM Basic Block Representation.
Definition: BasicBlock.h:59
This file contains the declarations for the subclasses of Constant, which represent the different fla...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:371
param_iterator param_begin() const
Definition: DerivedTypes.h:128
Represent the analysis usage information of a pass.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:101
arg_iterator arg_begin()
Definition: Function.h:603
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1320
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements...
Definition: SmallPtrSet.h:418
This is a &#39;vector&#39; (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:864
Module.h This file contains the declarations for the Module class.
Type::subtype_iterator param_iterator
Definition: DerivedTypes.h:126
Type * getReturnType() const
Definition: DerivedTypes.h:124
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Definition: Pass.cpp:285
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:145
static void FindUses(Value *V, Function &F, SmallVectorImpl< std::pair< Use *, Function *>> &Uses, SmallPtrSetImpl< Constant *> &ConstantBCs)
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:225
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:323
void eraseFromParent()
eraseFromParent - This method unlinks 'this' from the containing module and deletes it...
Definition: Function.cpp:202
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:556
LLVM Value Representation.
Definition: Value.h:73
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:49
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, const Twine &N="", Module *M=nullptr)
Definition: Function.h:136
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
Type * getElementType() const
Definition: DerivedTypes.h:486
PointerType * getType() const
Global values are always pointers.
Definition: GlobalValue.h:265