46#include "llvm/IR/IntrinsicsAMDGPU.h"
47#include "llvm/IR/IntrinsicsNVPTX.h"
62#define DEBUG_TYPE "openmp-opt"
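// Command-line options of the OpenMP optimization pass. Only the option
// strings, descriptions, and a few initializers survive in this excerpt; the
// enclosing static cl::opt declarations (variable names, cl::Hidden,
// cl::init(...) defaults) are elided.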
65 "openmp-opt-disable",
cl::desc(
"Disable OpenMP specific optimizations."),
69 "openmp-opt-enable-merging",
75 cl::desc(
"Disable function internalization."),
86 "openmp-hide-memory-transfer-latency",
87 cl::desc(
"[WIP] Tries to hide the latency of host to device memory"
92 "openmp-opt-disable-deglobalization",
93 cl::desc(
"Disable OpenMP optimizations involving deglobalization."),
97 "openmp-opt-disable-spmdization",
98 cl::desc(
"Disable OpenMP optimizations involving SPMD-ization."),
102 "openmp-opt-disable-folding",
107 "openmp-opt-disable-state-machine-rewrite",
108 cl::desc(
"Disable OpenMP optimizations that replace the state machine."),
112 "openmp-opt-disable-barrier-elimination",
113 cl::desc(
"Disable OpenMP optimizations that eliminate barriers."),
117 "openmp-opt-print-module-after",
118 cl::desc(
"Print the current module after OpenMP optimizations."),
122 "openmp-opt-print-module-before",
123 cl::desc(
"Print the current module before OpenMP optimizations."),
127 "openmp-opt-inline-device",
138 cl::desc(
"Maximal number of attributor iterations."),
143 cl::desc(
"Maximum amount of shared memory to use."),
144 cl::init(std::numeric_limits<unsigned>::max()));
147 "Number of OpenMP runtime calls deduplicated");
149 "Number of OpenMP parallel regions deleted");
151 "Number of OpenMP runtime functions identified");
153 "Number of OpenMP runtime function uses identified");
155 "Number of OpenMP target region entry points (=kernels) identified");
157 "Number of OpenMP target region entry points (=kernels) executed in "
158 "SPMD-mode instead of generic-mode");
159STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
160 "Number of OpenMP target region entry points (=kernels) executed in "
161 "generic-mode without a state machines");
162STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
163 "Number of OpenMP target region entry points (=kernels) executed in "
164 "generic-mode with customized state machines with fallback");
165STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
166 "Number of OpenMP target region entry points (=kernels) executed in "
167 "generic-mode with customized state machines without fallback");
169 NumOpenMPParallelRegionsReplacedInGPUStateMachine,
170 "Number of OpenMP parallel regions replaced with ID in GPU state machines");
172 "Number of OpenMP parallel regions merged");
174 "Amount of memory pushed to shared memory");
175STATISTIC(NumBarriersEliminated,
"Number of redundant barriers eliminated");
struct AAHeapToShared;

        Kernels(Kernels), OpenMPPostLink(OpenMPPostLink) {
    OMPBuilder.initialize();
    initializeRuntimeFunctions(M);
    initializeInternalControlVars();
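  // The enclosing class (OMPInformationCache in the full source) caches
  // OpenMP-specific information for the Attributor: internal control
  // variables (ICVs), known runtime functions and their uses (RFIs), and the
  // set of kernel entry points. The nested structs below are excerpted from
  // it.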
  struct InternalControlVarInfo {

  struct RuntimeFunctionInfo {

    void clearUsesMap() { UsesMap.clear(); }

    operator bool() const { return Declaration; }

    UseVector &getOrCreateUseVector(Function *F) {
      std::shared_ptr<UseVector> &UV = UsesMap[F];
        UV = std::make_shared<UseVector>();

    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();

    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    size_t getNumArgs() const { return ArgumentTypes.size(); }

      UseVector &UV = getOrCreateUseVector(F);

      while (!ToBeDeleted.empty()) {

    decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
    decltype(UsesMap)::iterator end() { return UsesMap.end(); }

      RuntimeFunction::OMPRTL___last>

      InternalControlVar::ICV___last>
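  // initializeInternalControlVars() and initializeRuntimeFunctions() populate
  // the ICV and runtime-function tables by expanding the macro definitions in
  // llvm/Frontend/OpenMP/OMPKinds.def.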
  void initializeInternalControlVars() {
#define ICV_RT_SET(_Name, RTL) \
    auto &ICV = ICVs[_Name]; \
#define ICV_RT_GET(Name, RTL) \
    auto &ICV = ICVs[Name]; \
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
    auto &ICV = ICVs[Enum]; \
    ICV.InitKind = Init; \
    ICV.EnvVarName = _EnvVarName; \
    switch (ICV.InitKind) { \
    case ICV_IMPLEMENTATION_DEFINED: \
      ICV.InitValue = nullptr; \
      ICV.InitValue = ConstantInt::get( \
          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
                                  SmallVector<Type *, 8> &RTFArgTypes) {
    if (F->getReturnType() != RTFRetType)
      return false;
    if (F->arg_size() != RTFArgTypes.size())
      return false;

    auto *RTFTyIt = RTFArgTypes.begin();
      if (Arg.getType() != *RTFTyIt)
        return false;
  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
    unsigned NumUses = 0;
    if (!RFI.Declaration)
      return NumUses;

    NumOpenMPRuntimeFunctionsIdentified += 1;
    NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();

    for (Use &U : RFI.Declaration->uses()) {
      if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
        if (ModuleSlice.empty() || ModuleSlice.count(UserI->getFunction())) {
          RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);

        RFI.getOrCreateUseVector(nullptr).push_back(&U);

    auto &RFI = RFIs[RTF];
    collectUses(RFI, false);

  void recollectUses() {
    for (int Idx = 0; Idx < RFIs.size(); ++Idx)

    RuntimeFunctionInfo &RFI = RFIs[Fn];
    if (RFI.Declaration && RFI.Declaration->isDeclaration())
  void initializeRuntimeFunctions(Module &M) {
#define OMP_TYPE(VarName, ...) \
  Type *VarName = OMPBuilder.VarName; \
#define OMP_ARRAY_TYPE(VarName, ...) \
  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
  (void)VarName##PtrTy;
#define OMP_FUNCTION_TYPE(VarName, ...) \
  FunctionType *VarName = OMPBuilder.VarName; \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
#define OMP_STRUCT_TYPE(VarName, ...) \
  StructType *VarName = OMPBuilder.VarName; \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
    Function *F = M.getFunction(_Name); \
    RTLFunctions.insert(F); \
    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
      RuntimeFunctionIDMap[F] = _Enum; \
      auto &RFI = RFIs[_Enum]; \
      RFI.IsVarArg = _IsVarArg; \
      RFI.ReturnType = OMPBuilder._ReturnType; \
      RFI.ArgumentTypes = std::move(ArgsTypes); \
      RFI.Declaration = F; \
      unsigned NumUses = collectUses(RFI); \
      dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
      if (RFI.Declaration) \
        dbgs() << TAG << "-> got " << NumUses << " uses in " \
               << RFI.getNumFunctionsWithUses() \
               << " different functions.\n"; \
#include "llvm/Frontend/OpenMP/OMPKinds.def"

    for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
      if (F.hasFnAttribute(Attribute::NoInline) &&
          F.getName().startswith(Prefix) &&
          !F.hasFnAttribute(Attribute::OptimizeNone))
        F.removeFnAttr(Attribute::NoInline);

  bool OpenMPPostLink = false;
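/// A BooleanState extended with a SetVector payload; used below to accumulate
/// sets (e.g. reached parallel regions) while tracking state validity.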
template <typename Ty, bool InsertInvalidates = true>
  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
  bool insert(const Ty &Elem) {
    if (InsertInvalidates)
    return Set.insert(Elem);

  const Ty &operator[](int Idx) const { return Set[Idx]; }
  bool operator==(const BooleanStateWithSetVector &RHS) const {
    return BooleanState::operator==(RHS) && Set == RHS.Set;
  bool operator!=(const BooleanStateWithSetVector &RHS) const {
    return !(*this == RHS);

  bool empty() const { return Set.empty(); }
  size_t size() const { return Set.size(); }

  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
    BooleanState::operator^=(RHS);
    Set.insert(RHS.Set.begin(), RHS.Set.end());

  typename decltype(Set)::iterator begin() { return Set.begin(); }
  typename decltype(Set)::iterator end() { return Set.end(); }

template <typename Ty, bool InsertInvalidates = true>
using BooleanStateWithPtrSetVector =
    BooleanStateWithSetVector<Ty *, InsertInvalidates>;
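/// Abstract-interpretation state describing an OpenMP (device) kernel: the
/// parallel regions it may reach, its SPMD compatibility, the kernel entries
/// that reach it, and parallel nesting, plus the usual fixpoint plumbing.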
  bool IsAtFixpoint = false;

  BooleanStateWithPtrSetVector<Function, false>
      ReachedKnownParallelRegions;

  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;

  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;

  bool IsKernelEntry = false;

  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;

  BooleanStateWithSetVector<uint8_t> ParallelLevels;

  bool NestedParallelism = false;

  KernelInfoState() = default;
  KernelInfoState(bool BestState) {

  bool isAtFixpoint() const override { return IsAtFixpoint; }

    ParallelLevels.indicatePessimisticFixpoint();
    ReachingKernelEntries.indicatePessimisticFixpoint();
    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();

    ParallelLevels.indicateOptimisticFixpoint();
    ReachingKernelEntries.indicateOptimisticFixpoint();
    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();

  KernelInfoState &getAssumed() { return *this; }
  const KernelInfoState &getAssumed() const { return *this; }

    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
    if (ParallelLevels != RHS.ParallelLevels)

  bool mayContainParallelRegion() {
    return !ReachedKnownParallelRegions.empty() ||
           !ReachedUnknownParallelRegions.empty();

  static KernelInfoState getBestState() { return KernelInfoState(true); }

  static KernelInfoState getBestState(KernelInfoState &KIS) {
    return getBestState();

  static KernelInfoState getWorstState() { return KernelInfoState(false); }

  KernelInfoState operator^=(const KernelInfoState &KIS) {
    if (KIS.KernelInitCB) {
      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
      KernelInitCB = KIS.KernelInitCB;
    if (KIS.KernelDeinitCB) {
      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
      KernelDeinitCB = KIS.KernelDeinitCB;
    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
    NestedParallelism |= KIS.NestedParallelism;

  KernelInfoState operator&=(const KernelInfoState &KIS) {
    return (*this ^= KIS);
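/// Helper that records the values stored into one of the offload arrays
/// (base pointers, pointers, sizes) passed to __tgt_target_data_begin_mapper.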
  OffloadArray() = default;

    if (!Array.getAllocatedType()->isArrayTy())
    if (!getValues(Array, Before))
    this->Array = &Array;

  static const unsigned DeviceIDArgNum = 1;
  static const unsigned BasePtrsArgNum = 3;
  static const unsigned PtrsArgNum = 4;
  static const unsigned SizesArgNum = 5;

    const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
    StoredValues.assign(NumValues, nullptr);
    LastAccesses.assign(NumValues, nullptr);

    const DataLayout &DL = Array.getModule()->getDataLayout();
    const unsigned int PointerSize = DL.getPointerSize();

      if (!isa<StoreInst>(&I))
      auto *S = cast<StoreInst>(&I);
      LastAccesses[Idx] = S;

    const unsigned NumValues = StoredValues.size();
    for (unsigned I = 0; I < NumValues; ++I) {
      if (!StoredValues[I] || !LastAccesses[I])
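/// The OpenMPOpt pass driver: given an SCC (or the whole module), it runs the
/// Attributor, merges and deletes parallel regions, deduplicates runtime
/// calls, hides memory-transfer latency, and rewrites device state machines.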
  using OptimizationRemarkGetter =

            OptimizationRemarkGetter OREGetter,
            OMPInformationCache &OMPInfoCache, Attributor &A)
      : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
        OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}

  bool remarksEnabled() {
    auto &Ctx = M.getContext();
    return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);

  bool run(bool IsModulePass) {
    bool Changed = false;

               << " functions in a slice with "
               << OMPInfoCache.ModuleSlice.size() << " functions\n");

      Changed |= runAttributor(IsModulePass);

      OMPInfoCache.recollectUses();

      Changed |= rewriteDeviceCodeStateMachine();

      if (remarksEnabled())
        analysisGlobalization();

      Changed |= runAttributor(IsModulePass);

      OMPInfoCache.recollectUses();

      Changed |= deleteParallelRegions();

      Changed |= hideMemTransfersLatency();
      Changed |= deduplicateRuntimeCalls();

      if (mergeParallelRegions()) {
        deduplicateRuntimeCalls();
  void printICVs() const {

    for (auto ICV : ICVs) {
      auto ICVInfo = OMPInfoCache.ICVs[ICV];
        return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
                   << (ICVInfo.InitValue
                           ? toString(ICVInfo.InitValue->getValue(), 10, true)
                           : "IMPLEMENTATION_DEFINED");

      emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);

  void printKernels() const {
      if (!OMPInfoCache.Kernels.count(F))

        return ORA << "OpenMP GPU kernel "
                   << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";

      emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);

  static CallInst *getCallIfRegularCall(
      Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(U.getUser());

  static CallInst *getCallIfRegularCall(
      Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(&V);
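  /// Merge adjacent __kmpc_fork_call parallel regions within a basic block
  /// into a single region, wrapping the code in between in sequential
  /// (master + barrier) regions; implemented by mergeParallelRegions() below.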
945 bool mergeParallelRegions() {
946 const unsigned CallbackCalleeOperand = 2;
947 const unsigned CallbackFirstArgOperand = 3;
951 OMPInformationCache::RuntimeFunctionInfo &RFI =
952 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
954 if (!RFI.Declaration)
958 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
959 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
960 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
963 bool Changed =
false;
969 BasicBlock *StartBB =
nullptr, *EndBB =
nullptr;
970 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
973 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
974 assert(StartBB !=
nullptr &&
"StartBB should not be null");
976 assert(EndBB !=
nullptr &&
"EndBB should not be null");
977 EndBB->getTerminator()->setSuccessor(0, CGEndBB);
980 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
Value &,
981 Value &Inner,
Value *&ReplacementValue) -> InsertPointTy {
982 ReplacementValue = &Inner;
986 auto FiniCB = [&](InsertPointTy CodeGenIP) {};
990 auto CreateSequentialRegion = [&](
Function *OuterFn,
1002 SplitBlock(ParentBB, SeqStartI, DT, LI,
nullptr,
"seq.par.merged");
1005 "Expected a different CFG");
1009 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1010 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1012 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1013 assert(SeqStartBB !=
nullptr &&
"SeqStartBB should not be null");
1015 assert(SeqEndBB !=
nullptr &&
"SeqEndBB should not be null");
1018 auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1024 for (
User *Usr :
I.users()) {
1032 OutsideUsers.
insert(&UsrI);
1035 if (OutsideUsers.
empty())
1042 I.getType(),
DL.getAllocaAddrSpace(),
nullptr,
1043 I.getName() +
".seq.output.alloc", &OuterFn->
front().
front());
1053 I.getType(), AllocaI,
I.getName() +
".seq.output.load", UsrI);
1054 UsrI->replaceUsesOfWith(&
I, LoadI);
1059 InsertPointTy(ParentBB, ParentBB->
end()),
DL);
1060 InsertPointTy SeqAfterIP =
1061 OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
1063 OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
1082 assert(MergableCIs.
size() > 1 &&
"Assumed multiple mergable CIs");
1085 OR <<
"Parallel region merged with parallel region"
1086 << (MergableCIs.
size() > 2 ?
"s" :
"") <<
" at ";
1088 OR <<
ore::NV(
"OpenMPParallelMerge", CI->getDebugLoc());
1089 if (CI != MergableCIs.
back())
1095 emitRemark<OptimizationRemark>(MergableCIs.
front(),
"OMP150",
Remark);
1099 <<
" parallel regions in " << OriginalFn->
getName()
1103 EndBB =
SplitBlock(BB, MergableCIs.
back()->getNextNode(), DT, LI);
1105 SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1109 assert(BB->getUniqueSuccessor() == StartBB &&
"Expected a different CFG");
1110 const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1115 for (
auto *It = MergableCIs.
begin(), *
End = MergableCIs.
end() - 1;
1124 CreateSequentialRegion(OriginalFn, BB, ForkCI->
getNextNode(),
1135 InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1136 Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB,
nullptr,
nullptr,
1137 OMP_PROC_BIND_default,
false);
1141 OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1148 for (
auto *CI : MergableCIs) {
1149 Value *
Callee = CI->getArgOperand(CallbackCalleeOperand);
1150 FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1152 Args.push_back(OutlinedFn->
getArg(0));
1153 Args.push_back(OutlinedFn->
getArg(1));
1154 for (
unsigned U = CallbackFirstArgOperand,
E = CI->arg_size(); U <
E;
1156 Args.push_back(CI->getArgOperand(U));
1159 if (CI->getDebugLoc())
1163 for (
unsigned U = CallbackFirstArgOperand,
E = CI->arg_size(); U <
E;
1165 for (
const Attribute &A : CI->getAttributes().getParamAttrs(U))
1167 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1170 if (CI != MergableCIs.back()) {
1173 OMPInfoCache.OMPBuilder.createBarrier(
1179 CI->eraseFromParent();
1182 assert(OutlinedFn != OriginalFn &&
"Outlining failed");
1186 NumOpenMPParallelRegionsMerged += MergableCIs.size();
1194 CallInst *CI = getCallIfRegularCall(U, &RFI);
1201 RFI.foreachUse(SCC, DetectPRsCB);
1207 for (
auto &It : BB2PRMap) {
1208 auto &CIs = It.getSecond();
1223 auto IsMergable = [&](
Instruction &
I,
bool IsBeforeMergableRegion) {
1226 if (
I.isTerminator())
1229 if (!isa<CallInst>(&
I))
1233 if (IsBeforeMergableRegion) {
1235 if (!CalledFunction)
1242 for (
const auto &RFI : UnmergableCallsInfo) {
1243 if (CalledFunction == RFI.Declaration)
1251 if (!isa<IntrinsicInst>(CI))
1262 if (CIs.count(&
I)) {
1268 if (IsMergable(
I, MergableCIs.
empty()))
1273 for (; It !=
End; ++It) {
1275 if (CIs.count(&SkipI)) {
1277 <<
" due to " <<
I <<
"\n");
1284 if (MergableCIs.
size() > 1) {
1285 MergableCIsVector.
push_back(MergableCIs);
1287 <<
" parallel regions in block " << BB->
getName()
1292 MergableCIs.
clear();
1295 if (!MergableCIsVector.
empty()) {
1298 for (
auto &MergableCIs : MergableCIsVector)
1299 Merge(MergableCIs, BB);
1300 MergableCIsVector.clear();
1307 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1308 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1309 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1310 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
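  /// Delete __kmpc_fork_call calls whose outlined callback has no side
  /// effects (only reads memory and is guaranteed to return); see
  /// deleteParallelRegions() below.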
  bool deleteParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;

    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    bool Changed = false;

      CallInst *CI = getCallIfRegularCall(U);

      auto *Fn = dyn_cast<Function>(

      if (!Fn->onlyReadsMemory())
      if (!Fn->hasFnAttribute(Attribute::WillReturn))

        return OR << "Removing parallel region with no side-effects.";
      emitRemark<OptimizationRemark>(CI, "OMP160", Remark);

      ++NumOpenMPParallelRegionsDeleted;

    RFI.foreachUse(SCC, DeleteCallCB);

  bool deduplicateRuntimeCalls() {
    bool Changed = false;

        OMPRTL_omp_get_num_threads,
        OMPRTL_omp_in_parallel,
        OMPRTL_omp_get_cancellation,
        OMPRTL_omp_get_thread_limit,
        OMPRTL_omp_get_supported_active_levels,
        OMPRTL_omp_get_level,
        OMPRTL_omp_get_ancestor_thread_num,
        OMPRTL_omp_get_team_size,
        OMPRTL_omp_get_active_level,
        OMPRTL_omp_in_final,
        OMPRTL_omp_get_proc_bind,
        OMPRTL_omp_get_num_places,
        OMPRTL_omp_get_num_procs,
        OMPRTL_omp_get_place_num,
        OMPRTL_omp_get_partition_num_places,
        OMPRTL_omp_get_partition_place_nums};

    collectGlobalThreadIdArguments(GTIdArgs);
               << " global thread ID arguments\n");

      for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
        Changed |= deduplicateRuntimeCalls(
            *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);

      Value *GTIdArg = nullptr;
        Changed |= deduplicateRuntimeCalls(
            *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
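  /// Split __tgt_target_data_begin_mapper calls into an "issue" and a "wait"
  /// part so host-to-device transfers can overlap with other work; see
  /// hideMemTransfersLatency() and splitTargetDataBeginRTC() below.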
  bool hideMemTransfersLatency() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
    bool Changed = false;
    auto SplitMemTransfers = [&](Use &U, Function &Decl) {
      auto *RTCall = getCallIfRegularCall(U, &RFI);

      OffloadArray OffloadArrays[3];
      if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))

      LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));

      bool WasSplit = false;
      Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
      if (WaitMovementPoint)
        WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);

      Changed |= WasSplit;

    if (OMPInfoCache.runtimeFnsAvailable(
            {OMPRTL___tgt_target_data_begin_mapper_issue,
             OMPRTL___tgt_target_data_begin_mapper_wait}))
      RFI.foreachUse(SCC, SplitMemTransfers);

  void analysisGlobalization() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];

    auto CheckGlobalization = [&](Use &U, Function &Decl) {
      if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
          << "Found thread data sharing on the GPU. "
          << "Expect degraded performance due to data globalization.";
        emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);

    RFI.foreachUse(SCC, CheckGlobalization);
  bool getValuesInOffloadArrays(CallInst &RuntimeCall,
    assert(OAs.size() == 3 && "Need space for three offload arrays!");

    Value *BasePtrsArg =

    if (!isa<AllocaInst>(V))
    auto *BasePtrsArray = cast<AllocaInst>(V);
    if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))

    if (!isa<AllocaInst>(V))
    auto *PtrsArray = cast<AllocaInst>(V);
    if (!OAs[1].initialize(*PtrsArray, RuntimeCall))

    if (isa<GlobalValue>(V))
      return isa<Constant>(V);
    if (!isa<AllocaInst>(V))
    auto *SizesArray = cast<AllocaInst>(V);
    if (!OAs[2].initialize(*SizesArray, RuntimeCall))

    assert(OAs.size() == 3 && "There are three offload arrays to debug!");

    std::string ValuesStr;
    std::string Separator = " --- ";

    for (auto *BP : OAs[0].StoredValues) {
    for (auto *P : OAs[1].StoredValues) {
    for (auto *S : OAs[2].StoredValues) {

    bool IsWorthIt = false;
1579 bool splitTargetDataBeginRTC(
CallInst &RuntimeCall,
1584 auto &
IRBuilder = OMPInfoCache.OMPBuilder;
1588 Entry.getFirstNonPHIOrDbgOrAlloca());
1590 IRBuilder.AsyncInfo,
nullptr,
"handle");
1598 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1602 for (
auto &
Arg : RuntimeCall.
args())
1603 Args.push_back(
Arg.get());
1604 Args.push_back(Handle);
1608 OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1614 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1616 Value *WaitParams[2] = {
1618 OffloadArray::DeviceIDArgNum),
1622 WaitDecl, WaitParams,
"", &WaitMovementPoint);
1623 OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1628 static Value *combinedIdentStruct(
Value *CurrentIdent,
Value *NextIdent,
1629 bool GlobalOnly,
bool &SingleChoice) {
1630 if (CurrentIdent == NextIdent)
1631 return CurrentIdent;
1635 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1636 SingleChoice = !CurrentIdent;
1648 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1650 bool SingleChoice =
true;
1651 Value *Ident =
nullptr;
1652 auto CombineIdentStruct = [&](
Use &U,
Function &Caller) {
1653 CallInst *CI = getCallIfRegularCall(U, &RFI);
1654 if (!CI || &
F != &Caller)
1657 true, SingleChoice);
1660 RFI.foreachUse(SCC, CombineIdentStruct);
1662 if (!Ident || !SingleChoice) {
1665 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1667 &
F.getEntryBlock(),
F.getEntryBlock().begin()));
1672 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1673 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1680 bool deduplicateRuntimeCalls(
Function &
F,
1681 OMPInformationCache::RuntimeFunctionInfo &RFI,
1682 Value *ReplVal =
nullptr) {
1683 auto *UV = RFI.getUseVector(
F);
1684 if (!UV || UV->size() + (ReplVal !=
nullptr) < 2)
1688 dbgs() <<
TAG <<
"Deduplicate " << UV->size() <<
" uses of " << RFI.Name
1689 << (ReplVal ?
" with an existing value\n" :
"\n") <<
"\n");
1691 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1692 cast<Argument>(ReplVal)->
getParent() == &
F)) &&
1693 "Unexpected replacement value!");
1696 auto CanBeMoved = [
this](
CallBase &CB) {
1697 unsigned NumArgs = CB.arg_size();
1700 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1702 for (
unsigned U = 1; U < NumArgs; ++U)
1703 if (isa<Instruction>(CB.getArgOperand(U)))
1714 for (
Use *U : *UV) {
1715 if (
CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1717 IP = DT->findNearestCommonDominator(IP, CI);
1720 if (!CanBeMoved(*CI))
1728 assert(IP &&
"Expected insertion point!");
1729 cast<Instruction>(ReplVal)->moveBefore(IP);
1735 if (
CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1736 if (!CI->arg_empty() &&
1737 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1738 Value *Ident = getCombinedIdentFromCallUsesIn(RFI,
F,
1740 CI->setArgOperand(0, Ident);
1744 bool Changed =
false;
1745 auto ReplaceAndDeleteCB = [&](
Use &U,
Function &Caller) {
1746 CallInst *CI = getCallIfRegularCall(U, &RFI);
1747 if (!CI || CI == ReplVal || &
F != &Caller)
1752 return OR <<
"OpenMP runtime call "
1753 <<
ore::NV(
"OpenMPOptRuntime", RFI.Name) <<
" deduplicated.";
1756 emitRemark<OptimizationRemark>(CI,
"OMP170",
Remark);
1758 emitRemark<OptimizationRemark>(&
F,
"OMP170",
Remark);
1763 ++NumOpenMPRuntimeCallsDeduplicated;
1767 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1781 if (!
F.hasLocalLinkage())
1783 for (
Use &U :
F.uses()) {
1784 if (
CallInst *CI = getCallIfRegularCall(U)) {
1785 Value *ArgOp = CI->getArgOperand(ArgNo);
1786 if (CI == &RefCI || GTIdArgs.
count(ArgOp) ||
1787 getCallIfRegularCall(
1788 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1797 auto AddUserArgs = [&](
Value >Id) {
1798 for (
Use &U : GTId.uses())
1799 if (
CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1800 if (CI->isArgOperand(&U))
1802 if (CallArgOpIsGTId(*
Callee, U.getOperandNo(), *CI))
1807 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1808 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1810 GlobThreadNumRFI.foreachUse(SCC, [&](
Use &U,
Function &
F) {
1811 if (
CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1819 for (
unsigned U = 0; U < GTIdArgs.
size(); ++U)
1820 AddUserArgs(*GTIdArgs[U]);
1828 bool isKernel(
Function &
F) {
return OMPInfoCache.Kernels.count(&
F); }
1838 return getUniqueKernelFor(*
I.getFunction());
1843 bool rewriteDeviceCodeStateMachine();
1859 template <
typename RemarkKind,
typename RemarkCallBack>
1861 RemarkCallBack &&RemarkCB)
const {
1863 auto &ORE = OREGetter(
F);
1867 return RemarkCB(RemarkKind(
DEBUG_TYPE, RemarkName,
I))
1868 <<
" [" << RemarkName <<
"]";
1872 [&]() {
return RemarkCB(RemarkKind(
DEBUG_TYPE, RemarkName,
I)); });
1876 template <
typename RemarkKind,
typename RemarkCallBack>
1878 RemarkCallBack &&RemarkCB)
const {
1879 auto &ORE = OREGetter(
F);
1883 return RemarkCB(RemarkKind(
DEBUG_TYPE, RemarkName,
F))
1884 <<
" [" << RemarkName <<
"]";
1888 [&]() {
return RemarkCB(RemarkKind(
DEBUG_TYPE, RemarkName,
F)); });
1902 OptimizationRemarkGetter OREGetter;
1905 OMPInformationCache &OMPInfoCache;
1911 bool runAttributor(
bool IsModulePass) {
1915 registerAAs(IsModulePass);
1920 <<
" functions, result: " << Changed <<
".\n");
1922 return Changed == ChangeStatus::CHANGED;
1929 void registerAAs(
bool IsModulePass);
1938 if (!OMPInfoCache.ModuleSlice.empty() && !OMPInfoCache.ModuleSlice.count(&
F))
1943 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&
F];
1945 return *CachedKernel;
1952 return *CachedKernel;
1955 CachedKernel =
nullptr;
1956 if (!
F.hasLocalLinkage()) {
1960 return ORA <<
"Potentially unknown OpenMP target region caller.";
1962 emitRemark<OptimizationRemarkAnalysis>(&
F,
"OMP100",
Remark);
1968 auto GetUniqueKernelForUse = [&](
const Use &U) ->
Kernel {
1969 if (
auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
1971 if (Cmp->isEquality())
1972 return getUniqueKernelFor(*Cmp);
1975 if (
auto *CB = dyn_cast<CallBase>(U.getUser())) {
1977 if (CB->isCallee(&U))
1978 return getUniqueKernelFor(*CB);
1980 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1981 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1983 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
1984 return getUniqueKernelFor(*CB);
1993 OMPInformationCache::foreachUse(
F, [&](
const Use &U) {
1994 PotentialKernels.
insert(GetUniqueKernelForUse(U));
1998 if (PotentialKernels.
size() == 1)
1999 K = *PotentialKernels.
begin();
2002 UniqueKernelMap[&
F] = K;
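// Rewrite the generic-mode state machine of a kernel: if an outlined parallel
// region is reachable from a single kernel only, its function-pointer uses in
// the state machine are replaced by a kernel-unique global "ID" value.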
bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
  OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
      OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];

  bool Changed = false;
  if (!KernelParallelRFI)
    return Changed;

    bool UnknownUse = false;
    bool KernelParallelUse = false;
    unsigned NumDirectCalls = 0;

    OMPInformationCache::foreachUse(*F, [&](Use &U) {
      if (auto *CB = dyn_cast<CallBase>(U.getUser()))
        if (CB->isCallee(&U)) {

      if (isa<ICmpInst>(U.getUser())) {
        ToBeReplacedStateMachineUses.push_back(&U);

          OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
      const unsigned int WrapperFunctionArgNo = 6;
      if (!KernelParallelUse && CI &&
          CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
        KernelParallelUse = true;
        ToBeReplacedStateMachineUses.push_back(&U);

    if (!KernelParallelUse)
      continue;

    if (UnknownUse || NumDirectCalls != 1 ||
        ToBeReplacedStateMachineUses.size() > 2) {
        return ORA << "Parallel region is used in "
                   << (UnknownUse ? "unknown" : "unexpected")
                   << " ways. Will not attempt to rewrite the state machine.";
      emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);

    Kernel K = getUniqueKernelFor(*F);
        return ORA << "Parallel region is not called from a unique kernel. "
                      "Will not attempt to rewrite the state machine.";
      emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);

    for (Use *U : ToBeReplacedStateMachineUses)
          ID, U->get()->getType()));

    ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
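/// AAICVTracker and its specializations track the values of OpenMP internal
/// control variables (ICVs), e.g. set via omp_set_num_threads, so that getter
/// calls can be folded to a unique known value.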
struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {

    if (!F || !A.isFunctionIPOAmendable(*F))
      indicatePessimisticFixpoint();

  bool isAssumedTracked() const { return getAssumed(); }

  bool isKnownTracked() const { return getAssumed(); }

    return std::nullopt;

  virtual std::optional<Value *>

  const std::string getName() const override { return "AAICVTracker"; }

  const char *getIdAddr() const override { return &ID; }

  static const char ID;
2159struct AAICVTrackerFunction :
public AAICVTracker {
2161 : AAICVTracker(IRP,
A) {}
2164 const std::string getAsStr()
const override {
return "ICVTrackerFunction"; }
2167 void trackStatistics()
const override {}
2171 return ChangeStatus::UNCHANGED;
2176 InternalControlVar::ICV___last>
2177 ICVReplacementValuesMap;
2184 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
2187 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2189 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2191 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2197 if (ValuesMap.insert(std::make_pair(CI, CI->
getArgOperand(0))).second)
2198 HasChanged = ChangeStatus::CHANGED;
2204 std::optional<Value *> ReplVal = getValueForCall(
A,
I, ICV);
2205 if (ReplVal && ValuesMap.insert(std::make_pair(&
I, *ReplVal)).second)
2206 HasChanged = ChangeStatus::CHANGED;
2212 SetterRFI.foreachUse(TrackValues,
F);
2214 bool UsedAssumedInformation =
false;
2215 A.checkForAllInstructions(CallCheck, *
this, {Instruction::Call},
2216 UsedAssumedInformation,
2222 if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2223 ValuesMap.insert(std::make_pair(Entry,
nullptr));
2234 const auto *CB = dyn_cast<CallBase>(&
I);
2235 if (!CB || CB->hasFnAttr(
"no_openmp") ||
2236 CB->hasFnAttr(
"no_openmp_routines"))
2237 return std::nullopt;
2239 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
2240 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2241 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2242 Function *CalledFunction = CB->getCalledFunction();
2245 if (CalledFunction ==
nullptr)
2247 if (CalledFunction == GetterRFI.Declaration)
2248 return std::nullopt;
2249 if (CalledFunction == SetterRFI.Declaration) {
2250 if (ICVReplacementValuesMap[ICV].
count(&
I))
2251 return ICVReplacementValuesMap[ICV].
lookup(&
I);
2260 const auto &ICVTrackingAA =
A.getAAFor<AAICVTracker>(
2263 if (ICVTrackingAA.isAssumedTracked()) {
2264 std::optional<Value *> URV = ICVTrackingAA.getUniqueReplacementValue(ICV);
2275 std::optional<Value *>
2277 return std::nullopt;
2284 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2285 if (ValuesMap.count(
I))
2286 return ValuesMap.lookup(
I);
2292 std::optional<Value *> ReplVal;
2294 while (!Worklist.
empty()) {
2296 if (!Visited.
insert(CurrInst).second)
2304 if (ValuesMap.count(CurrInst)) {
2305 std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2308 ReplVal = NewReplVal;
2314 if (ReplVal != NewReplVal)
2320 std::optional<Value *> NewReplVal = getValueForCall(
A, *CurrInst, ICV);
2326 ReplVal = NewReplVal;
2332 if (ReplVal != NewReplVal)
2337 if (CurrBB ==
I->getParent() && ReplVal)
2342 if (
const Instruction *Terminator = Pred->getTerminator())
2350struct AAICVTrackerFunctionReturned : AAICVTracker {
2352 : AAICVTracker(IRP,
A) {}
2355 const std::string getAsStr()
const override {
2356 return "ICVTrackerFunctionReturned";
2360 void trackStatistics()
const override {}
2364 return ChangeStatus::UNCHANGED;
2369 InternalControlVar::ICV___last>
2370 ICVReplacementValuesMap;
2373 std::optional<Value *>
2375 return ICVReplacementValuesMap[ICV];
2380 const auto &ICVTrackingAA =
A.getAAFor<AAICVTracker>(
2383 if (!ICVTrackingAA.isAssumedTracked())
2384 return indicatePessimisticFixpoint();
2387 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2388 std::optional<Value *> UniqueICVValue;
2391 std::optional<Value *> NewReplVal =
2392 ICVTrackingAA.getReplacementValue(ICV, &
I,
A);
2395 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2398 UniqueICVValue = NewReplVal;
2403 bool UsedAssumedInformation =
false;
2404 if (!
A.checkForAllInstructions(CheckReturnInst, *
this, {Instruction::Ret},
2405 UsedAssumedInformation,
2407 UniqueICVValue =
nullptr;
2409 if (UniqueICVValue == ReplVal)
2412 ReplVal = UniqueICVValue;
2413 Changed = ChangeStatus::CHANGED;
2420struct AAICVTrackerCallSite : AAICVTracker {
2422 : AAICVTracker(IRP,
A) {}
2426 if (!
F || !
A.isFunctionIPOAmendable(*
F))
2427 indicatePessimisticFixpoint();
2431 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
2433 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2434 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2435 if (Getter.Declaration == getAssociatedFunction()) {
2436 AssociatedICV = ICVInfo.Kind;
2442 indicatePessimisticFixpoint();
2446 if (!ReplVal || !*ReplVal)
2447 return ChangeStatus::UNCHANGED;
2450 A.deleteAfterManifest(*getCtxI());
2452 return ChangeStatus::CHANGED;
2456 const std::string getAsStr()
const override {
return "ICVTrackerCallSite"; }
2459 void trackStatistics()
const override {}
2462 std::optional<Value *> ReplVal;
2465 const auto &ICVTrackingAA =
A.getAAFor<AAICVTracker>(
2469 if (!ICVTrackingAA.isAssumedTracked())
2470 return indicatePessimisticFixpoint();
2472 std::optional<Value *> NewReplVal =
2473 ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(),
A);
2475 if (ReplVal == NewReplVal)
2476 return ChangeStatus::UNCHANGED;
2478 ReplVal = NewReplVal;
2479 return ChangeStatus::CHANGED;
2484 std::optional<Value *>
2490struct AAICVTrackerCallSiteReturned : AAICVTracker {
2492 : AAICVTracker(IRP,
A) {}
2495 const std::string getAsStr()
const override {
2496 return "ICVTrackerCallSiteReturned";
2500 void trackStatistics()
const override {}
2504 return ChangeStatus::UNCHANGED;
2509 InternalControlVar::ICV___last>
2510 ICVReplacementValuesMap;
2514 std::optional<Value *>
2516 return ICVReplacementValuesMap[ICV];
2521 const auto &ICVTrackingAA =
A.getAAFor<AAICVTracker>(
2523 DepClassTy::REQUIRED);
2526 if (!ICVTrackingAA.isAssumedTracked())
2527 return indicatePessimisticFixpoint();
2530 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2531 std::optional<Value *> NewReplVal =
2532 ICVTrackingAA.getUniqueReplacementValue(ICV);
2534 if (ReplVal == NewReplVal)
2537 ReplVal = NewReplVal;
2538 Changed = ChangeStatus::CHANGED;
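/// AAExecutionDomain(Function): computes, per block and per call site, whether
/// code is executed only by the initial thread and whether it is reached from
/// and reaches only aligned barriers; used to eliminate redundant barriers.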
2548 ~AAExecutionDomainFunction() {
delete RPOT; }
2558 const std::string
getAsStr()
const override {
2559 unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2560 for (
auto &It : BEDMap) {
2564 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2565 AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2566 It.getSecond().IsReachingAlignedBarrierOnly;
2568 return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) +
"/" +
2569 std::to_string(AlignedBlocks) +
" of " +
2570 std::to_string(TotalBlocks) +
2571 " executed by initial thread / aligned";
2583 << BB.getName() <<
" is executed by a single thread.\n";
2593 auto HandleAlignedBarrier = [&](
CallBase *CB) {
2594 const ExecutionDomainTy &ED = CEDMap[{CB, PRE}];
2595 if (!ED.IsReachedFromAlignedBarrierOnly ||
2596 ED.EncounteredNonLocalSideEffect)
2604 DeletedBarriers.
insert(CB);
2605 A.deleteAfterManifest(*CB);
2606 ++NumBarriersEliminated;
2608 }
else if (!ED.AlignedBarriers.empty()) {
2609 NumBarriersEliminated += ED.AlignedBarriers.size();
2612 ED.AlignedBarriers.end());
2614 while (!Worklist.empty()) {
2615 CallBase *LastCB = Worklist.pop_back_val();
2616 if (!Visited.
insert(LastCB))
2620 if (!DeletedBarriers.
count(LastCB)) {
2621 A.deleteAfterManifest(*LastCB);
2627 const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2628 Worklist.append(LastED.AlignedBarriers.begin(),
2629 LastED.AlignedBarriers.end());
2635 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2636 for (
auto *AssumeCB : ED.EncounteredAssumes)
2637 A.deleteAfterManifest(*AssumeCB);
2640 for (
auto *CB : AlignedBarriers)
2641 HandleAlignedBarrier(CB);
2643 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
2645 if (OMPInfoCache.Kernels.count(getAnchorScope()))
2646 HandleAlignedBarrier(
nullptr);
2654 mergeInPredecessorBarriersAndAssumptions(
Attributor &
A, ExecutionDomainTy &ED,
2655 const ExecutionDomainTy &PredED);
2660 bool mergeInPredecessor(
Attributor &
A, ExecutionDomainTy &ED,
2661 const ExecutionDomainTy &PredED,
2662 bool InitialEdgeOnly =
false);
2665 bool handleCallees(
Attributor &
A, ExecutionDomainTy &EntryBBED);
2675 assert(BB.
getParent() == getAnchorScope() &&
"Block is out of scope!");
2676 return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2681 assert(
I.getFunction() == getAnchorScope() &&
2682 "Instruction is out of scope!");
2691 auto *CB = dyn_cast<CallBase>(CurI);
2694 if (CB != &
I && AlignedBarriers.contains(
const_cast<CallBase *
>(CB))) {
2697 const auto &It = CEDMap.find({CB, PRE});
2698 if (It == CEDMap.end())
2700 if (!It->getSecond().IsReachingAlignedBarrierOnly)
2705 if (!CurI && !BEDMap.lookup(
I.getParent()).IsReachingAlignedBarrierOnly)
2711 auto *CB = dyn_cast<CallBase>(CurI);
2714 if (CB != &
I && AlignedBarriers.contains(
const_cast<CallBase *
>(CB))) {
2717 const auto &It = CEDMap.find({CB, POST});
2718 if (It == CEDMap.end())
2720 if (It->getSecond().IsReachedFromAlignedBarrierOnly)
2728 return BEDMap.lookup(
nullptr).IsReachedFromAlignedBarrierOnly;
2730 return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
2742 "No request should be made against an invalid state!");
2743 return BEDMap.lookup(&BB);
2745 std::pair<ExecutionDomainTy, ExecutionDomainTy>
2748 "No request should be made against an invalid state!");
2749 return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
2753 "No request should be made against an invalid state!");
2754 return InterProceduralED;
2768 if (!Cmp || !
Cmp->isTrueWhenEqual() || !
Cmp->isEquality())
2776 if (
C->isAllOnesValue()) {
2777 auto *CB = dyn_cast<CallBase>(
Cmp->getOperand(0));
2778 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
2779 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2780 CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2783 const int InitModeArgNo = 1;
2784 auto *ModeCI = dyn_cast<ConstantInt>(CB->
getOperand(InitModeArgNo));
2790 if (
auto *II = dyn_cast<IntrinsicInst>(
Cmp->getOperand(0)))
2791 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2795 if (
auto *II = dyn_cast<IntrinsicInst>(
Cmp->getOperand(0)))
2796 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2804 ExecutionDomainTy InterProceduralED;
2816 static bool setAndRecord(
bool &R,
bool V) {
2823void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2824 Attributor &
A, ExecutionDomainTy &ED,
const ExecutionDomainTy &PredED) {
2825 for (
auto *EA : PredED.EncounteredAssumes)
2826 ED.addAssumeInst(
A, *EA);
2828 for (
auto *AB : PredED.AlignedBarriers)
2829 ED.addAlignedBarrier(
A, *AB);
2832bool AAExecutionDomainFunction::mergeInPredecessor(
2833 Attributor &
A, ExecutionDomainTy &ED,
const ExecutionDomainTy &PredED,
2834 bool InitialEdgeOnly) {
2836 bool Changed =
false;
2838 setAndRecord(ED.IsExecutedByInitialThreadOnly,
2839 InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
2840 ED.IsExecutedByInitialThreadOnly));
2842 Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
2843 ED.IsReachedFromAlignedBarrierOnly &&
2844 PredED.IsReachedFromAlignedBarrierOnly);
2845 Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
2846 ED.EncounteredNonLocalSideEffect |
2847 PredED.EncounteredNonLocalSideEffect);
2849 if (ED.IsReachedFromAlignedBarrierOnly)
2850 mergeInPredecessorBarriersAndAssumptions(
A, ED, PredED);
2852 ED.clearAssumeInstAndAlignedBarriers();
2856bool AAExecutionDomainFunction::handleCallees(
Attributor &
A,
2857 ExecutionDomainTy &EntryBBED) {
2862 DepClassTy::OPTIONAL);
2863 if (!EDAA.getState().isValidState())
2866 EDAA.getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
2870 ExecutionDomainTy ExitED;
2871 bool AllCallSitesKnown;
2872 if (
A.checkForAllCallSites(PredForCallSite, *
this,
2874 AllCallSitesKnown)) {
2875 for (
const auto &[CSInED, CSOutED] : CallSiteEDs) {
2876 mergeInPredecessor(
A, EntryBBED, CSInED);
2877 ExitED.IsReachingAlignedBarrierOnly &=
2878 CSOutED.IsReachingAlignedBarrierOnly;
2884 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
2885 if (OMPInfoCache.Kernels.count(getAnchorScope())) {
2886 EntryBBED.IsExecutedByInitialThreadOnly =
false;
2887 EntryBBED.IsReachedFromAlignedBarrierOnly =
true;
2888 EntryBBED.EncounteredNonLocalSideEffect =
false;
2889 ExitED.IsReachingAlignedBarrierOnly =
true;
2891 EntryBBED.IsExecutedByInitialThreadOnly =
false;
2892 EntryBBED.IsReachedFromAlignedBarrierOnly =
false;
2893 EntryBBED.EncounteredNonLocalSideEffect =
true;
2894 ExitED.IsReachingAlignedBarrierOnly =
false;
2898 bool Changed =
false;
2899 auto &FnED = BEDMap[
nullptr];
2900 Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
2901 FnED.IsReachedFromAlignedBarrierOnly &
2902 EntryBBED.IsReachedFromAlignedBarrierOnly);
2903 Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
2904 FnED.IsReachingAlignedBarrierOnly &
2905 ExitED.IsReachingAlignedBarrierOnly);
2906 Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
2907 EntryBBED.IsExecutedByInitialThreadOnly);
2913 bool Changed =
false;
2918 auto HandleAlignedBarrier = [&](
CallBase *CB, ExecutionDomainTy &ED) {
2920 Changed |= AlignedBarriers.insert(CB);
2922 auto &CallInED = CEDMap[{CB, PRE}];
2923 Changed |= mergeInPredecessor(
A, CallInED, ED);
2924 CallInED.IsReachingAlignedBarrierOnly =
true;
2926 ED.EncounteredNonLocalSideEffect =
false;
2927 ED.IsReachedFromAlignedBarrierOnly =
true;
2929 ED.clearAssumeInstAndAlignedBarriers();
2931 ED.addAlignedBarrier(
A, *CB);
2932 auto &CallOutED = CEDMap[{CB, POST}];
2933 Changed |= mergeInPredecessor(
A, CallOutED, ED);
2937 A.getAAFor<
AAIsDead>(*
this, getIRPosition(), DepClassTy::OPTIONAL);
2939 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
2943 bool IsKernel = OMPInfoCache.Kernels.count(
F);
2946 for (
auto &RIt : *RPOT) {
2949 bool IsEntryBB = &BB == &EntryBB;
2952 bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
2953 ExecutionDomainTy ED;
2956 Changed |= handleCallees(
A, ED);
2960 if (LivenessAA.isAssumedDead(&BB))
2964 if (LivenessAA.isEdgeDead(PredBB, &BB))
2966 bool InitialEdgeOnly = isInitialThreadOnlyEdge(
2967 A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
2968 mergeInPredecessor(
A, ED, BEDMap[PredBB], InitialEdgeOnly);
2975 bool UsedAssumedInformation;
2976 if (
A.isAssumedDead(
I, *
this, &LivenessAA, UsedAssumedInformation,
2977 false, DepClassTy::OPTIONAL,
2983 if (
auto *II = dyn_cast<IntrinsicInst>(&
I)) {
2984 if (
auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
2985 ED.addAssumeInst(
A, *AI);
2989 if (II->isAssumeLikeIntrinsic())
2993 auto *CB = dyn_cast<CallBase>(&
I);
2995 bool IsAlignedBarrier =
2999 AlignedBarrierLastInBlock &= IsNoSync;
3005 if (IsAlignedBarrier) {
3006 HandleAlignedBarrier(CB, ED);
3007 AlignedBarrierLastInBlock =
true;
3012 if (isa<MemIntrinsic>(&
I)) {
3013 if (!ED.EncounteredNonLocalSideEffect &&
3015 ED.EncounteredNonLocalSideEffect =
true;
3017 ED.IsReachedFromAlignedBarrierOnly =
false;
3025 auto &CallInED = CEDMap[{CB, PRE}];
3026 Changed |= mergeInPredecessor(
A, CallInED, ED);
3035 if (EDAA.getState().isValidState()) {
3038 CalleeED.IsReachedFromAlignedBarrierOnly;
3039 AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3040 if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3041 ED.EncounteredNonLocalSideEffect |=
3042 CalleeED.EncounteredNonLocalSideEffect;
3044 ED.EncounteredNonLocalSideEffect =
3045 CalleeED.EncounteredNonLocalSideEffect;
3046 if (!CalleeED.IsReachingAlignedBarrierOnly) {
3048 setAndRecord(CallInED.IsReachingAlignedBarrierOnly,
false);
3051 if (CalleeED.IsReachedFromAlignedBarrierOnly)
3052 mergeInPredecessorBarriersAndAssumptions(
A, ED, CalleeED);
3053 auto &CallOutED = CEDMap[{CB, POST}];
3054 Changed |= mergeInPredecessor(
A, CallOutED, ED);
3059 ED.IsReachedFromAlignedBarrierOnly =
false;
3060 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly,
false);
3063 AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3065 auto &CallOutED = CEDMap[{CB, POST}];
3066 Changed |= mergeInPredecessor(
A, CallOutED, ED);
3069 if (!
I.mayHaveSideEffects() && !
I.mayReadFromMemory())
3083 if (MemAA.getState().isValidState() &&
3084 MemAA.checkForAllAccessesToMemoryKind(
3089 if (!
I.mayHaveSideEffects() && OMPInfoCache.isOnlyUsedByAssume(
I))
3092 if (
auto *LI = dyn_cast<LoadInst>(&
I))
3093 if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3096 if (!ED.EncounteredNonLocalSideEffect &&
3098 ED.EncounteredNonLocalSideEffect =
true;
3101 bool IsEndAndNotReachingAlignedBarriersOnly =
false;
3102 if (!isa<UnreachableInst>(BB.getTerminator()) &&
3103 !BB.getTerminator()->getNumSuccessors()) {
3105 Changed |= mergeInPredecessor(
A, InterProceduralED, ED);
3107 auto &FnED = BEDMap[
nullptr];
3108 if (!FnED.IsReachingAlignedBarrierOnly) {
3109 IsEndAndNotReachingAlignedBarriersOnly =
true;
3110 SyncInstWorklist.
push_back(BB.getTerminator());
3111 auto &BBED = BEDMap[&BB];
3112 Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly,
false);
3115 HandleAlignedBarrier(
nullptr, ED);
3118 ExecutionDomainTy &StoredED = BEDMap[&BB];
3119 ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3120 !IsEndAndNotReachingAlignedBarriersOnly;
3126 if (ED.IsExecutedByInitialThreadOnly !=
3127 StoredED.IsExecutedByInitialThreadOnly ||
3128 ED.IsReachedFromAlignedBarrierOnly !=
3129 StoredED.IsReachedFromAlignedBarrierOnly ||
3130 ED.EncounteredNonLocalSideEffect !=
3131 StoredED.EncounteredNonLocalSideEffect)
3135 StoredED = std::move(ED);
3141 while (!SyncInstWorklist.
empty()) {
3144 bool HitAlignedBarrierOrKnownEnd =
false;
3146 auto *CB = dyn_cast<CallBase>(CurInst);
3149 auto &CallOutED = CEDMap[{CB, POST}];
3150 if (setAndRecord(CallOutED.IsReachingAlignedBarrierOnly,
false))
3152 auto &CallInED = CEDMap[{CB, PRE}];
3153 HitAlignedBarrierOrKnownEnd =
3154 AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
3155 if (HitAlignedBarrierOrKnownEnd)
3158 if (HitAlignedBarrierOrKnownEnd)
3162 if (LivenessAA.isEdgeDead(PredBB, SyncBB))
3164 if (!Visited.
insert(PredBB))
3166 auto &PredED = BEDMap[PredBB];
3167 if (setAndRecord(PredED.IsReachingAlignedBarrierOnly,
false)) {
3169 SyncInstWorklist.
push_back(PredBB->getTerminator());
3172 if (SyncBB != &EntryBB)
3174 if (setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly,
false))
3178 return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
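/// AAHeapToShared: replaces __kmpc_alloc_shared/__kmpc_free_shared pairs of
/// constant size that are executed only by the initial thread with a static
/// buffer in GPU shared memory.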
struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {

  static AAHeapToShared &createForPosition(const IRPosition &IRP,

  virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;

  virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;

  const std::string getName() const override { return "AAHeapToShared"; }

  const char *getIdAddr() const override { return &ID; }

  static const char ID;
3214struct AAHeapToSharedFunction :
public AAHeapToShared {
3216 : AAHeapToShared(IRP,
A) {}
3218 const std::string getAsStr()
const override {
3219 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3220 " malloc calls eligible.";
3224 void trackStatistics()
const override {}
3228 void findPotentialRemovedFreeCalls(
Attributor &
A) {
3229 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
3230 auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3232 PotentialRemovedFreeCalls.clear();
3236 for (
auto *U : CB->
users()) {
3238 if (
C &&
C->getCalledFunction() == FreeRFI.Declaration)
3242 if (FreeCalls.
size() != 1)
3245 PotentialRemovedFreeCalls.insert(FreeCalls.
front());
3251 indicatePessimisticFixpoint();
3255 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
3256 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3257 if (!RFI.Declaration)
3262 bool &) -> std::optional<Value *> {
return nullptr; };
3265 for (
User *U : RFI.Declaration->
users())
3266 if (
CallBase *CB = dyn_cast<CallBase>(U)) {
3269 MallocCalls.insert(CB);
3274 findPotentialRemovedFreeCalls(
A);
3277 bool isAssumedHeapToShared(
CallBase &CB)
const override {
3278 return isValidState() && MallocCalls.count(&CB);
3281 bool isAssumedHeapToSharedRemovedFree(
CallBase &CB)
const override {
3282 return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3286 if (MallocCalls.empty())
3287 return ChangeStatus::UNCHANGED;
3289 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
3290 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3294 DepClassTy::OPTIONAL);
3299 if (HS &&
HS->isAssumedHeapToStack(*CB))
3304 for (
auto *U : CB->
users()) {
3306 if (
C &&
C->getCalledFunction() == FreeCall.Declaration)
3309 if (FreeCalls.
size() != 1)
3316 <<
" with shared memory."
3317 <<
" Shared memory usage is limited to "
3323 <<
" with " << AllocSize->getZExtValue()
3324 <<
" bytes of shared memory\n");
3330 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3335 static_cast<unsigned>(AddressSpace::Shared));
3340 return OR <<
"Replaced globalized variable with "
3341 <<
ore::NV(
"SharedMemory", AllocSize->getZExtValue())
3342 << (AllocSize->isOne() ?
" byte " :
" bytes ")
3343 <<
"of shared memory.";
3349 "HeapToShared on allocation without alignment attribute");
3350 SharedMem->setAlignment(*Alignment);
3353 A.deleteAfterManifest(*CB);
3354 A.deleteAfterManifest(*FreeCalls.
front());
3356 SharedMemoryUsed += AllocSize->getZExtValue();
3357 NumBytesMovedToSharedMemory = SharedMemoryUsed;
3358 Changed = ChangeStatus::CHANGED;
3365 if (MallocCalls.empty())
3366 return indicatePessimisticFixpoint();
3367 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
3368 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3369 if (!RFI.Declaration)
3370 return ChangeStatus::UNCHANGED;
3374 auto NumMallocCalls = MallocCalls.size();
3377 for (
User *U : RFI.Declaration->
users()) {
3378 if (
CallBase *CB = dyn_cast<CallBase>(U)) {
3379 if (CB->getCaller() !=
F)
3381 if (!MallocCalls.count(CB))
3383 if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3384 MallocCalls.remove(CB);
3389 if (!ED.isExecutedByInitialThreadOnly(*CB))
3390 MallocCalls.remove(CB);
3394 findPotentialRemovedFreeCalls(
A);
3396 if (NumMallocCalls != MallocCalls.size())
3397 return ChangeStatus::CHANGED;
3399 return ChangeStatus::UNCHANGED;
3407 unsigned SharedMemoryUsed = 0;
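/// AAKernelInfo: attribute built on KernelInfoState that is attached to
/// kernels and call sites; it drives SPMD-ization and the custom state-machine
/// rewrite for generic-mode kernels.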
struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {

  void trackStatistics() const override {}

  const std::string getAsStr() const override {
    if (!isValidState())
    return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
           std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
           std::string(" #PRs: ") +
           (ReachedKnownParallelRegions.isValidState()
                ? std::to_string(ReachedKnownParallelRegions.size())
           ", #Unknown PRs: " +
           (ReachedUnknownParallelRegions.isValidState()
           ", #Reaching Kernels: " +
           (ReachingKernelEntries.isValidState()
           (ParallelLevels.isValidState()

  const std::string getName() const override { return "AAKernelInfo"; }

  const char *getIdAddr() const override { return &ID; }

  static const char ID;
3462struct AAKernelInfoFunction : AAKernelInfo {
3464 : AAKernelInfo(IRP,
A) {}
3469 return GuardedInstructions;
3477 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
3481 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3482 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3483 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3484 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3488 auto StoreCallBase = [](
Use &
U,
3489 OMPInformationCache::RuntimeFunctionInfo &RFI,
3491 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3493 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3495 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3501 StoreCallBase(U, InitRFI, KernelInitCB);
3505 DeinitRFI.foreachUse(
3507 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3513 if (!KernelInitCB || !KernelDeinitCB)
3517 ReachingKernelEntries.insert(Fn);
3518 IsKernelEntry =
true;
3527 bool &UsedAssumedInformation) -> std::optional<Value *> {
3533 bool &UsedAssumedInformation) -> std::optional<Value *> {
3538 if (!SPMDCompatibilityTracker.isValidState())
3540 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3542 A.recordDependence(*
this, *AA, DepClassTy::OPTIONAL);
3543 UsedAssumedInformation =
true;
3545 UsedAssumedInformation =
false;
3554 constexpr const int InitModeArgNo = 1;
3555 constexpr const int DeinitModeArgNo = 1;
3556 constexpr const int InitUseStateMachineArgNo = 2;
3557 A.registerSimplificationCallback(
3559 StateMachineSimplifyCB);
3560 A.registerSimplificationCallback(
3563 A.registerSimplificationCallback(
3569 dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
3571 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3574 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3579 if (!OMPInfoCache.RFIs[RFKind].Declaration)
3581 A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
    // Add a dependence to ensure updates if the state changes.
    auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
                            const AbstractAttribute *QueryingAA) {
      if (QueryingAA)
        A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
      return true;
    };

    // A custom state machine will insert calls to several runtime functions;
    // keep their declarations alive while that is still a possibility.
    Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
        [&](Attributor &A, const AbstractAttribute *QueryingAA) {
          // Not needed if we are on track for SPMD-zation.
          if (SPMDCompatibilityTracker.isValidState())
            return AddDependence(A, this, QueryingAA);
          // Not needed if we cannot rewrite due to an invalid state.
          if (!ReachedKnownParallelRegions.isValidState())
            return AddDependence(A, this, QueryingAA);
          return false;
        };

    if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
      RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
                         CustomStateMachineUseCB);
      RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
      RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
                         CustomStateMachineUseCB);
      RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
                         CustomStateMachineUseCB);
      RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
                         CustomStateMachineUseCB);
    }

    // If we do not perform SPMD-zation we do not need the virtual uses below.
    if (SPMDCompatibilityTracker.isAtFixpoint())
      return;

    // SPMD-zation will insert __kmpc_get_hardware_thread_id_in_block calls.
    Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
        [&](Attributor &A, const AbstractAttribute *QueryingAA) {
          if (!SPMDCompatibilityTracker.isValidState())
            return AddDependence(A, this, QueryingAA);
          return false;
        };
    RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
                       HWThreadIdUseCB);

    // SPMD-zation with guarding will insert __kmpc_barrier_simple_spmd calls.
    Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
        [&](Attributor &A, const AbstractAttribute *QueryingAA) {
          if (!SPMDCompatibilityTracker.isValidState())
            return AddDependence(A, this, QueryingAA);
          if (SPMDCompatibilityTracker.empty())
            return AddDependence(A, this, QueryingAA);
          if (!mayContainParallelRegion())
            return AddDependence(A, this, QueryingAA);
          return false;
        };
    RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
  }
  /// Sanitize the string \p S such that it is a valid global name.
  static std::string sanitizeForGlobalName(std::string S) {
    std::replace_if(
        S.begin(), S.end(),
        [](const char C) {
          return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
                   (C >= '0' && C <= '9') || C == '_');
        },
        '.');
    return S;
  }
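
  /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
  /// finished now.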
  ChangeStatus manifest(Attributor &A) override {
    // Without an init and deinit call we cannot manifest anything.
    if (!KernelInitCB || !KernelDeinitCB)
      return ChangeStatus::UNCHANGED;

    ChangeStatus Changed = ChangeStatus::UNCHANGED;

    // Prefer SPMD-mode; if that is not possible, build a specialized state
    // machine for generic-mode execution.
    if (!changeToSPMDMode(A, Changed)) {
      if (!KernelInitCB->getCalledFunction()->isDeclaration())
        return buildCustomStateMachine(A);
    }

    return Changed;
  }
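
  /// Guard all side-effecting instructions that are not SPMD-amenable: they
  /// are executed by thread 0 only, followed by a simple SPMD barrier, and any
  /// value still used outside the guarded region is broadcast through shared
  /// memory.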
  void insertInstructionGuardsHelper(Attributor &A) {
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

    auto CreateGuardedRegion = [&](Instruction *RegionStartI,
                                   Instruction *RegionEndI) {
      // Split the region out of its parent block and create the surrounding
      // check/barrier/exit blocks.
      BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
                                           DT, LI, MSU, "region.guarded.end");
      BasicBlock *RegionBarrierBB =
          SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
                     MSU, "region.barrier");
      BasicBlock *RegionExitBB =
          SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
                     DT, LI, MSU, "region.exit");
      BasicBlock *RegionStartBB =
          SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");

      assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
             "Expected a different CFG");

      BasicBlock *RegionCheckTidBB = SplitBlock(
          ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");

      // Register the new basic blocks with the Attributor.
      A.registerManifestAddedBasicBlock(*RegionEndBB);
      A.registerManifestAddedBasicBlock(*RegionBarrierBB);
      A.registerManifestAddedBasicBlock(*RegionExitBB);
      A.registerManifestAddedBasicBlock(*RegionStartBB);
      A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
      bool HasBroadcastValues = false;
      // Find values computed in the guarded region that are used outside of
      // it; those have to be broadcast to the other threads.
      for (Instruction &I : *RegionStartBB) {
        SmallPtrSet<Instruction *, 4> OutsideUsers;
        for (User *Usr : I.users()) {
          Instruction &UsrI = *cast<Instruction>(Usr);
          if (UsrI.getParent() != RegionStartBB)
            OutsideUsers.insert(&UsrI);
        }

        if (OutsideUsers.empty())
          continue;

        HasBroadcastValues = true;

        // Emit a global variable in shared memory to store the broadcast
        // value.
        auto *SharedMem = new GlobalVariable(
            M, I.getType(), /* IsConstant */ false,
            GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
            sanitizeForGlobalName(
                (I.getName() + ".guarded.output.alloc").str()),
            nullptr, GlobalValue::NotThreadLocal,
            static_cast<unsigned>(AddressSpace::Shared));

        // Store the value at the end of the guarded region and reload it after
        // the barrier, then redirect all outside users to the reload.
        new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());

        LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
                                       I.getName() + ".guarded.output.load",
                                       RegionBarrierBB->getTerminator());

        for (Instruction *UsrI : OutsideUsers)
          UsrI->replaceUsesOfWith(&I, LoadI);
      }
      auto &OMPInfoCache =
          static_cast<OMPInformationCache &>(A.getInfoCache());

      // Go to the tid check block from the parent block.
      const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
      ParentBB->getTerminator()->eraseFromParent();
      OpenMPIRBuilder::LocationDescription Loc(
          InsertPointTy(ParentBB, ParentBB->end()), DL);
      OMPInfoCache.OMPBuilder.updateToLocation(Loc);
      uint32_t SrcLocStrSize;
      auto *SrcLocStr =
          OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
      Value *Ident =
          OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
      BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);

      // Add the thread id check in RegionCheckTidBB: only thread 0 executes
      // the guarded region.
      RegionCheckTidBB->getTerminator()->eraseFromParent();
      OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
          InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
      OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
      FunctionCallee HardwareTidFn =
          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
              M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
      CallInst *Tid =
          OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
      OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
      Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
      OMPInfoCache.OMPBuilder.Builder
          .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
          ->setDebugLoc(DL);

      // First barrier: ensure the main thread has written the broadcast
      // values before the workers read them.
      FunctionCallee BarrierFn =
          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
              M, OMPRTL___kmpc_barrier_simple_spmd);
      OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
          RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
      CallInst *Barrier =
          OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
      OMPInfoCache.setCallingConvention(BarrierFn, Barrier);

      // Second barrier: ensure the workers have read the broadcast values
      // before the shared memory is reused.
      if (HasBroadcastValues) {
        CallInst *Barrier = CallInst::Create(BarrierFn, {Ident, Tid}, "",
                                             RegionBarrierBB->getTerminator());
        OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
      }
    };

    auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
    SmallPtrSet<BasicBlock *, 8> Visited;
    for (Instruction *GuardedI : SPMDCompatibilityTracker) {
      BasicBlock *BB = GuardedI->getParent();
      if (!Visited.insert(BB).second)
        continue;

      // Move trailing side-effect free instructions out of the way so guarded
      // regions become contiguous.
      SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
      Instruction *LastEffect = nullptr;
      BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
      while (++IP != IPEnd) {
        if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
          continue;
        Instruction *I = &*IP;
        if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
          continue;
        if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
          LastEffect = nullptr;
          continue;
        }
        if (LastEffect)
          Reorders.push_back({I, LastEffect});
        LastEffect = &*IP;
      }
      for (auto &Reorder : Reorders)
        Reorder.first->moveBefore(Reorder.second);
    }

    SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;

    for (Instruction *GuardedI : SPMDCompatibilityTracker) {
      BasicBlock *BB = GuardedI->getParent();
      auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
          IRPosition::function(*GuardedI->getFunction()));
      assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
      auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
      // Continue if the instruction is already guarded.
      if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
        continue;

      Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
      for (Instruction &I : *BB) {
        // If instruction I needs to be guarded update the guarded region
        // bounds.
        if (SPMDCompatibilityTracker.contains(&I)) {
          CalleeAAFunction.getGuardedInstructions().insert(&I);
          if (GuardedRegionStart)
            GuardedRegionEnd = &I;
          else
            GuardedRegionStart = GuardedRegionEnd = &I;
          continue;
        }

        // Instruction I does not need guarding, store the current region if
        // present.
        if (GuardedRegionStart) {
          GuardedRegions.push_back(
              std::make_pair(GuardedRegionStart, GuardedRegionEnd));
          GuardedRegionStart = nullptr;
          GuardedRegionEnd = nullptr;
        }
      }
    }

    for (auto &GR : GuardedRegions)
      CreateGuardedRegion(GR.first, GR.second);
  }
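
  /// Alternative to guarding: if the kernel cannot reach a parallel region we
  /// simply let all but the main thread skip the user code entirely.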
  void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
    auto &Ctx = getAnchorValue().getContext();

    BasicBlock *InitBB = KernelInitCB->getParent();
    BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
        KernelInitCB->getNextNode(), "main.thread.user_code");
    BasicBlock *ReturnBB =
        BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);

    A.registerManifestAddedBasicBlock(*InitBB);
    A.registerManifestAddedBasicBlock(*UserCodeBB);
    A.registerManifestAddedBasicBlock(*ReturnBB);

    // Keep the debug location of the init call around.
    const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
    ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
    InitBB->getTerminator()->eraseFromParent();

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    FunctionCallee ThreadIdInBlockFn =
        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
            M, OMPRTL___kmpc_get_hardware_thread_id_in_block);

    // Compare the hardware thread id against 0 and let every other thread
    // return immediately.
    CallInst *ThreadIdInBlock =
        CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
    OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
    Instruction *IsMainThread =
        CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
                        ConstantInt::get(ThreadIdInBlock->getType(), 0),
                        "thread.is_main", InitBB);
    BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
  }
  bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

    // We cannot change to SPMD mode if the required runtime functions are not
    // available.
    if (!OMPInfoCache.runtimeFnsAvailable(
            {OMPRTL___kmpc_get_hardware_thread_id_in_block,
             OMPRTL___kmpc_barrier_simple_spmd}))
      return false;

    if (!SPMDCompatibilityTracker.isAssumed()) {
      for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
        if (!NonCompatibleI)
          continue;

        // Skip diagnostics on calls to known OpenMP runtime functions for now.
        if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
          if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
            continue;

        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
          ORA << "Value has potential side effects preventing SPMD-mode "
                 "execution";
          if (isa<CallBase>(NonCompatibleI)) {
            ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
                   "the called function to override";
          }
          return ORA << ".";
        };
        A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
                                                 Remark);

        LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
                          << *NonCompatibleI << "\n");
      }

      return false;
    }

    // Get the actual kernel, which could be the caller of the anchor scope if
    // we are looking at a debug wrapper.
    Function *Kernel = getAnchorScope();
    if (Kernel->hasLocalLinkage()) {
      auto *CB = cast<CallBase>(Kernel->user_back());
      Kernel = CB->getCaller();
    }
    assert(OMPInfoCache.Kernels.count(Kernel) && "Expected kernel function!");

    // Read the current exec mode of the kernel.
    GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
        (Kernel->getName() + "_exec_mode").str());
    assert(ExecMode && "Kernel without exec mode?");
    assert(isa<ConstantInt>(ExecMode->getInitializer()) &&
           "ExecMode is not an integer!");
    const int8_t ExecModeVal =
        cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue();

    // We will now unconditionally modify the IR, indicate a change.
    Changed = ChangeStatus::CHANGED;

    // Guard incompatible instructions if a parallel region may be reached,
    // otherwise force a single thread per workgroup.
    if (mayContainParallelRegion())
      insertInstructionGuardsHelper(A);
    else
      forceSingleThreadPerWorkgroupHelper(A);

    // Adjust the global exec mode flag that tells the runtime what mode this
    // kernel is executed in.
    assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
           "Initially non-SPMD kernel has SPMD exec mode!");
    ExecMode->setInitializer(
        ConstantInt::get(ExecMode->getInitializer()->getType(),
                         ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
    // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
    const int InitModeArgNo = 1;
    const int DeinitModeArgNo = 1;
    const int InitUseStateMachineArgNo = 2;

    auto &Ctx = getAnchorValue().getContext();
    A.changeUseAfterManifest(
        KernelInitCB->getArgOperandUse(InitModeArgNo),
        *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
                                OMP_TGT_EXEC_MODE_SPMD));
    A.changeUseAfterManifest(
        KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
        *ConstantInt::getBool(Ctx, false));
    A.changeUseAfterManifest(
        KernelDeinitCB->getArgOperandUse(DeinitModeArgNo),
        *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
                                OMP_TGT_EXEC_MODE_SPMD));

    ++NumOpenMPTargetRegionKernelsSPMD;

    auto Remark = [&](OptimizationRemark OR) {
      return OR << "Transformed generic-mode kernel to SPMD-mode.";
    };
    A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
    return true;
  }
  /// Build a specialized state machine for this generic-mode kernel: workers
  /// wait at a barrier, the main thread hands out known parallel regions
  /// directly, and only unknown regions fall back to an indirect call.
  ChangeStatus buildCustomStateMachine(Attributor &A) {
    // If state machine rewrites are disabled, don't make a custom one.
    if (DisableOpenMPOptStateMachineRewrite)
      return ChangeStatus::UNCHANGED;

    // Don't rewrite the state machine if we are not in a valid state.
    if (!ReachedKnownParallelRegions.isValidState())
      return ChangeStatus::UNCHANGED;

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    if (!OMPInfoCache.runtimeFnsAvailable(
            {OMPRTL___kmpc_get_hardware_num_threads_in_block,
             OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
             OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
      return ChangeStatus::UNCHANGED;

    const int InitModeArgNo = 1;
    const int InitUseStateMachineArgNo = 2;

    // Check that the current configuration is non-SPMD with a generic state
    // machine; anything else (including non-constant arguments) means there is
    // nothing for us to do.
    ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
        KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
    ConstantInt *Mode =
        dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
    if (!UseStateMachine || UseStateMachine->isZero() || !Mode ||
        (Mode->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
      return ChangeStatus::UNCHANGED;

    // Indicate that we use a custom state machine now.
    auto &Ctx = getAnchorValue().getContext();
    auto *FalseVal = ConstantInt::getBool(Ctx, false);
    A.changeUseAfterManifest(
        KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);

    // If there are no parallel regions we do not need a state machine at all;
    // worker threads simply exit right away.
    if (!mayContainParallelRegion()) {
      ++NumOpenMPTargetRegionKernelsWithoutStateMachine;

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Removing unused state machine from generic-mode kernel.";
      };
      A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);

      return ChangeStatus::CHANGED;
    }

    // Keep track of the new custom state machine in the statistics.
    if (ReachedUnknownParallelRegions.empty()) {
      ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Rewriting generic-mode kernel with a customized state "
                     "machine.";
      };
      A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
    } else {
      ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;

      auto Remark = [&](OptimizationRemarkAnalysis OR) {
        return OR << "Generic-mode kernel is executed with a customized state "
                     "machine that requires a fallback.";
      };
      A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
      // Tell the user why we ended up with a fallback.
      for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
        if (!UnknownParallelRegionCB)
          continue;
        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
          return ORA << "Call may contain unknown parallel regions. Use "
                     << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
                        "override.";
        };
        A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
                                                 "OMP133", Remark);
      }
    }
    // Create all the blocks and logic of the worker state machine. Workers
    // wait at a barrier until the main thread hands out a work function via
    // __kmpc_kernel_parallel, execute it (directly for known parallel regions,
    // indirectly as a fallback), and loop until the kernel is finished.
    auto &Ctx = getAnchorValue().getContext();
    Function *Kernel = getAssociatedFunction();
    assert(Kernel && "Expected an associated function!");

    BasicBlock *InitBB = KernelInitCB->getParent();
    BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
        KernelInitCB->getNextNode(), "thread.user_code.check");
    BasicBlock *IsWorkerCheckBB =
        BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
    BasicBlock *StateMachineBeginBB = BasicBlock::Create(
        Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
    BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
        Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
    BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
        Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
    BasicBlock *StateMachineIfCascadeCurrentBB =
        BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
                           Kernel, UserCodeEntryBB);
    BasicBlock *StateMachineEndParallelBB =
        BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
                           Kernel, UserCodeEntryBB);
    BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
        Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
    A.registerManifestAddedBasicBlock(*InitBB);
    A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
    A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
    A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
    A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
    A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
    A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
    A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
    A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);

    const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
    ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
    InitBB->getTerminator()->eraseFromParent();

    // The thread id returned by __kmpc_target_init is -1 for the main thread,
    // every other thread is a (potential) worker.
    Instruction *IsWorker =
        CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, KernelInitCB,
                        ConstantInt::get(KernelInitCB->getType(), -1),
                        "thread.is_worker", InitBB);
    IsWorker->setDebugLoc(DLoc);
    BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);

    // Threads in the last warp are reserved by the runtime; only ids below
    // block size minus warp size are workers.
    Module &M = *Kernel->getParent();
    FunctionCallee BlockHwSizeFn =
        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
            M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
    FunctionCallee WarpSizeFn =
        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
            M, OMPRTL___kmpc_get_warp_size);
    CallInst *BlockHwSize =
        CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
    OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
    CallInst *WarpSize =
        CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
    OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
    Instruction *BlockSize = BinaryOperator::CreateSub(
        BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
    Instruction *IsMainOrWorker = CmpInst::Create(
        Instruction::ICmp, CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
        "thread.is_main_or_worker", IsWorkerCheckBB);
    BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
                       IsMainOrWorker, IsWorkerCheckBB);
    // Create local storage for the work function pointer.
    const DataLayout &DL = M.getDataLayout();
    Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
    Instruction *WorkFnAI =
        new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
                       "worker.work_fn.addr", &Kernel->getEntryBlock().front());
    WorkFnAI->setDebugLoc(DLoc);

    OMPInfoCache.OMPBuilder.updateToLocation(
        OpenMPIRBuilder::LocationDescription(
            IRBuilder<>::InsertPoint(StateMachineBeginBB,
                                     StateMachineBeginBB->end()),
            DLoc));

    Value *Ident = KernelInitCB->getArgOperand(0);
    Value *GTid = KernelInitCB;

    FunctionCallee BarrierFn =
        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
            M, OMPRTL___kmpc_barrier_simple_generic);
    CallInst *Barrier =
        CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
    OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
    Barrier->setDebugLoc(DLoc);

    if (WorkFnAI->getType()->getPointerAddressSpace() !=
        (unsigned int)AddressSpace::Generic) {
      WorkFnAI = new AddrSpaceCastInst(
          WorkFnAI,
          PointerType::getWithSamePointeeType(
              cast<PointerType>(WorkFnAI->getType()),
              (unsigned int)AddressSpace::Generic),
          WorkFnAI->getName() + ".generic", StateMachineBeginBB);
      WorkFnAI->setDebugLoc(DLoc);
    }

    FunctionCallee KernelParallelFn =
        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
            M, OMPRTL___kmpc_kernel_parallel);
    CallInst *IsActiveWorker = CallInst::Create(
        KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
    OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
    Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
                                       StateMachineBeginBB);

    FunctionType *ParallelRegionFnTy = FunctionType::get(
        Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
        false);

    Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
        WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
        StateMachineBeginBB);

    // A null work function means the kernel is done.
    Instruction *IsDone =
        CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, WorkFn,
                        Constant::getNullValue(VoidPtrTy), "worker.is_done",
                        StateMachineBeginBB);
    BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
                       IsDone, StateMachineBeginBB)
        ->setDebugLoc(DLoc);

    BranchInst::Create(StateMachineIfCascadeCurrentBB,
                       StateMachineDoneBarrierBB, IsActiveWorker,
                       StateMachineIsActiveCheckBB)
        ->setDebugLoc(DLoc);
    Value *ZeroArg =
        Constant::getNullValue(ParallelRegionFnTy->getParamType(0));

    // Build the if-cascade that compares the work function pointer against the
    // parallel regions we expect, if there are any.
    for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
      auto *ParallelRegion = ReachedKnownParallelRegions[I];
      BasicBlock *PRExecuteBB = BasicBlock::Create(
          Ctx, "worker_state_machine.parallel_region.execute", Kernel,
          StateMachineEndParallelBB);
      CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
          ->setDebugLoc(DLoc);
      BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
          ->setDebugLoc(DLoc);

      BasicBlock *PRNextBB =
          BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
                             Kernel, StateMachineEndParallelBB);

      // Only compare the pointer if this is not the last (and only) choice.
      Value *IsPR;
      if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
        Instruction *CmpI = CmpInst::Create(
            Instruction::ICmp, CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
            "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
        CmpI->setDebugLoc(DLoc);
        IsPR = CmpI;
      } else {
        IsPR = ConstantInt::getTrue(Ctx);
      }

      BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
                         StateMachineIfCascadeCurrentBB)
          ->setDebugLoc(DLoc);
      StateMachineIfCascadeCurrentBB = PRNextBB;
    }

    // If there can be parallel regions we have not handled above, place an
    // indirect call through the work function pointer as a fallback.
    if (!ReachedUnknownParallelRegions.empty()) {
      StateMachineIfCascadeCurrentBB->setName(
          "worker_state_machine.parallel_region.fallback.execute");
      CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
                       StateMachineIfCascadeCurrentBB)
          ->setDebugLoc(DLoc);
    }
    BranchInst::Create(StateMachineEndParallelBB,
                       StateMachineIfCascadeCurrentBB)
        ->setDebugLoc(DLoc);

    FunctionCallee EndParallelFn =
        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
            M, OMPRTL___kmpc_kernel_end_parallel);
    CallInst *EndParallel =
        CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
    OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
    EndParallel->setDebugLoc(DLoc);
    BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
        ->setDebugLoc(DLoc);

    CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
        ->setDebugLoc(DLoc);
    BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
        ->setDebugLoc(DLoc);

    return ChangeStatus::CHANGED;
  }
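
  /// Fixpoint iteration update function. Will be called every time a
  /// dependence changed its state (and in the beginning).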
  ChangeStatus updateImpl(Attributor &A) override {
    KernelInfoState StateBefore = getState();

    // Callback to check a read/write instruction.
    auto CheckRWInst = [&](Instruction &I) {
      // We handle calls later.
      if (isa<CallBase>(I))
        return true;
      // We only care about write effects.
      if (!I.mayWriteToMemory())
        return true;
      if (auto *SI = dyn_cast<StoreInst>(&I)) {
        const auto &UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
            *this, IRPosition::value(*SI->getPointerOperand()),
            DepClassTy::OPTIONAL);
        auto &HS = A.getAAFor<AAHeapToStack>(
            *this, IRPosition::function(*I.getFunction()),
            DepClassTy::OPTIONAL);
        if (UnderlyingObjsAA.forallUnderlyingObjects([&](Value &Obj) {
              if (AA::isAssumedThreadLocalObject(A, Obj, *this))
                return true;
              // Check for AAHeapToStack moved objects which must not be
              // guarded.
              auto *CB = dyn_cast<CallBase>(&Obj);
              return CB && HS.isAssumedHeapToStack(*CB);
            }))
          return true;
      }

      // Insert instruction that needs guarding.
      SPMDCompatibilityTracker.insert(&I);
      return true;
    };

    bool UsedAssumedInformationInCheckRWInst = false;
    if (!SPMDCompatibilityTracker.isAtFixpoint())
      if (!A.checkForAllReadWriteInstructions(
              CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
    bool UsedAssumedInformationFromReachingKernels = false;
    if (!IsKernelEntry) {
      updateParallelLevels(A);

      bool AllReachingKernelsKnown = true;
      updateReachingKernelEntries(A, AllReachingKernelsKnown);
      UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;

      if (!SPMDCompatibilityTracker.empty()) {
        if (!ParallelLevels.isValidState())
          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        else if (!ReachingKernelEntries.isValidState())
          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        else {
          // Check if all reaching kernels agree on the execution mode; if not
          // we cannot guard instructions and have to give up.
          int SPMD = 0, Generic = 0;
          for (auto *Kernel : ReachingKernelEntries) {
            auto &CBAA = A.getAAFor<AAKernelInfo>(
                *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
            if (CBAA.SPMDCompatibilityTracker.isValidState() &&
                CBAA.SPMDCompatibilityTracker.isAssumed())
              ++SPMD;
            else
              ++Generic;
            if (!CBAA.SPMDCompatibilityTracker.isAtFixpoint())
              UsedAssumedInformationFromReachingKernels = true;
          }
          if (SPMD != 0 && Generic != 0)
            SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        }
      }
    }
    // Callback to check a call instruction.
    bool AllParallelRegionStatesWereFixed = true;
    bool AllSPMDStatesWereFixed = true;
    auto CheckCallInst = [&](Instruction &I) {
      auto &CB = cast<CallBase>(I);
      auto &CBAA = A.getAAFor<AAKernelInfo>(
          *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
      getState() ^= CBAA.getState();
      AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
      AllParallelRegionStatesWereFixed &=
          CBAA.ReachedKnownParallelRegions.isAtFixpoint();
      AllParallelRegionStatesWereFixed &=
          CBAA.ReachedUnknownParallelRegions.isAtFixpoint();
      return true;
    };

    bool UsedAssumedInformationInCheckCallInst = false;
    if (!A.checkForAllCallLikeInstructions(
            CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
      LLVM_DEBUG(dbgs() << TAG
                        << "Failed to visit all call-like instructions!\n";);
      return indicatePessimisticFixpoint();
    }

    // If we haven't used any assumed information for the reached parallel
    // region states we can fix them.
    if (!UsedAssumedInformationInCheckCallInst &&
        AllParallelRegionStatesWereFixed) {
      ReachedKnownParallelRegions.indicateOptimisticFixpoint();
      ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
    }

    // If we haven't used any assumed information for the SPMD state we can fix
    // it as well.
    if (!UsedAssumedInformationInCheckRWInst &&
        !UsedAssumedInformationInCheckCallInst &&
        !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
      SPMDCompatibilityTracker.indicateOptimisticFixpoint();

    return StateBefore == getState() ? ChangeStatus::UNCHANGED
                                     : ChangeStatus::CHANGED;
  }
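
  /// Update info regarding reaching kernels.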
  void updateReachingKernelEntries(Attributor &A,
                                   bool &AllReachingKernelsKnown) {
    auto PredCallSite = [&](AbstractCallSite ACS) {
      Function *Caller = ACS.getInstruction()->getFunction();

      assert(Caller && "Caller is nullptr");

      auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
          IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
      if (CAA.ReachingKernelEntries.isValidState()) {
        ReachingKernelEntries ^= CAA.ReachingKernelEntries;
        return true;
      }

      // We lost track of the caller of the associated function, any kernel
      // could reach now.
      ReachingKernelEntries.indicatePessimisticFixpoint();

      return true;
    };

    if (!A.checkForAllCallSites(PredCallSite, *this,
                                true /* RequireAllCallSites */,
                                AllReachingKernelsKnown))
      ReachingKernelEntries.indicatePessimisticFixpoint();
  }
  /// Update info regarding parallel levels.
  void updateParallelLevels(Attributor &A) {
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];

    auto PredCallSite = [&](AbstractCallSite ACS) {
      Function *Caller = ACS.getInstruction()->getFunction();

      assert(Caller && "Caller is nullptr");

      auto &CAA =
          A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
      if (CAA.ParallelLevels.isValidState()) {
        // Any function called by __kmpc_parallel_51 runs at an updated
        // parallel level, so be conservative across that boundary.
        if (Caller == Parallel51RFI.Declaration) {
          ParallelLevels.indicatePessimisticFixpoint();
          return true;
        }

        ParallelLevels ^= CAA.ParallelLevels;

        return true;
      }

      // We lost track of the caller of the associated function, any kernel
      // could reach now.
      ParallelLevels.indicatePessimisticFixpoint();

      return true;
    };

    bool AllCallSitesKnown = true;
    if (!A.checkForAllCallSites(PredCallSite, *this,
                                true /* RequireAllCallSites */,
                                AllCallSitesKnown))
      ParallelLevels.indicatePessimisticFixpoint();
  }
};
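
/// The call site kernel info abstract attribute: what can we say about a call
/// site with regards to the KernelInfoState.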
struct AAKernelInfoCallSite : AAKernelInfo {
  AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
      : AAKernelInfo(IRP, A) {}

  /// See AbstractAttribute::initialize(...).
  void initialize(Attributor &A) override {
    AAKernelInfo::initialize(A);

    CallBase &CB = cast<CallBase>(getAssociatedValue());
    Function *Callee = getAssociatedFunction();

    auto &AssumptionAA = A.getAAFor<AAAssumptionInfo>(
        *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);

    // Check for SPMD-mode assumptions.
    if (AssumptionAA.hasAssumption("ompx_spmd_amenable")) {
      SPMDCompatibilityTracker.indicateOptimisticFixpoint();
      indicateOptimisticFixpoint();
    }

    // First weed out calls we do not care about, that is readonly/readnone
    // calls and intrinsics. Neither of these can reach a parallel region or
    // anything else we are looking for.
    if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
      indicateOptimisticFixpoint();
      return;
    }

    // Next we check if we know the callee. If it is a known OpenMP runtime
    // function we handle it explicitly in the switch below; otherwise we merge
    // the callee's AAKernelInfo state into ours in updateImpl.
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
    if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
      // Unknown callees might contain parallel regions, except if they have an
      // appropriate assumption attached.
      if (!(AssumptionAA.hasAssumption("omp_no_openmp") ||
            AssumptionAA.hasAssumption("omp_no_parallelism")))
        ReachedUnknownParallelRegions.insert(&CB);

      // If SPMDCompatibilityTracker is not fixed, we have to give up on the
      // idea that we can run something unknown in SPMD-mode.
      if (!SPMDCompatibilityTracker.isAtFixpoint()) {
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        SPMDCompatibilityTracker.insert(&CB);
      }

      // We updated the state for this unknown call properly, there won't be
      // any change so we indicate a fixpoint.
      indicateOptimisticFixpoint();
      return;
    }
    // Handle known OpenMP runtime functions explicitly.
    const unsigned int WrapperFunctionArgNo = 6;
    RuntimeFunction RF = It->getSecond();
    switch (RF) {
    // All the functions we know are compatible with SPMD mode.
    case OMPRTL___kmpc_is_spmd_exec_mode:
    case OMPRTL___kmpc_distribute_static_fini:
    case OMPRTL___kmpc_for_static_fini:
    case OMPRTL___kmpc_global_thread_num:
    case OMPRTL___kmpc_get_hardware_num_threads_in_block:
    case OMPRTL___kmpc_get_hardware_num_blocks:
    case OMPRTL___kmpc_single:
    case OMPRTL___kmpc_end_single:
    case OMPRTL___kmpc_master:
    case OMPRTL___kmpc_end_master:
    case OMPRTL___kmpc_barrier:
    case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
    case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
    case OMPRTL___kmpc_nvptx_end_reduce_nowait:
      break;
    case OMPRTL___kmpc_distribute_static_init_4:
    case OMPRTL___kmpc_distribute_static_init_4u:
    case OMPRTL___kmpc_distribute_static_init_8:
    case OMPRTL___kmpc_distribute_static_init_8u:
    case OMPRTL___kmpc_for_static_init_4:
    case OMPRTL___kmpc_for_static_init_4u:
    case OMPRTL___kmpc_for_static_init_8:
    case OMPRTL___kmpc_for_static_init_8u: {
      // Check the schedule and allow static schedules in SPMD mode.
      unsigned ScheduleArgOpNo = 2;
      auto *ScheduleTypeCI =
          dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
      unsigned ScheduleTypeVal =
          ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
      switch (OMPScheduleType(ScheduleTypeVal)) {
      case OMPScheduleType::UnorderedStatic:
      case OMPScheduleType::UnorderedStaticChunked:
      case OMPScheduleType::OrderedDistribute:
      case OMPScheduleType::OrderedDistributeChunked:
        break;
      default:
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        SPMDCompatibilityTracker.insert(&CB);
        break;
      };
    } break;
    case OMPRTL___kmpc_target_init:
      KernelInitCB = &CB;
      break;
    case OMPRTL___kmpc_target_deinit:
      KernelDeinitCB = &CB;
      break;
    case OMPRTL___kmpc_parallel_51:
      if (auto *ParallelRegion = dyn_cast<Function>(
              CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
        ReachedKnownParallelRegions.insert(ParallelRegion);
        /// Check nested parallelism
        auto &FnAA = A.getAAFor<AAKernelInfo>(
            *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
        NestedParallelism |= !FnAA.getState().isValidState() ||
                             !FnAA.ReachedKnownParallelRegions.empty() ||
                             !FnAA.ReachedUnknownParallelRegions.empty();
        break;
      }
      // The wrapper function argument should be a function; if it is not we
      // conservatively record an unknown parallel region.
      ReachedUnknownParallelRegions.insert(&CB);
      break;
    case OMPRTL___kmpc_omp_task:
      // We do not look into tasks right now, just give up.
      SPMDCompatibilityTracker.indicatePessimisticFixpoint();
      SPMDCompatibilityTracker.insert(&CB);
      ReachedUnknownParallelRegions.insert(&CB);
      break;
    case OMPRTL___kmpc_alloc_shared:
    case OMPRTL___kmpc_free_shared:
      // Return without setting a fixpoint, to be resolved in updateImpl.
      return;
    default:
      // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
      // generally, but they do not hide parallel regions.
      SPMDCompatibilityTracker.indicatePessimisticFixpoint();
      SPMDCompatibilityTracker.insert(&CB);
      break;
    }
    // All other OpenMP runtime calls will not reach parallel regions; since we
    // modeled all their effects there is no need for any update.
    indicateOptimisticFixpoint();
  }
  ChangeStatus updateImpl(Attributor &A) override {
    // We only need to handle generic visitation here as we handled the OpenMP
    // runtime calls in the initialize method.
    Function *F = getAssociatedFunction();

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);

    // If F is not a runtime function, propagate the AAKernelInfo of the
    // callee.
    if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
      const IRPosition &FnPos = IRPosition::function(*F);
      auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
      if (getState() == FnAA.getState())
        return ChangeStatus::UNCHANGED;
      getState() = FnAA.getState();
      return ChangeStatus::CHANGED;
    }

    KernelInfoState StateBefore = getState();
    assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
            It->getSecond() == OMPRTL___kmpc_free_shared) &&
           "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");

    CallBase &CB = cast<CallBase>(getAssociatedValue());

    auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
        *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
    auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
        *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);

    switch (It->getSecond()) {
    // If neither HeapToStack nor HeapToShared assume the call is removed, the
    // call is not SPMD-compatible and has to be guarded.
    case OMPRTL___kmpc_alloc_shared:
      if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
          !HeapToSharedAA.isAssumedHeapToShared(CB))
        SPMDCompatibilityTracker.insert(&CB);
      break;
    case OMPRTL___kmpc_free_shared:
      if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
          !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
        SPMDCompatibilityTracker.insert(&CB);
      break;
    default:
      SPMDCompatibilityTracker.indicatePessimisticFixpoint();
      SPMDCompatibilityTracker.insert(&CB);
    }

    return StateBefore == getState() ? ChangeStatus::UNCHANGED
                                     : ChangeStatus::CHANGED;
  }
};
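
/// Abstract attribute for folding OpenMP runtime calls whose result depends
/// only on properties of the reaching kernels (e.g. their execution mode) into
/// constants.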
struct AAFoldRuntimeCall
    : public StateWrapper<BooleanState, AbstractAttribute> {
  using Base = StateWrapper<BooleanState, AbstractAttribute>;

  AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

  /// Statistics are tracked as part of manifest for now.
  void trackStatistics() const override {}

  /// Create an abstract attribute view for the position \p IRP.
  static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
                                              Attributor &A);

  /// See AbstractAttribute::getName().
  const std::string getName() const override { return "AAFoldRuntimeCall"; }

  /// See AbstractAttribute::getIdAddr().
  const char *getIdAddr() const override { return &ID; }

  /// Unique ID (due to the unique address).
  static const char ID;
};
struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
  AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
      : AAFoldRuntimeCall(IRP, A) {}

  /// See AbstractAttribute::getAsStr().
  const std::string getAsStr() const override {
    if (!isValidState())
      return "<invalid>";

    std::string Str("simplified value: ");

    if (!SimplifiedValue)
      return Str + std::string("none");

    if (!*SimplifiedValue)
      return Str + std::string("nullptr");

    if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
      return Str + std::to_string(CI->getSExtValue());

    return Str + std::string("unknown");
  }
  void initialize(Attributor &A) override {
    if (DisableOpenMPOptFolding)
      indicatePessimisticFixpoint();

    Function *Callee = getAssociatedFunction();

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
    assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
           "Expected a known OpenMP runtime function");

    RFKind = It->getSecond();

    CallBase &CB = cast<CallBase>(getAssociatedValue());
    A.registerSimplificationCallback(
        IRPosition::callsite_returned(CB),
        [&](const IRPosition &IRP, const AbstractAttribute *AA,
            bool &UsedAssumedInformation) -> std::optional<Value *> {
          assert((isValidState() ||
                  (SimplifiedValue && *SimplifiedValue == nullptr)) &&
                 "Unexpected invalid state!");

          if (!isAtFixpoint()) {
            UsedAssumedInformation = true;
            if (AA)
              A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
          }
          return SimplifiedValue;
        });
  }
  ChangeStatus updateImpl(Attributor &A) override {
    ChangeStatus Changed = ChangeStatus::UNCHANGED;
    switch (RFKind) {
    case OMPRTL___kmpc_is_spmd_exec_mode:
      Changed |= foldIsSPMDExecMode(A);
      break;
    case OMPRTL___kmpc_parallel_level:
      Changed |= foldParallelLevel(A);
      break;
    case OMPRTL___kmpc_get_hardware_num_threads_in_block:
      Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
      break;
    case OMPRTL___kmpc_get_hardware_num_blocks:
      Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
      break;
    default:
      llvm_unreachable("Unhandled OpenMP runtime function!");
    }

    return Changed;
  }
  ChangeStatus manifest(Attributor &A) override {
    ChangeStatus Changed = ChangeStatus::UNCHANGED;

    if (SimplifiedValue && *SimplifiedValue) {
      Instruction &I = *getCtxI();
      A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
      A.deleteAfterManifest(I);

      CallBase *CB = dyn_cast<CallBase>(&I);
      auto Remark = [&](OptimizationRemark OR) {
        if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
          return OR << "Replacing OpenMP runtime call "
                    << CB->getCalledFunction()->getName() << " with "
                    << ore::NV("FoldedValue", C->getZExtValue()) << ".";
        return OR << "Replacing OpenMP runtime call "
                  << CB->getCalledFunction()->getName() << ".";
      };

      if (CB && EnableVerboseRemarks)
        A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);

      LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
                        << **SimplifiedValue << "\n");

      Changed = ChangeStatus::CHANGED;
    }

    return Changed;
  }
  ChangeStatus indicatePessimisticFixpoint() override {
    SimplifiedValue = nullptr;
    return AAFoldRuntimeCall::indicatePessimisticFixpoint();
  }
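
  /// Fold __kmpc_is_spmd_exec_mode into a constant if all kernels that can
  /// reach the associated call site agree on their execution mode.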
  ChangeStatus foldIsSPMDExecMode(Attributor &A) {
    std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;

    unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
    unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
    auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);

    if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
      return indicatePessimisticFixpoint();

    for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
      auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
                                          DepClassTy::REQUIRED);

      if (!AA.isValidState()) {
        SimplifiedValue = nullptr;
        return indicatePessimisticFixpoint();
      }

      if (AA.SPMDCompatibilityTracker.isAssumed()) {
        if (AA.SPMDCompatibilityTracker.isAtFixpoint())
          ++KnownSPMDCount;
        else
          ++AssumedSPMDCount;
      } else {
        if (AA.SPMDCompatibilityTracker.isAtFixpoint())
          ++KnownNonSPMDCount;
        else
          ++AssumedNonSPMDCount;
      }
    }

    // Mixed SPMD and non-SPMD kernels cannot be folded.
    if ((AssumedSPMDCount + KnownSPMDCount) &&
        (AssumedNonSPMDCount + KnownNonSPMDCount))
      return indicatePessimisticFixpoint();

    auto &Ctx = getAnchorValue().getContext();
    if (KnownSPMDCount || AssumedSPMDCount) {
      assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
             "Expected only SPMD kernels!");
      // All reaching kernels are in SPMD mode, fold to 1.
      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
    } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
      assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
             "Expected only non-SPMD kernels!");
      // All reaching kernels are in generic mode, fold to 0.
      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
    } else {
      // No reaching kernels, we cannot tell if the call site can be folded.
      assert(!SimplifiedValue && "SimplifiedValue should be none");
    }

    return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
                                                    : ChangeStatus::CHANGED;
  }
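
  /// Fold __kmpc_parallel_level into a constant if all reaching kernels agree:
  /// the level is 1 below SPMD kernels and 0 below generic-mode kernels.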
  ChangeStatus foldParallelLevel(Attributor &A) {
    std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;

    auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);

    if (!CallerKernelInfoAA.ParallelLevels.isValidState())
      return indicatePessimisticFixpoint();

    if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
      return indicatePessimisticFixpoint();

    if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
      assert(!SimplifiedValue &&
             "SimplifiedValue should keep none at this point");
      return ChangeStatus::UNCHANGED;
    }

    unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
    unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
    for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
      auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
                                          DepClassTy::REQUIRED);
      if (!AA.SPMDCompatibilityTracker.isValidState())
        return indicatePessimisticFixpoint();

      if (AA.SPMDCompatibilityTracker.isAssumed()) {
        if (AA.SPMDCompatibilityTracker.isAtFixpoint())
          ++KnownSPMDCount;
        else
          ++AssumedSPMDCount;
      } else {
        if (AA.SPMDCompatibilityTracker.isAtFixpoint())
          ++KnownNonSPMDCount;
        else
          ++AssumedNonSPMDCount;
      }
    }

    if ((AssumedSPMDCount + KnownSPMDCount) &&
        (AssumedNonSPMDCount + KnownNonSPMDCount))
      return indicatePessimisticFixpoint();

    auto &Ctx = getAnchorValue().getContext();
    if (AssumedSPMDCount || KnownSPMDCount) {
      assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
             "Expected only SPMD kernels!");
      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
    } else {
      assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
             "Expected only non-SPMD kernels!");
      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
    }
    return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
                                                    : ChangeStatus::CHANGED;
  }
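
  /// Fold __kmpc_get_hardware_num_threads_in_block or
  /// __kmpc_get_hardware_num_blocks if all kernels reaching the call site
  /// agree on the value of the corresponding kernel function attribute
  /// ("omp_target_thread_limit" resp. "omp_target_num_teams").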
  ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
    // Specialize only if all reaching kernels agree on the attribute value.
    int32_t CurrentAttrValue = -1;
    std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;

    auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);

    if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
      return indicatePessimisticFixpoint();

    // Iterate over the kernels that reach this function.
    for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
      int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);

      if (NextAttrVal == -1 ||
          (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
        return indicatePessimisticFixpoint();
      CurrentAttrValue = NextAttrVal;
    }

    if (CurrentAttrValue != -1) {
      auto &Ctx = getAnchorValue().getContext();
      SimplifiedValue =
          ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
    }
    return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
                                                    : ChangeStatus::CHANGED;
  }

  /// An optional value the associated value is assumed to fold to. That is, we
  /// assume the associated value (which is a call) can be replaced by this
  /// simplified value.
  std::optional<Value *> SimplifiedValue;

  /// The runtime function kind of the callee of the associated call site.
  RuntimeFunction RFKind;
};
void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
  auto &RFI = OMPInfoCache.RFIs[RF];

  RFI.foreachUse(SCC, [&](Use &U, Function &F) {
    CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
    if (!CI)
      return false;
    A.getOrCreateAAFor<AAFoldRuntimeCall>(
        IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
        DepClassTy::NONE, /* ForceUpdate */ false,
        /* UpdateAfterInit */ false);
    return false;
  });
}
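
/// Register the abstract attributes that drive the OpenMP-specific
/// optimizations (kernel info, runtime call folding, ICV tracking, and the
/// per-function attributes) for the functions in the current SCC.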
void OpenMPOpt::registerAAs(bool IsModulePass) {
  if (SCC.empty())
    return;

  if (IsModulePass) {
    // Create the AAKernelInfo AAs first and without triggering an update, so
    // all value simplification callbacks are registered before any other AA
    // runs.
    auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
      A.getOrCreateAAFor<AAKernelInfo>(
          IRPosition::function(Kernel), /* QueryingAA */ nullptr,
          DepClassTy::NONE, /* ForceUpdate */ false,
          /* UpdateAfterInit */ false);
      return false;
    };
    OMPInformationCache::RuntimeFunctionInfo &InitRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
    InitRFI.foreachUse(SCC, CreateKernelInfoCB);

    registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
    registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
    registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
    registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
  }

  // Create call site AAs for all ICV getter calls.
  for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
    auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];

    auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];

    auto CreateAA = [&](Use &U, Function &Caller) {
      CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
      if (!CI)
        return false;

      auto &CB = cast<CallBase>(*CI);
      IRPosition CBPos = IRPosition::callsite_function(CB);
      A.getOrCreateAAFor<AAICVTracker>(CBPos);
      return false;
    };

    GetterRFI.foreachUse(SCC, CreateAA);
  }

  for (auto *F : SCC) {
    if (F->isDeclaration())
      continue;

    // Internal functions are only analyzed if all their call sites are known
    // and part of the functions we run on.
    if (F->hasLocalLinkage()) {
      bool AllCallSitesKnown = all_of(F->uses(), [this](const Use &U) {
        const auto *CB = dyn_cast<CallBase>(U.getUser());
        return CB && CB->isCallee(&U) &&
               A.isRunOn(const_cast<Function *>(CB->getCaller()));
      });
      if (!AllCallSitesKnown)
        continue;
    }
    registerAAsForFunction(A, *F);
  }
}

void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
  if (F.hasFnAttribute(Attribute::Convergent))
    A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));

  for (auto &I : instructions(F)) {
    if (auto *LI = dyn_cast<LoadInst>(&I)) {
      bool UsedAssumedInformation = false;
      A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
                             UsedAssumedInformation, AA::Interprocedural);
      continue;
    }
    if (auto *SI = dyn_cast<StoreInst>(&I)) {
      A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
      continue;
    }
    if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
      if (II->getIntrinsicID() == Intrinsic::assume) {
        A.getOrCreateAAFor<AAPotentialValues>(
            IRPosition::value(*II->getArgOperand(0)));
        continue;
      }
    }
  }
}
const char AAICVTracker::ID = 0;
const char AAKernelInfo::ID = 0;
const char AAExecutionDomain::ID = 0;
const char AAHeapToShared::ID = 0;
const char AAFoldRuntimeCall::ID = 0;

AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
                                              Attributor &A) {
  AAICVTracker *AA = nullptr;
  switch (IRP.getPositionKind()) {
  case IRPosition::IRP_RETURNED:
    AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
    break;
  case IRPosition::IRP_CALL_SITE_RETURNED:
    AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
    break;
  case IRPosition::IRP_CALL_SITE:
    AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
    break;
  case IRPosition::IRP_FUNCTION:
    AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
    break;
  default:
    llvm_unreachable("ICVTracker can only be created for function position!");
  }

  return *AA;
}

AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
                                                        Attributor &A) {
  AAExecutionDomainFunction *AA = nullptr;
  switch (IRP.getPositionKind()) {
  case IRPosition::IRP_FUNCTION:
    AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
    break;
  default:
    llvm_unreachable(
        "AAExecutionDomain can only be created for function position!");
  }

  return *AA;
}

AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
                                                  Attributor &A) {
  AAHeapToSharedFunction *AA = nullptr;
  switch (IRP.getPositionKind()) {
  case IRPosition::IRP_FUNCTION:
    AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
    break;
  default:
    llvm_unreachable(
        "AAHeapToShared can only be created for function position!");
  }

  return *AA;
}

AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
                                              Attributor &A) {
  AAKernelInfo *AA = nullptr;
  switch (IRP.getPositionKind()) {
  case IRPosition::IRP_CALL_SITE:
    AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
    break;
  case IRPosition::IRP_FUNCTION:
    AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
    break;
  default:
    llvm_unreachable("KernelInfo can only be created for function position!");
  }

  return *AA;
}

AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
                                                        Attributor &A) {
  AAFoldRuntimeCall *AA = nullptr;
  switch (IRP.getPositionKind()) {
  case IRPosition::IRP_CALL_SITE_RETURNED:
    AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
    break;
  default:
    llvm_unreachable(
        "KernelInfo can only be created for call site position!");
  }

  return *AA;
}
  auto IsCalled = [&](Function &F) {
    if (Kernels.contains(&F))
      return true;
    for (const User *U : F.users())
      if (!isa<BlockAddress>(U))
        return true;
    return false;
  };

  auto EmitRemark = [&](Function &F) {
    auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    ORE.emit([&]() {
      OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
      return ORA << "Could not internalize function. "
                 << "Some optimizations may not be possible. [OMP140]";
    });
  };

  // Create internal copies of each function if this is a kernel module. This
  // allows interprocedural passes to see every call edge.
  DenseMap<Function *, Function *> InternalizedMap;
  if (isOpenMPDevice(M)) {
    SmallPtrSet<Function *, 16> InternalizeFns;
    for (Function &F : M)
      if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
          !DisableInternalization) {
        if (Attributor::isInternalizable(F)) {
          InternalizeFns.insert(&F);
        } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
          EmitRemark(F);
        }
      }

    Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
  }