LLVM 22.0.0git
NVVMReflect.cpp
Go to the documentation of this file.
1//===- NVVMReflect.cpp - NVVM Emulate conditional compilation -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
10// with an integer.
11//
12// We choose the value we use by looking at metadata in the module itself. Note
13// that we intentionally only have one way to choose these values, because other
14// parts of LLVM (particularly, InstCombineCall) rely on being able to predict
15// the values chosen by this pass.
16//
17// If we see an unknown string, we replace its call with 0.
18//
19//===----------------------------------------------------------------------===//
20
21#include "NVPTX.h"
26#include "llvm/IR/Constants.h"
28#include "llvm/IR/Function.h"
30#include "llvm/IR/Intrinsics.h"
31#include "llvm/IR/IntrinsicsNVPTX.h"
32#include "llvm/IR/Module.h"
33#include "llvm/IR/PassManager.h"
34#include "llvm/IR/Type.h"
35#include "llvm/Pass.h"
37#include "llvm/Support/Debug.h"
// Names of the reflect functions whose calls this pass looks up and replaces.
42#define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
43#define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
44// Argument of reflect call to retrieve arch number
45#define CUDA_ARCH_NAME "__CUDA_ARCH"
46// Argument of reflect call to retrieve ftz mode
47#define CUDA_FTZ_NAME "__CUDA_FTZ"
48// Name of the module flag where the ftz mode is stored
49#define CUDA_FTZ_MODULE_NAME "nvvm-reflect-ftz"
50
51using namespace llvm;
52
53#define DEBUG_TYPE "nvvm-reflect"
54
55namespace {
56class NVVMReflect {
57 // Map from reflect function call arguments to the value to replace the call
58 // with. Should include __CUDA_FTZ and __CUDA_ARCH values.
59 StringMap<unsigned> ReflectMap;
60 bool handleReflectFunction(Module &M, StringRef ReflectName);
61 void populateReflectMap(Module &M);
62 void foldReflectCall(CallInst *Call, Constant *NewValue);
63
64public:
65 // __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
66 // metadata.
67 explicit NVVMReflect(unsigned SmVersion)
68 : ReflectMap({{CUDA_ARCH_NAME, SmVersion * 10}}) {}
69 bool runOnModule(Module &M);
70};
71
72class NVVMReflectLegacyPass : public ModulePass {
73 NVVMReflect Impl;
74
75public:
76 static char ID;
77 NVVMReflectLegacyPass(unsigned SmVersion) : ModulePass(ID), Impl(SmVersion) {}
78 bool runOnModule(Module &M) override;
79};
80} // namespace
81
82ModulePass *llvm::createNVVMReflectPass(unsigned SmVersion) {
83 return new NVVMReflectLegacyPass(SmVersion);
84}
85
// Hidden boolean command-line option "nvvm-reflect-enable"; it is initialised
// to true, so reflection handling is on unless the option is overridden.
86static cl::opt<bool>
87 NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden,
88 cl::desc("NVVM reflection, enabled by default"));
89
// Storage for the legacy pass's static ID member, followed by its registration
// under the pass name "nvvm-reflect".
// NOTE(review): the description string says calls are replaced with "0/1", but
// the constructor seeds __CUDA_ARCH with SmVersion * 10 and -nvvm-reflect-add
// accepts arbitrary unsigned values - consider updating the wording.
90char NVVMReflectLegacyPass::ID = 0;
91INITIALIZE_PASS(NVVMReflectLegacyPass, "nvvm-reflect",
92 "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
93 false)
94
95// Allow users to specify additional key/value pairs to reflect. These key/value
96// pairs are the last to be added to the ReflectMap, and therefore will take
97// precedence over initial values (i.e. __CUDA_FTZ from module metadata and
98// __CUDA_ARCH from SmVersion).
99static cl::list<std::string> ReflectList(
100 "nvvm-reflect-add", cl::value_desc("name=<int>"), cl::Hidden,
101 cl::desc("A key=value pair. Replace __nvvm_reflect(name) with value."),
102 cl::ValueRequired);
103
104// Set the ReflectMap with, first, the value of __CUDA_FTZ from module metadata,
105// and then the key/value pairs from the command line.
// Each "name=value" option is validated in three steps (non-empty name,
// non-empty value, value parses as a base-10 unsigned integer); any failure
// aborts through report_fatal_error with the offending option in the message.
106void NVVMReflect::populateReflectMap(Module &M) {
// NOTE(review): listing line 107 is missing from this extract. It presumably
// opens the `if` that binds `Flag` (the three closing parentheses on the next
// line and the use of `Flag` below imply it) - confirm against upstream.
108 M.getModuleFlag(CUDA_FTZ_MODULE_NAME)))
109 ReflectMap[CUDA_FTZ_NAME] = Flag->getSExtValue();
110
// Command-line entries are written with operator[], so they overwrite any
// value already stored under the same name (including the two defaults).
111 for (auto &Option : ReflectList) {
112 LLVM_DEBUG(dbgs() << "ReflectOption : " << Option << "\n");
113 StringRef OptionRef(Option);
114 auto [Name, Val] = OptionRef.split('=');
115 if (Name.empty())
116 report_fatal_error(Twine("Empty name in nvvm-reflect-add option '") +
117 Option + "'");
118 if (Val.empty())
119 report_fatal_error(Twine("Missing value in nvvm-reflect-add option '") +
120 Option + "'");
121 unsigned ValInt;
// Surrounding whitespace in the value is trimmed before parsing.
122 if (!to_integer(Val.trim(), ValInt, 10))
// NOTE(review): listing line 123 is missing from this extract; presumably the
// `report_fatal_error(` that receives the message below - confirm.
124 Twine("integer value expected in nvvm-reflect-add option '") +
125 Option + "'");
126 ReflectMap[Name] = ValInt;
127 }
128}
129
130/// Process a reflect function by finding all its calls and replacing them with
131/// appropriate constant values. For __CUDA_FTZ, uses the module flag value.
132/// For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0.
/// Returns false when the module has no function named ReflectName; otherwise
/// returns whether that function had any uses on entry.
/// NOTE(review): the declaration is erased from the module even when it had no
/// uses, yet the return value is false in that case - confirm that callers do
/// not treat false as "module untouched".
133bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
134 Function *F = M.getFunction(ReflectName);
135 if (!F)
136 return false;
137 assert(F->isDeclaration() && "_reflect function should not have a body");
138 assert(F->getReturnType()->isIntegerTy() &&
139 "_reflect's return type should be integer");
140
// Snapshot taken before the loop below starts rewriting the users.
141 const bool Changed = !F->use_empty();
142 for (User *U : make_early_inc_range(F->users())) {
143 // Reflect function calls look like:
144 // @arch = private unnamed_addr addrspace(1) constant [12 x i8]
145 // c"__CUDA_ARCH\00" call i32 @__nvvm_reflect(ptr addrspacecast (ptr
146 // addrspace(1) @arch to ptr)) We need to extract the string argument from
147 // the call (i.e. "__CUDA_ARCH")
148 auto *Call = dyn_cast<CallInst>(U);
149 if (!Call)
// NOTE(review): listing line 150 is missing from this extract; presumably the
// `report_fatal_error(` that receives the message below - confirm.
151 "__nvvm_reflect can only be used in a call instruction");
// The error text equates an operand count of two with a single call argument.
152 if (Call->getNumOperands() != 2)
153 report_fatal_error("__nvvm_reflect requires exactly one argument");
154
155 auto *GlobalStr =
// NOTE(review): listing line 156 is missing from this extract, so the
// expression that initialises GlobalStr is not visible here.
157 if (!GlobalStr)
158 report_fatal_error("__nvvm_reflect argument must be a constant string");
159
160 auto *ConstantStr =
// NOTE(review): listing line 161 is missing from this extract, so the
// expression that initialises ConstantStr is not visible here.
162 if (!ConstantStr)
163 report_fatal_error("__nvvm_reflect argument must be a string constant");
164 if (!ConstantStr->isCString())
// NOTE(review): listing line 165 is missing from this extract; presumably the
// `report_fatal_error(` that receives the message below - confirm.
166 "__nvvm_reflect argument must be a null-terminated string");
167
// drop_back() removes the final character, i.e. the terminator that the
// null-terminated check above requires.
168 StringRef ReflectArg = ConstantStr->getAsString().drop_back();
169 if (ReflectArg.empty())
170 report_fatal_error("__nvvm_reflect argument cannot be empty");
171 // Now that we have extracted the string argument, we can look it up in the
172 // ReflectMap
173 unsigned ReflectVal = 0; // The default value is 0
174 if (ReflectMap.contains(ReflectArg))
175 ReflectVal = ReflectMap[ReflectArg];
176
177 LLVM_DEBUG(dbgs() << "Replacing call of reflect function " << F->getName()
178 << "(" << ReflectArg << ") with value " << ReflectVal
179 << "\n");
// The replacement constant takes the call's own result type.
180 auto *NewValue = ConstantInt::get(Call->getType(), ReflectVal);
181 foldReflectCall(Call, NewValue);
// NOTE(review): listing line 182 is missing from this extract; what happens to
// `Call` after folding is not visible here - confirm against upstream.
183 }
184
185 // Remove the __nvvm_reflect function from the module
186 F->eraseFromParent();
187 return Changed;
188}
189
// Replaces every use of Call with NewValue, then follows the users
// transitively: each queued instruction that ConstantFoldInstruction can fold
// is itself replaced by the folded constant and its users are queued in turn.
// A queued terminator that does not fold is handed to ConstantFoldTerminator
// for its parent block.
190void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
191 SmallVector<Instruction *, 8> Worklist;
192 // Replace an instruction with a constant and add all users of the instruction
193 // to the worklist
194 auto ReplaceInstructionWithConst = [&](Instruction *I, Constant *C) {
// Users are collected first, before replaceAllUsesWith rewires them to C.
// NOTE(review): one entry is pushed per element returned by users(). If the
// same instruction can show up more than once (e.g. it uses the value in two
// operands), a later pop would touch it after the erase in the loop below -
// confirm whether the worklist needs de-duplication.
195 for (auto *U : I->users())
196 if (auto *UI = dyn_cast<Instruction>(U))
197 Worklist.push_back(UI);
198 I->replaceAllUsesWith(C);
199 };
200
201 ReplaceInstructionWithConst(Call, NewValue);
202
203 auto &DL = Call->getModule()->getDataLayout();
204 while (!Worklist.empty()) {
205 auto *I = Worklist.pop_back_val();
206 if (auto *C = ConstantFoldInstruction(I, DL)) {
207 ReplaceInstructionWithConst(I, C);
// NOTE(review): listing line 208 is missing from this extract. It is
// presumably a condition guarding the erase below (the page's reference list
// mentions isInstructionTriviallyDead, which no visible line calls) - confirm.
209 I->eraseFromParent();
210 } else if (I->isTerminator()) {
211 ConstantFoldTerminator(I->getParent());
212 }
213 }
214}
215
// Populates the reflect map and then processes the three reflect entry points
// in turn: __nvvm_reflect, __nvvm_reflect_ocl and the nvvm_reflect intrinsic
// (looked up by its intrinsic name).
216bool NVVMReflect::runOnModule(Module &M) {
// NOTE(review): listing line 217 is missing from this extract; presumably the
// condition guarding this early return (likely the nvvm-reflect-enable
// option) - confirm against upstream.
218 return false;
219 populateReflectMap(M);
// NOTE(review): Changed starts out true, so the `|=` results below cannot
// affect the return value; once past the early return this always reports a
// change. Confirm whether that conservatism is intended.
220 bool Changed = true;
221 Changed |= handleReflectFunction(M, NVVM_REFLECT_FUNCTION);
222 Changed |= handleReflectFunction(M, NVVM_REFLECT_OCL_FUNCTION);
223 Changed |=
224 handleReflectFunction(M, Intrinsic::getName(Intrinsic::nvvm_reflect));
225 return Changed;
226}
227
228bool NVVMReflectLegacyPass::runOnModule(Module &M) {
229 return Impl.runOnModule(M);
230}
231
// NOTE(review): listing lines 232 and 234 are missing from this extract, so
// the signature of this function and the other arm of the conditional below
// are not visible. What is visible: a temporary NVVMReflect is built from
// SmVersion, run on M, and PreservedAnalyses::none() is selected when it
// reports a change.
233 return NVVMReflect(SmVersion).runOnModule(M) ? PreservedAnalyses::none()
235}
for(const MachineOperand &MO :llvm::drop_begin(OldMI.operands(), Desc.getNumOperands()))
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
Machine Check Debug Module
static cl::opt< std::string > GlobalStr("nvptx-lower-global-ctor-dtor-id", cl::desc("Override unique ID of ctor/dtor globals."), cl::init(""), cl::Hidden)
#define CUDA_FTZ_MODULE_NAME
#define NVVM_REFLECT_OCL_FUNCTION
#define NVVM_REFLECT_FUNCTION
static cl::opt< bool > NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden, cl::desc("NVVM reflection, enabled by default"))
#define CUDA_ARCH_NAME
#define CUDA_FTZ_NAME
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition PassSupport.h:56
This file defines the SmallVector class.
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
Value * getArgOperand(unsigned i) const
This class represents a function call, abstracting a target machine's calling convention.
This is an important base class in LLVM.
Definition Constant.h:43
LLVM_ABI const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr if the function does not...
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition Pass.h:255
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
const DataLayout & getDataLayout() const
Get the data layout for the module's target platform.
Definition Module.h:278
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition Analysis.h:115
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
void push_back(const T &Elt)
StringMap - This is an unconventional map that is specialized for handling keys that are "strings",...
Definition StringMap.h:133
bool contains(StringRef Key) const
contains - Return true if the element is in the map, false otherwise.
Definition StringMap.h:275
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
constexpr bool empty() const
empty - Check if the string is empty.
Definition StringRef.h:151
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition Twine.h:82
unsigned getNumOperands() const
Definition User.h:254
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
LLVM_ABI const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs and address space casts.
Definition Value.cpp:701
CallInst * Call
Changed
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
LLVM_ABI StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
This namespace contains all of the command line option processing machinery.
Definition CommandLine.h:53
initializer< Ty > init(const Ty &Val)
std::enable_if_t< detail::IsValidPointer< X, Y >::value, X * > extract_or_null(Y &&MD)
Extract a Value from Metadata, allowing null.
Definition Metadata.h:681
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
LLVM_ABI bool ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions=false, const TargetLibraryInfo *TLI=nullptr, DomTreeUpdater *DTU=nullptr)
If a terminator instruction is predicated on a constant value, convert it into an unconditional branc...
Definition Local.cpp:134
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:649
LLVM_ABI Constant * ConstantFoldInstruction(const Instruction *I, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr)
ConstantFoldInstruction - Try to constant fold the specified instruction.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition STLExtras.h:634
LLVM_ABI bool isInstructionTriviallyDead(Instruction *I, const TargetLibraryInfo *TLI=nullptr)
Return true if the result produced by the instruction is not used, and the instruction will return.
Definition Local.cpp:402
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)
Definition Error.cpp:167
ModulePass * createNVVMReflectPass(unsigned int SmVersion)
bool to_integer(StringRef S, N &Num, unsigned Base=0)
Convert the string S to an integer of the specified type using the radix Base. If Base is 0,...
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
Definition MIRParser.h:39
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:851
PreservedAnalyses run(Module &F, ModuleAnalysisManager &AM)