35#define DEBUG_TYPE "wasm-fix-function-bitcasts"
38class FixFunctionBitcasts final :
public ModulePass {
40 return "WebAssembly Fix Function Bitcasts";
56char FixFunctionBitcasts::ID = 0;
58 "Fix mismatching bitcasts for WebAssembly",
false,
false)
61 return new FixFunctionBitcasts();
68 for (
User *U : V->users()) {
69 if (
auto *BC = dyn_cast<BitCastOperator>(U))
71 else if (
auto *
A = dyn_cast<GlobalAlias>(U))
73 else if (
auto *CB = dyn_cast<CallBase>(U)) {
74 Value *Callee = CB->getCalledOperand();
78 if (CB->getFunctionType() ==
F.getValueType())
81 Uses.push_back(std::make_pair(CB, &
F));
113 F->getName() +
"_bitcast", M);
123 bool TypeMismatch =
false;
124 bool WrapperNeeded =
false;
126 Type *ExpectedRtnType =
F->getFunctionType()->getReturnType();
127 Type *RtnType = Ty->getReturnType();
129 if ((
F->getFunctionType()->getNumParams() != Ty->getNumParams()) ||
130 (
F->getFunctionType()->isVarArg() != Ty->isVarArg()) ||
131 (ExpectedRtnType != RtnType))
132 WrapperNeeded =
true;
134 for (; AI != AE && PI != PE; ++AI, ++PI) {
136 Type *ParamType = *PI;
138 if (ArgType == ParamType) {
139 Args.push_back(&*AI);
145 Args.push_back(PtrCast);
147 LLVM_DEBUG(
dbgs() <<
"createWrapper: struct param type in bitcast: "
148 <<
F->getName() <<
"\n");
149 WrapperNeeded =
false;
152 <<
F->getName() <<
"\n");
154 << *ParamType <<
" Got: " << *ArgType <<
"\n");
161 if (WrapperNeeded && !TypeMismatch) {
162 for (; PI != PE; ++PI)
165 for (; AI != AE; ++AI)
166 Args.push_back(&*AI);
170 Type *ExpectedRtnType =
F->getFunctionType()->getReturnType();
171 Type *RtnType = Ty->getReturnType();
175 }
else if (ExpectedRtnType->
isVoidTy()) {
176 LLVM_DEBUG(
dbgs() <<
"Creating dummy return: " << *RtnType <<
"\n");
178 }
else if (RtnType == ExpectedRtnType) {
187 LLVM_DEBUG(
dbgs() <<
"createWrapper: struct return type in bitcast: "
188 <<
F->getName() <<
"\n");
189 WrapperNeeded =
false;
191 LLVM_DEBUG(
dbgs() <<
"createWrapper: return type mismatch calling: "
192 <<
F->getName() <<
"\n");
194 <<
" Got: " << *RtnType <<
"\n");
203 F->getName() +
"_bitcast_invalid", M);
206 Wrapper->setName(
F->getName() +
"_bitcast_invalid");
207 }
else if (!WrapperNeeded) {
208 LLVM_DEBUG(
dbgs() <<
"createWrapper: no wrapper needed: " <<
F->getName()
223 return FuncTy->getReturnType() == MainTy->getReturnType() &&
224 FuncTy->getNumParams() == 0 &&
228bool FixFunctionBitcasts::runOnModule(
Module &M) {
229 LLVM_DEBUG(
dbgs() <<
"********** Fix Function Bitcasts **********\n");
247 if (
F.getName() ==
"main") {
254 LLVM_DEBUG(
dbgs() <<
"Found `main` function with incorrect type: "
255 << *
F.getFunctionType() <<
"\n");
259 Uses.push_back(std::make_pair(CallMain, &
F));
266 for (
auto &UseFunc :
Uses) {
271 auto Pair = Wrappers.
insert(std::make_pair(std::make_pair(
F, Ty),
nullptr));
285 Main->
setName(
"__original_main");
296 MainWrapper->setName(
"main");
amdgpu aa AMDGPU Address space based Alias Analysis Wrapper
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Rewrite Partial Register Uses
Module.h This file contains the declarations for the Module class.
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
static void findUses(Value *V, Function &F, SmallVectorImpl< std::pair< CallBase *, Function * > > &Uses)
static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy)
static Function * createWrapper(Function *F, FunctionType *Ty)
This file contains the entry points for global functions defined in the LLVM WebAssembly back-end.
Represent the analysis usage information of a pass.
void setPreservesCFG()
This function should be called by the pass, iff they do not:
This class represents an incoming formal argument to a Function.
LLVM Basic Block Representation.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
const DataLayout & getDataLayout() const
Get the data layout of the module this basic block belongs to.
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Value * getCalledOperand() const
FunctionType * getFunctionType() const
void setCalledOperand(Value *V)
This class represents a function call, abstracting a target machine's calling convention.
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
static bool isBitOrNoopPointerCastable(Type *SrcTy, Type *DestTy, const DataLayout &DL)
Check whether a bitcast, inttoptr, or ptrtoint cast between these types is valid and a no-op.
static CastInst * CreateBitOrPointerCast(Value *S, Type *Ty, const Twine &Name="", InsertPosition InsertBefore=nullptr)
Create a BitCast, a PtrToInt, or an IntToPTr cast instruction.
A parsed version of the target data layout string in and methods for querying it.
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Type::subtype_iterator param_iterator
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
VisibilityTypes getVisibility() const
bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
LinkageTypes getLinkage() const
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
InstListType::iterator insertInto(BasicBlock *ParentBB, InstListType::iterator It)
Inserts an unlinked instruction into ParentBB at position It and returns the iterator of the inserted...
This is an important class for using LLVM in a threaded context.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, InsertPosition InsertBefore=nullptr)
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
The instances of the Type class are immutable: once they are created, they are never changed.
bool isStructTy() const
True if this is an instance of StructType.
static IntegerType * getInt32Ty(LLVMContext &C)
bool isVoidTy() const
Return true if this is 'void'.
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
This function has undefined behavior.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void setName(const Twine &Name)
Change the name of the value.
const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs and address space casts.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ Swift
Calling convention for Swift.
@ C
The default llvm calling convention, compatible with C.
This is an optimization pass for GlobalISel generic memory operations.
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
ModulePass * createWebAssemblyFixFunctionBitcasts()