65#define DEBUG_TYPE "nvptx-lower-args"
71class NVPTXLowerArgsLegacyPass :
public ModulePass {
72 bool runOnModule(
Module &M)
override;
78 return "Lower pointer arguments of CUDA kernels";
86char NVPTXLowerArgsLegacyPass::ID = 1;
89 "Lower arguments (NVPTX)",
false,
false)
101 const auto CloneInstInParamAS = [](
const IP &
I) ->
Value * {
104 LI->setOperand(0,
I.NewParam);
110 GEP->getSourceElementType(),
I.NewParam, Indices,
GEP->getName(),
112 NewGEP->setNoWrapFlags(
GEP->getNoWrapFlags());
119 BC->getName(), BC->getIterator());
128 if (
MI->getRawSource() ==
I.OldUse->get()) {
133 CallInst *
B = Builder.CreateMemTransferInst(
134 ID,
MI->getRawDest(),
MI->getDestAlign(),
I.NewParam,
135 MI->getSourceAlign(),
MI->getLength(),
MI->isVolatile());
136 for (
unsigned I : {0, 1})
137 if (
uint64_t Bytes =
MI->getParamDereferenceableBytes(
I))
138 B->addDereferenceableParamAttr(
I, Bytes);
146 auto ItemsToConvert =
150 while (!ItemsToConvert.empty()) {
151 IP
I = ItemsToConvert.pop_back_val();
152 Value *NewInst = CloneInstInParamAS(
I);
155 if (NewInst && NewInst != OldInst) {
160 ItemsToConvert.push_back({&U, NewInst});
162 InstructionsToDelete.push_back(OldInst);
174 I->eraseFromParent();
179 using Base = PtrUseVisitor<ArgUseChecker>;
181 SmallPtrSet<Instruction *, 4> Conditionals;
183 ArgUseChecker(
const DataLayout &
DL) : PtrUseVisitor(
DL) {}
185 PtrInfo visitArgPtr(Argument &
A) {
186 assert(
A.getType()->isPointerTy());
188 IsOffsetKnown =
false;
199 while (!(Worklist.empty() || PI.isAborted())) {
200 UseToVisit ToVisit = Worklist.pop_back_val();
201 U = ToVisit.UseAndIsOffsetKnown.getPointer();
207 LLVM_DEBUG(
dbgs() <<
"Argument pointer escaped: " << *PI.getEscapingInst()
209 else if (PI.isAborted())
210 LLVM_DEBUG(
dbgs() <<
"Pointer use needs a copy: " << *PI.getAbortingInst()
213 <<
" conditionals\n");
217 void visitStoreInst(StoreInst &SI) {
219 if (
U->get() ==
SI.getValueOperand())
220 return PI.setEscapedAndAborted(&SI);
225 void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
228 return PI.setEscapedAndAborted(&ASC);
234 void visitPHINodeOrSelectInst(Instruction &
I) {
237 Conditionals.insert(&
I);
// PHI users of the pointer are delegated to the shared PHI/select handler.
240 void visitPHINode(PHINode &PN) { visitPHINodeOrSelectInst(PN); }
// Select users of the pointer are delegated to the shared PHI/select handler.
241 void visitSelectInst(SelectInst &SI) { visitPHINodeOrSelectInst(SI); }
245 void visitMemTransferInst(MemTransferInst &
II) {
246 if (*U ==
II.getRawDest())
// A memset user marks the pointer info as aborted, recording the memset as
// the aborting instruction.
250 void visitMemSetInst(MemSetInst &
II) { PI.setAborted(&
II); }
285 return A.hasByValAttr() &&
286 A.getType()->getPointerAddressSpace() != ADDRESS_SPACE_ENTRY_PARAM;
305 ArgUseChecker AUC(
DL);
306 ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(OldArg);
307 const bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
308 if (ArgUseIsReadOnly && AUC.Conditionals.
empty()) {
312 for (
Use *U : UsesToUpdate)
321 DVR->replaceVariableLocationOp(&OldArg, &NewParamArg);
328 LLVM_DEBUG(
dbgs() <<
"Using non-copy pointer to " << OldArg <<
"\n");
343 copyByValParam(
F, OldArg, NewParamArg);
362 F.getLinkage(),
F.getAddressSpace());
365 F.getParent()->getFunctionList().insert(
F.getIterator(), NF);
372 if (NewArg.hasByValAttr())
373 NewArg.addAttr(Attribute::ReadOnly);
377 F.replaceAllUsesWith(NF);
385 if (OldArg.hasByValAttr())
388 OldArg.replaceAllUsesWith(&NewArg);
389 NewArg.takeName(&OldArg);
411 LLVM_DEBUG(
dbgs() <<
"Lowering kernel args of " <<
F.getName() <<
"\n");
424bool NVPTXLowerArgsLegacyPass::runOnModule(
Module &M) {
425 auto &TM = getAnalysis<TargetPassConfig>().getTM<NVPTXTargetMachine>();
430 return new NVPTXLowerArgsLegacyPass();
434 LLVM_DEBUG(
dbgs() <<
"Creating a copy of byval args of " <<
F.getName()
440 copyByValParam(
F, Arg, Arg);
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file contains the simple types necessary to represent the attributes associated with functions a...
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
NVPTX address space definition.
static void rewriteKernelByValSignature(Function &F, const bool HasCvtaParam)
static void lowerKernelByValParam(Argument &OldArg, Argument &NewParamArg, Function &F, const bool HasCvtaParam)
static void convertToParamAS(ArrayRef< Use * > OldUses, Value *Param)
Recursively convert the users of a param to the param address space.
static bool processModule(Module &M, NVPTXTargetMachine &TM)
static bool kernelNeedsByValLowering(const Function &F)
static bool copyFunctionByValArgs(Function &F)
static bool processFunction(Function &F, NVPTXTargetMachine &TM)
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file provides a collection of visitors which walk the (instruction) uses of a pointer.
Target-Independent Code Generator Pass Configuration Options pass.
unsigned getDestAddressSpace() const
Returns the address space of the result.
an instruction to allocate memory on the stack
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
LLVM_ABI std::optional< TypeSize > getAllocationSize(const DataLayout &DL) const
Get allocation size in bytes.
void setAlignment(Align Align)
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
This class represents an incoming formal argument to a Function.
LLVM_ABI bool hasByValAttr() const
Return true if this argument has the byval attribute.
unsigned getArgNo() const
Return the index of this formal argument in its containing function.
LLVM_ABI Type * getParamByValType() const
If this is a byval argument, return its type.
LLVM_ABI MaybeAlign getParamAlign() const
If this is a byval or inalloca argument, return its alignment.
Represent a constant reference to an array (0 or more elements consecutively in memory),...
This class represents a function call, abstracting a target machine's calling convention.
static LLVM_ABI CastInst * Create(Instruction::CastOps, Value *S, Type *Ty, const Twine &Name="", InsertPosition InsertBefore=nullptr)
Provides a way to construct any of the CastInst subclasses using an opcode instead of the subclass's ...
A parsed version of the target data layout string, and methods for querying it.
Record of a variable value-assignment, aka a non instruction representation of the dbg....
static LLVM_ABI FunctionType * get(Type *Result, ArrayRef< Type * > Params, bool isVarArg)
This static method is the primary way of constructing a FunctionType.
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
void splice(Function::iterator ToIt, Function *FromF)
Transfer all blocks from FromF to this function at ToIt.
iterator_range< arg_iterator > args()
void copyAttributesFrom(const Function *Src)
copyAttributesFrom - copy all additional attributes (those not needed to create a Function) from the ...
static GetElementPtrInst * Create(Type *PointeeType, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
LLVM_ABI void copyMetadata(const GlobalObject *Src, unsigned Offset)
Copy metadata from Src, adjusting offsets by Offset.
LLVM_ABI void setComdat(Comdat *C)
PointerType * getPtrTy(unsigned AddrSpace=0)
Fetch the type representing a pointer.
Value * CreateAddrSpaceCast(Value *V, Type *DestTy, const Twine &Name="")
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
void visit(Iterator Start, Iterator End)
unsigned getBitWidth() const
Get the number of bits in this IntegerType.
This is an important class for using LLVM in a threaded context.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
A Module instance is used to store all the information related to an LLVM module.
const NVPTXSubtarget * getSubtargetImpl(const Function &) const override
Virtual method implemented by subclasses that returns a reference to that target's TargetSubtargetInf...
static LLVM_ABI PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
A base class for visitors over the uses of a pointer value.
void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC)
void visitPtrToIntInst(PtrToIntInst &I)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Represent a constant reference to a string, i.e. a character array and a length, which need not be null terminated.
Target-Independent Code Generator Pass Configuration Options.
The instances of the Type class are immutable: once they are created, they are never changed.
A Use represents the edge between a Value definition and its users.
LLVM Value Representation.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
bool isUsedByMetadata() const
Return true if there is metadata referencing this value.
iterator_range< use_iterator > uses()
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
LLVM_ABI void takeName(Value *V)
Transfer the name from V to this value.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ ADDRESS_SPACE_ENTRY_PARAM
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
This is an optimization pass for GlobalISel generic memory operations.
detail::zippy< detail::zip_first, T, U, Args... > zip_equal(T &&t, U &&u, Args &&...args)
zip iterator that assumes that all iteratees have the same length.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
auto map_to_vector(ContainerTy &&C, FuncTy &&F)
Map a range to a SmallVector with element types deduced from the mapping.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
auto reverse(ContainerTy &&C)
ModulePass * createNVPTXLowerArgsPass()
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
bool isParamGridConstant(const Argument &Arg)
bool isKernelFunction(const Function &F)
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
iterator_range< pointer_iterator< WrappedIteratorT > > make_pointer_range(RangeT &&Range)
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI void findDbgUsers(Value *V, SmallVectorImpl< DbgVariableRecord * > &DbgVariableRecords)
Finds the debug info records describing a value.
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)