#include "llvm/IR/IntrinsicsNVPTX.h"

#define DEBUG_TYPE "nvptx-lower-args"
class NVPTXLowerArgs : public FunctionPass {
  // ...
  void markPointerAsGlobal(Value *Ptr);

public:
  // ...
  StringRef getPassName() const override {
    return "Lower pointer arguments of CUDA kernels";
  }
  // ...
};

char NVPTXLowerArgs::ID = 1;

INITIALIZE_PASS_BEGIN(NVPTXLowerArgs, "nvptx-lower-args",
                      "Lower arguments (NVPTX)", false, false)
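
// Illustrative sketch, not taken from this file: for a hypothetical CUDA
// kernel such as
//
//   struct S { int *x; int *y; };
//   __global__ void foo(S s) { *s.x = 1; }
//
// the byval aggregate `s` arrives in the .param address space (addrspace 101
// on NVPTX). Where possible the pass rewrites reads of the argument to go
// through .param directly instead of materializing a local copy, roughly:
//
//   %v = load ptr, ptr %s                                   ; before
//     -->
//   %s.param = addrspacecast ptr %s to ptr addrspace(101)
//   %v = load ptr, ptr addrspace(101) %s.param              ; after
//
// Pointers loaded out of such parameters are then marked as global so that
// NVPTXInferAddressSpaces can specialize the final memory accesses.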

static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam,
                             bool IsGridConstant) {
  Instruction *I = dyn_cast<Instruction>(OldUse->getUser());
  assert(I && "OldUse must be in an instruction");
  struct IP {
    Use *OldUse;
    Instruction *OldInstruction;
    Value *NewParam;
  };
  SmallVector<IP> ItemsToConvert = {{OldUse, I, Param}};
  SmallVector<Instruction *> InstructionsToDelete;

  auto CloneInstInParamAS = [HasCvtaParam,
                             IsGridConstant](const IP &I) -> Value * {
    if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
      LI->setOperand(0, I.NewParam);
      return LI;
    }
    if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
      SmallVector<Value *, 4> Indices(GEP->indices());
      auto *NewGEP = GetElementPtrInst::Create(
          GEP->getSourceElementType(), I.NewParam, Indices, GEP->getName(),
          GEP->getIterator());
      NewGEP->setIsInBounds(GEP->isInBounds());
      return NewGEP;
    }
    if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) {
      auto *NewBCType = PointerType::get(BC->getContext(), ADDRESS_SPACE_PARAM);
      return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType,
                                 BC->getName(), BC->getIterator());
    }
    if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) {
      // A cast back to the generic AS just passes the new param pointer
      // through; the old addrspacecast is no longer needed.
      return I.NewParam;
    }
    if (auto *MI = dyn_cast<MemTransferInst>(I.OldInstruction)) {
      if (MI->getRawSource() == I.OldUse->get()) {
        // Convert the memcpy/memmove so it reads directly from param space.
        IRBuilder<> Builder(I.OldInstruction);
        Intrinsic::ID ID = MI->getIntrinsicID();

        CallInst *B = Builder.CreateMemTransferInst(
            ID, MI->getRawDest(), MI->getDestAlign(), I.NewParam,
            MI->getSourceAlign(), MI->getLength(), MI->isVolatile());
        for (unsigned I : {0, 1})
          if (uint64_t Bytes = MI->getParamDereferenceableBytes(I))
            B->addDereferenceableParamAttr(I, Bytes);
        return B;
      }
    }

    if (HasCvtaParam) {
      auto GetParamAddrCastToGeneric =
          [](Value *Addr, Instruction *OriginalUser) -> Value * {
        PointerType *ReturnTy = PointerType::get(OriginalUser->getContext(),
                                                 ADDRESS_SPACE_GENERIC);
        Function *CvtToGen = Intrinsic::getOrInsertDeclaration(
            OriginalUser->getModule(), Intrinsic::nvvm_ptr_param_to_gen,
            {ReturnTy, PointerType::get(OriginalUser->getContext(),
                                        ADDRESS_SPACE_PARAM)});

        // Cast the param-space address to the generic address space.
        Value *CvtToGenCall =
            CallInst::Create(CvtToGen, Addr, Addr->getName() + ".gen",
                             OriginalUser->getIterator());
        return CvtToGenCall;
      };
      auto *ParamInGenericAS =
          GetParamAddrCastToGeneric(I.NewParam, I.OldInstruction);

      // phi/select can forward a generic pointer to the argument even without
      // __grid_constant__.
      if (auto *PHI = dyn_cast<PHINode>(I.OldInstruction)) {
        for (auto [Idx, V] : enumerate(PHI->incoming_values()))
          if (V.get() == I.OldUse->get())
            PHI->setIncomingValue(Idx, ParamInGenericAS);
        return PHI;
      }
      if (auto *SI = dyn_cast<SelectInst>(I.OldInstruction)) {
        if (SI->getTrueValue() == I.OldUse->get())
          SI->setTrueValue(ParamInGenericAS);
        if (SI->getFalseValue() == I.OldUse->get())
          SI->setFalseValue(ParamInGenericAS);
        return SI;
      }

      // Uses that may write through or escape the pointer are only allowed
      // for __grid_constant__ arguments.
      if (IsGridConstant) {
        if (auto *CI = dyn_cast<CallInst>(I.OldInstruction)) {
          I.OldUse->set(ParamInGenericAS);
          return CI;
        }
        if (auto *SI = dyn_cast<StoreInst>(I.OldInstruction)) {
          // The byval address itself is being stored; use the generic pointer.
          if (SI->getValueOperand() == I.OldUse->get())
            SI->setOperand(0, ParamInGenericAS);
          return SI;
        }
        if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction)) {
          if (PI->getPointerOperand() == I.OldUse->get())
            PI->setOperand(0, ParamInGenericAS);
          return PI;
        }
      }
      // ...
    }
    llvm_unreachable("Unsupported instruction");
  };

  while (!ItemsToConvert.empty()) {
    IP I = ItemsToConvert.pop_back_val();
    Value *NewInst = CloneInstInParamAS(I);

    if (NewInst && NewInst != I.OldInstruction) {
      // A new instruction was created in param AS: queue the users of the old
      // one for conversion, and the old instruction itself for deletion (it
      // may still have uses at this point).
      for (Use &U : I.OldInstruction->uses())
        ItemsToConvert.push_back({&U, cast<Instruction>(U.getUser()), NewInst});

      InstructionsToDelete.push_back(I.OldInstruction);
    }
  }

  // Erase the replaced instructions in reverse order, so that instructions are
  // removed after the users that depend on them.
  for (Instruction *I : llvm::reverse(InstructionsToDelete))
    I->eraseFromParent();
}
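
// A rough example of the rewrite performed above (hypothetical IR; %arg.param
// is the byval argument already cast to addrspace(101) by the caller of
// convertToParamAS):
//
//   %p = getelementptr %struct.S, ptr %arg, i32 0, i32 1
//   %v = load i32, ptr %p
//     -->
//   %p1 = getelementptr %struct.S, ptr addrspace(101) %arg.param, i32 0, i32 1
//   %v = load i32, ptr addrspace(101) %p1
//
// The old generic-AS GEP ends up in InstructionsToDelete and is erased once
// all of its users have been rewritten.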

// Increase the alignment of byval arguments in the .param address space when
// that enables vectorized parameter loads, and propagate the new alignment to
// loads at known offsets from the argument.
static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
                                    const NVPTXTargetLowering *TLI) {
  // ...
  if (CurArgAlign >= NewArgAlign)
    return;

  LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign << " instead of "
                    << CurArgAlign << " for " << *Arg << '\n');

  auto NewAlignAttr =
      Attribute::get(Func->getContext(), Attribute::Alignment, NewArgAlign);
  Arg->removeAttr(Attribute::Alignment);
  Arg->addAttr(NewAlignAttr);

  struct Load {
    LoadInst *Inst;
    uint64_t Offset;
  };

  struct LoadContext {
    Value *InitialVal;
    uint64_t Offset;
  };

  SmallVector<Load> Loads;
  std::queue<LoadContext> Worklist;
  Worklist.push({ArgInParamAS, 0});
  bool IsGridConstant = isParamGridConstant(*Arg);

  while (!Worklist.empty()) {
    LoadContext Ctx = Worklist.front();
    Worklist.pop();

    for (User *CurUser : Ctx.InitialVal->users()) {
      if (auto *I = dyn_cast<LoadInst>(CurUser)) {
        Loads.push_back({I, Ctx.Offset});
        continue;
      }

      if (auto *I = dyn_cast<BitCastInst>(CurUser)) {
        Worklist.push({I, Ctx.Offset});
        continue;
      }

      if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
        APInt OffsetAccumulated =
            APInt::getZero(DL.getIndexSizeInBits(ADDRESS_SPACE_PARAM));

        if (!I->accumulateConstantOffset(DL, OffsetAccumulated))
          continue;

        uint64_t OffsetLimit = -1;
        uint64_t Offset = OffsetAccumulated.getLimitedValue(OffsetLimit);
        assert(Offset != OffsetLimit && "Expect Offset less than UINT64_MAX");

        Worklist.push({I, Ctx.Offset + Offset});
        continue;
      }

      if (isa<MemTransferInst>(CurUser))
        continue;

      // Calls, stores, and ptrtoint are only expected for __grid_constant__
      // arguments.
      if (IsGridConstant &&
          (isa<CallInst>(CurUser) || isa<StoreInst>(CurUser) ||
           isa<PtrToIntInst>(CurUser)))
        continue;

      llvm_unreachable("All users must be one of: load, "
                       "bitcast, getelementptr, call, store, ptrtoint");
    }
  }

  for (Load &CurLoad : Loads) {
    Align NewLoadAlign(std::gcd(NewArgAlign, CurLoad.Offset));
    Align CurLoadAlign(CurLoad.Inst->getAlign());
    CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));
  }
}
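
// Worked example with hypothetical numbers: if getFunctionParamOptimizedAlign
// raises the argument alignment to 16, a load at byte offset 8 from the start
// of the aggregate can be given align gcd(16, 8) = 8, a load at offset 4 only
// gcd(16, 4) = 4, and a load at offset 0 the full 16 (std::gcd(16, 0) == 16).
// std::max above keeps any larger alignment a load already had.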

// Checks how a byval pointer argument is used: whether it escapes, whether it
// is written through, and which phi/select instructions it flows into.
struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
  // ...
  ArgUseChecker(const DataLayout &DL, bool IsGridConstant)
      : PtrUseVisitor(DL), IsGridConstant(IsGridConstant) {}

  PtrInfo visitArgPtr(Argument &A) {
    assert(A.getType()->isPointerTy());
    IntegerType *IntIdxTy = cast<IntegerType>(DL.getIndexType(A.getType()));
    IsOffsetKnown = false;
    Offset = APInt(IntIdxTy->getBitWidth(), 0);
    // ...
    Conditionals.clear();
    // ...
    enqueueUsers(A);

    // Visit uses until the worklist is drained or the walk is aborted.
    while (!(Worklist.empty() || PI.isAborted())) {
      UseToVisit ToVisit = Worklist.pop_back_val();
      U = ToVisit.UseAndIsOffsetKnown.getPointer();
      Instruction *I = cast<Instruction>(U->getUser());
      if (isa<PHINode>(I) || isa<SelectInst>(I))
        Conditionals.insert(I);
      // ...
    }
    if (PI.isEscaped())
      LLVM_DEBUG(dbgs() << "Argument pointer escaped: " << *PI.getEscapingInst()
                        << "\n");
    else if (PI.isAborted())
      LLVM_DEBUG(dbgs() << "Pointer use needs a copy: " << *PI.getAbortingInst()
                        << "\n");
    LLVM_DEBUG(dbgs() << "Traversed " << Conditionals.size()
                      << " conditionals\n");
    return PI;
  }

  void visitStoreInst(StoreInst &SI) {
    // Storing the argument pointer itself escapes it.
    if (U->get() == SI.getValueOperand())
      return PI.setEscapedAndAborted(&SI);
    // Writes through the pointer force a copy, unless the argument is a
    // __grid_constant__ (where such writes are UB anyway).
    if (!IsGridConstant)
      return PI.setAborted(&SI);
  }

  void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
    // Casts to the param address space are no-ops and need no copy.
    if (ASC.getDestAddressSpace() != ADDRESS_SPACE_PARAM)
      return PI.setEscapedAndAborted(&ASC);
  }

  // ...
  void visitPHINodeOrSelectInst(Instruction &I) {
    assert(isa<PHINode>(I) || isa<SelectInst>(I));
  }
  // phi and select just pass the pointer through to their users.
  void visitPHINode(PHINode &PN) { enqueueUsers(PN); }
  void visitSelectInst(SelectInst &SI) { enqueueUsers(SI); }

  void visitMemTransferInst(MemTransferInst &II) {
    if (*U == II.getRawDest() && !IsGridConstant)
      PI.setAborted(&II);
    // Reading from the argument via memcpy/memmove is fine; it can be
    // converted to an AS-specific copy later.
  }
  // ...
}; // struct ArgUseChecker
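
// Rough intuition, using hypothetical kernels:
//
//   __global__ void k1(S s) { consume(s.a); }   // only reads from s
//
// leaves the walk neither escaped nor aborted, so handleByValParam below can
// rewrite the uses to the .param space directly. By contrast,
//
//   __global__ void k2(S s) { s.a = 1; }        // writes into s
//
// aborts the walk at the store, and the argument either gets a local copy or,
// on sm_70+ with __grid_constant__, a generic pointer obtained via cvta.param.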

static void copyByValParam(Function &F, Argument &Arg) {
  // ...
  // Copy the whole byval aggregate from param space into the local alloca.
  IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),
                   ArgSize);
}

void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
                                      Argument *Arg) {
  // ...
  ArgUseChecker AUC(DL, IsGridConstant);
  ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(*Arg);
  bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
  // Easy case: the parameter is only read and never flows through phi/select,
  // so its uses can be rewritten to access the .param space directly.
  if (ArgUseIsReadOnly && AUC.Conditionals.empty()) {
    // ...
    for (Use *U : UsesToUpdate)
      convertToParamAS(U, ArgInParamAS, HasCvtaParam, IsGridConstant);
    // ...
    const auto *TLI =
        cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
    adjustByValArgAlignment(Arg, ArgInParamAS, TLI);
    return;
  }

  // On sm_70+ we can take a generic pointer to the argument via cvta.param
  // instead of creating a local copy, provided the pointer is effectively
  // read-only or the argument is __grid_constant__.
  if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant)) {
    LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
    // ...
    auto *CastToParam = cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
        Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
    // The param -> generic conversion goes through an intrinsic (rather than
    // another addrspacecast) so that it is lowered to cvta.param and not
    // folded away against the cast above.
    Value *CvtToGenCall = IRB.CreateIntrinsic(
        IRB.getPtrTy(ADDRESS_SPACE_GENERIC), Intrinsic::nvvm_ptr_param_to_gen,
        CastToParam, nullptr, CastToParam->getName() + ".gen");

    Arg->replaceAllUsesWith(CvtToGenCall);
    // Keep Arg itself as the operand of the cast to param space.
    CastToParam->setOperand(0, Arg);
  } else
    copyByValParam(*Func, *Arg);
}
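
// Rough sketch of the two fallback paths above (hypothetical IR; names and
// intrinsic mangling are illustrative only):
//
//   ; local-copy path (copyByValParam):
//   %copy = alloca %struct.S, align 8
//   %arg.param = addrspacecast ptr %arg to ptr addrspace(101)
//   call void @llvm.memcpy.p0.p101.i64(ptr %copy, ptr addrspace(101) %arg.param,
//                                      i64 %size, i1 false)
//
//   ; sm_70+ non-copy path:
//   %arg.param = addrspacecast ptr %arg to ptr addrspace(101)
//   %arg.param.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(
//                        ptr addrspace(101) %arg.param)
//   ; all former uses of %arg now go through %arg.param.gen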

void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
  // ...
  // If Ptr is an instruction, insert right after it.
  InsertPt = ++cast<Instruction>(Ptr)->getIterator();
  assert(InsertPt != InsertPt->getParent()->end() &&
         "We don't call this function with Ptr being a terminator.");
  // ...
  // Cast Ptr to the global address space and immediately back to generic; the
  // round trip lets NVPTXInferAddressSpaces specialize later accesses.
  Instruction *PtrInGlobal = new AddrSpaceCastInst(
      Ptr, PointerType::get(Ptr->getContext(), ADDRESS_SPACE_GLOBAL),
      Ptr->getName(), InsertPt);
  Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
                                              Ptr->getName(), InsertPt);
  // Replace all uses of Ptr (except the cast itself) with PtrInGeneric.
  Ptr->replaceAllUsesWith(PtrInGeneric);
  PtrInGlobal->setOperand(0, Ptr);
}
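
// Rough sketch of the round trip produced above (hypothetical IR):
//
//   %ptr.global = addrspacecast ptr %ptr to ptr addrspace(1)
//   %ptr.gen    = addrspacecast ptr addrspace(1) %ptr.global to ptr
//   ; uses of %ptr (other than the first cast) now use %ptr.gen
//
// NVPTXInferAddressSpaces can then fold the generic cast into the actual
// loads/stores and emit ld.global/st.global instead of generic accesses.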

bool NVPTXLowerArgs::runOnKernelFunction(const NVPTXTargetMachine &TM,
                                         Function &F) {
  // Pointers that reach the kernel as integers (e.g. after SROA of a byval
  // struct) are marked global only if every user converts them back with
  // inttoptr.
  auto HandleIntToPtr = [this](Value &V) {
    if (llvm::all_of(V.users(), [](User *U) { return isa<IntToPtrInst>(U); })) {
      SmallVector<User *, 16> UsersToUpdate(V.users());
      for (User *U : UsersToUpdate)
        markPointerAsGlobal(U);
    }
  };
  if (TM.getDrvInterface() == NVPTX::CUDA) {
    // Mark pointers loaded out of byval kernel parameters as global.
    for (auto &B : F) {
      for (auto &I : B) {
        if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
          if (LI->getType()->isPointerTy() || LI->getType()->isIntegerTy()) {
            Value *UO = getUnderlyingObject(LI->getPointerOperand());
            if (Argument *Arg = dyn_cast<Argument>(UO)) {
              if (Arg->hasByValAttr()) {
                // LI loads a pointer (or an integer holding one) from a byval
                // kernel parameter.
                if (LI->getType()->isPointerTy())
                  markPointerAsGlobal(LI);
                else
                  HandleIntToPtr(*LI);
              }
            }
          }
        }
      }
    }
  }

  LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
  for (Argument &Arg : F.args()) {
    if (Arg.getType()->isPointerTy()) {
      if (Arg.hasByValAttr())
        handleByValParam(TM, &Arg);
      else if (TM.getDrvInterface() == NVPTX::CUDA)
        markPointerAsGlobal(&Arg);
    } else if (Arg.getType()->isIntegerTy() &&
               TM.getDrvInterface() == NVPTX::CUDA) {
      HandleIntToPtr(Arg);
    }
  }
  return true;
}

// Device functions only need their byval arguments copied into local storage.
bool NVPTXLowerArgs::runOnDeviceFunction(const NVPTXTargetMachine &TM,
                                         Function &F) {
  LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
  for (Argument &Arg : F.args())
    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
      handleByValParam(TM, &Arg);
  return true;
}

bool NVPTXLowerArgs::runOnFunction(Function &F) {
  auto &TM = getAnalysis<TargetPassConfig>().getTM<NVPTXTargetMachine>();

  return isKernelFunction(F) ? runOnKernelFunction(TM, F)
                             : runOnDeviceFunction(TM, F);
}

static bool copyFunctionByValArgs(Function &F) {
  LLVM_DEBUG(dbgs() << "Creating a copy of byval args of " << F.getName()
                    << "\n");
  bool Changed = false;
  for (Argument &Arg : F.args())
    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr() &&
        !(isParamGridConstant(Arg) && isKernelFunction(F))) {
      copyByValParam(F, Arg);
      Changed = true;
    }
  return Changed;
}