30#include "llvm/IR/IntrinsicsSPIRV.h"
44class SPIRVPrepareFunctions :
public ModulePass {
46 bool substituteIntrinsicCalls(
Function *
F);
66char SPIRVPrepareFunctions::ID = 0;
69 "SPIRV prepare functions",
false,
false)
72 Function *IntrinsicFunc =
II->getCalledFunction();
73 assert(IntrinsicFunc &&
"Missing function");
74 std::string FuncName = IntrinsicFunc->
getName().
str();
75 std::replace(FuncName.begin(), FuncName.end(),
'.',
'_');
76 FuncName =
"spirv." + FuncName;
85 if (
F &&
F->getFunctionType() == FT)
99 if (
auto *MSI = dyn_cast<MemSetInst>(Intrinsic))
100 if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
103 Module *M = Intrinsic->getModule();
104 std::string FuncName = lowerLLVMIntrinsicName(Intrinsic);
105 if (Intrinsic->isVolatile())
106 FuncName +=
".volatile";
110 Intrinsic->setCalledFunction(
F);
115 M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType());
116 auto IntrinsicID = Intrinsic->getIntrinsicID();
117 Intrinsic->setCalledFunction(FC);
119 F = dyn_cast<Function>(FC.getCallee());
120 assert(
F &&
"Callee must be a function");
122 switch (IntrinsicID) {
123 case Intrinsic::memset: {
124 auto *MSI =
static_cast<MemSetInst *
>(Intrinsic);
132 IsVolatile->setName(
"isvolatile");
135 auto *MemSet = IRB.
CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),
139 MemSet->eraseFromParent();
142 case Intrinsic::bswap: {
145 auto *BSwap = IRB.
CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),
159 if (
auto *
Ref = dyn_cast_or_null<GetElementPtrInst>(AnnoVal))
160 AnnoVal =
Ref->getOperand(0);
161 if (
auto *
Ref = dyn_cast_or_null<BitCastInst>(OptAnnoVal))
162 OptAnnoVal =
Ref->getOperand(0);
165 if (
auto *
C = dyn_cast_or_null<Constant>(AnnoVal)) {
172 if (
auto *
C = dyn_cast_or_null<Constant>(OptAnnoVal);
173 C &&
C->getNumOperands()) {
174 Value *MaybeStruct =
C->getOperand(0);
175 if (
auto *
Struct = dyn_cast<ConstantStruct>(MaybeStruct)) {
176 for (
unsigned I = 0, E =
Struct->getNumOperands();
I != E; ++
I) {
177 if (
auto *CInt = dyn_cast<ConstantInt>(
Struct->getOperand(
I)))
178 Anno += (
I == 0 ?
": " :
", ") +
179 std::to_string(CInt->getType()->getIntegerBitWidth() == 1
180 ? CInt->getZExtValue()
181 : CInt->getSExtValue());
183 }
else if (
auto *
Struct = dyn_cast<ConstantAggregateZero>(MaybeStruct)) {
185 for (
unsigned I = 0, E =
Struct->getType()->getStructNumElements();
187 Anno +=
I == 0 ?
": 0" :
", 0";
194 const std::string &Anno,
201 static const std::regex R(
202 "\\{(\\d+)(?:[:,](\\d+|\"[^\"]*\")(?:,(\\d+|\"[^\"]*\"))*)?\\}");
205 for (std::sregex_iterator
206 It = std::sregex_iterator(Anno.begin(), Anno.end(), R),
207 ItEnd = std::sregex_iterator();
209 if (It->position() != Pos)
211 Pos = It->position() + It->length();
212 std::smatch
Match = *It;
214 for (std::size_t i = 1; i <
Match.size(); ++i) {
215 std::ssub_match SMatch =
Match[i];
216 std::string Item = SMatch.str();
217 if (Item.length() == 0)
219 if (Item[0] ==
'"') {
220 Item = Item.substr(1, Item.length() - 2);
222 static const std::regex RStr(
"^(\\d+)(?:,(\\d+))*$");
223 if (std::smatch MatchStr; std::regex_match(Item, MatchStr, RStr)) {
224 for (std::size_t SubIdx = 1; SubIdx < MatchStr.size(); ++SubIdx)
225 if (std::string SubStr = MatchStr[SubIdx].str(); SubStr.length())
227 ConstantInt::get(Int32Ty, std::stoi(SubStr))));
231 }
else if (int32_t Num;
232 std::from_chars(Item.data(), Item.data() + Item.size(), Num)
233 .ec == std::errc{}) {
240 if (MDsItem.
size() == 0)
244 return Pos ==
static_cast<int>(Anno.length()) ? MDs
253 Value *PtrArg =
nullptr;
254 if (
auto *BI = dyn_cast<BitCastInst>(
II->getArgOperand(0)))
255 PtrArg = BI->getOperand(0);
257 PtrArg =
II->getOperand(0);
260 4 <
II->arg_size() ?
II->getArgOperand(4) :
nullptr);
269 if (MDs.
size() == 0) {
271 Int32Ty,
static_cast<uint32_t>(SPIRV::Decoration::UserSemantic)));
279 Intrinsic::spv_assign_decoration, {PtrArg->
getType()},
281 II->replaceAllUsesWith(
II->getOperand(0));
291 Type *FSHRetTy = FSHFuncTy->getReturnType();
292 const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
296 if (!FSHFunc->
empty()) {
310 Value *BitWidthForInsts =
314 Value *RotateModVal =
316 Value *FirstShift =
nullptr, *SecShift =
nullptr;
329 Value *SubRotateVal = IRB.
CreateSub(BitWidthForInsts, RotateModVal);
347 if (!UMulFunc->
empty())
379 if (
II->getIntrinsicID() == Intrinsic::assume) {
381 II->getModule(), Intrinsic::SPVIntrinsics::spv_assume);
382 II->setCalledFunction(
F);
383 }
else if (
II->getIntrinsicID() == Intrinsic::expect) {
385 II->getModule(), Intrinsic::SPVIntrinsics::spv_expect,
386 {II->getOperand(0)->getType()});
387 II->setCalledFunction(
F);
402 for (
unsigned OpNo : OpNos)
406 II->setCalledFunction(
F);
416 Type *FSHLRetTy = UMulFuncTy->getReturnType();
417 const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
426bool SPIRVPrepareFunctions::substituteIntrinsicCalls(
Function *
F) {
427 bool Changed =
false;
430 auto Call = dyn_cast<CallInst>(&
I);
436 auto *
II = cast<IntrinsicInst>(Call);
437 switch (
II->getIntrinsicID()) {
438 case Intrinsic::memset:
439 case Intrinsic::bswap:
442 case Intrinsic::fshl:
443 case Intrinsic::fshr:
447 case Intrinsic::umul_with_overflow:
451 case Intrinsic::assume:
452 case Intrinsic::expect: {
458 case Intrinsic::lifetime_start:
460 II, Intrinsic::SPVIntrinsics::spv_lifetime_start, {1});
462 case Intrinsic::lifetime_end:
464 II, Intrinsic::SPVIntrinsics::spv_lifetime_end, {1});
466 case Intrinsic::ptr_annotation:
480SPIRVPrepareFunctions::removeAggregateTypesFromSignature(
Function *
F) {
483 bool IsRetAggr =
F->getReturnType()->isAggregateType();
485 std::any_of(
F->arg_begin(),
F->arg_end(), [](
Argument &Arg) {
486 return Arg.getType()->isAggregateType();
488 bool DoClone = IsRetAggr || HasAggrArg;
492 Type *RetType = IsRetAggr ?
B.getInt32Ty() :
F->getReturnType();
494 ChangedTypes.
push_back(std::pair<int, Type *>(-1,
F->getReturnType()));
496 for (
const auto &Arg :
F->args()) {
497 if (Arg.getType()->isAggregateType()) {
500 std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
505 FunctionType::get(RetType, ArgTypes,
F->getFunctionType()->isVarArg());
511 for (
auto &Arg :
F->args()) {
513 NewFArgIt->setName(ArgName);
514 VMap[&Arg] = &(*NewFArgIt++);
523 F->getParent()->getOrInsertNamedMetadata(
"spv.cloned_funcs");
526 for (
auto &ChangedTyP : ChangedTypes)
529 {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
530 ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
535 if (
auto *CI = dyn_cast<CallInst>(U))
537 U->replaceUsesOfWith(
F, NewF);
541 if (RetType !=
F->getReturnType())
543 NewF,
F->getReturnType());
547bool SPIRVPrepareFunctions::runOnModule(
Module &M) {
548 bool Changed =
false;
550 Changed |= substituteIntrinsicCalls(&
F);
552 std::vector<Function *> FuncsWorklist;
554 FuncsWorklist.push_back(&
F);
556 for (
auto *
F : FuncsWorklist) {
557 Function *NewF = removeAggregateTypesFromSignature(
F);
560 F->eraseFromParent();
569 return new SPIRVPrepareFunctions(
TM);
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
uint64_t IntrinsicInst * II
const char LLVMTargetMachineRef TM
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic)
static std::string getAnnotation(Value *AnnoVal, Value *OptAnnoVal)
static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID, ArrayRef< unsigned > OpNos)
static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic)
static void lowerPtrAnnotation(IntrinsicInst *II)
static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic)
static SmallVector< Metadata * > parseAnnotation(Value *I, const std::string &Anno, LLVMContext &Ctx, Type *Int32Ty)
static void lowerExpectAssume(IntrinsicInst *II)
static void buildUMulWithOverflowFunc(Function *UMulFunc)
static Function * getOrCreateFunction(Module *M, Type *RetTy, ArrayRef< Type * > ArgTypes, StringRef Name)
Represent the analysis usage information of a pass.
This class represents an incoming formal argument to a Function.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
bool empty() const
empty - Check if the array is empty.
LLVM Basic Block Representation.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
FunctionType * getFunctionType() const
void setCalledFunction(Function *Fn)
Sets the function called, including updating the function type.
This is the shared class of boolean and integer constants.
Class to represent fixed width SIMD vectors.
unsigned getNumElements() const
A handy container for a FunctionType+Callee-pointer pair, which can be passed around as a single enti...
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
FunctionType * getFunctionType() const
Returns the FunctionType for me.
bool isIntrinsic() const
isIntrinsic - Returns true if the function's name starts with "llvm.".
Type * getReturnType() const
Returns the type of the ret val.
void setCallingConv(CallingConv::ID CC)
Argument * getArg(unsigned i) const
Module * getParent()
Get the module that this global value is contained inside of...
void setDSOLocal(bool Local)
@ ExternalLinkage
Externally visible function.
Value * CreateNUWMul(Value *LHS, Value *RHS, const Twine &Name="")
Value * CreateInsertValue(Value *Agg, Value *Val, ArrayRef< unsigned > Idxs, const Twine &Name="")
Value * CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name="")
Return a vector value that contains.
CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, Instruction *FMFSource=nullptr, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
CallInst * CreateMemSet(Value *Ptr, Value *Val, uint64_t Size, MaybeAlign Align, bool isVolatile=false, MDNode *TBAATag=nullptr, MDNode *ScopeTag=nullptr, MDNode *NoAliasTag=nullptr)
Create and insert a memset to the specified pointer and the specified value.
Value * CreateLShr(Value *LHS, Value *RHS, const Twine &Name="", bool isExact=false)
ReturnInst * CreateRet(Value *V)
Create a 'ret <val>' instruction.
Value * CreateUDiv(Value *LHS, Value *RHS, const Twine &Name="", bool isExact=false)
Value * CreateICmpNE(Value *LHS, Value *RHS, const Twine &Name="")
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
ReturnInst * CreateRetVoid()
Create a 'ret void' instruction.
Value * CreateOr(Value *LHS, Value *RHS, const Twine &Name="")
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
ConstantInt * getInt(const APInt &AI)
Get a constant integer value.
Value * CreateURem(Value *LHS, Value *RHS, const Twine &Name="")
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr it the function does not...
A wrapper class for inspecting calls to intrinsic functions.
Intrinsic::ID getIntrinsicID() const
Return the intrinsic ID of this intrinsic.
void LowerIntrinsicCall(CallInst *CI)
Replace a call to the specified intrinsic function.
This is an important class for using LLVM in a threaded context.
static MDTuple * get(LLVMContext &Context, ArrayRef< Metadata * > MDs)
static MDString * get(LLVMContext &Context, StringRef Str)
This class wraps the llvm.memset and llvm.memset.inline intrinsics.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
LLVMContext & getContext() const
Get the global data context.
void addOperand(MDNode *M)
PassRegistry - This class manages the registration and intitialization of the pass subsystem as appli...
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
void addMutated(Value *Val, Type *Ty)
SPIRVGlobalRegistry * getSPIRVGlobalRegistry() const
bool canUseExtension(SPIRV::Extension::Extension E) const
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
std::string str() const
str - Get the contents as an std::string.
The instances of the Type class are immutable: once they are created, they are never changed.
unsigned getIntegerBitWidth() const
static IntegerType * getInt32Ty(LLVMContext &C)
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void setName(const Twine &Name)
Change the name of the value.
StringRef getName() const
Return a constant reference to the value's name.
void takeName(Value *V)
Transfer the name from V to this value.
Type * getElementType() const
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ SPIR_FUNC
Used for SPIR non-kernel device functions.
@ C
The default llvm calling convention, compatible with C.
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=std::nullopt)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
This is an optimization pass for GlobalISel generic memory operations.
void initializeSPIRVPrepareFunctionsPass(PassRegistry &)
bool getConstantStringInfo(const Value *V, StringRef &Str, bool TrimAtNul=true)
This function computes the length of a null-terminated C string pointed to by V.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
@ Ref
The access may reference the value stored in memory.
constexpr unsigned BitWidth
void CloneFunctionInto(Function *NewFunc, const Function *OldFunc, ValueToValueMapTy &VMap, CloneFunctionChangeType Changes, SmallVectorImpl< ReturnInst * > &Returns, const char *NameSuffix="", ClonedCodeInfo *CodeInfo=nullptr, ValueMapTypeRemapper *TypeMapper=nullptr, ValueMaterializer *Materializer=nullptr)
Clone OldFunc into NewFunc, transforming the old arguments into references to VMap values.
ModulePass * createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM)
void expandMemSetAsLoop(MemSetInst *MemSet)
Expand MemSet as a loop. MemSet is not deleted.
Implement std::hash so that hash_code can be used in STL containers.