65#define DEBUG_TYPE "amdgpu-split-module"
70 "amdgpu-module-splitting-large-function-threshold",
cl::init(2.0f),
73 "consider a function as large and needing special treatment when the "
74 "cost of importing it into a partition"
75 "exceeds the average cost of a partition by this factor; e;g. 2.0 "
76 "means if the function and its dependencies is 2 times bigger than "
77 "an average partition; 0 disables large functions handling entirely"));
80 "amdgpu-module-splitting-large-function-merge-overlap",
cl::init(0.8f),
83 "defines how much overlap between two large function's dependencies "
84 "is needed to put them in the same partition"));
87 "amdgpu-module-splitting-no-externalize-globals",
cl::Hidden,
88 cl::desc(
"disables externalization of global variable with local linkage; "
89 "may cause globals to be duplicated which increases binary size"));
92 LogDirOpt(
"amdgpu-module-splitting-log-dir",
cl::Hidden,
93 cl::desc(
"output directory for AMDGPU module splitting logs"));
96 LogPrivate(
"amdgpu-module-splitting-log-private",
cl::Hidden,
97 cl::desc(
"hash value names before printing them in the AMDGPU "
98 "module splitting logs"));
109 static bool HideNames;
114 HideNames = LogPrivate;
116 const auto EV = sys::Process::GetEnv(
"AMD_SPLIT_MODULE_LOG_PRIVATE");
117 HideNames = (EV.value_or(
"0") !=
"0");
122 return V.getName().str();
123 return toHex(
SHA256::hash(arrayRefFromStringRef(V.getName())),
160class SplitModuleLogger {
162 SplitModuleLogger(
const Module &M) {
163 std::string LogDir = LogDirOpt;
181 "': " + Err.message(),
185 FileOS = std::make_unique<raw_fd_ostream>(Fd,
true);
188 bool hasLogFile()
const {
return FileOS !=
nullptr; }
191 assert(FileOS &&
"no logfile!");
198 operator bool()
const {
203 std::unique_ptr<raw_fd_ostream> FileOS;
206template <
typename Ty>
207static SplitModuleLogger &
operator<<(SplitModuleLogger &SML,
const Ty &Val) {
209 !std::is_same_v<Ty, Value>,
210 "do not print values to logs directly, use handleName instead!");
212 if (SML.hasLogFile())
213 SML.logfile() << Val;
224calculateFunctionCosts(SplitModuleLogger &SML, GetTTIFn GetTTI,
Module &M,
226 CostType ModuleCost = 0;
227 CostType KernelCost = 0;
230 if (Fn.isDeclaration())
234 const auto &
TTI = GetTTI(Fn);
235 for (
const auto &BB : Fn) {
236 for (
const auto &
I : BB) {
243 assert((FnCost + CostVal) >= FnCost &&
"Overflow!");
250 CostMap[&Fn] = FnCost;
251 assert((ModuleCost + FnCost) >= ModuleCost &&
"Overflow!");
252 ModuleCost += FnCost;
255 KernelCost += FnCost;
258 CostType FnCost = (ModuleCost - KernelCost);
259 CostType ModuleCostOr1 = ModuleCost ? ModuleCost : 1;
260 SML <<
"=> Total Module Cost: " << ModuleCost <<
'\n'
261 <<
" => KernelCost: " << KernelCost <<
" ("
262 <<
format(
"%0.2f", (
float(KernelCost) / ModuleCostOr1) * 100) <<
"%)\n"
263 <<
" => FnsCost: " << FnCost <<
" ("
264 <<
format(
"%0.2f", (
float(FnCost) / ModuleCostOr1) * 100) <<
"%)\n";
269static bool canBeIndirectlyCalled(
const Function &
F) {
272 return !
F.hasLocalLinkage() ||
273 F.hasAddressTaken(
nullptr,
290static void addAllIndirectCallDependencies(
const Module &M,
292 for (
const auto &Fn : M) {
293 if (canBeIndirectlyCalled(Fn))
310static void addAllDependencies(SplitModuleLogger &SML,
const CallGraph &CG,
313 bool &HadIndirectCall) {
318 while (!WorkList.empty()) {
320 assert(!CurFn.isDeclaration());
326 for (
auto &CGEntry : *CG[&CurFn]) {
327 auto *CGNode = CGEntry.second;
328 auto *Callee = CGNode->getFunction();
341 SML <<
"Indirect call detected in " <<
getName(CurFn)
342 <<
" - treating all non-entrypoint functions as "
343 "potential dependencies\n";
346 addAllIndirectCallDependencies(M, Fns);
347 HadIndirectCall =
true;
351 if (Callee->isDeclaration())
354 auto [It, Inserted] = Fns.
insert(Callee);
356 WorkList.push_back(Callee);
364struct FunctionWithDependencies {
365 FunctionWithDependencies(SplitModuleLogger &SML,
CallGraph &CG,
372 addAllDependencies(SML, CG, *Fn, Dependencies,
374 TotalCost = FnCosts.
at(Fn);
375 for (
const auto *Dep : Dependencies) {
376 TotalCost += FnCosts.
at(Dep);
380 HasNonDuplicatableDependecy |=
381 (Dep->hasExternalLinkage() || !Dep->isDefinitionExact());
388 bool HasIndirectCall =
false;
390 bool HasNonDuplicatableDependecy =
false;
392 CostType TotalCost = 0;
396 bool isLarge(CostType Threshold)
const {
397 return TotalCost > Threshold && !Dependencies.
empty();
407 for (
const auto *
F :
A) {
415 unsigned NumCommon = 0;
416 for (
const auto *
F :
B) {
420 auto [It, Inserted] =
Total.insert(
F);
425 return static_cast<float>(NumCommon) /
Total.size();
436static std::vector<DenseSet<const Function *>>
437doPartitioning(SplitModuleLogger &SML,
Module &M,
unsigned NumParts,
442 SML <<
"\n--Partitioning Starts--\n";
451 const CostType LargeFnThreshold =
452 LargeFnFactor ? CostType(((ModuleCost / NumParts) * LargeFnFactor))
453 : std::numeric_limits<CostType>::max();
455 std::vector<DenseSet<const Function *>> Partitions;
456 Partitions.resize(NumParts);
467 auto ComparePartitions = [](
const std::pair<PartitionID, CostType> &a,
468 const std::pair<PartitionID, CostType> &b) {
472 if (a.second == b.second)
473 return a.first < b.first;
474 return a.second > b.second;
481 std::vector<std::pair<PartitionID, CostType>> BalancingQueue;
482 for (
unsigned I = 0;
I < NumParts; ++
I)
483 BalancingQueue.emplace_back(
I, 0);
487 const auto AssignToPartition = [&](PartitionID PID,
488 const FunctionWithDependencies &
FWD) {
489 auto &FnsInPart = Partitions[PID];
490 FnsInPart.insert(
FWD.Fn);
491 FnsInPart.insert(
FWD.Dependencies.begin(),
FWD.Dependencies.end());
493 SML <<
"assign " <<
getName(*
FWD.Fn) <<
" to P" << PID <<
"\n -> ";
494 if (!
FWD.Dependencies.empty()) {
495 SML <<
FWD.Dependencies.size() <<
" dependencies added\n";
500 for (
auto &[QueuePID,
Cost] :
reverse(BalancingQueue)) {
501 if (QueuePID == PID) {
502 CostType NewCost = 0;
503 for (
auto *Fn : Partitions[PID])
504 NewCost += FnCosts.
at(Fn);
506 SML <<
"[Updating P" << PID <<
" Cost]:" <<
Cost <<
" -> " << NewCost;
508 SML <<
" (" <<
unsigned(((
float(NewCost) /
Cost) - 1) * 100)
517 sort(BalancingQueue, ComparePartitions);
520 for (
auto &CurFn : WorkList) {
524 if (CurFn.HasIndirectCall) {
525 SML <<
"Function with indirect call(s): " <<
getName(*CurFn.Fn)
526 <<
" defaulting to P0\n";
527 AssignToPartition(0, CurFn);
535 if (CurFn.HasNonDuplicatableDependecy) {
536 SML <<
"Function with externally visible dependency "
537 <<
getName(*CurFn.Fn) <<
" defaulting to P0\n";
538 AssignToPartition(0, CurFn);
543 if (CurFn.isLarge(LargeFnThreshold)) {
544 assert(LargeFnOverlapForMerge >= 0.0f && LargeFnOverlapForMerge <= 1.0f);
545 SML <<
"Large Function: " <<
getName(*CurFn.Fn)
546 <<
" - looking for partition with at least "
547 <<
format(
"%0.2f", LargeFnOverlapForMerge * 100) <<
"% overlap\n";
549 bool Assigned =
false;
550 for (
const auto &[PID, Fns] :
enumerate(Partitions)) {
551 float Overlap = calculateOverlap(CurFn.Dependencies, Fns);
552 SML <<
" => " <<
format(
"%0.2f", Overlap * 100) <<
"% overlap with P"
554 if (Overlap > LargeFnOverlapForMerge) {
555 SML <<
" selecting P" << PID <<
'\n';
556 AssignToPartition(PID, CurFn);
566 auto [PID, CurCost] = BalancingQueue.back();
567 AssignToPartition(PID, CurFn);
571 CostType ModuleCostOr1 = ModuleCost ? ModuleCost : 1;
574 for (
auto *Fn : Part)
576 SML <<
"P" <<
Idx <<
" has a total cost of " <<
Cost <<
" ("
577 <<
format(
"%0.2f", (
float(
Cost) / ModuleCostOr1) * 100)
578 <<
"% of source module)\n";
581 SML <<
"--Partitioning Done--\n\n";
587 for (
const auto &Part : Partitions)
588 AllFunctions.
insert(Part.begin(), Part.end());
609 GV.
setName(
"__llvmsplit_unnamed");
612static bool hasDirectCaller(
const Function &Fn) {
613 for (
auto &U : Fn.
uses()) {
614 if (
auto *CB = dyn_cast<CallBase>(U.getUser()); CB && CB->isCallee(&U))
620static void splitAMDGPUModule(
621 GetTTIFn GetTTI,
Module &M,
unsigned N,
622 function_ref<
void(std::unique_ptr<Module> MPart)> ModuleCallback) {
624 SplitModuleLogger SML(M);
642 SML <<
"[externalize] " << Fn.
getName()
643 <<
" because its address is taken\n";
651 if (!NoExternalizeGlobals) {
652 for (
auto &GV : M.globals()) {
654 SML <<
"[externalize] GV " << GV.
getName() <<
'\n';
662 const CostType ModuleCost = calculateFunctionCosts(SML, GetTTI, M, FnCosts);
676 for (
const auto &
FWD : WorkList) {
678 SeenFunctions.
insert(
FWD.Dependencies.begin(),
FWD.Dependencies.end());
685 !SeenFunctions.
count(&Fn) && !hasDirectCaller(Fn)) {
686 WorkList.emplace_back(SML, CG, FnCosts, &Fn);
692 sort(WorkList, [&](
auto &
A,
auto &
B) {
695 if (
A.TotalCost ==
B.TotalCost)
696 return A.Fn->getName() <
B.Fn->getName();
697 return A.TotalCost >
B.TotalCost;
702 for (
const auto &
FWD : WorkList) {
703 SML <<
"[root] " <<
getName(*
FWD.Fn) <<
" (totalCost:" <<
FWD.TotalCost
704 <<
" indirect:" <<
FWD.HasIndirectCall
705 <<
" hasNonDuplicatableDep:" <<
FWD.HasNonDuplicatableDependecy
709 SortedDepNames.
reserve(
FWD.Dependencies.size());
710 for (
const auto *Dep :
FWD.Dependencies)
712 sort(SortedDepNames);
714 for (
const auto &
Name : SortedDepNames)
715 SML <<
" [dependency] " <<
Name <<
'\n';
720 auto Partitions = doPartitioning(SML, M,
N, ModuleCost, FnCosts, WorkList);
721 assert(Partitions.size() ==
N);
726 const auto NeedsConservativeImport = [&](
const GlobalValue *GV) {
729 const auto *Var = dyn_cast<GlobalVariable>(GV);
730 return Var && Var->hasLocalLinkage();
733 SML <<
"Creating " <<
N <<
" modules...\n";
734 unsigned TotalFnImpls = 0;
735 for (
unsigned I = 0;
I <
N; ++
I) {
736 const auto &FnsInPart = Partitions[
I];
739 std::unique_ptr<Module> MPart(
742 if (
const auto *Fn = dyn_cast<Function>(GV))
743 return FnsInPart.contains(Fn);
745 if (NeedsConservativeImport(GV))
754 if (NeedsConservativeImport(&GV) && GV.
use_empty())
758 unsigned NumAllFns = 0, NumKernels = 0;
759 for (
auto &Cur : *MPart) {
760 if (!Cur.isDeclaration()) {
766 TotalFnImpls += NumAllFns;
767 SML <<
" - Module " <<
I <<
" with " << NumAllFns <<
" functions ("
768 << NumKernels <<
" kernels)\n";
769 ModuleCallback(std::move(MPart));
772 SML << TotalFnImpls <<
" function definitions across all modules ("
773 <<
format(
"%0.2f", (
float(TotalFnImpls) / FnCosts.
size()) * 100)
774 <<
"% of original module)\n";
785 splitAMDGPUModule(TTIGetter, M, N, ModuleCallback);
The AMDGPU TargetMachine interface definition for hw codegen targets.
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
This file provides interfaces used to build and manipulate a call graph, which is a very useful tool ...
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
This file defines the DenseMap class.
Module.h This file contains the declarations for the Module class.
FunctionAnalysisManager FAM
ModuleAnalysisManager MAM
Provides a library for accessing information about this process and other processes on the operating ...
static StringRef getName(Value *V)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallVector class.
static void externalize(GlobalValue *GV)
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
A container for analyses that lazily runs them and caches their results.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
The basic data container for the call graph of a Module of IR.
CallGraphNode * getCallsExternalNode() const
const ValueT & at(const_arg_type_t< KeyT > Val) const
at - Return the entry for the specified key, or abort if no such entry exists.
Implements a dense probed hash-table based set.
bool hasAddressTaken(const User **=nullptr, bool IgnoreCallbackUses=false, bool IgnoreAssumeLikeCalls=true, bool IngoreLLVMUsed=false, bool IgnoreARCAttachedCall=false, bool IgnoreCastedDirectCall=false) const
hasAddressTaken - returns true if there are any uses of this function other than direct calls or invo...
bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
bool hasLocalLinkage() const
void setLinkage(LinkageTypes LT)
Module * getParent()
Get the module that this global value is contained inside of...
void eraseFromParent()
This method unlinks 'this' from the containing module and deletes it.
@ HiddenVisibility
The GV is hidden.
void setVisibility(VisibilityTypes V)
@ ExternalLinkage
Externally visible function.
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
static InstructionCost getMax()
std::optional< CostType > getValue() const
This function is intended to be used as sparingly as possible, since the class provides the full rang...
A Module instance is used to store all the information related to an LLVM module.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
static std::array< uint8_t, 32 > hash(ArrayRef< uint8_t > Data)
Returns a raw 256-bit SHA256 hash for the given data.
SmallString - A SmallString is just a SmallVector with methods and accessors that make it work better...
StringRef str() const
Explicit conversion to StringRef.
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Analysis pass providing the TargetTransformInfo.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
LLVM Value Representation.
void setName(const Twine &Name)
Change the name of the value.
iterator_range< use_iterator > uses()
StringRef getName() const
Return a constant reference to the value's name.
int getNumOccurrences() const
std::pair< iterator, bool > insert(const ValueT &V)
bool contains(const_arg_type_t< ValueT > V) const
Check if the set contains the given element.
size_type count(const_arg_type_t< ValueT > V) const
Return 1 if the specified key is in the set, 0 otherwise.
An efficient, type-erasing, non-owning reference to a callable.
This class implements an extremely fast bulk output stream that can only output to a stream.
static std::optional< std::string > GetEnv(StringRef name)
bool isEntryFunctionCC(CallingConv::ID CC)
initializer< Ty > init(const Ty &Val)
std::error_code createUniqueFile(const Twine &Model, int &ResultFD, SmallVectorImpl< char > &ResultPath, OpenFlags Flags=OF_None, unsigned Mode=all_read|all_write)
Create a uniquely named file.
void append(SmallVectorImpl< char > &path, const Twine &a, const Twine &b="", const Twine &c="", const Twine &d="")
Append to path.
This is an optimization pass for GlobalISel generic memory operations.
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B,...
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
bool DebugFlag
This boolean is set to true if the '-debug' command line option is specified.
auto reverse(ContainerTy &&C)
void sort(IteratorTy Start, IteratorTy End)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
format_object< Ts... > format(const char *Fmt, const Ts &... Vals)
These are helper functions used to produce formatted output.
bool isEntryPoint(const Function &F)
bool isCurrentDebugType(const char *Type)
isCurrentDebugType - Return true if the specified string is the debug type specified on the command l...
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
void call_once(once_flag &flag, Function &&F, Args &&... ArgList)
Execute the function specified as a parameter once.
std::unique_ptr< Module > CloneModule(const Module &M)
Return an exact copy of the specified module.
The llvm::once_flag structure.