147#include "llvm/IR/IntrinsicsNVPTX.h"
155#define DEBUG_TYPE "nvptx-lower-args"
176 void markPointerAsGlobal(
Value *
Ptr);
182 return "Lower pointer arguments of CUDA kernels";
190char NVPTXLowerArgs::ID = 1;
193 "Lower arguments (NVPTX)",
false,
false)
221 Instruction *
I = dyn_cast<Instruction>(OldUse->getUser());
222 assert(
I &&
"OldUse must be in an instruction");
231 auto CloneInstInParamAS = [GridConstant](
const IP &
I) ->
Value * {
232 if (
auto *LI = dyn_cast<LoadInst>(
I.OldInstruction)) {
233 LI->setOperand(0,
I.NewParam);
236 if (
auto *
GEP = dyn_cast<GetElementPtrInst>(
I.OldInstruction)) {
239 GEP->getSourceElementType(),
I.NewParam, Indices,
GEP->getName(),
241 NewGEP->setIsInBounds(
GEP->isInBounds());
244 if (
auto *BC = dyn_cast<BitCastInst>(
I.OldInstruction)) {
246 return BitCastInst::Create(BC->getOpcode(),
I.NewParam, NewBCType,
247 BC->getName(), BC->getIterator());
249 if (
auto *ASC = dyn_cast<AddrSpaceCastInst>(
I.OldInstruction)) {
257 auto GetParamAddrCastToGeneric =
262 OriginalUser->getModule(), Intrinsic::nvvm_ptr_param_to_gen,
263 {ReturnTy, PointerType::get(OriginalUser->getContext(),
264 ADDRESS_SPACE_PARAM)});
267 Value *CvtToGenCall =
269 OriginalUser->getIterator());
273 if (
auto *CI = dyn_cast<CallInst>(
I.OldInstruction)) {
274 I.OldUse->set(GetParamAddrCastToGeneric(
I.NewParam, CI));
277 if (
auto *SI = dyn_cast<StoreInst>(
I.OldInstruction)) {
279 if (SI->getValueOperand() ==
I.OldUse->get())
280 SI->setOperand(0, GetParamAddrCastToGeneric(
I.NewParam, SI));
283 if (
auto *PI = dyn_cast<PtrToIntInst>(
I.OldInstruction)) {
284 if (PI->getPointerOperand() ==
I.OldUse->get())
285 PI->setOperand(0, GetParamAddrCastToGeneric(
I.NewParam, PI));
289 "Instruction unsupported even for grid_constant argument");
295 while (!ItemsToConvert.
empty()) {
297 Value *NewInst = CloneInstInParamAS(
I);
299 if (NewInst && NewInst !=
I.OldInstruction) {
303 for (
Use &U :
I.OldInstruction->uses())
304 ItemsToConvert.
push_back({&U, cast<Instruction>(U.getUser()), NewInst});
306 InstructionsToDelete.
push_back(
I.OldInstruction);
318 I->eraseFromParent();
338 if (CurArgAlign >= NewArgAlign)
341 LLVM_DEBUG(
dbgs() <<
"Try to use alignment " << NewArgAlign <<
" instead of "
342 << CurArgAlign <<
" for " << *Arg <<
'\n');
345 Attribute::get(Func->getContext(), Attribute::Alignment, NewArgAlign);
360 std::queue<LoadContext> Worklist;
361 Worklist.push({ArgInParamAS, 0});
364 while (!Worklist.empty()) {
365 LoadContext Ctx = Worklist.front();
368 for (
User *CurUser : Ctx.InitialVal->users()) {
369 if (
auto *
I = dyn_cast<LoadInst>(CurUser)) {
374 if (
auto *
I = dyn_cast<BitCastInst>(CurUser)) {
375 Worklist.push({
I, Ctx.Offset});
379 if (
auto *
I = dyn_cast<GetElementPtrInst>(CurUser)) {
380 APInt OffsetAccumulated =
383 if (!
I->accumulateConstantOffset(
DL, OffsetAccumulated))
388 assert(
Offset != OffsetLimit &&
"Expect Offset less than UINT64_MAX");
390 Worklist.push({
I, Ctx.Offset +
Offset});
395 if (IsGridConstant &&
396 (isa<CallInst>(CurUser) || isa<StoreInst>(CurUser) ||
397 isa<PtrToIntInst>(CurUser)))
401 "bitcast, getelementptr, call, store, ptrtoint");
405 for (Load &CurLoad : Loads) {
406 Align NewLoadAlign(std::gcd(NewArgAlign, CurLoad.Offset));
407 Align CurLoadAlign(CurLoad.Inst->getAlign());
408 CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));
420 auto AreSupportedUsers = [&](
Value *Start) {
422 auto IsSupportedUse = [IsGridConstant](
Value *
V) ->
bool {
423 if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
426 if (
auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
432 if (IsGridConstant &&
433 (isa<CallInst>(V) || isa<StoreInst>(V) || isa<PtrToIntInst>(V)))
438 while (!ValuesToCheck.
empty()) {
440 if (!IsSupportedUse(V)) {
443 <<
"of " << *Arg <<
" because of " << *V <<
"\n");
447 if (!isa<LoadInst>(V) && !isa<CallInst>(V) && !isa<StoreInst>(V) &&
448 !isa<PtrToIntInst>(V))
464 for (
Use *U : UsesToUpdate)
469 cast<NVPTXTargetLowering>(
TM.getSubtargetImpl()->getTargetLowering());
477 unsigned AS =
DL.getAllocaAddrSpace();
486 auto *CastToParam = cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
494 Value *CvtToGenCall = IRB.CreateIntrinsic(
496 CastToParam,
nullptr, CastToParam->getName() +
".gen");
501 CastToParam->setOperand(0, Arg);
521 false, AllocA->
getAlign(), FirstInst);
526void NVPTXLowerArgs::markPointerAsGlobal(
Value *
Ptr) {
537 InsertPt = ++cast<Instruction>(
Ptr)->getIterator();
538 assert(InsertPt != InsertPt->getParent()->end() &&
539 "We don't call this function with Ptr being a terminator.");
544 Ptr->getName(), InsertPt);
546 Ptr->getName(), InsertPt);
548 Ptr->replaceAllUsesWith(PtrInGeneric);
561 auto HandleIntToPtr = [
this](
Value &
V) {
564 for (
User *U : UsersToUpdate)
565 markPointerAsGlobal(U);
568 if (
TM.getDrvInterface() == NVPTX::CUDA) {
572 if (
LoadInst *LI = dyn_cast<LoadInst>(&
I)) {
573 if (LI->getType()->isPointerTy() || LI->getType()->isIntegerTy()) {
575 if (
Argument *Arg = dyn_cast<Argument>(UO)) {
578 if (LI->getType()->isPointerTy())
579 markPointerAsGlobal(LI);
590 LLVM_DEBUG(
dbgs() <<
"Lowering kernel args of " <<
F.getName() <<
"\n");
594 handleByValParam(
TM, &Arg);
595 else if (
TM.getDrvInterface() == NVPTX::CUDA)
596 markPointerAsGlobal(&Arg);
598 TM.getDrvInterface() == NVPTX::CUDA) {
608 LLVM_DEBUG(
dbgs() <<
"Lowering function args of " <<
F.getName() <<
"\n");
611 handleByValParam(
TM, &Arg);
615bool NVPTXLowerArgs::runOnFunction(
Function &
F) {
619 : runOnDeviceFunction(
TM,
F);
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
Module.h This file contains the declarations for the Module class.
nvptx lower Lower arguments(NVPTX)"
nvptx lower Lower static false void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant)
static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS, const NVPTXTargetLowering *TLI)
const char LLVMTargetMachineRef TM
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
Target-Independent Code Generator Pass Configuration Options pass.
Class for arbitrary precision integers.
uint64_t getLimitedValue(uint64_t Limit=UINT64_MAX) const
If this value is smaller than the specified limit, return it, otherwise return the limit value.
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
This class represents a conversion between pointers from one address space to another.
an instruction to allocate memory on the stack
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
void setAlignment(Align Align)
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
This class represents an incoming formal argument to a Function.
Attribute getAttribute(Attribute::AttrKind Kind) const
void addAttr(Attribute::AttrKind Kind)
bool hasByValAttr() const
Return true if this argument has the byval attribute.
void removeAttr(Attribute::AttrKind Kind)
Remove attributes from an argument.
const Function * getParent() const
unsigned getArgNo() const
Return the index of this formal argument in its containing function.
Type * getParamByValType() const
If this is a byval argument, return its type.
uint64_t getValueAsInt() const
Return the attribute's value as an integer.
static Attribute get(LLVMContext &Context, AttrKind Kind, uint64_t Val=0)
Return a uniquified Attribute object.
iterator begin()
Instruction iterator methods.
InstListType::iterator iterator
Instruction iterators...
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
A parsed version of the target data layout string, and methods for querying it.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
const BasicBlock & getEntryBlock() const
static GetElementPtrInst * Create(Type *PointeeType, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
An instruction for reading from memory.
Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy, const DataLayout &DL) const
getFunctionParamOptimizedAlign - since function arguments are passed via .param space,...
PassRegistry - This class manages the registration and initialization of the pass subsystem as appli...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
StringRef - Represent a constant reference to a string, i.e.
Class to represent struct types.
Target-Independent Code Generator Pass Configuration Options.
The instances of the Type class are immutable: once they are created, they are never changed.
bool isPointerTy() const
True if this is an instance of PointerType.
bool isIntegerTy() const
True if this is an instance of IntegerType.
A Use represents the edge between a Value definition and its users.
void setOperand(unsigned i, Value *Val)
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
iterator_range< user_iterator > users()
LLVMContext & getContext() const
All values hold a context through their type.
iterator_range< use_iterator > uses()
StringRef getName() const
Return a constant reference to the value's name.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Function * getDeclaration(Module *M, ID id, ArrayRef< Type * > Tys=std::nullopt)
Create or insert an LLVM Function declaration for an intrinsic, and return it.
NodeAddr< FuncNode * > Func
This is an optimization pass for GlobalISel generic memory operations.
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
bool isParamGridConstant(const Value &V)
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....
void initializeNVPTXLowerArgsPass(PassRegistry &)
FunctionPass * createNVPTXLowerArgsPass()
auto reverse(ContainerTy &&C)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
bool isKernelFunction(const Function &F)
This struct is a compact representation of a valid (non-zero power of two) alignment.
uint64_t value() const
This is a hole in the type system and should not be abused.