1//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpenMP specific optimizations:
10//
11// - Deduplication of runtime calls, e.g., omp_get_thread_num.
12// - Replacing globalized device memory with stack memory.
13// - Replacing globalized device memory with shared memory.
14// - Parallel region merging.
15// - Transforming generic-mode device kernels to SPMD mode.
16// - Specializing the state machine for generic-mode device kernels.
17//
18//===----------------------------------------------------------------------===//
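//
// Note (orientation only, not part of the upstream header): the module-level
// variant of this pass is registered with the new pass manager as
// "openmp-opt", so it can be exercised in isolation with, roughly:
//   opt -passes=openmp-opt device.bc -S -o device.opt.ll
//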
19
20#include "llvm/Transforms/IPO/OpenMPOpt.h"
21
24#include "llvm/ADT/SetVector.h"
27#include "llvm/ADT/Statistic.h"
29#include "llvm/ADT/StringRef.h"
37#include "llvm/IR/Assumptions.h"
38#include "llvm/IR/BasicBlock.h"
39#include "llvm/IR/Constants.h"
41#include "llvm/IR/Dominators.h"
42#include "llvm/IR/Function.h"
43#include "llvm/IR/GlobalValue.h"
45#include "llvm/IR/InstrTypes.h"
46#include "llvm/IR/Instruction.h"
49#include "llvm/IR/IntrinsicsAMDGPU.h"
50#include "llvm/IR/IntrinsicsNVPTX.h"
51#include "llvm/IR/LLVMContext.h"
54#include "llvm/Support/Debug.h"
58
59#include <algorithm>
60#include <optional>
61#include <string>
62
63using namespace llvm;
64using namespace omp;
65
66#define DEBUG_TYPE "openmp-opt"
67
68static cl::opt<bool> DisableOpenMPOptimizations(
69 "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
70 cl::Hidden, cl::init(false));
71
72static cl::opt<bool> EnableParallelRegionMerging(
73 "openmp-opt-enable-merging",
74 cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
75 cl::init(false));
76
77static cl::opt<bool>
78 DisableInternalization("openmp-opt-disable-internalization",
79 cl::desc("Disable function internalization."),
80 cl::Hidden, cl::init(false));
81
82static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
83 cl::init(false), cl::Hidden);
84static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
85 cl::Hidden);
86static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
87 cl::init(false), cl::Hidden);
88
89static cl::opt<bool> HideMemoryTransferLatency(
90 "openmp-hide-memory-transfer-latency",
91 cl::desc("[WIP] Tries to hide the latency of host to device memory"
92 " transfers"),
93 cl::Hidden, cl::init(false));
94
95static cl::opt<bool> DisableOpenMPOptDeglobalization(
96 "openmp-opt-disable-deglobalization",
97 cl::desc("Disable OpenMP optimizations involving deglobalization."),
98 cl::Hidden, cl::init(false));
99
100static cl::opt<bool> DisableOpenMPOptSPMDization(
101 "openmp-opt-disable-spmdization",
102 cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
103 cl::Hidden, cl::init(false));
104
105static cl::opt<bool> DisableOpenMPOptFolding(
106 "openmp-opt-disable-folding",
107 cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
108 cl::init(false));
109
110static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
111 "openmp-opt-disable-state-machine-rewrite",
112 cl::desc("Disable OpenMP optimizations that replace the state machine."),
113 cl::Hidden, cl::init(false));
114
115static cl::opt<bool> DisableOpenMPOptBarrierElimination(
116 "openmp-opt-disable-barrier-elimination",
117 cl::desc("Disable OpenMP optimizations that eliminate barriers."),
118 cl::Hidden, cl::init(false));
119
120static cl::opt<bool> PrintModuleAfterOptimizations(
121 "openmp-opt-print-module-after",
122 cl::desc("Print the current module after OpenMP optimizations."),
123 cl::Hidden, cl::init(false));
124
125static cl::opt<bool> PrintModuleBeforeOptimizations(
126 "openmp-opt-print-module-before",
127 cl::desc("Print the current module before OpenMP optimizations."),
128 cl::Hidden, cl::init(false));
129
130static cl::opt<bool> AlwaysInlineDeviceFunctions(
131 "openmp-opt-inline-device",
132 cl::desc("Inline all applicable functions on the device."), cl::Hidden,
133 cl::init(false));
134
135static cl::opt<bool>
136 EnableVerboseRemarks("openmp-opt-verbose-remarks",
137 cl::desc("Enables more verbose remarks."), cl::Hidden,
138 cl::init(false));
139
140static cl::opt<unsigned>
141 SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
142 cl::desc("Maximal number of attributor iterations."),
143 cl::init(256));
144
145static cl::opt<unsigned>
146 SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
147 cl::desc("Maximum amount of shared memory to use."),
148 cl::init(std::numeric_limits<unsigned>::max()));
149
150STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
151 "Number of OpenMP runtime calls deduplicated");
152STATISTIC(NumOpenMPParallelRegionsDeleted,
153 "Number of OpenMP parallel regions deleted");
154STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
155 "Number of OpenMP runtime functions identified");
156STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
157 "Number of OpenMP runtime function uses identified");
158STATISTIC(NumOpenMPTargetRegionKernels,
159 "Number of OpenMP target region entry points (=kernels) identified");
160STATISTIC(NumNonOpenMPTargetRegionKernels,
161 "Number of non-OpenMP target region kernels identified");
162STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
163 "Number of OpenMP target region entry points (=kernels) executed in "
164 "SPMD-mode instead of generic-mode");
165STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
166 "Number of OpenMP target region entry points (=kernels) executed in "
167 "generic-mode without a state machine");
168STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
169 "Number of OpenMP target region entry points (=kernels) executed in "
170 "generic-mode with customized state machines with fallback");
171STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
172 "Number of OpenMP target region entry points (=kernels) executed in "
173 "generic-mode with customized state machines without fallback");
174STATISTIC(
175 NumOpenMPParallelRegionsReplacedInGPUStateMachine,
176 "Number of OpenMP parallel regions replaced with ID in GPU state machines");
177STATISTIC(NumOpenMPParallelRegionsMerged,
178 "Number of OpenMP parallel regions merged");
179STATISTIC(NumBytesMovedToSharedMemory,
180 "Amount of memory pushed to shared memory");
181STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
182
183#if !defined(NDEBUG)
184static constexpr auto TAG = "[" DEBUG_TYPE "]";
185#endif
186
187namespace KernelInfo {
188
189// struct ConfigurationEnvironmentTy {
190// uint8_t UseGenericStateMachine;
191// uint8_t MayUseNestedParallelism;
192// llvm::omp::OMPTgtExecModeFlags ExecMode;
193// int32_t MinThreads;
194// int32_t MaxThreads;
195// int32_t MinTeams;
196// int32_t MaxTeams;
197// };
198
199// struct DynamicEnvironmentTy {
200// uint16_t DebugIndentionLevel;
201// };
202
203// struct KernelEnvironmentTy {
204// ConfigurationEnvironmentTy Configuration;
205// IdentTy *Ident;
206// DynamicEnvironmentTy *DynamicEnv;
207// };
208
209#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX) \
210 constexpr const unsigned MEMBER##Idx = IDX;
211
212KERNEL_ENVIRONMENT_IDX(Configuration, 0)
213KERNEL_ENVIRONMENT_IDX(Ident, 1)
214
215#undef KERNEL_ENVIRONMENT_IDX
216
217#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX) \
218 constexpr const unsigned MEMBER##Idx = IDX;
219
220KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
221KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
222KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2)
223KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3)
224KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4)
225KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5)
226KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6)
227
228#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX
229
230#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE) \
231 RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
232 return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx)); \
233 }
234
235KERNEL_ENVIRONMENT_GETTER(Ident, Constant)
236KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)
237
238#undef KERNEL_ENVIRONMENT_GETTER
239
240#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER) \
241 ConstantInt *get##MEMBER##FromKernelEnvironment( \
242 ConstantStruct *KernelEnvC) { \
243 ConstantStruct *ConfigC = \
244 getConfigurationFromKernelEnvironment(KernelEnvC); \
245 return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx)); \
246 }
247
248KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
249KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
250KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode)
251KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads)
252KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads)
253KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams)
254KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams)
255
256#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
257
260 constexpr const int InitKernelEnvironmentArgNo = 0;
261 return cast<GlobalVariable>(
262 KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
264}
265
267 GlobalVariable *KernelEnvGV =
269 return cast<ConstantStruct>(KernelEnvGV->getInitializer());
270}
271} // namespace KernelInfo
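// Usage sketch (illustrative, not code from this file): given the
// ConstantStruct initializer of a kernel environment, the macro-generated
// getters above are queried like:
//   ConstantInt *UseSM =
//       KernelInfo::getUseGenericStateMachineFromKernelEnvironment(KernelEnvC);
//   bool NeedsStateMachine = !UseSM || !UseSM->isZero();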
272
273namespace {
274
275struct AAHeapToShared;
276
277struct AAICVTracker;
278
279/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
280/// Attributor runs.
281struct OMPInformationCache : public InformationCache {
282 OMPInformationCache(Module &M, AnalysisGetter &AG,
283 BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
284 bool OpenMPPostLink)
285 : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
286 OpenMPPostLink(OpenMPPostLink) {
287
288 OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
289 const Triple T(OMPBuilder.M.getTargetTriple());
290 switch (T.getArch()) {
291 case Triple::nvptx:
292 case Triple::nvptx64:
293 case Triple::amdgcn:
294 assert(OMPBuilder.Config.IsTargetDevice &&
295 "OpenMP AMDGPU/NVPTX is only prepared to deal with device code.");
296 OMPBuilder.Config.IsGPU = true;
297 break;
298 default:
299 OMPBuilder.Config.IsGPU = false;
300 break;
301 }
302 OMPBuilder.initialize();
303 initializeRuntimeFunctions(M);
304 initializeInternalControlVars();
305 }
306
307 /// Generic information that describes an internal control variable.
308 struct InternalControlVarInfo {
309 /// The kind, as described by InternalControlVar enum.
310 InternalControlVar Kind;
311
312 /// The name of the ICV.
313 StringRef Name;
314
315 /// Environment variable associated with this ICV.
316 StringRef EnvVarName;
317
318 /// Initial value kind.
319 ICVInitValue InitKind;
320
321 /// Initial value.
322 ConstantInt *InitValue;
323
324 /// Setter RTL function associated with this ICV.
325 RuntimeFunction Setter;
326
327 /// Getter RTL function associated with this ICV.
328 RuntimeFunction Getter;
329
330 /// RTL Function corresponding to the override clause of this ICV
331 RuntimeFunction Clause;
332 };
333
334 /// Generic information that describes a runtime function
335 struct RuntimeFunctionInfo {
336
337 /// The kind, as described by the RuntimeFunction enum.
338 RuntimeFunction Kind;
339
340 /// The name of the function.
341 StringRef Name;
342
343 /// Flag to indicate a variadic function.
344 bool IsVarArg;
345
346 /// The return type of the function.
347 Type *ReturnType;
348
349 /// The argument types of the function.
350 SmallVector<Type *, 8> ArgumentTypes;
351
352 /// The declaration if available.
353 Function *Declaration = nullptr;
354
355 /// Uses of this runtime function per function containing the use.
356 using UseVector = SmallVector<Use *, 16>;
357
358 /// Clear UsesMap for runtime function.
359 void clearUsesMap() { UsesMap.clear(); }
360
361 /// Boolean conversion that is true if the runtime function was found.
362 operator bool() const { return Declaration; }
363
364 /// Return the vector of uses in function \p F.
365 UseVector &getOrCreateUseVector(Function *F) {
366 std::shared_ptr<UseVector> &UV = UsesMap[F];
367 if (!UV)
368 UV = std::make_shared<UseVector>();
369 return *UV;
370 }
371
372 /// Return the vector of uses in function \p F or `nullptr` if there are
373 /// none.
374 const UseVector *getUseVector(Function &F) const {
375 auto I = UsesMap.find(&F);
376 if (I != UsesMap.end())
377 return I->second.get();
378 return nullptr;
379 }
380
381 /// Return how many functions contain uses of this runtime function.
382 size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
383
384 /// Return the number of arguments (or the minimal number for variadic
385 /// functions).
386 size_t getNumArgs() const { return ArgumentTypes.size(); }
387
388 /// Run the callback \p CB on each use and forget the use if the result is
389 /// true. The callback will be fed the function in which the use was
390 /// encountered as second argument.
391 void foreachUse(SmallVectorImpl<Function *> &SCC,
392 function_ref<bool(Use &, Function &)> CB) {
393 for (Function *F : SCC)
394 foreachUse(CB, F);
395 }
396
397 /// Run the callback \p CB on each use within the function \p F and forget
398 /// the use if the result is true.
399 void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
400 SmallVector<unsigned, 8> ToBeDeleted;
401 ToBeDeleted.clear();
402
403 unsigned Idx = 0;
404 UseVector &UV = getOrCreateUseVector(F);
405
406 for (Use *U : UV) {
407 if (CB(*U, *F))
408 ToBeDeleted.push_back(Idx);
409 ++Idx;
410 }
411
412 // Remove the to-be-deleted indices in reverse order as prior
413 // modifications will not modify the smaller indices.
414 while (!ToBeDeleted.empty()) {
415 unsigned Idx = ToBeDeleted.pop_back_val();
416 UV[Idx] = UV.back();
417 UV.pop_back();
418 }
419 }
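 // Usage sketch (illustrative): a transformation can visit and selectively
 // forget uses; returning true from the callback drops the use, e.g.:
 //   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
 //     CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
 //     return CI && tryToRewrite(*CI); // tryToRewrite is hypothetical
 //   });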
420
421 private:
422 /// Map from functions to all uses of this runtime function contained in
423 /// them.
424 DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
425
426 public:
427 /// Iterators for the uses of this runtime function.
428 decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
429 decltype(UsesMap)::iterator end() { return UsesMap.end(); }
430 };
431
432 /// An OpenMP-IR-Builder instance
433 OpenMPIRBuilder OMPBuilder;
434
435 /// Map from runtime function kind to the runtime function description.
436 EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
437 RuntimeFunction::OMPRTL___last>
438 RFIs;
439
440 /// Map from function declarations/definitions to their runtime enum type.
441 DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
442
443 /// Map from ICV kind to the ICV description.
444 EnumeratedArray<InternalControlVarInfo, InternalControlVar,
445 InternalControlVar::ICV___last>
446 ICVs;
447
448 /// Helper to initialize all internal control variable information for those
449 /// defined in OMPKinds.def.
450 void initializeInternalControlVars() {
451#define ICV_RT_SET(_Name, RTL) \
452 { \
453 auto &ICV = ICVs[_Name]; \
454 ICV.Setter = RTL; \
455 }
456#define ICV_RT_GET(Name, RTL) \
457 { \
458 auto &ICV = ICVs[Name]; \
459 ICV.Getter = RTL; \
460 }
461#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
462 { \
463 auto &ICV = ICVs[Enum]; \
464 ICV.Name = _Name; \
465 ICV.Kind = Enum; \
466 ICV.InitKind = Init; \
467 ICV.EnvVarName = _EnvVarName; \
468 switch (ICV.InitKind) { \
469 case ICV_IMPLEMENTATION_DEFINED: \
470 ICV.InitValue = nullptr; \
471 break; \
472 case ICV_ZERO: \
473 ICV.InitValue = ConstantInt::get( \
474 Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
475 break; \
476 case ICV_FALSE: \
477 ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
478 break; \
479 case ICV_LAST: \
480 break; \
481 } \
482 }
483#include "llvm/Frontend/OpenMP/OMPKinds.def"
484 }
485
486 /// Returns true if the function declaration \p F matches the runtime
487 /// function types, that is, return type \p RTFRetType, and argument types
488 /// \p RTFArgTypes.
489 static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
490 SmallVector<Type *, 8> &RTFArgTypes) {
491 // TODO: We should output information to the user (under debug output
492 // and via remarks).
493
494 if (!F)
495 return false;
496 if (F->getReturnType() != RTFRetType)
497 return false;
498 if (F->arg_size() != RTFArgTypes.size())
499 return false;
500
501 auto *RTFTyIt = RTFArgTypes.begin();
502 for (Argument &Arg : F->args()) {
503 if (Arg.getType() != *RTFTyIt)
504 return false;
505
506 ++RTFTyIt;
507 }
508
509 return true;
510 }
511
512 // Helper to collect all uses of the declaration in the UsesMap.
513 unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
514 unsigned NumUses = 0;
515 if (!RFI.Declaration)
516 return NumUses;
517 OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
518
519 if (CollectStats) {
520 NumOpenMPRuntimeFunctionsIdentified += 1;
521 NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
522 }
523
524 // TODO: We directly convert uses into proper calls and unknown uses.
525 for (Use &U : RFI.Declaration->uses()) {
526 if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
527 if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
528 RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
529 ++NumUses;
530 }
531 } else {
532 RFI.getOrCreateUseVector(nullptr).push_back(&U);
533 ++NumUses;
534 }
535 }
536 return NumUses;
537 }
538
539 // Helper function to recollect uses of a runtime function.
540 void recollectUsesForFunction(RuntimeFunction RTF) {
541 auto &RFI = RFIs[RTF];
542 RFI.clearUsesMap();
543 collectUses(RFI, /*CollectStats*/ false);
544 }
545
546 // Helper function to recollect uses of all runtime functions.
547 void recollectUses() {
548 for (int Idx = 0; Idx < RFIs.size(); ++Idx)
549 recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
550 }
551
552 // Helper function to inherit the calling convention of the function callee.
553 void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
554 if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
555 CI->setCallingConv(Fn->getCallingConv());
556 }
557
558 // Helper function to determine if it's legal to create a call to the runtime
559 // functions.
560 bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
561 // We can always emit calls if we haven't yet linked in the runtime.
562 if (!OpenMPPostLink)
563 return true;
564
565 // Once the runtime has already been linked in, we cannot emit calls to
566 // any undefined functions.
567 for (RuntimeFunction Fn : Fns) {
568 RuntimeFunctionInfo &RFI = RFIs[Fn];
569
570 if (RFI.Declaration && RFI.Declaration->isDeclaration())
571 return false;
572 }
573 return true;
574 }
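 // Usage sketch (mirrors the real use further down in this file): guard any
 // transformation that has to emit new runtime calls, e.g.:
 //   if (runtimeFnsAvailable({OMPRTL___tgt_target_data_begin_mapper_issue,
 //                            OMPRTL___tgt_target_data_begin_mapper_wait}))
 //     ...; // safe to introduce the issue/wait pair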
575
576 /// Helper to initialize all runtime function information for those defined
577 /// in OpenMPKinds.def.
578 void initializeRuntimeFunctions(Module &M) {
579
580 // Helper macros for handling __VA_ARGS__ in OMP_RTL
581#define OMP_TYPE(VarName, ...) \
582 Type *VarName = OMPBuilder.VarName; \
583 (void)VarName;
584
585#define OMP_ARRAY_TYPE(VarName, ...) \
586 ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
587 (void)VarName##Ty; \
588 PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
589 (void)VarName##PtrTy;
590
591#define OMP_FUNCTION_TYPE(VarName, ...) \
592 FunctionType *VarName = OMPBuilder.VarName; \
593 (void)VarName; \
594 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
595 (void)VarName##Ptr;
596
597#define OMP_STRUCT_TYPE(VarName, ...) \
598 StructType *VarName = OMPBuilder.VarName; \
599 (void)VarName; \
600 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
601 (void)VarName##Ptr;
602
603#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
604 { \
605 SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
606 Function *F = M.getFunction(_Name); \
607 RTLFunctions.insert(F); \
608 if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
609 RuntimeFunctionIDMap[F] = _Enum; \
610 auto &RFI = RFIs[_Enum]; \
611 RFI.Kind = _Enum; \
612 RFI.Name = _Name; \
613 RFI.IsVarArg = _IsVarArg; \
614 RFI.ReturnType = OMPBuilder._ReturnType; \
615 RFI.ArgumentTypes = std::move(ArgsTypes); \
616 RFI.Declaration = F; \
617 unsigned NumUses = collectUses(RFI); \
618 (void)NumUses; \
619 LLVM_DEBUG({ \
620 dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
621 << " found\n"; \
622 if (RFI.Declaration) \
623 dbgs() << TAG << "-> got " << NumUses << " uses in " \
624 << RFI.getNumFunctionsWithUses() \
625 << " different functions.\n"; \
626 }); \
627 } \
628 }
629#include "llvm/Frontend/OpenMP/OMPKinds.def"
630
631 // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
632 // functions, except if `optnone` is present.
633 if (isOpenMPDevice(M)) {
634 for (Function &F : M) {
635 for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
636 if (F.hasFnAttribute(Attribute::NoInline) &&
637 F.getName().starts_with(Prefix) &&
638 !F.hasFnAttribute(Attribute::OptimizeNone))
639 F.removeFnAttr(Attribute::NoInline);
640 }
641 }
642
643 // TODO: We should attach the attributes defined in OMPKinds.def.
644 }
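 // Expansion sketch (illustrative): for an entry such as
 // OMP_RTL(__kmpc_barrier, ...) in OMPKinds.def, the macro above looks up
 // @__kmpc_barrier in the module, checks its declaration against the expected
 // signature, records it in RFIs[OMPRTL___kmpc_barrier], and collects all of
 // its uses for later queries.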
645
646 /// Collection of known OpenMP runtime functions.
647 DenseSet<const Function *> RTLFunctions;
648
649 /// Indicates if we have already linked in the OpenMP device library.
650 bool OpenMPPostLink = false;
651};
652
653template <typename Ty, bool InsertInvalidates = true>
654struct BooleanStateWithSetVector : public BooleanState {
655 bool contains(const Ty &Elem) const { return Set.contains(Elem); }
656 bool insert(const Ty &Elem) {
657 if (InsertInvalidates)
658 BooleanState::indicatePessimisticFixpoint();
659 return Set.insert(Elem);
660 }
661
662 const Ty &operator[](int Idx) const { return Set[Idx]; }
663 bool operator==(const BooleanStateWithSetVector &RHS) const {
664 return BooleanState::operator==(RHS) && Set == RHS.Set;
665 }
666 bool operator!=(const BooleanStateWithSetVector &RHS) const {
667 return !(*this == RHS);
668 }
669
670 bool empty() const { return Set.empty(); }
671 size_t size() const { return Set.size(); }
672
673 /// "Clamp" this state with \p RHS.
674 BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
675 BooleanState::operator^=(RHS);
676 Set.insert(RHS.Set.begin(), RHS.Set.end());
677 return *this;
678 }
679
680private:
681 /// A set to keep track of elements.
682 SetVector<Ty> Set;
683
684public:
685 typename decltype(Set)::iterator begin() { return Set.begin(); }
686 typename decltype(Set)::iterator end() { return Set.end(); }
687 typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
688 typename decltype(Set)::const_iterator end() const { return Set.end(); }
689};
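// Semantics sketch (illustrative): operator^= "clamps" the boolean state and
// unions the tracked sets, so merging two states never drops elements:
//   BooleanStateWithSetVector<int> A, B;
//   A.insert(1);
//   B.insert(2);
//   A ^= B; // A now tracks {1, 2}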
690
691template <typename Ty, bool InsertInvalidates = true>
692using BooleanStateWithPtrSetVector =
693 BooleanStateWithSetVector<Ty *, InsertInvalidates>;
694
695struct KernelInfoState : AbstractState {
696 /// Flag to track if we reached a fixpoint.
697 bool IsAtFixpoint = false;
698
699 /// The parallel regions (identified by the outlined parallel functions) that
700 /// can be reached from the associated function.
701 BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
702 ReachedKnownParallelRegions;
703
704 /// State to track what parallel region we might reach.
705 BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
706
707 /// State to track if we are in SPMD-mode, assumed or known, and why we decided
708 /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
709 /// false.
710 BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
711
712 /// The __kmpc_target_init call in this kernel, if any. If we find more than
713 /// one we abort as the kernel is malformed.
714 CallBase *KernelInitCB = nullptr;
715
716 /// The constant kernel environment as taken from and passed to
717 /// __kmpc_target_init.
718 ConstantStruct *KernelEnvC = nullptr;
719
720 /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
721 /// one we abort as the kernel is malformed.
722 CallBase *KernelDeinitCB = nullptr;
723
724 /// Flag to indicate if the associated function is a kernel entry.
725 bool IsKernelEntry = false;
726
727 /// State to track what kernel entries can reach the associated function.
728 BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
729
730 /// State to indicate if we can track parallel level of the associated
731 /// function. We will give up tracking if we encounter unknown caller or the
732 /// caller is __kmpc_parallel_51.
733 BooleanStateWithSetVector<uint8_t> ParallelLevels;
734
735 /// Flag that indicates if the kernel has nested parallelism.
736 bool NestedParallelism = false;
737
738 /// Abstract State interface
739 ///{
740
741 KernelInfoState() = default;
742 KernelInfoState(bool BestState) {
743 if (!BestState)
744 indicatePessimisticFixpoint();
745 }
746
747 /// See AbstractState::isValidState(...)
748 bool isValidState() const override { return true; }
749
750 /// See AbstractState::isAtFixpoint(...)
751 bool isAtFixpoint() const override { return IsAtFixpoint; }
752
753 /// See AbstractState::indicatePessimisticFixpoint(...)
754 ChangeStatus indicatePessimisticFixpoint() override {
755 IsAtFixpoint = true;
756 ParallelLevels.indicatePessimisticFixpoint();
757 ReachingKernelEntries.indicatePessimisticFixpoint();
758 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
759 ReachedKnownParallelRegions.indicatePessimisticFixpoint();
760 ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
761 NestedParallelism = true;
762 return ChangeStatus::CHANGED;
763 }
764
765 /// See AbstractState::indicateOptimisticFixpoint(...)
766 ChangeStatus indicateOptimisticFixpoint() override {
767 IsAtFixpoint = true;
768 ParallelLevels.indicateOptimisticFixpoint();
769 ReachingKernelEntries.indicateOptimisticFixpoint();
770 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
771 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
772 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
773 return ChangeStatus::UNCHANGED;
774 }
775
776 /// Return the assumed state
777 KernelInfoState &getAssumed() { return *this; }
778 const KernelInfoState &getAssumed() const { return *this; }
779
780 bool operator==(const KernelInfoState &RHS) const {
781 if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
782 return false;
783 if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
784 return false;
785 if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
786 return false;
787 if (ReachingKernelEntries != RHS.ReachingKernelEntries)
788 return false;
789 if (ParallelLevels != RHS.ParallelLevels)
790 return false;
791 if (NestedParallelism != RHS.NestedParallelism)
792 return false;
793 return true;
794 }
795
796 /// Returns true if this kernel contains any OpenMP parallel regions.
797 bool mayContainParallelRegion() {
798 return !ReachedKnownParallelRegions.empty() ||
799 !ReachedUnknownParallelRegions.empty();
800 }
801
802 /// Return empty set as the best state of potential values.
803 static KernelInfoState getBestState() { return KernelInfoState(true); }
804
805 static KernelInfoState getBestState(KernelInfoState &KIS) {
806 return getBestState();
807 }
808
809 /// Return full set as the worst state of potential values.
810 static KernelInfoState getWorstState() { return KernelInfoState(false); }
811
812 /// "Clamp" this state with \p KIS.
813 KernelInfoState operator^=(const KernelInfoState &KIS) {
814 // Do not merge two different _init and _deinit call sites.
815 if (KIS.KernelInitCB) {
816 if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
817 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
818 "assumptions.");
819 KernelInitCB = KIS.KernelInitCB;
820 }
821 if (KIS.KernelDeinitCB) {
822 if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
823 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
824 "assumptions.");
825 KernelDeinitCB = KIS.KernelDeinitCB;
826 }
827 if (KIS.KernelEnvC) {
828 if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
829 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
830 "assumptions.");
831 KernelEnvC = KIS.KernelEnvC;
832 }
833 SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
834 ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
835 ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
836 NestedParallelism |= KIS.NestedParallelism;
837 return *this;
838 }
839
840 KernelInfoState operator&=(const KernelInfoState &KIS) {
841 return (*this ^= KIS);
842 }
843
844 ///}
845};
846
847/// Used to map the values physically (in the IR) stored in an offload
848/// array, to a vector in memory.
849struct OffloadArray {
850 /// Physical array (in the IR).
851 AllocaInst *Array = nullptr;
852 /// Mapped values.
853 SmallVector<Value *, 8> StoredValues;
854 /// Last stores made in the offload array.
855 SmallVector<StoreInst *, 8> LastAccesses;
856
857 OffloadArray() = default;
858
859 /// Initializes the OffloadArray with the values stored in \p Array before
860 /// instruction \p Before is reached. Returns false if the initialization
861 /// fails.
862 /// This MUST be used immediately after the construction of the object.
863 bool initialize(AllocaInst &Array, Instruction &Before) {
864 if (!Array.getAllocatedType()->isArrayTy())
865 return false;
866
867 if (!getValues(Array, Before))
868 return false;
869
870 this->Array = &Array;
871 return true;
872 }
873
874 static const unsigned DeviceIDArgNum = 1;
875 static const unsigned BasePtrsArgNum = 3;
876 static const unsigned PtrsArgNum = 4;
877 static const unsigned SizesArgNum = 5;
878
879private:
880 /// Traverses the BasicBlock where \p Array is, collecting the stores made to
881 /// \p Array, leaving StoredValues with the values stored before the
882 /// instruction \p Before is reached.
883 bool getValues(AllocaInst &Array, Instruction &Before) {
884 // Initialize container.
885 const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
886 StoredValues.assign(NumValues, nullptr);
887 LastAccesses.assign(NumValues, nullptr);
888
889 // TODO: This assumes the instruction \p Before is in the same
890 // BasicBlock as Array. Make it general, for any control flow graph.
891 BasicBlock *BB = Array.getParent();
892 if (BB != Before.getParent())
893 return false;
894
895 const DataLayout &DL = Array.getDataLayout();
896 const unsigned int PointerSize = DL.getPointerSize();
897
898 for (Instruction &I : *BB) {
899 if (&I == &Before)
900 break;
901
902 if (!isa<StoreInst>(&I))
903 continue;
904
905 auto *S = cast<StoreInst>(&I);
906 int64_t Offset = -1;
907 auto *Dst =
908 GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
909 if (Dst == &Array) {
910 int64_t Idx = Offset / PointerSize;
911 StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
912 LastAccesses[Idx] = S;
913 }
914 }
915
916 return isFilled();
917 }
918
919 /// Returns true if all values in StoredValues and
920 /// LastAccesses are not nullptrs.
921 bool isFilled() {
922 const unsigned NumValues = StoredValues.size();
923 for (unsigned I = 0; I < NumValues; ++I) {
924 if (!StoredValues[I] || !LastAccesses[I])
925 return false;
926 }
927
928 return true;
929 }
930};
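// Usage sketch (illustrative; BasePtrsAlloca is a hypothetical alloca feeding
// the runtime call): the mapped values are reconstructed right before the call
// that consumes them:
//   OffloadArray OA;
//   if (OA.initialize(*BasePtrsAlloca, RuntimeCall))
//     for (Value *V : OA.StoredValues)
//       ; // inspect the values last stored before RuntimeCall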
931
932struct OpenMPOpt {
933
934 using OptimizationRemarkGetter =
935 function_ref<OptimizationRemarkEmitter &(Function *)>;
936
937 OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
938 OptimizationRemarkGetter OREGetter,
939 OMPInformationCache &OMPInfoCache, Attributor &A)
940 : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
941 OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
942
943 /// Check if any remarks are enabled for openmp-opt
944 bool remarksEnabled() {
945 auto &Ctx = M.getContext();
946 return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
947 }
948
949 /// Run all OpenMP optimizations on the underlying SCC.
950 bool run(bool IsModulePass) {
951 if (SCC.empty())
952 return false;
953
954 bool Changed = false;
955
956 LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
957 << " functions\n");
958
959 if (IsModulePass) {
960 Changed |= runAttributor(IsModulePass);
961
962 // Recollect uses, in case Attributor deleted any.
963 OMPInfoCache.recollectUses();
964
965 // TODO: This should be folded into buildCustomStateMachine.
966 Changed |= rewriteDeviceCodeStateMachine();
967
968 if (remarksEnabled())
969 analysisGlobalization();
970 } else {
971 if (PrintICVValues)
972 printICVs();
973 if (PrintOpenMPKernels)
974 printKernels();
975
976 Changed |= runAttributor(IsModulePass);
977
978 // Recollect uses, in case Attributor deleted any.
979 OMPInfoCache.recollectUses();
980
981 Changed |= deleteParallelRegions();
982
983 if (HideMemoryTransferLatency)
984 Changed |= hideMemTransfersLatency();
985 Changed |= deduplicateRuntimeCalls();
986 if (EnableParallelRegionMerging) {
987 if (mergeParallelRegions()) {
988 deduplicateRuntimeCalls();
989 Changed = true;
990 }
991 }
992 }
993
994 if (OMPInfoCache.OpenMPPostLink)
995 Changed |= removeRuntimeSymbols();
996
997 return Changed;
998 }
999
1000 /// Print initial ICV values for testing.
1001 /// FIXME: This should be done from the Attributor once it is added.
1002 void printICVs() const {
1003 InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
1004 ICV_proc_bind};
1005
1006 for (Function *F : SCC) {
1007 for (auto ICV : ICVs) {
1008 auto ICVInfo = OMPInfoCache.ICVs[ICV];
1009 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1010 return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
1011 << " Value: "
1012 << (ICVInfo.InitValue
1013 ? toString(ICVInfo.InitValue->getValue(), 10, true)
1014 : "IMPLEMENTATION_DEFINED");
1015 };
1016
1017 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
1018 }
1019 }
1020 }
1021
1022 /// Print OpenMP GPU kernels for testing.
1023 void printKernels() const {
1024 for (Function *F : SCC) {
1025 if (!omp::isOpenMPKernel(*F))
1026 continue;
1027
1028 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1029 return ORA << "OpenMP GPU kernel "
1030 << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
1031 };
1032
1033 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
1034 }
1035 }
1036
1037 /// Return the call if \p U is a callee use in a regular call. If \p RFI is
1038 /// given it has to be the callee or a nullptr is returned.
1039 static CallInst *getCallIfRegularCall(
1040 Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1041 CallInst *CI = dyn_cast<CallInst>(U.getUser());
1042 if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
1043 (!RFI ||
1044 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1045 return CI;
1046 return nullptr;
1047 }
1048
1049 /// Return the call if \p V is a regular call. If \p RFI is given it has to be
1050 /// the callee or a nullptr is returned.
1051 static CallInst *getCallIfRegularCall(
1052 Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1053 CallInst *CI = dyn_cast<CallInst>(&V);
1054 if (CI && !CI->hasOperandBundles() &&
1055 (!RFI ||
1056 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1057 return CI;
1058 return nullptr;
1059 }
1060
1061private:
1062 /// Merge parallel regions when it is safe.
1063 bool mergeParallelRegions() {
1064 const unsigned CallbackCalleeOperand = 2;
1065 const unsigned CallbackFirstArgOperand = 3;
1066 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1067
1068 // Check if there are any __kmpc_fork_call calls to merge.
1069 OMPInformationCache::RuntimeFunctionInfo &RFI =
1070 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1071
1072 if (!RFI.Declaration)
1073 return false;
1074
1075 // Unmergable calls that prevent merging a parallel region.
1076 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
1077 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
1078 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
1079 };
1080
1081 bool Changed = false;
1082 LoopInfo *LI = nullptr;
1083 DominatorTree *DT = nullptr;
1084
1085 SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
1086
1087 BasicBlock *StartBB = nullptr, *EndBB = nullptr;
1088 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1089 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1090 BasicBlock *CGEndBB =
1091 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1092 assert(StartBB != nullptr && "StartBB should not be null");
1093 CGStartBB->getTerminator()->setSuccessor(0, StartBB);
1094 assert(EndBB != nullptr && "EndBB should not be null");
1095 EndBB->getTerminator()->setSuccessor(0, CGEndBB);
1096 return Error::success();
1097 };
1098
1099 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
1100 Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
1101 ReplacementValue = &Inner;
1102 return CodeGenIP;
1103 };
1104
1105 auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };
1106
1107 /// Create a sequential execution region within a merged parallel region,
1108 /// encapsulated in a master construct with a barrier for synchronization.
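 /// Conceptually (sketch): a statement S between two merged fork calls is
 /// guarded by __kmpc_master/__kmpc_end_master inside the merged region, its
 /// results are written to stack slots, and a __kmpc_barrier follows so all
 /// threads observe the broadcast values.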
1109 auto CreateSequentialRegion = [&](Function *OuterFn,
1110 BasicBlock *OuterPredBB,
1111 Instruction *SeqStartI,
1112 Instruction *SeqEndI) {
1113 // Isolate the instructions of the sequential region to a separate
1114 // block.
1115 BasicBlock *ParentBB = SeqStartI->getParent();
1116 BasicBlock *SeqEndBB =
1117 SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
1118 BasicBlock *SeqAfterBB =
1119 SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
1120 BasicBlock *SeqStartBB =
1121 SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
1122
1123 assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
1124 "Expected a different CFG");
1125 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
1126 ParentBB->getTerminator()->eraseFromParent();
1127
1128 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1129 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1130 BasicBlock *CGEndBB =
1131 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1132 assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
1133 CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
1134 assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
1135 SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
1136 return Error::success();
1137 };
1138 auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };
1139
1140 // Find outputs from the sequential region to outside users and
1141 // broadcast their values to them.
1142 for (Instruction &I : *SeqStartBB) {
1143 SmallPtrSet<Instruction *, 4> OutsideUsers;
1144 for (User *Usr : I.users()) {
1145 Instruction &UsrI = *cast<Instruction>(Usr);
1146 // Ignore outputs to LT intrinsics, code extraction for the merged
1147 // parallel region will fix them.
1148 if (UsrI.isLifetimeStartOrEnd())
1149 continue;
1150
1151 if (UsrI.getParent() != SeqStartBB)
1152 OutsideUsers.insert(&UsrI);
1153 }
1154
1155 if (OutsideUsers.empty())
1156 continue;
1157
1158 // Emit an alloca in the outer region to store the broadcasted
1159 // value.
1160 const DataLayout &DL = M.getDataLayout();
1161 AllocaInst *AllocaI = new AllocaInst(
1162 I.getType(), DL.getAllocaAddrSpace(), nullptr,
1163 I.getName() + ".seq.output.alloc", OuterFn->front().begin());
1164
1165 // Emit a store instruction in the sequential BB to update the
1166 // value.
1167 new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());
1168
1169 // Emit a load instruction and replace the use of the output value
1170 // with it.
1171 for (Instruction *UsrI : OutsideUsers) {
1172 LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,
1173 I.getName() + ".seq.output.load",
1174 UsrI->getIterator());
1175 UsrI->replaceUsesOfWith(&I, LoadI);
1176 }
1177 }
1178
1179 OpenMPIRBuilder::LocationDescription Loc(
1180 InsertPointTy(ParentBB, ParentBB->end()), DL);
1181 InsertPointTy SeqAfterIP = cantFail(
1182 OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB));
1183 cantFail(
1184 OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel));
1185
1186 BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1187
1188 LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1189 << "\n");
1190 };
1191
1192 // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1193 // contained in BB and only separated by instructions that can be
1194 // redundantly executed in parallel. The block BB is split before the first
1195 // call (in MergableCIs) and after the last so the entire region we merge
1196 // into a single parallel region is contained in a single basic block
1197 // without any other instructions. We use the OpenMPIRBuilder to outline
1198 // that block and call the resulting function via __kmpc_fork_call.
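 // Rough before/after picture (illustrative):
 //   before: __kmpc_fork_call(.., @cb1, ..); S; __kmpc_fork_call(.., @cb2, ..)
 //   after : one __kmpc_fork_call whose outlined body runs @cb1, a
 //           master-guarded copy of S, an explicit barrier, and then @cb2.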
1199 auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1200 BasicBlock *BB) {
1201 // TODO: Change the interface to allow single CIs expanded, e.g., to
1202 // include an outer loop.
1203 assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1204
1205 auto Remark = [&](OptimizationRemark OR) {
1206 OR << "Parallel region merged with parallel region"
1207 << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1208 for (auto *CI : llvm::drop_begin(MergableCIs)) {
1209 OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1210 if (CI != MergableCIs.back())
1211 OR << ", ";
1212 }
1213 return OR << ".";
1214 };
1215
1216 emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1217
1218 Function *OriginalFn = BB->getParent();
1219 LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1220 << " parallel regions in " << OriginalFn->getName()
1221 << "\n");
1222
1223 // Isolate the calls to merge in a separate block.
1224 EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1225 BasicBlock *AfterBB =
1226 SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1227 StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1228 "omp.par.merged");
1229
1230 assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1231 const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1232 BB->getTerminator()->eraseFromParent();
1233
1234 // Create sequential regions for sequential instructions that are
1235 // in-between mergable parallel regions.
1236 for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1237 It != End; ++It) {
1238 Instruction *ForkCI = *It;
1239 Instruction *NextForkCI = *(It + 1);
1240
1241 // Continue if there are no in-between instructions.
1242 if (ForkCI->getNextNode() == NextForkCI)
1243 continue;
1244
1245 CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1246 NextForkCI->getPrevNode());
1247 }
1248
1249 OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1250 DL);
1251 IRBuilder<>::InsertPoint AllocaIP(
1252 &OriginalFn->getEntryBlock(),
1253 OriginalFn->getEntryBlock().getFirstInsertionPt());
1254 // Create the merged parallel region with default proc binding, to
1255 // avoid overriding binding settings, and without explicit cancellation.
1256 InsertPointTy AfterIP =
1257 cantFail(OMPInfoCache.OMPBuilder.createParallel(
1258 Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1259 OMP_PROC_BIND_default, /* IsCancellable */ false));
1260 BranchInst::Create(AfterBB, AfterIP.getBlock());
1261
1262 // Perform the actual outlining.
1263 OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1264
1265 Function *OutlinedFn = MergableCIs.front()->getCaller();
1266
1267 // Replace the __kmpc_fork_call calls with direct calls to the outlined
1268 // callbacks.
1269 SmallVector<Value *, 8> Args;
1270 for (auto *CI : MergableCIs) {
1271 Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1272 FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1273 Args.clear();
1274 Args.push_back(OutlinedFn->getArg(0));
1275 Args.push_back(OutlinedFn->getArg(1));
1276 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1277 ++U)
1278 Args.push_back(CI->getArgOperand(U));
1279
1280 CallInst *NewCI =
1281 CallInst::Create(FT, Callee, Args, "", CI->getIterator());
1282 if (CI->getDebugLoc())
1283 NewCI->setDebugLoc(CI->getDebugLoc());
1284
1285 // Forward parameter attributes from the callback to the callee.
1286 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1287 ++U)
1288 for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1289 NewCI->addParamAttr(
1290 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1291
1292 // Emit an explicit barrier to replace the implicit fork-join barrier.
1293 if (CI != MergableCIs.back()) {
1294 // TODO: Remove barrier if the merged parallel region includes the
1295 // 'nowait' clause.
1296 cantFail(OMPInfoCache.OMPBuilder.createBarrier(
1297 InsertPointTy(NewCI->getParent(),
1298 NewCI->getNextNode()->getIterator()),
1299 OMPD_parallel));
1300 }
1301
1302 CI->eraseFromParent();
1303 }
1304
1305 assert(OutlinedFn != OriginalFn && "Outlining failed");
1306 CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1307 CGUpdater.reanalyzeFunction(*OriginalFn);
1308
1309 NumOpenMPParallelRegionsMerged += MergableCIs.size();
1310
1311 return true;
1312 };
1313
1314 // Helper function that identifies sequences of
1315 // __kmpc_fork_call uses in a basic block.
1316 auto DetectPRsCB = [&](Use &U, Function &F) {
1317 CallInst *CI = getCallIfRegularCall(U, &RFI);
1318 BB2PRMap[CI->getParent()].insert(CI);
1319
1320 return false;
1321 };
1322
1323 BB2PRMap.clear();
1324 RFI.foreachUse(SCC, DetectPRsCB);
1325 SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1326 // Find mergable parallel regions within a basic block that are
1327 // safe to merge, that is any in-between instructions can safely
1328 // execute in parallel after merging.
1329 // TODO: support merging across basic-blocks.
1330 for (auto &It : BB2PRMap) {
1331 auto &CIs = It.getSecond();
1332 if (CIs.size() < 2)
1333 continue;
1334
1335 BasicBlock *BB = It.getFirst();
1336 SmallVector<CallInst *, 4> MergableCIs;
1337
1338 /// Returns true if the instruction is mergable, false otherwise.
1339 /// A terminator instruction is unmergable by definition since merging
1340 /// works within a BB. Instructions before the mergable region are
1341 /// mergable if they are not calls to OpenMP runtime functions that may
1342 /// set different execution parameters for subsequent parallel regions.
1343 /// Instructions in-between parallel regions are mergable if they are not
1344 /// calls to any non-intrinsic function since that may call a non-mergable
1345 /// OpenMP runtime function.
1346 auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1347 // We do not merge across BBs, hence return false (unmergable) if the
1348 // instruction is a terminator.
1349 if (I.isTerminator())
1350 return false;
1351
1352 if (!isa<CallInst>(&I))
1353 return true;
1354
1355 CallInst *CI = cast<CallInst>(&I);
1356 if (IsBeforeMergableRegion) {
1357 Function *CalledFunction = CI->getCalledFunction();
1358 if (!CalledFunction)
1359 return false;
1360 // Return false (unmergable) if the call before the parallel
1361 // region calls an explicit affinity (proc_bind) or number of
1362 // threads (num_threads) compiler-generated function. Those settings
1363 // may be incompatible with following parallel regions.
1364 // TODO: ICV tracking to detect compatibility.
1365 for (const auto &RFI : UnmergableCallsInfo) {
1366 if (CalledFunction == RFI.Declaration)
1367 return false;
1368 }
1369 } else {
1370 // Return false (unmergable) if there is a call instruction
1371 // in-between parallel regions when it is not an intrinsic. It
1372 // may call an unmergable OpenMP runtime function in its callpath.
1373 // TODO: Keep track of possible OpenMP calls in the callpath.
1374 if (!isa<IntrinsicInst>(CI))
1375 return false;
1376 }
1377
1378 return true;
1379 };
1380 // Find maximal number of parallel region CIs that are safe to merge.
1381 for (auto It = BB->begin(), End = BB->end(); It != End;) {
1382 Instruction &I = *It;
1383 ++It;
1384
1385 if (CIs.count(&I)) {
1386 MergableCIs.push_back(cast<CallInst>(&I));
1387 continue;
1388 }
1389
1390 // Continue expanding if the instruction is mergable.
1391 if (IsMergable(I, MergableCIs.empty()))
1392 continue;
1393
1394 // Forward the instruction iterator to skip the next parallel region
1395 // since there is an unmergable instruction which can affect it.
1396 for (; It != End; ++It) {
1397 Instruction &SkipI = *It;
1398 if (CIs.count(&SkipI)) {
1399 LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1400 << " due to " << I << "\n");
1401 ++It;
1402 break;
1403 }
1404 }
1405
1406 // Store mergable regions found.
1407 if (MergableCIs.size() > 1) {
1408 MergableCIsVector.push_back(MergableCIs);
1409 LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1410 << " parallel regions in block " << BB->getName()
1411 << " of function " << BB->getParent()->getName()
1412 << "\n";);
1413 }
1414
1415 MergableCIs.clear();
1416 }
1417
1418 if (!MergableCIsVector.empty()) {
1419 Changed = true;
1420
1421 for (auto &MergableCIs : MergableCIsVector)
1422 Merge(MergableCIs, BB);
1423 MergableCIsVector.clear();
1424 }
1425 }
1426
1427 if (Changed) {
1428 /// Re-collect use for fork calls, emitted barrier calls, and
1429 /// any emitted master/end_master calls.
1430 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1431 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1432 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1433 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1434 }
1435
1436 return Changed;
1437 }
1438
1439 /// Try to delete parallel regions if possible.
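 /// A fork call is removable when its outlined callback provably has no
 /// observable effect, e.g. (sketch):
 ///   __kmpc_fork_call(%ident, 0, @cb, ...) ; @cb is readonly and willreturn
 /// can simply be erased.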
1440 bool deleteParallelRegions() {
1441 const unsigned CallbackCalleeOperand = 2;
1442
1443 OMPInformationCache::RuntimeFunctionInfo &RFI =
1444 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1445
1446 if (!RFI.Declaration)
1447 return false;
1448
1449 bool Changed = false;
1450 auto DeleteCallCB = [&](Use &U, Function &) {
1451 CallInst *CI = getCallIfRegularCall(U);
1452 if (!CI)
1453 return false;
1454 auto *Fn = dyn_cast<Function>(
1455 CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1456 if (!Fn)
1457 return false;
1458 if (!Fn->onlyReadsMemory())
1459 return false;
1460 if (!Fn->hasFnAttribute(Attribute::WillReturn))
1461 return false;
1462
1463 LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1464 << CI->getCaller()->getName() << "\n");
1465
1466 auto Remark = [&](OptimizationRemark OR) {
1467 return OR << "Removing parallel region with no side-effects.";
1468 };
1469 emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1470
1471 CI->eraseFromParent();
1472 Changed = true;
1473 ++NumOpenMPParallelRegionsDeleted;
1474 return true;
1475 };
1476
1477 RFI.foreachUse(SCC, DeleteCallCB);
1478
1479 return Changed;
1480 }
1481
1482 /// Try to eliminate runtime calls by reusing existing ones.
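 /// For example (sketch), two calls to omp_get_level() in the same function
 /// can be collapsed so the second user reuses the first call's result,
 /// assuming the value cannot change in between.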
1483 bool deduplicateRuntimeCalls() {
1484 bool Changed = false;
1485
1486 RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1487 OMPRTL_omp_get_num_threads,
1488 OMPRTL_omp_in_parallel,
1489 OMPRTL_omp_get_cancellation,
1490 OMPRTL_omp_get_supported_active_levels,
1491 OMPRTL_omp_get_level,
1492 OMPRTL_omp_get_ancestor_thread_num,
1493 OMPRTL_omp_get_team_size,
1494 OMPRTL_omp_get_active_level,
1495 OMPRTL_omp_in_final,
1496 OMPRTL_omp_get_proc_bind,
1497 OMPRTL_omp_get_num_places,
1498 OMPRTL_omp_get_num_procs,
1499 OMPRTL_omp_get_place_num,
1500 OMPRTL_omp_get_partition_num_places,
1501 OMPRTL_omp_get_partition_place_nums};
1502
1503 // Global-tid is handled separately.
1504 SmallSetVector<Value *, 16> GTIdArgs;
1505 collectGlobalThreadIdArguments(GTIdArgs);
1506 LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1507 << " global thread ID arguments\n");
1508
1509 for (Function *F : SCC) {
1510 for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1511 Changed |= deduplicateRuntimeCalls(
1512 *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1513
1514 // __kmpc_global_thread_num is special as we can replace it with an
1515 // argument in enough cases to make it worth trying.
1516 Value *GTIdArg = nullptr;
1517 for (Argument &Arg : F->args())
1518 if (GTIdArgs.count(&Arg)) {
1519 GTIdArg = &Arg;
1520 break;
1521 }
1522 Changed |= deduplicateRuntimeCalls(
1523 *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1524 }
1525
1526 return Changed;
1527 }
1528
1529 /// Tries to remove known runtime symbols that are optional from the module.
1530 bool removeRuntimeSymbols() {
1531 // The RPC client symbol is defined in `libc` and indicates that something
1532 // required an RPC server. If its users were all optimized out then we can
1533 // safely remove it.
1534 // TODO: This should be somewhere more common in the future.
1535 if (GlobalVariable *GV = M.getNamedGlobal("__llvm_rpc_client")) {
1536 if (GV->getNumUses() >= 1)
1537 return false;
1538
1539 GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
1540 GV->eraseFromParent();
1541 return true;
1542 }
1543 return false;
1544 }
1545
1546 /// Tries to hide the latency of runtime calls that involve host to
1547 /// device memory transfers by splitting them into their "issue" and "wait"
1548 /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1549 /// moved downwards as much as possible. The "issue" issues the memory transfer
1550 /// asynchronously, returning a handle. The "wait" waits in the returned
1551 /// handle for the memory transfer to finish.
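 /// Sketch of the rewrite (illustrative IR):
 ///   call void @__tgt_target_data_begin_mapper(...)
 ///   ; ...unrelated code...
 /// becomes
 ///   call void @__tgt_target_data_begin_mapper_issue(..., ptr %handle)
 ///   ; ...unrelated code...
 ///   call void @__tgt_target_data_begin_mapper_wait(i64 %device_id, ptr %handle)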
1552 bool hideMemTransfersLatency() {
1553 auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1554 bool Changed = false;
1555 auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1556 auto *RTCall = getCallIfRegularCall(U, &RFI);
1557 if (!RTCall)
1558 return false;
1559
1560 OffloadArray OffloadArrays[3];
1561 if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1562 return false;
1563
1564 LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1565
1566 // TODO: Check if can be moved upwards.
1567 bool WasSplit = false;
1568 Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1569 if (WaitMovementPoint)
1570 WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1571
1572 Changed |= WasSplit;
1573 return WasSplit;
1574 };
1575 if (OMPInfoCache.runtimeFnsAvailable(
1576 {OMPRTL___tgt_target_data_begin_mapper_issue,
1577 OMPRTL___tgt_target_data_begin_mapper_wait}))
1578 RFI.foreachUse(SCC, SplitMemTransfers);
1579
1580 return Changed;
1581 }
1582
1583 void analysisGlobalization() {
1584 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1585
1586 auto CheckGlobalization = [&](Use &U, Function &Decl) {
1587 if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1588 auto Remark = [&](OptimizationRemarkMissed ORM) {
1589 return ORM
1590 << "Found thread data sharing on the GPU. "
1591 << "Expect degraded performance due to data globalization.";
1592 };
1593 emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1594 }
1595
1596 return false;
1597 };
1598
1599 RFI.foreachUse(SCC, CheckGlobalization);
1600 }
1601
1602 /// Maps the values stored in the offload arrays passed as arguments to
1603 /// \p RuntimeCall into the offload arrays in \p OAs.
1604 bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1605 MutableArrayRef<OffloadArray> OAs) {
1606 assert(OAs.size() == 3 && "Need space for three offload arrays!");
1607
1608 // A runtime call that involves memory offloading looks something like:
1609 // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1610 // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1611 // ...)
1612 // So, the idea is to access the allocas that allocate space for these
1613 // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1614 // Therefore:
1615 // i8** %offload_baseptrs.
1616 Value *BasePtrsArg =
1617 RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1618 // i8** %offload_ptrs.
1619 Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1620 // i8** %offload_sizes.
1621 Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1622
1623 // Get values stored in **offload_baseptrs.
1624 auto *V = getUnderlyingObject(BasePtrsArg);
1625 if (!isa<AllocaInst>(V))
1626 return false;
1627 auto *BasePtrsArray = cast<AllocaInst>(V);
1628 if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1629 return false;
1630
1631 // Get values stored in **offload_ptrs.
1632 V = getUnderlyingObject(PtrsArg);
1633 if (!isa<AllocaInst>(V))
1634 return false;
1635 auto *PtrsArray = cast<AllocaInst>(V);
1636 if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1637 return false;
1638
1639 // Get values stored in **offload_sizes.
1640 V = getUnderlyingObject(SizesArg);
1641 // If it's a [constant] global array don't analyze it.
1642 if (isa<GlobalValue>(V))
1643 return isa<Constant>(V);
1644 if (!isa<AllocaInst>(V))
1645 return false;
1646
1647 auto *SizesArray = cast<AllocaInst>(V);
1648 if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1649 return false;
1650
1651 return true;
1652 }
1653
1654 /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1655 /// For now this is a way to test that the function getValuesInOffloadArrays
1656 /// is working properly.
1657 /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1658 void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1659 assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1660
1661 LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1662 std::string ValuesStr;
1663 raw_string_ostream Printer(ValuesStr);
1664 std::string Separator = " --- ";
1665
1666 for (auto *BP : OAs[0].StoredValues) {
1667 BP->print(Printer);
1668 Printer << Separator;
1669 }
1670 LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << ValuesStr << "\n");
1671 ValuesStr.clear();
1672
1673 for (auto *P : OAs[1].StoredValues) {
1674 P->print(Printer);
1675 Printer << Separator;
1676 }
1677 LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << ValuesStr << "\n");
1678 ValuesStr.clear();
1679
1680 for (auto *S : OAs[2].StoredValues) {
1681 S->print(Printer);
1682 Printer << Separator;
1683 }
1684 LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << ValuesStr << "\n");
1685 }
1686
1687 /// Returns the instruction where the "wait" counterpart of \p RuntimeCall can be
1688 /// moved. Returns nullptr if the movement is not possible, or not worth it.
1689 Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1690 // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1691 // Make it traverse the CFG.
1692
1693 Instruction *CurrentI = &RuntimeCall;
1694 bool IsWorthIt = false;
1695 while ((CurrentI = CurrentI->getNextNode())) {
1696
1697 // TODO: Once we detect the regions to be offloaded we should use the
1698 // alias analysis manager to check if CurrentI may modify one of
1699 // the offloaded regions.
1700 if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1701 if (IsWorthIt)
1702 return CurrentI;
1703
1704 return nullptr;
1705 }
1706
1707 // FIXME: For now, moving it over anything without side effects is
1708 // considered worth it.
1709 IsWorthIt = true;
1710 }
1711
1712 // Return end of BasicBlock.
1713 return RuntimeCall.getParent()->getTerminator();
1714 }
1715
1716 /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1717 bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1718 Instruction &WaitMovementPoint) {
1719 // Create a stack-allocated handle (__tgt_async_info) at the beginning of
1720 // the function. It is used to store information about the async transfer,
1721 // allowing us to wait on it later.
1722 auto &IRBuilder = OMPInfoCache.OMPBuilder;
1723 Function *F = RuntimeCall.getCaller();
1724 BasicBlock &Entry = F->getEntryBlock();
1725 IRBuilder.Builder.SetInsertPoint(&Entry,
1726 Entry.getFirstNonPHIOrDbgOrAlloca());
1727 Value *Handle = IRBuilder.Builder.CreateAlloca(
1728 IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1729 Handle =
1730 IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1731
1732 // Add "issue" runtime call declaration:
1733 // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1734 // i8**, i8**, i64*, i64*)
1735 FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1736 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1737
1738 // Change the RuntimeCall call site to its asynchronous version.
1739 SmallVector<Value *, 16> Args;
1740 for (auto &Arg : RuntimeCall.args())
1741 Args.push_back(Arg.get());
1742 Args.push_back(Handle);
1743
1744 CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"",
1745 RuntimeCall.getIterator());
1746 OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1747 RuntimeCall.eraseFromParent();
1748
1749 // Add "wait" runtime call declaration:
1750 // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1751 FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1752 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1753
1754 Value *WaitParams[2] = {
1755 IssueCallsite->getArgOperand(
1756 OffloadArray::DeviceIDArgNum), // device_id.
1757 Handle // handle to wait on.
1758 };
1759 CallInst *WaitCallsite = CallInst::Create(
1760 WaitDecl, WaitParams, /*NameStr=*/"", WaitMovementPoint.getIterator());
1761 OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1762
1763 return true;
1764 }
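// Illustrative sketch of the split above (hand-written IR, not taken from a
// test; %handle and the value names are made up):
//
//   Before:
//     call void @__tgt_target_data_begin_mapper(..., ptr %baseptrs,
//                                                ptr %ptrs, ptr %sizes, ...)
//     %x = add i32 %a, %b                      ; independent work
//
//   After:
//     %handle = alloca %struct.__tgt_async_info
//     call void @__tgt_target_data_begin_mapper_issue(..., ptr %handle)
//     %x = add i32 %a, %b                      ; overlaps with the transfer
//     call void @__tgt_target_data_begin_mapper_wait(i64 %device_id, ptr %handle)
//
// The "wait" call is placed at the movement point computed by
// canBeMovedDownwards, so independent work between the two calls can hide
// part of the transfer latency.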
1765
1766 static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1767 bool GlobalOnly, bool &SingleChoice) {
1768 if (CurrentIdent == NextIdent)
1769 return CurrentIdent;
1770
1771 // TODO: Figure out how to actually combine multiple debug locations. For
1772 // now we just keep an existing one if there is a single choice.
1773 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1774 SingleChoice = !CurrentIdent;
1775 return NextIdent;
1776 }
1777 return nullptr;
1778 }
1779
1780 /// Return a `struct ident_t*` value that represents the ones used in the
1781 /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1782 /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1783 /// return value we create one from scratch. We also do not yet combine
1784 /// information, e.g., the source locations, see combinedIdentStruct.
1785 Value *
1786 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1787 Function &F, bool GlobalOnly) {
1788 bool SingleChoice = true;
1789 Value *Ident = nullptr;
1790 auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1791 CallInst *CI = getCallIfRegularCall(U, &RFI);
1792 if (!CI || &F != &Caller)
1793 return false;
1794 Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1795 /* GlobalOnly */ true, SingleChoice);
1796 return false;
1797 };
1798 RFI.foreachUse(SCC, CombineIdentStruct);
1799
1800 if (!Ident || !SingleChoice) {
1801 // The IRBuilder uses the insertion block to get to the module, this is
1802 // unfortunate but we work around it for now.
1803 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1804 OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1805 &F.getEntryBlock(), F.getEntryBlock().begin()));
1806 // Create a fallback location if none was found.
1807 // TODO: Use the debug locations of the calls instead.
1808 uint32_t SrcLocStrSize;
1809 Constant *Loc =
1810 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1811 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1812 }
1813 return Ident;
1814 }
1815
1816 /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1817 /// \p ReplVal if given.
1818 bool deduplicateRuntimeCalls(Function &F,
1819 OMPInformationCache::RuntimeFunctionInfo &RFI,
1820 Value *ReplVal = nullptr) {
1821 auto *UV = RFI.getUseVector(F);
1822 if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1823 return false;
1824
1825 LLVM_DEBUG(
1826 dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1827 << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1828
1829 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1830 cast<Argument>(ReplVal)->getParent() == &F)) &&
1831 "Unexpected replacement value!");
1832
1833 // TODO: Use dominance to find a good position instead.
1834 auto CanBeMoved = [this](CallBase &CB) {
1835 unsigned NumArgs = CB.arg_size();
1836 if (NumArgs == 0)
1837 return true;
1838 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1839 return false;
1840 for (unsigned U = 1; U < NumArgs; ++U)
1841 if (isa<Instruction>(CB.getArgOperand(U)))
1842 return false;
1843 return true;
1844 };
1845
1846 if (!ReplVal) {
1847 auto *DT =
1848 OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1849 if (!DT)
1850 return false;
1851 Instruction *IP = nullptr;
1852 for (Use *U : *UV) {
1853 if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1854 if (IP)
1855 IP = DT->findNearestCommonDominator(IP, CI);
1856 else
1857 IP = CI;
1858 if (!CanBeMoved(*CI))
1859 continue;
1860 if (!ReplVal)
1861 ReplVal = CI;
1862 }
1863 }
1864 if (!ReplVal)
1865 return false;
1866 assert(IP && "Expected insertion point!");
1867 cast<Instruction>(ReplVal)->moveBefore(IP);
1868 }
1869
1870 // If we use a call as a replacement value we need to make sure the ident is
1871 // valid at the new location. For now we just pick a global one, either
1872 // existing and used by one of the calls, or created from scratch.
1873 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1874 if (!CI->arg_empty() &&
1875 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1876 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1877 /* GlobalOnly */ true);
1878 CI->setArgOperand(0, Ident);
1879 }
1880 }
1881
1882 bool Changed = false;
1883 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1884 CallInst *CI = getCallIfRegularCall(U, &RFI);
1885 if (!CI || CI == ReplVal || &F != &Caller)
1886 return false;
1887 assert(CI->getCaller() == &F && "Unexpected call!");
1888
1889 auto Remark = [&](OptimizationRemark OR) {
1890 return OR << "OpenMP runtime call "
1891 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1892 };
1893 if (CI->getDebugLoc())
1894 emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1895 else
1896 emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1897
1898 CI->replaceAllUsesWith(ReplVal);
1899 CI->eraseFromParent();
1900 ++NumOpenMPRuntimeCallsDeduplicated;
1901 Changed = true;
1902 return true;
1903 };
1904 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1905
1906 return Changed;
1907 }
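// A minimal sketch of the effect (hypothetical IR, names invented): two
// calls to the same runtime function with the same ident in one function,
//
//   %tid0 = call i32 @__kmpc_global_thread_num(ptr @ident)
//   ...
//   %tid1 = call i32 @__kmpc_global_thread_num(ptr @ident)
//
// are collapsed by moving one call to a common dominating position,
// replacing all uses of the other call with it, and erasing the duplicate.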
1908
1909 /// Collect arguments that represent the global thread id in \p GTIdArgs.
1910 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1911 // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1912 // initialization. We could define an AbstractAttribute instead and
1913 // run the Attributor here once it can be run as an SCC pass.
1914
1915 // Helper to check the argument \p ArgNo at all call sites of \p F for
1916 // a GTId.
1917 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1918 if (!F.hasLocalLinkage())
1919 return false;
1920 for (Use &U : F.uses()) {
1921 if (CallInst *CI = getCallIfRegularCall(U)) {
1922 Value *ArgOp = CI->getArgOperand(ArgNo);
1923 if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1924 getCallIfRegularCall(
1925 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1926 continue;
1927 }
1928 return false;
1929 }
1930 return true;
1931 };
1932
1933 // Helper to identify uses of a GTId as GTId arguments.
1934 auto AddUserArgs = [&](Value &GTId) {
1935 for (Use &U : GTId.uses())
1936 if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1937 if (CI->isArgOperand(&U))
1938 if (Function *Callee = CI->getCalledFunction())
1939 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1940 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1941 };
1942
1943 // The argument users of __kmpc_global_thread_num calls are GTIds.
1944 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1945 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1946
1947 GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1948 if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1949 AddUserArgs(*CI);
1950 return false;
1951 });
1952
1953 // Transitively search for more arguments by looking at the users of the
1954 // ones we know already. During the search the GTIdArgs vector is extended
1955 // so we cannot cache the size nor can we use a range based for.
1956 for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1957 AddUserArgs(*GTIdArgs[U]);
1958 }
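// Sketch of the transitive step (hypothetical IR): if %gtid is known to be a
// global thread id and @bar has local linkage and receives a known GTId in
// that argument position at every call site, then @bar's corresponding
// argument is itself collected as a GTId and its users are inspected next.
//
//   %gtid = call i32 @__kmpc_global_thread_num(ptr @ident)
//   call void @bar(i32 %gtid)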
1959
1960 /// Kernel (=GPU) optimizations and utility functions
1961 ///
1962 ///{{
1963
1964 /// Cache to remember the unique kernel for a function.
1965 DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
1966
1967 /// Find the unique kernel that will execute \p F, if any.
1968 Kernel getUniqueKernelFor(Function &F);
1969
1970 /// Find the unique kernel that will execute \p I, if any.
1971 Kernel getUniqueKernelFor(Instruction &I) {
1972 return getUniqueKernelFor(*I.getFunction());
1973 }
1974
1975 /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
1976 /// the cases where we can avoid taking the address of a function.
1977 bool rewriteDeviceCodeStateMachine();
1978
1979 ///
1980 ///}}
1981
1982 /// Emit a remark generically
1983 ///
1984 /// This template function can be used to generically emit a remark. The
1985 /// RemarkKind should be one of the following:
1986 /// - OptimizationRemark to indicate a successful optimization attempt
1987 /// - OptimizationRemarkMissed to report a failed optimization attempt
1988 /// - OptimizationRemarkAnalysis to provide additional information about an
1989 /// optimization attempt
1990 ///
1991 /// The remark is built using a callback function provided by the caller that
1992 /// takes a RemarkKind as input and returns a RemarkKind.
1993 template <typename RemarkKind, typename RemarkCallBack>
1994 void emitRemark(Instruction *I, StringRef RemarkName,
1995 RemarkCallBack &&RemarkCB) const {
1996 Function *F = I->getParent()->getParent();
1997 auto &ORE = OREGetter(F);
1998
1999 if (RemarkName.starts_with("OMP"))
2000 ORE.emit([&]() {
2001 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2002 << " [" << RemarkName << "]";
2003 });
2004 else
2005 ORE.emit(
2006 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2007 }
2008
2009 /// Emit a remark on a function.
2010 template <typename RemarkKind, typename RemarkCallBack>
2011 void emitRemark(Function *F, StringRef RemarkName,
2012 RemarkCallBack &&RemarkCB) const {
2013 auto &ORE = OREGetter(F);
2014
2015 if (RemarkName.starts_with("OMP"))
2016 ORE.emit([&]() {
2017 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2018 << " [" << RemarkName << "]";
2019 });
2020 else
2021 ORE.emit(
2022 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2023 }
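// Typical usage, mirroring the call sites elsewhere in this file (the remark
// name "OMP170" and the message are just an example):
//
//   auto Remark = [&](OptimizationRemark OR) {
//     return OR << "OpenMP runtime call "
//               << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
//   };
//   emitRemark<OptimizationRemark>(CI, "OMP170", Remark);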
2024
2025 /// The underlying module.
2026 Module &M;
2027
2028 /// The SCC we are operating on.
2029 SmallVectorImpl<Function *> &SCC;
2030
2031 /// Callback to update the call graph, the first argument is a removed call,
2032 /// the second an optional replacement call.
2033 CallGraphUpdater &CGUpdater;
2034
2035 /// Callback to get an OptimizationRemarkEmitter from a Function *
2036 OptimizationRemarkGetter OREGetter;
2037
2038 /// OpenMP-specific information cache. Also used for Attributor runs.
2039 OMPInformationCache &OMPInfoCache;
2040
2041 /// Attributor instance.
2042 Attributor &A;
2043
2044 /// Helper function to run Attributor on SCC.
2045 bool runAttributor(bool IsModulePass) {
2046 if (SCC.empty())
2047 return false;
2048
2049 registerAAs(IsModulePass);
2050
2051 ChangeStatus Changed = A.run();
2052
2053 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2054 << " functions, result: " << Changed << ".\n");
2055
2056 if (Changed == ChangeStatus::CHANGED)
2057 OMPInfoCache.invalidateAnalyses();
2058
2059 return Changed == ChangeStatus::CHANGED;
2060 }
2061
2062 void registerFoldRuntimeCall(RuntimeFunction RF);
2063
2064 /// Populate the Attributor with abstract attribute opportunities in the
2065 /// functions.
2066 void registerAAs(bool IsModulePass);
2067
2068public:
2069 /// Callback to register AAs for live functions, including internal functions
2070 /// marked live during the traversal.
2071 static void registerAAsForFunction(Attributor &A, const Function &F);
2072};
2073
2074Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2075 if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2076 !OMPInfoCache.CGSCC->contains(&F))
2077 return nullptr;
2078
2079 // Use a scope to keep the lifetime of the CachedKernel short.
2080 {
2081 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2082 if (CachedKernel)
2083 return *CachedKernel;
2084
2085 // TODO: We should use an AA to create an (optimistic and callback
2086 // call-aware) call graph. For now we stick to simple patterns that
2087 // are less powerful, basically the worst fixpoint.
2088 if (isOpenMPKernel(F)) {
2089 CachedKernel = Kernel(&F);
2090 return *CachedKernel;
2091 }
2092
2093 CachedKernel = nullptr;
2094 if (!F.hasLocalLinkage()) {
2095
2096 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2097 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2098 return ORA << "Potentially unknown OpenMP target region caller.";
2099 };
2100 emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2101
2102 return nullptr;
2103 }
2104 }
2105
2106 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2107 if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2108 // Allow use in equality comparisons.
2109 if (Cmp->isEquality())
2110 return getUniqueKernelFor(*Cmp);
2111 return nullptr;
2112 }
2113 if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2114 // Allow direct calls.
2115 if (CB->isCallee(&U))
2116 return getUniqueKernelFor(*CB);
2117
2118 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2119 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2120 // Allow the use in __kmpc_parallel_51 calls.
2121 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2122 return getUniqueKernelFor(*CB);
2123 return nullptr;
2124 }
2125 // Disallow every other use.
2126 return nullptr;
2127 };
2128
2129 // TODO: In the future we want to track more than just a unique kernel.
2130 SmallPtrSet<Kernel, 2> PotentialKernels;
2131 OMPInformationCache::foreachUse(F, [&](const Use &U) {
2132 PotentialKernels.insert(GetUniqueKernelForUse(U));
2133 });
2134
2135 Kernel K = nullptr;
2136 if (PotentialKernels.size() == 1)
2137 K = *PotentialKernels.begin();
2138
2139 // Cache the result.
2140 UniqueKernelMap[&F] = K;
2141
2142 return K;
2143}
2144
2145bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2146 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2147 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2148
2149 bool Changed = false;
2150 if (!KernelParallelRFI)
2151 return Changed;
2152
2153 // If we have disabled state machine changes, exit
2154 if (DisableOpenMPOptStateMachineRewrite)
2155 return Changed;
2156
2157 for (Function *F : SCC) {
2158
2159 // Check if the function is a use in a __kmpc_parallel_51 call at
2160 // all.
2161 bool UnknownUse = false;
2162 bool KernelParallelUse = false;
2163 unsigned NumDirectCalls = 0;
2164
2165 SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2166 OMPInformationCache::foreachUse(*F, [&](Use &U) {
2167 if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2168 if (CB->isCallee(&U)) {
2169 ++NumDirectCalls;
2170 return;
2171 }
2172
2173 if (isa<ICmpInst>(U.getUser())) {
2174 ToBeReplacedStateMachineUses.push_back(&U);
2175 return;
2176 }
2177
2178 // Find wrapper functions that represent parallel kernels.
2179 CallInst *CI =
2180 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2181 const unsigned int WrapperFunctionArgNo = 6;
2182 if (!KernelParallelUse && CI &&
2183 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2184 KernelParallelUse = true;
2185 ToBeReplacedStateMachineUses.push_back(&U);
2186 return;
2187 }
2188 UnknownUse = true;
2189 });
2190
2191 // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2192 // use.
2193 if (!KernelParallelUse)
2194 continue;
2195
2196 // If this ever hits, we should investigate.
2197 // TODO: Checking the number of uses is not a necessary restriction and
2198 // should be lifted.
2199 if (UnknownUse || NumDirectCalls != 1 ||
2200 ToBeReplacedStateMachineUses.size() > 2) {
2201 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2202 return ORA << "Parallel region is used in "
2203 << (UnknownUse ? "unknown" : "unexpected")
2204 << " ways. Will not attempt to rewrite the state machine.";
2205 };
2206 emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2207 continue;
2208 }
2209
2210 // Even if we have __kmpc_parallel_51 calls, we (for now) give
2211 // up if the function is not called from a unique kernel.
2212 Kernel K = getUniqueKernelFor(*F);
2213 if (!K) {
2214 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2215 return ORA << "Parallel region is not called from a unique kernel. "
2216 "Will not attempt to rewrite the state machine.";
2217 };
2218 emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2219 continue;
2220 }
2221
2222 // We now know F is a parallel body function called only from the kernel K.
2223 // We also identified the state machine uses in which we replace the
2224 // function pointer by a new global symbol for identification purposes. This
2225 // ensures only direct calls to the function are left.
2226
2227 Module &M = *F->getParent();
2228 Type *Int8Ty = Type::getInt8Ty(M.getContext());
2229
2230 auto *ID = new GlobalVariable(
2231 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2232 UndefValue::get(Int8Ty), F->getName() + ".ID");
2233
2234 for (Use *U : ToBeReplacedStateMachineUses)
2235 U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2236 ID, U->get()->getType()));
2237
2238 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2239
2240 Changed = true;
2241 }
2242
2243 return Changed;
2244}
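// Rough illustration (hand-written IR, not a real test case): for a parallel
// region wrapper @wrapper that is only used as the work function argument of
// __kmpc_parallel_51 and in state machine comparisons,
//
//   call void @__kmpc_parallel_51(..., ptr @wrapper, ...)
//   %eq = icmp eq ptr %work_fn, @wrapper
//
// those uses are rewritten to the new private global @wrapper.ID, so the only
// remaining uses of @wrapper are direct calls and the specialized state
// machine does not need to take the function's address.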
2245
2246/// Abstract Attribute for tracking ICV values.
2247struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2248 using Base = StateWrapper<BooleanState, AbstractAttribute>;
2249 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2250
2251 /// Returns true if value is assumed to be tracked.
2252 bool isAssumedTracked() const { return getAssumed(); }
2253
2254 /// Returns true if value is known to be tracked.
2255 bool isKnownTracked() const { return getAssumed(); }
2256
2257 /// Create an abstract attribute view for the position \p IRP.
2258 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2259
2260 /// Return the value with which \p I can be replaced for specific \p ICV.
2261 virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2262 const Instruction *I,
2263 Attributor &A) const {
2264 return std::nullopt;
2265 }
2266
2267 /// Return an assumed unique ICV value if a single candidate is found. If
2268 /// there cannot be one, return a nullptr. If it is not clear yet, return
2269 /// std::nullopt.
2270 virtual std::optional<Value *>
2271 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2272
2273 // Currently only nthreads is being tracked.
2274 // This array will only grow with time.
2275 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2276
2277 /// See AbstractAttribute::getName()
2278 const std::string getName() const override { return "AAICVTracker"; }
2279
2280 /// See AbstractAttribute::getIdAddr()
2281 const char *getIdAddr() const override { return &ID; }
2282
2283 /// This function should return true if the type of the \p AA is AAICVTracker
2284 static bool classof(const AbstractAttribute *AA) {
2285 return (AA->getIdAddr() == &ID);
2286 }
2287
2288 static const char ID;
2289};
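// For intuition, a hedged source-level example (not a regression test) of
// what ICV tracking enables for the nthreads ICV:
//
//   omp_set_num_threads(4);        // setter call, tracked value becomes 4
//   int N = omp_get_max_threads(); // getter call, may be replaced by 4
//
// The tracker only answers queries; folding the getter is performed by the
// call-site AAs below.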
2290
2291struct AAICVTrackerFunction : public AAICVTracker {
2292 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2293 : AAICVTracker(IRP, A) {}
2294
2295 // FIXME: come up with better string.
2296 const std::string getAsStr(Attributor *) const override {
2297 return "ICVTrackerFunction";
2298 }
2299
2300 // FIXME: come up with some stats.
2301 void trackStatistics() const override {}
2302
2303 /// We don't manifest anything for this AA.
2304 ChangeStatus manifest(Attributor &A) override {
2305 return ChangeStatus::UNCHANGED;
2306 }
2307
2308 // Map of ICVs to their values at specific program points.
2309 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2310 InternalControlVar::ICV___last>
2311 ICVReplacementValuesMap;
2312
2313 ChangeStatus updateImpl(Attributor &A) override {
2314 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2315
2316 Function *F = getAnchorScope();
2317
2318 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2319
2320 for (InternalControlVar ICV : TrackableICVs) {
2321 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2322
2323 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2324 auto TrackValues = [&](Use &U, Function &) {
2325 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2326 if (!CI)
2327 return false;
2328
2329 // FIXME: handle setters with more than one argument.
2330 /// Track new value.
2331 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2332 HasChanged = ChangeStatus::CHANGED;
2333
2334 return false;
2335 };
2336
2337 auto CallCheck = [&](Instruction &I) {
2338 std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2339 if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2340 HasChanged = ChangeStatus::CHANGED;
2341
2342 return true;
2343 };
2344
2345 // Track all changes of an ICV.
2346 SetterRFI.foreachUse(TrackValues, F);
2347
2348 bool UsedAssumedInformation = false;
2349 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2350 UsedAssumedInformation,
2351 /* CheckBBLivenessOnly */ true);
2352
2353 /// TODO: Figure out a way to avoid adding an entry in
2354 /// ICVReplacementValuesMap.
2355 Instruction *Entry = &F->getEntryBlock().front();
2356 if (HasChanged == ChangeStatus::CHANGED)
2357 ValuesMap.try_emplace(Entry);
2358 }
2359
2360 return HasChanged;
2361 }
2362
2363 /// Helper to check if \p I is a call and get the value for it if it is
2364 /// unique.
2365 std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2366 InternalControlVar &ICV) const {
2367
2368 const auto *CB = dyn_cast<CallBase>(&I);
2369 if (!CB || CB->hasFnAttr("no_openmp") ||
2370 CB->hasFnAttr("no_openmp_routines"))
2371 return std::nullopt;
2372
2373 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2374 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2375 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2376 Function *CalledFunction = CB->getCalledFunction();
2377
2378 // Indirect call, assume ICV changes.
2379 if (CalledFunction == nullptr)
2380 return nullptr;
2381 if (CalledFunction == GetterRFI.Declaration)
2382 return std::nullopt;
2383 if (CalledFunction == SetterRFI.Declaration) {
2384 if (ICVReplacementValuesMap[ICV].count(&I))
2385 return ICVReplacementValuesMap[ICV].lookup(&I);
2386
2387 return nullptr;
2388 }
2389
2390 // Since we don't know, assume it changes the ICV.
2391 if (CalledFunction->isDeclaration())
2392 return nullptr;
2393
2394 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2395 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2396
2397 if (ICVTrackingAA->isAssumedTracked()) {
2398 std::optional<Value *> URV =
2399 ICVTrackingAA->getUniqueReplacementValue(ICV);
2400 if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2401 OMPInfoCache)))
2402 return URV;
2403 }
2404
2405 // If we don't know, assume it changes.
2406 return nullptr;
2407 }
2408
2409 // We don't check unique value for a function, so return std::nullopt.
2410 std::optional<Value *>
2411 getUniqueReplacementValue(InternalControlVar ICV) const override {
2412 return std::nullopt;
2413 }
2414
2415 /// Return the value with which \p I can be replaced for specific \p ICV.
2416 std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2417 const Instruction *I,
2418 Attributor &A) const override {
2419 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2420 if (ValuesMap.count(I))
2421 return ValuesMap.lookup(I);
2422
2423 SmallVector<const Instruction *, 16> Worklist;
2424 SmallPtrSet<const Instruction *, 16> Visited;
2425 Worklist.push_back(I);
2426
2427 std::optional<Value *> ReplVal;
2428
2429 while (!Worklist.empty()) {
2430 const Instruction *CurrInst = Worklist.pop_back_val();
2431 if (!Visited.insert(CurrInst).second)
2432 continue;
2433
2434 const BasicBlock *CurrBB = CurrInst->getParent();
2435
2436 // Go up and look for all potential setters/calls that might change the
2437 // ICV.
2438 while ((CurrInst = CurrInst->getPrevNode())) {
2439 if (ValuesMap.count(CurrInst)) {
2440 std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2441 // Unknown value, track new.
2442 if (!ReplVal) {
2443 ReplVal = NewReplVal;
2444 break;
2445 }
2446
2447 // If we found a new value, we can't know the ICV value anymore.
2448 if (NewReplVal)
2449 if (ReplVal != NewReplVal)
2450 return nullptr;
2451
2452 break;
2453 }
2454
2455 std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2456 if (!NewReplVal)
2457 continue;
2458
2459 // Unknown value, track new.
2460 if (!ReplVal) {
2461 ReplVal = NewReplVal;
2462 break;
2463 }
2464
2465 // We found a new value here, so we can't know the ICV value anymore;
2466 // give up if it differs from the one tracked so far.
2467 if (ReplVal != NewReplVal)
2468 return nullptr;
2469 }
2470
2471 // If we are in the same BB and we have a value, we are done.
2472 if (CurrBB == I->getParent() && ReplVal)
2473 return ReplVal;
2474
2475 // Go through all predecessors and add terminators for analysis.
2476 for (const BasicBlock *Pred : predecessors(CurrBB))
2477 if (const Instruction *Terminator = Pred->getTerminator())
2478 Worklist.push_back(Terminator);
2479 }
2480
2481 return ReplVal;
2482 }
2483};
2484
2485struct AAICVTrackerFunctionReturned : AAICVTracker {
2486 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2487 : AAICVTracker(IRP, A) {}
2488
2489 // FIXME: come up with better string.
2490 const std::string getAsStr(Attributor *) const override {
2491 return "ICVTrackerFunctionReturned";
2492 }
2493
2494 // FIXME: come up with some stats.
2495 void trackStatistics() const override {}
2496
2497 /// We don't manifest anything for this AA.
2498 ChangeStatus manifest(Attributor &A) override {
2499 return ChangeStatus::UNCHANGED;
2500 }
2501
2502 // Map of ICVs to their values at specific program points.
2503 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2504 InternalControlVar::ICV___last>
2505 ICVReplacementValuesMap;
2506
2507 /// Return the unique replacement value for \p ICV, if one was found.
2508 std::optional<Value *>
2509 getUniqueReplacementValue(InternalControlVar ICV) const override {
2510 return ICVReplacementValuesMap[ICV];
2511 }
2512
2513 ChangeStatus updateImpl(Attributor &A) override {
2514 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2515 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2516 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2517
2518 if (!ICVTrackingAA->isAssumedTracked())
2519 return indicatePessimisticFixpoint();
2520
2521 for (InternalControlVar ICV : TrackableICVs) {
2522 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2523 std::optional<Value *> UniqueICVValue;
2524
2525 auto CheckReturnInst = [&](Instruction &I) {
2526 std::optional<Value *> NewReplVal =
2527 ICVTrackingAA->getReplacementValue(ICV, &I, A);
2528
2529 // If we found a second ICV value there is no unique returned value.
2530 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2531 return false;
2532
2533 UniqueICVValue = NewReplVal;
2534
2535 return true;
2536 };
2537
2538 bool UsedAssumedInformation = false;
2539 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2540 UsedAssumedInformation,
2541 /* CheckBBLivenessOnly */ true))
2542 UniqueICVValue = nullptr;
2543
2544 if (UniqueICVValue == ReplVal)
2545 continue;
2546
2547 ReplVal = UniqueICVValue;
2548 Changed = ChangeStatus::CHANGED;
2549 }
2550
2551 return Changed;
2552 }
2553};
2554
2555struct AAICVTrackerCallSite : AAICVTracker {
2556 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2557 : AAICVTracker(IRP, A) {}
2558
2559 void initialize(Attributor &A) override {
2560 assert(getAnchorScope() && "Expected anchor function");
2561
2562 // We only initialize this AA for getters, so we need to know which ICV it
2563 // gets.
2564 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2565 for (InternalControlVar ICV : TrackableICVs) {
2566 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2567 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2568 if (Getter.Declaration == getAssociatedFunction()) {
2569 AssociatedICV = ICVInfo.Kind;
2570 return;
2571 }
2572 }
2573
2574 /// Unknown ICV.
2575 indicatePessimisticFixpoint();
2576 }
2577
2578 ChangeStatus manifest(Attributor &A) override {
2579 if (!ReplVal || !*ReplVal)
2580 return ChangeStatus::UNCHANGED;
2581
2582 A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2583 A.deleteAfterManifest(*getCtxI());
2584
2585 return ChangeStatus::CHANGED;
2586 }
2587
2588 // FIXME: come up with better string.
2589 const std::string getAsStr(Attributor *) const override {
2590 return "ICVTrackerCallSite";
2591 }
2592
2593 // FIXME: come up with some stats.
2594 void trackStatistics() const override {}
2595
2596 InternalControlVar AssociatedICV;
2597 std::optional<Value *> ReplVal;
2598
2599 ChangeStatus updateImpl(Attributor &A) override {
2600 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2601 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2602
2603 // We don't have any information, so we assume it changes the ICV.
2604 if (!ICVTrackingAA->isAssumedTracked())
2605 return indicatePessimisticFixpoint();
2606
2607 std::optional<Value *> NewReplVal =
2608 ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
2609
2610 if (ReplVal == NewReplVal)
2611 return ChangeStatus::UNCHANGED;
2612
2613 ReplVal = NewReplVal;
2614 return ChangeStatus::CHANGED;
2615 }
2616
2617 // Return the value with which the associated value can be replaced for the
2618 // specific \p ICV.
2619 std::optional<Value *>
2620 getUniqueReplacementValue(InternalControlVar ICV) const override {
2621 return ReplVal;
2622 }
2623};
2624
2625struct AAICVTrackerCallSiteReturned : AAICVTracker {
2626 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2627 : AAICVTracker(IRP, A) {}
2628
2629 // FIXME: come up with better string.
2630 const std::string getAsStr(Attributor *) const override {
2631 return "ICVTrackerCallSiteReturned";
2632 }
2633
2634 // FIXME: come up with some stats.
2635 void trackStatistics() const override {}
2636
2637 /// We don't manifest anything for this AA.
2638 ChangeStatus manifest(Attributor &A) override {
2639 return ChangeStatus::UNCHANGED;
2640 }
2641
2642 // Map of ICVs to their values at specific program points.
2643 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2644 InternalControlVar::ICV___last>
2645 ICVReplacementValuesMap;
2646
2647 /// Return the value with which the associated value can be replaced for the
2648 /// specific \p ICV.
2649 std::optional<Value *>
2650 getUniqueReplacementValue(InternalControlVar ICV) const override {
2651 return ICVReplacementValuesMap[ICV];
2652 }
2653
2654 ChangeStatus updateImpl(Attributor &A) override {
2655 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2656 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2657 *this, IRPosition::returned(*getAssociatedFunction()),
2658 DepClassTy::REQUIRED);
2659
2660 // We don't have any information, so we assume it changes the ICV.
2661 if (!ICVTrackingAA->isAssumedTracked())
2662 return indicatePessimisticFixpoint();
2663
2664 for (InternalControlVar ICV : TrackableICVs) {
2665 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2666 std::optional<Value *> NewReplVal =
2667 ICVTrackingAA->getUniqueReplacementValue(ICV);
2668
2669 if (ReplVal == NewReplVal)
2670 continue;
2671
2672 ReplVal = NewReplVal;
2673 Changed = ChangeStatus::CHANGED;
2674 }
2675 return Changed;
2676 }
2677};
2678
2679/// Determines if \p BB exits the function unconditionally itself or reaches a
2680/// block that does so through only unique successors.
2681static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2682 if (succ_empty(BB))
2683 return true;
2684 const BasicBlock *const Successor = BB->getUniqueSuccessor();
2685 if (!Successor)
2686 return false;
2687 return hasFunctionEndAsUniqueSuccessor(Successor);
2688}
2689
2690struct AAExecutionDomainFunction : public AAExecutionDomain {
2691 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2692 : AAExecutionDomain(IRP, A) {}
2693
2694 ~AAExecutionDomainFunction() { delete RPOT; }
2695
2696 void initialize(Attributor &A) override {
2697 Function *F = getAnchorScope();
2698 assert(F && "Expected anchor function");
2699 RPOT = new ReversePostOrderTraversal<Function *>(F);
2700 }
2701
2702 const std::string getAsStr(Attributor *) const override {
2703 unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2704 for (auto &It : BEDMap) {
2705 if (!It.getFirst())
2706 continue;
2707 TotalBlocks++;
2708 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2709 AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2710 It.getSecond().IsReachingAlignedBarrierOnly;
2711 }
2712 return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
2713 std::to_string(AlignedBlocks) + " of " +
2714 std::to_string(TotalBlocks) +
2715 " executed by initial thread / aligned";
2716 }
2717
2718 /// See AbstractAttribute::trackStatistics().
2719 void trackStatistics() const override {}
2720
2721 ChangeStatus manifest(Attributor &A) override {
2722 LLVM_DEBUG({
2723 for (const BasicBlock &BB : *getAnchorScope()) {
2724 if (!isExecutedByInitialThreadOnly(BB))
2725 continue;
2726 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2727 << BB.getName() << " is executed by a single thread.\n";
2728 }
2729 });
2730
2731 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2732
2733 if (!isValidState())
2734 return Changed;
2735
2736 SmallPtrSet<CallBase *, 16> DeletedBarriers;
2737 auto HandleAlignedBarrier = [&](CallBase *CB) {
2738 const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
2739 if (!ED.IsReachedFromAlignedBarrierOnly ||
2740 ED.EncounteredNonLocalSideEffect)
2741 return;
2742 if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
2743 return;
2744
2745 // We can remove this barrier, if it is one, or aligned barriers reaching
2746 // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
2747 // end should only be removed if the kernel end is their unique successor;
2748 // otherwise, they may have side-effects that aren't accounted for in the
2749 // kernel end in their other successors. If those barriers have other
2750 // barriers reaching them, those can be transitively removed as well as
2751 // long as the kernel end is also their unique successor.
2752 if (CB) {
2753 DeletedBarriers.insert(CB);
2754 A.deleteAfterManifest(*CB);
2755 ++NumBarriersEliminated;
2756 Changed = ChangeStatus::CHANGED;
2757 } else if (!ED.AlignedBarriers.empty()) {
2758 Changed = ChangeStatus::CHANGED;
2759 SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
2760 ED.AlignedBarriers.end());
2761 SmallSetVector<CallBase *, 16> Visited;
2762 while (!Worklist.empty()) {
2763 CallBase *LastCB = Worklist.pop_back_val();
2764 if (!Visited.insert(LastCB))
2765 continue;
2766 if (LastCB->getFunction() != getAnchorScope())
2767 continue;
2768 if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
2769 continue;
2770 if (!DeletedBarriers.count(LastCB)) {
2771 ++NumBarriersEliminated;
2772 A.deleteAfterManifest(*LastCB);
2773 continue;
2774 }
2775 // The final aligned barrier (LastCB) reaching the kernel end was
2776 // removed already. This means we can go one step further and remove
2777 // the barriers encountered last before (LastCB).
2778 const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2779 Worklist.append(LastED.AlignedBarriers.begin(),
2780 LastED.AlignedBarriers.end());
2781 }
2782 }
2783
2784 // If we actually eliminated a barrier we need to eliminate the associated
2785 // llvm.assumes as well to avoid creating UB.
2786 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2787 for (auto *AssumeCB : ED.EncounteredAssumes)
2788 A.deleteAfterManifest(*AssumeCB);
2789 };
2790
2791 for (auto *CB : AlignedBarriers)
2792 HandleAlignedBarrier(CB);
2793
2794 // Handle the "kernel end barrier" for kernels too.
2795 if (omp::isOpenMPKernel(*getAnchorScope()))
2796 HandleAlignedBarrier(nullptr);
2797
2798 return Changed;
2799 }
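// Sketch of the effect (hypothetical IR): if the second barrier below is
// only reached from another aligned barrier and no non-local side effects
// happen in between, it is redundant and deleted, together with any
// llvm.assume calls encountered in that region (to avoid introducing UB):
//
//   call void @__kmpc_barrier_simple_spmd(ptr @ident, i32 %tid)
//   ; only thread-local work here
//   call void @__kmpc_barrier_simple_spmd(ptr @ident, i32 %tid)  ; removable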
2800
2801 bool isNoOpFence(const FenceInst &FI) const override {
2802 return getState().isValidState() && !NonNoOpFences.count(&FI);
2803 }
2804
2805 /// Merge barrier and assumption information from \p PredED into the successor
2806 /// \p ED.
2807 void
2808 mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
2809 const ExecutionDomainTy &PredED);
2810
2811 /// Merge all information from \p PredED into the successor \p ED. If
2812 /// \p InitialEdgeOnly is set, only the initial edge will enter the block
2813 /// represented by \p ED from this predecessor.
2814 bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
2815 const ExecutionDomainTy &PredED,
2816 bool InitialEdgeOnly = false);
2817
2818 /// Accumulate information for the entry block in \p EntryBBED.
2819 bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);
2820
2821 /// See AbstractAttribute::updateImpl.
2822 ChangeStatus updateImpl(Attributor &A) override;
2823
2824 /// Query interface, see AAExecutionDomain
2825 ///{
2826 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2827 if (!isValidState())
2828 return false;
2829 assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
2830 return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2831 }
2832
2833 bool isExecutedInAlignedRegion(Attributor &A,
2834 const Instruction &I) const override {
2835 assert(I.getFunction() == getAnchorScope() &&
2836 "Instruction is out of scope!");
2837 if (!isValidState())
2838 return false;
2839
2840 bool ForwardIsOk = true;
2841 const Instruction *CurI;
2842
2843 // Check forward until a call or the block end is reached.
2844 CurI = &I;
2845 do {
2846 auto *CB = dyn_cast<CallBase>(CurI);
2847 if (!CB)
2848 continue;
2849 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2850 return true;
2851 const auto &It = CEDMap.find({CB, PRE});
2852 if (It == CEDMap.end())
2853 continue;
2854 if (!It->getSecond().IsReachingAlignedBarrierOnly)
2855 ForwardIsOk = false;
2856 break;
2857 } while ((CurI = CurI->getNextNonDebugInstruction()));
2858
2859 if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
2860 ForwardIsOk = false;
2861
2862 // Check backward until a call or the block beginning is reached.
2863 CurI = &I;
2864 do {
2865 auto *CB = dyn_cast<CallBase>(CurI);
2866 if (!CB)
2867 continue;
2868 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2869 return true;
2870 const auto &It = CEDMap.find({CB, POST});
2871 if (It == CEDMap.end())
2872 continue;
2873 if (It->getSecond().IsReachedFromAlignedBarrierOnly)
2874 break;
2875 return false;
2876 } while ((CurI = CurI->getPrevNonDebugInstruction()));
2877
2878 // Delayed decision on the forward pass to allow aligned barrier detection
2879 // in the backwards traversal.
2880 if (!ForwardIsOk)
2881 return false;
2882
2883 if (!CurI) {
2884 const BasicBlock *BB = I.getParent();
2885 if (BB == &BB->getParent()->getEntryBlock())
2886 return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
2887 if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
2888 return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
2889 })) {
2890 return false;
2891 }
2892 }
2893
2894 // On neither traversal did we find anything but aligned barriers.
2895 return true;
2896 }
2897
2898 ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
2899 assert(isValidState() &&
2900 "No request should be made against an invalid state!");
2901 return BEDMap.lookup(&BB);
2902 }
2903 std::pair<ExecutionDomainTy, ExecutionDomainTy>
2904 getExecutionDomain(const CallBase &CB) const override {
2905 assert(isValidState() &&
2906 "No request should be made against an invalid state!");
2907 return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
2908 }
2909 ExecutionDomainTy getFunctionExecutionDomain() const override {
2910 assert(isValidState() &&
2911 "No request should be made against an invalid state!");
2912 return InterProceduralED;
2913 }
2914 ///}
2915
2916 // Check if the edge into the successor block contains a condition that only
2917 // lets the main thread execute it.
2918 static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,
2919 BasicBlock &SuccessorBB) {
2920 if (!Edge || !Edge->isConditional())
2921 return false;
2922 if (Edge->getSuccessor(0) != &SuccessorBB)
2923 return false;
2924
2925 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2926 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2927 return false;
2928
2929 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2930 if (!C)
2931 return false;
2932
2933 // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2934 if (C->isAllOnesValue()) {
2935 auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2936 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2937 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2938 CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2939 if (!CB)
2940 return false;
2941 ConstantStruct *KernelEnvC =
2942 KernelInfo::getKernelEnvironementCFromKernelInitCB(CB);
2943 ConstantInt *ExecModeC =
2944 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
2945 return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
2946 }
2947
2948 if (C->isZero()) {
2949 // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2950 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2951 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2952 return true;
2953
2954 // Match: 0 == llvm.amdgcn.workitem.id.x()
2955 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2956 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2957 return true;
2958 }
2959
2960 return false;
2961 };
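// Example of a recognized guard (hypothetical IR): only thread 0 takes the
// %then edge, so %then is treated as executed by the initial thread only.
//
//   %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
//   %cmp = icmp eq i32 %tid, 0
//   br i1 %cmp, label %then, label %else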
2962
2963 /// Mapping containing information about the function for other AAs.
2964 ExecutionDomainTy InterProceduralED;
2965
2966 enum Direction { PRE = 0, POST = 1 };
2967 /// Mapping containing information per block.
2968 DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
2969 DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>
2970 CEDMap;
2971 SmallSetVector<CallBase *, 16> AlignedBarriers;
2972
2973 ReversePostOrderTraversal<Function *> *RPOT = nullptr;
2974
2975 /// Set \p R to \p V and report true if that changed \p R.
2976 static bool setAndRecord(bool &R, bool V) {
2977 bool Eq = (R == V);
2978 R = V;
2979 return !Eq;
2980 }
2981
2982 /// Collection of fences known to be non-no-op. All fences not in this set
2983 /// can be assumed to be no-ops.
2984 SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
2985};
2986
2987void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2988 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2989 for (auto *EA : PredED.EncounteredAssumes)
2990 ED.addAssumeInst(A, *EA);
2991
2992 for (auto *AB : PredED.AlignedBarriers)
2993 ED.addAlignedBarrier(A, *AB);
2994}
2995
2996bool AAExecutionDomainFunction::mergeInPredecessor(
2997 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
2998 bool InitialEdgeOnly) {
2999
3000 bool Changed = false;
3001 Changed |=
3002 setAndRecord(ED.IsExecutedByInitialThreadOnly,
3003 InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3004 ED.IsExecutedByInitialThreadOnly));
3005
3006 Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3007 ED.IsReachedFromAlignedBarrierOnly &&
3008 PredED.IsReachedFromAlignedBarrierOnly);
3009 Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3010 ED.EncounteredNonLocalSideEffect |
3011 PredED.EncounteredNonLocalSideEffect);
3012 // Do not track assumptions and barriers as part of Changed.
3013 if (ED.IsReachedFromAlignedBarrierOnly)
3014 mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3015 else
3016 ED.clearAssumeInstAndAlignedBarriers();
3017 return Changed;
3018}
3019
3020bool AAExecutionDomainFunction::handleCallees(Attributor &A,
3021 ExecutionDomainTy &EntryBBED) {
3022 SmallVector<std::pair<ExecutionDomainTy, ExecutionDomainTy>> CallSiteEDs;
3023 auto PredForCallSite = [&](AbstractCallSite ACS) {
3024 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3025 *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
3026 DepClassTy::OPTIONAL);
3027 if (!EDAA || !EDAA->getState().isValidState())
3028 return false;
3029 CallSiteEDs.emplace_back(
3030 EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
3031 return true;
3032 };
3033
3034 ExecutionDomainTy ExitED;
3035 bool AllCallSitesKnown;
3036 if (A.checkForAllCallSites(PredForCallSite, *this,
3037 /* RequiresAllCallSites */ true,
3038 AllCallSitesKnown)) {
3039 for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
3040 mergeInPredecessor(A, EntryBBED, CSInED);
3041 ExitED.IsReachingAlignedBarrierOnly &=
3042 CSOutED.IsReachingAlignedBarrierOnly;
3043 }
3044
3045 } else {
3046 // We could not find all predecessors, so this is either a kernel or a
3047 // function with external linkage (or with some other weird uses).
3048 if (omp::isOpenMPKernel(*getAnchorScope())) {
3049 EntryBBED.IsExecutedByInitialThreadOnly = false;
3050 EntryBBED.IsReachedFromAlignedBarrierOnly = true;
3051 EntryBBED.EncounteredNonLocalSideEffect = false;
3052 ExitED.IsReachingAlignedBarrierOnly = false;
3053 } else {
3054 EntryBBED.IsExecutedByInitialThreadOnly = false;
3055 EntryBBED.IsReachedFromAlignedBarrierOnly = false;
3056 EntryBBED.EncounteredNonLocalSideEffect = true;
3057 ExitED.IsReachingAlignedBarrierOnly = false;
3058 }
3059 }
3060
3061 bool Changed = false;
3062 auto &FnED = BEDMap[nullptr];
3063 Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
3064 FnED.IsReachedFromAlignedBarrierOnly &
3065 EntryBBED.IsReachedFromAlignedBarrierOnly);
3066 Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
3067 FnED.IsReachingAlignedBarrierOnly &
3068 ExitED.IsReachingAlignedBarrierOnly);
3069 Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
3070 EntryBBED.IsExecutedByInitialThreadOnly);
3071 return Changed;
3072}
3073
3074ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
3075
3076 bool Changed = false;
3077
3078 // Helper to deal with an aligned barrier encountered during the forward
3079 // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
3080 // it was encountered.
3081 auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
3082 Changed |= AlignedBarriers.insert(&CB);
3083 // First, update the barrier ED kept in the separate CEDMap.
3084 auto &CallInED = CEDMap[{&CB, PRE}];
3085 Changed |= mergeInPredecessor(A, CallInED, ED);
3086 CallInED.IsReachingAlignedBarrierOnly = true;
3087 // Next adjust the ED we use for the traversal.
3088 ED.EncounteredNonLocalSideEffect = false;
3089 ED.IsReachedFromAlignedBarrierOnly = true;
3090 // Aligned barrier collection has to come last.
3091 ED.clearAssumeInstAndAlignedBarriers();
3092 ED.addAlignedBarrier(A, CB);
3093 auto &CallOutED = CEDMap[{&CB, POST}];
3094 Changed |= mergeInPredecessor(A, CallOutED, ED);
3095 };
3096
3097 auto *LivenessAA =
3098 A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
3099
3100 Function *F = getAnchorScope();
3101 BasicBlock &EntryBB = F->getEntryBlock();
3102 bool IsKernel = omp::isOpenMPKernel(*F);
3103
3104 SmallVector<Instruction *> SyncInstWorklist;
3105 for (auto &RIt : *RPOT) {
3106 BasicBlock &BB = *RIt;
3107
3108 bool IsEntryBB = &BB == &EntryBB;
3109 // TODO: We use local reasoning since we don't have a divergence analysis
3110 // running as well. We could basically allow uniform branches here.
3111 bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
3112 bool IsExplicitlyAligned = IsEntryBB && IsKernel;
3113 ExecutionDomainTy ED;
3114 // Propagate "incoming edges" into information about this block.
3115 if (IsEntryBB) {
3116 Changed |= handleCallees(A, ED);
3117 } else {
3118 // For live non-entry blocks we only propagate
3119 // information via live edges.
3120 if (LivenessAA && LivenessAA->isAssumedDead(&BB))
3121 continue;
3122
3123 for (auto *PredBB : predecessors(&BB)) {
3124 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
3125 continue;
3126 bool InitialEdgeOnly = isInitialThreadOnlyEdge(
3127 A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
3128 mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
3129 }
3130 }
3131
3132 // Now we traverse the block, accumulate effects in ED and attach
3133 // information to calls.
3134 for (Instruction &I : BB) {
3135 bool UsedAssumedInformation;
3136 if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
3137 /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
3138 /* CheckForDeadStore */ true))
3139 continue;
3140
3141 // Assumes and "assume-like" intrinsics (dbg, lifetime, ...) are handled
3142 // first; the former are collected, the latter are ignored.
3143 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
3144 if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
3145 ED.addAssumeInst(A, *AI);
3146 continue;
3147 }
3148 // TODO: Should we also collect and delete lifetime markers?
3149 if (II->isAssumeLikeIntrinsic())
3150 continue;
3151 }
3152
3153 if (auto *FI = dyn_cast<FenceInst>(&I)) {
3154 if (!ED.EncounteredNonLocalSideEffect) {
3155 // An aligned fence without non-local side-effects is a no-op.
3156 if (ED.IsReachedFromAlignedBarrierOnly)
3157 continue;
3158 // A non-aligned fence without non-local side-effects is a no-op
3159 // if the ordering only publishes non-local side-effects (or less).
3160 switch (FI->getOrdering()) {
3161 case AtomicOrdering::NotAtomic:
3162 continue;
3163 case AtomicOrdering::Unordered:
3164 continue;
3165 case AtomicOrdering::Monotonic:
3166 continue;
3167 case AtomicOrdering::Acquire:
3168 break;
3169 case AtomicOrdering::Release:
3170 continue;
3171 case AtomicOrdering::AcquireRelease:
3172 break;
3173 case AtomicOrdering::SequentiallyConsistent:
3174 break;
3175 };
3176 }
3177 NonNoOpFences.insert(FI);
3178 }
3179
3180 auto *CB = dyn_cast<CallBase>(&I);
3181 bool IsNoSync = AA::isNoSyncInst(A, I, *this);
3182 bool IsAlignedBarrier =
3183 !IsNoSync && CB &&
3184 AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);
3185
3186 AlignedBarrierLastInBlock &= IsNoSync;
3187 IsExplicitlyAligned &= IsNoSync;
3188
3189 // Next we check for calls. Aligned barriers are handled
3190 // explicitly, everything else is kept for the backward traversal and will
3191 // also affect our state.
3192 if (CB) {
3193 if (IsAlignedBarrier) {
3194 HandleAlignedBarrier(*CB, ED);
3195 AlignedBarrierLastInBlock = true;
3196 IsExplicitlyAligned = true;
3197 continue;
3198 }
3199
3200 // Check the pointer(s) of a memory intrinsic explicitly.
3201 if (isa<MemIntrinsic>(&I)) {
3202 if (!ED.EncounteredNonLocalSideEffect &&
3203 AA::isPotentiallyAffectedByBarrier(A, I, *this))
3204 ED.EncounteredNonLocalSideEffect = true;
3205 if (!IsNoSync) {
3206 ED.IsReachedFromAlignedBarrierOnly = false;
3207 SyncInstWorklist.push_back(&I);
3208 }
3209 continue;
3210 }
3211
3212 // Record how we entered the call, then accumulate the effect of the
3213 // call in ED for potential use by the callee.
3214 auto &CallInED = CEDMap[{CB, PRE}];
3215 Changed |= mergeInPredecessor(A, CallInED, ED);
3216
3217 // If we have a sync-definition we can check if it starts/ends in an
3218 // aligned barrier. If we are unsure we assume any sync breaks
3219 // alignment.
3220 Function *Callee = CB->getCalledFunction();
3221 if (!IsNoSync && Callee && !Callee->isDeclaration()) {
3222 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3223 *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
3224 if (EDAA && EDAA->getState().isValidState()) {
3225 const auto &CalleeED = EDAA->getFunctionExecutionDomain();
3226 ED.IsReachedFromAlignedBarrierOnly =
3227 CalleeED.IsReachedFromAlignedBarrierOnly;
3228 AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3229 if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3230 ED.EncounteredNonLocalSideEffect |=
3231 CalleeED.EncounteredNonLocalSideEffect;
3232 else
3233 ED.EncounteredNonLocalSideEffect =
3234 CalleeED.EncounteredNonLocalSideEffect;
3235 if (!CalleeED.IsReachingAlignedBarrierOnly) {
3236 Changed |=
3237 setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3238 SyncInstWorklist.push_back(&I);
3239 }
3240 if (CalleeED.IsReachedFromAlignedBarrierOnly)
3241 mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
3242 auto &CallOutED = CEDMap[{CB, POST}];
3243 Changed |= mergeInPredecessor(A, CallOutED, ED);
3244 continue;
3245 }
3246 }
3247 if (!IsNoSync) {
3248 ED.IsReachedFromAlignedBarrierOnly = false;
3249 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3250 SyncInstWorklist.push_back(&I);
3251 }
3252 AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3253 ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
3254 auto &CallOutED = CEDMap[{CB, POST}];
3255 Changed |= mergeInPredecessor(A, CallOutED, ED);
3256 }
3257
3258 if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
3259 continue;
3260
3261 // If we have a callee we try to use fine-grained information to
3262 // determine local side-effects.
3263 if (CB) {
3264 const auto *MemAA = A.getAAFor<AAMemoryLocation>(
3265 *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
3266
3267 auto AccessPred = [&](const Instruction *I, const Value *Ptr,
3268 AAMemoryLocation::AccessKind,
3269 AAMemoryLocation::MemoryLocationsKind) {
3270 return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
3271 };
3272 if (MemAA && MemAA->getState().isValidState() &&
3273 MemAA->checkForAllAccessesToMemoryKind(
3274 AccessPred, AAMemoryLocation::ALL_LOCATIONS))
3275 continue;
3276 }
3277
3278 auto &InfoCache = A.getInfoCache();
3279 if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
3280 continue;
3281
3282 if (auto *LI = dyn_cast<LoadInst>(&I))
3283 if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3284 continue;
3285
3286 if (!ED.EncounteredNonLocalSideEffect &&
3287 AA::isPotentiallyAffectedByBarrier(A, I, *this))
3288 ED.EncounteredNonLocalSideEffect = true;
3289 }
3290
3291 bool IsEndAndNotReachingAlignedBarriersOnly = false;
3292 if (!isa<UnreachableInst>(BB.getTerminator()) &&
3293 !BB.getTerminator()->getNumSuccessors()) {
3294
3295 Changed |= mergeInPredecessor(A, InterProceduralED, ED);
3296
3297 auto &FnED = BEDMap[nullptr];
3298 if (IsKernel && !IsExplicitlyAligned)
3299 FnED.IsReachingAlignedBarrierOnly = false;
3300 Changed |= mergeInPredecessor(A, FnED, ED);
3301
3302 if (!FnED.IsReachingAlignedBarrierOnly) {
3303 IsEndAndNotReachingAlignedBarriersOnly = true;
3304 SyncInstWorklist.push_back(BB.getTerminator());
3305 auto &BBED = BEDMap[&BB];
3306 Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
3307 }
3308 }
3309
3310 ExecutionDomainTy &StoredED = BEDMap[&BB];
3311 ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3312 !IsEndAndNotReachingAlignedBarriersOnly;
3313
3314 // Check if we computed anything different as part of the forward
3315 // traversal. We do not take assumptions and aligned barriers into account
3316 // as they do not influence the state we iterate. Backward traversal values
3317 // are handled later on.
3318 if (ED.IsExecutedByInitialThreadOnly !=
3319 StoredED.IsExecutedByInitialThreadOnly ||
3320 ED.IsReachedFromAlignedBarrierOnly !=
3321 StoredED.IsReachedFromAlignedBarrierOnly ||
3322 ED.EncounteredNonLocalSideEffect !=
3323 StoredED.EncounteredNonLocalSideEffect)
3324 Changed = true;
3325
3326 // Update the state with the new value.
3327 StoredED = std::move(ED);
3328 }
3329
3330 // Propagate (non-aligned) sync instruction effects backwards until the
3331 // entry is hit or an aligned barrier.
3332 SmallSetVector<BasicBlock *, 16> Visited;
3333 while (!SyncInstWorklist.empty()) {
3334 Instruction *SyncInst = SyncInstWorklist.pop_back_val();
3335 Instruction *CurInst = SyncInst;
3336 bool HitAlignedBarrierOrKnownEnd = false;
3337 while ((CurInst = CurInst->getPrevNode())) {
3338 auto *CB = dyn_cast<CallBase>(CurInst);
3339 if (!CB)
3340 continue;
3341 auto &CallOutED = CEDMap[{CB, POST}];
3342 Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
3343 auto &CallInED = CEDMap[{CB, PRE}];
3344 HitAlignedBarrierOrKnownEnd =
3345 AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
3346 if (HitAlignedBarrierOrKnownEnd)
3347 break;
3348 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3349 }
3350 if (HitAlignedBarrierOrKnownEnd)
3351 continue;
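      // We reached the beginning of the block without hitting an aligned
      // barrier; propagate the information into all (live) predecessors.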
3352 BasicBlock *SyncBB = SyncInst->getParent();
3353 for (auto *PredBB : predecessors(SyncBB)) {
3354 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
3355 continue;
3356 if (!Visited.insert(PredBB))
3357 continue;
3358 auto &PredED = BEDMap[PredBB];
3359 if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
3360 Changed = true;
3361 SyncInstWorklist.push_back(PredBB->getTerminator());
3362 }
3363 }
3364 if (SyncBB != &EntryBB)
3365 continue;
3366 Changed |=
3367 setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
3368 }
3369
3370 return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
3371}
3372
3373/// Try to replace memory allocation calls called by a single thread with a
3374/// static buffer of shared memory.
3375struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
3376  using Base = StateWrapper<BooleanState, AbstractAttribute>;
3377  AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3378
3379 /// Create an abstract attribute view for the position \p IRP.
3380 static AAHeapToShared &createForPosition(const IRPosition &IRP,
3381 Attributor &A);
3382
3383 /// Returns true if HeapToShared conversion is assumed to be possible.
3384 virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3385
3386 /// Returns true if HeapToShared conversion is assumed and the CB is a
3387 /// callsite to a free operation to be removed.
3388 virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3389
3390 /// See AbstractAttribute::getName().
3391 const std::string getName() const override { return "AAHeapToShared"; }
3392
3393 /// See AbstractAttribute::getIdAddr().
3394 const char *getIdAddr() const override { return &ID; }
3395
3396 /// This function should return true if the type of the \p AA is
3397 /// AAHeapToShared.
3398 static bool classof(const AbstractAttribute *AA) {
3399 return (AA->getIdAddr() == &ID);
3400 }
3401
3402 /// Unique ID (due to the unique address)
3403 static const char ID;
3404};
3405
3406struct AAHeapToSharedFunction : public AAHeapToShared {
3407 AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
3408 : AAHeapToShared(IRP, A) {}
3409
3410 const std::string getAsStr(Attributor *) const override {
3411 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3412 " malloc calls eligible.";
3413 }
3414
3415 /// See AbstractAttribute::trackStatistics().
3416 void trackStatistics() const override {}
3417
3418  /// This function finds free calls that will be removed by the
3419 /// HeapToShared transformation.
3420 void findPotentialRemovedFreeCalls(Attributor &A) {
3421 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3422 auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3423
3424 PotentialRemovedFreeCalls.clear();
3425 // Update free call users of found malloc calls.
3426 for (CallBase *CB : MallocCalls) {
3427      SmallVector<CallBase *, 4> FreeCalls;
3428      for (auto *U : CB->users()) {
3429 CallBase *C = dyn_cast<CallBase>(U);
3430 if (C && C->getCalledFunction() == FreeRFI.Declaration)
3431 FreeCalls.push_back(C);
3432 }
3433
3434 if (FreeCalls.size() != 1)
3435 continue;
3436
3437 PotentialRemovedFreeCalls.insert(FreeCalls.front());
3438 }
3439 }
3440
3441 void initialize(Attributor &A) override {
3442    if (DisableOpenMPOptDeglobalization) {
3443      indicatePessimisticFixpoint();
3444 return;
3445 }
3446
3447 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3448 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3449 if (!RFI.Declaration)
3450 return;
3451
3453 [](const IRPosition &, const AbstractAttribute *,
3454 bool &) -> std::optional<Value *> { return nullptr; };
3455
3456 Function *F = getAnchorScope();
3457 for (User *U : RFI.Declaration->users())
3458 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3459 if (CB->getFunction() != F)
3460 continue;
3461 MallocCalls.insert(CB);
3462 A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
3463 SCB);
3464 }
3465
3466 findPotentialRemovedFreeCalls(A);
3467 }
3468
3469 bool isAssumedHeapToShared(CallBase &CB) const override {
3470 return isValidState() && MallocCalls.count(&CB);
3471 }
3472
3473 bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
3474 return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3475 }
3476
3477 ChangeStatus manifest(Attributor &A) override {
3478 if (MallocCalls.empty())
3479 return ChangeStatus::UNCHANGED;
3480
3481 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3482 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3483
3484 Function *F = getAnchorScope();
3485 auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3486 DepClassTy::OPTIONAL);
3487
3488 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3489 for (CallBase *CB : MallocCalls) {
3490 // Skip replacing this if HeapToStack has already claimed it.
3491 if (HS && HS->isAssumedHeapToStack(*CB))
3492 continue;
3493
3494 // Find the unique free call to remove it.
3495      SmallVector<CallBase *, 4> FreeCalls;
3496      for (auto *U : CB->users()) {
3497 CallBase *C = dyn_cast<CallBase>(U);
3498 if (C && C->getCalledFunction() == FreeCall.Declaration)
3499 FreeCalls.push_back(C);
3500 }
3501 if (FreeCalls.size() != 1)
3502 continue;
3503
3504 auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3505
3506 if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3507 LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3508 << " with shared memory."
3509 << " Shared memory usage is limited to "
3510 << SharedMemoryLimit << " bytes\n");
3511 continue;
3512 }
3513
3514 LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3515 << " with " << AllocSize->getZExtValue()
3516 << " bytes of shared memory\n");
3517
3518 // Create a new shared memory buffer of the same size as the allocation
3519 // and replace all the uses of the original allocation with it.
3520 Module *M = CB->getModule();
3521 Type *Int8Ty = Type::getInt8Ty(M->getContext());
3522 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3523 auto *SharedMem = new GlobalVariable(
3524 *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3525 PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3526          GlobalValue::NotThreadLocal,
3527          static_cast<unsigned>(AddressSpace::Shared));
3528 auto *NewBuffer = ConstantExpr::getPointerCast(
3529 SharedMem, PointerType::getUnqual(M->getContext()));
3530
3531 auto Remark = [&](OptimizationRemark OR) {
3532 return OR << "Replaced globalized variable with "
3533 << ore::NV("SharedMemory", AllocSize->getZExtValue())
3534 << (AllocSize->isOne() ? " byte " : " bytes ")
3535 << "of shared memory.";
3536 };
3537 A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3538
3539 MaybeAlign Alignment = CB->getRetAlign();
3540 assert(Alignment &&
3541 "HeapToShared on allocation without alignment attribute");
3542 SharedMem->setAlignment(*Alignment);
3543
3544 A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3545 A.deleteAfterManifest(*CB);
3546 A.deleteAfterManifest(*FreeCalls.front());
3547
3548 SharedMemoryUsed += AllocSize->getZExtValue();
3549 NumBytesMovedToSharedMemory = SharedMemoryUsed;
3550 Changed = ChangeStatus::CHANGED;
3551 }
3552
3553 return Changed;
3554 }
3555
3556 ChangeStatus updateImpl(Attributor &A) override {
3557 if (MallocCalls.empty())
3558 return indicatePessimisticFixpoint();
3559 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3560 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3561 if (!RFI.Declaration)
3562 return ChangeStatus::UNCHANGED;
3563
3564 Function *F = getAnchorScope();
3565
3566 auto NumMallocCalls = MallocCalls.size();
3567
3568    // Only consider malloc calls executed by a single thread with a constant size.
3569 for (User *U : RFI.Declaration->users()) {
3570 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3571 if (CB->getCaller() != F)
3572 continue;
3573 if (!MallocCalls.count(CB))
3574 continue;
3575 if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3576 MallocCalls.remove(CB);
3577 continue;
3578 }
3579 const auto *ED = A.getAAFor<AAExecutionDomain>(
3580 *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3581 if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
3582 MallocCalls.remove(CB);
3583 }
3584 }
3585
3586 findPotentialRemovedFreeCalls(A);
3587
3588 if (NumMallocCalls != MallocCalls.size())
3589 return ChangeStatus::CHANGED;
3590
3591 return ChangeStatus::UNCHANGED;
3592 }
3593
3594 /// Collection of all malloc calls in a function.
3595  SmallSetVector<CallBase *, 4> MallocCalls;
3596  /// Collection of potentially removed free calls in a function.
3597 SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3598 /// The total amount of shared memory that has been used for HeapToShared.
3599 unsigned SharedMemoryUsed = 0;
3600};
3601
3602struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3603  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3604  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3605
3606 /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
3607 /// unknown callees.
3608 static bool requiresCalleeForCallBase() { return false; }
3609
3610 /// Statistics are tracked as part of manifest for now.
3611 void trackStatistics() const override {}
3612
3613 /// See AbstractAttribute::getAsStr()
3614 const std::string getAsStr(Attributor *) const override {
3615 if (!isValidState())
3616 return "<invalid>";
3617 return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3618 : "generic") +
3619 std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3620 : "") +
3621 std::string(" #PRs: ") +
3622 (ReachedKnownParallelRegions.isValidState()
3623 ? std::to_string(ReachedKnownParallelRegions.size())
3624 : "<invalid>") +
3625 ", #Unknown PRs: " +
3626 (ReachedUnknownParallelRegions.isValidState()
3627 ? std::to_string(ReachedUnknownParallelRegions.size())
3628 : "<invalid>") +
3629 ", #Reaching Kernels: " +
3630 (ReachingKernelEntries.isValidState()
3631 ? std::to_string(ReachingKernelEntries.size())
3632 : "<invalid>") +
3633 ", #ParLevels: " +
3634 (ParallelLevels.isValidState()
3635 ? std::to_string(ParallelLevels.size())
3636 : "<invalid>") +
3637 ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3638 }
3639
3640  /// Create an abstract attribute view for the position \p IRP.
3641 static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3642
3643 /// See AbstractAttribute::getName()
3644 const std::string getName() const override { return "AAKernelInfo"; }
3645
3646 /// See AbstractAttribute::getIdAddr()
3647 const char *getIdAddr() const override { return &ID; }
3648
3649 /// This function should return true if the type of the \p AA is AAKernelInfo
3650 static bool classof(const AbstractAttribute *AA) {
3651 return (AA->getIdAddr() == &ID);
3652 }
3653
3654 static const char ID;
3655};
3656
3657/// The function kernel info abstract attribute, basically, what can we say
3658/// about a function with regards to the KernelInfoState.
3659struct AAKernelInfoFunction : AAKernelInfo {
3660 AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3661 : AAKernelInfo(IRP, A) {}
3662
3663 SmallPtrSet<Instruction *, 4> GuardedInstructions;
3664
3665 SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3666 return GuardedInstructions;
3667 }
3668
3669 void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3670    Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction(
3671        KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3672 assert(NewKernelEnvC && "Failed to create new kernel environment");
3673 KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3674 }
3675
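  // Generate a set<Member>OfKernelEnvironment helper for each configuration
  // field; each one folds the new value into the configuration struct and
  // re-installs it in the (assumed) kernel environment constant.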
3676#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \
3677 void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \
3678 ConstantStruct *ConfigC = \
3679 KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \
3680 Constant *NewConfigC = ConstantFoldInsertValueInstruction( \
3681 ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \
3682 assert(NewConfigC && "Failed to create new configuration environment"); \
3683 setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \
3684 }
3685
3686 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
3687 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
3688  KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)
3689  KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads)
3690  KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads)
3691  KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams)
3692  KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams)
3693
3694#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
3695
3696 /// See AbstractAttribute::initialize(...).
3697 void initialize(Attributor &A) override {
3698 // This is a high-level transform that might change the constant arguments
3699    // of the init and deinit calls. We need to tell the Attributor about this
3700    // to avoid other parts using the current constant value for simplification.
3701 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3702
3703 Function *Fn = getAnchorScope();
3704
3705 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3706 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3707 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3708 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3709
3710 // For kernels we perform more initialization work, first we find the init
3711 // and deinit calls.
3712 auto StoreCallBase = [](Use &U,
3713 OMPInformationCache::RuntimeFunctionInfo &RFI,
3714 CallBase *&Storage) {
3715 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3716 assert(CB &&
3717 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3718 assert(!Storage &&
3719 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3720 Storage = CB;
3721 return false;
3722 };
3723 InitRFI.foreachUse(
3724 [&](Use &U, Function &) {
3725 StoreCallBase(U, InitRFI, KernelInitCB);
3726 return false;
3727 },
3728 Fn);
3729 DeinitRFI.foreachUse(
3730 [&](Use &U, Function &) {
3731 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3732 return false;
3733 },
3734 Fn);
3735
3736 // Ignore kernels without initializers such as global constructors.
3737 if (!KernelInitCB || !KernelDeinitCB)
3738 return;
3739
3740 // Add itself to the reaching kernel and set IsKernelEntry.
3741 ReachingKernelEntries.insert(Fn);
3742 IsKernelEntry = true;
3743
3744 KernelEnvC =
3746 GlobalVariable *KernelEnvGV =
3748
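    // Let queries through the Attributor see our assumed kernel environment
    // constant while we iterate; a dependence is recorded until a fixpoint.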
3750 KernelConfigurationSimplifyCB =
3751 [&](const GlobalVariable &GV, const AbstractAttribute *AA,
3752 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3753 if (!isAtFixpoint()) {
3754 if (!AA)
3755 return nullptr;
3756 UsedAssumedInformation = true;
3757 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3758 }
3759 return KernelEnvC;
3760 };
3761
3762 A.registerGlobalVariableSimplificationCallback(
3763 *KernelEnvGV, KernelConfigurationSimplifyCB);
3764
3765    // We cannot change to SPMD mode if the runtime functions aren't available.
3766 bool CanChangeToSPMD = OMPInfoCache.runtimeFnsAvailable(
3767 {OMPRTL___kmpc_get_hardware_thread_id_in_block,
3768 OMPRTL___kmpc_barrier_simple_spmd});
3769
3770 // Check if we know we are in SPMD-mode already.
3771 ConstantInt *ExecModeC =
3772 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3773 ConstantInt *AssumedExecModeC = ConstantInt::get(
3774 ExecModeC->getIntegerType(),
3775        OMP_TGT_EXEC_MODE_GENERIC | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
3776    if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
3777 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3778 else if (DisableOpenMPOptSPMDization || !CanChangeToSPMD)
3779 // This is a generic region but SPMDization is disabled so stop
3780 // tracking.
3781 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3782 else
3783 setExecModeOfKernelEnvironment(AssumedExecModeC);
3784
3785 const Triple T(Fn->getParent()->getTargetTriple());
3786 auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
3787 auto [MinThreads, MaxThreads] =
3789 if (MinThreads)
3790 setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3791 if (MaxThreads)
3792 setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3793 auto [MinTeams, MaxTeams] =
3795 if (MinTeams)
3796 setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3797 if (MaxTeams)
3798 setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3799
3800 ConstantInt *MayUseNestedParallelismC =
3801 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3802 ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3803 MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
3804 setMayUseNestedParallelismOfKernelEnvironment(
3805 AssumedMayUseNestedParallelismC);
3806
3807    if (!DisableOpenMPOptStateMachineRewrite) {
3808      ConstantInt *UseGenericStateMachineC =
3809 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3810 KernelEnvC);
3811 ConstantInt *AssumedUseGenericStateMachineC =
3812 ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3813 setUseGenericStateMachineOfKernelEnvironment(
3814 AssumedUseGenericStateMachineC);
3815 }
3816
3817 // Register virtual uses of functions we might need to preserve.
3818 auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3819                                  Attributor::VirtualUseCallbackTy &CB) {
3820      if (!OMPInfoCache.RFIs[RFKind].Declaration)
3821 return;
3822 A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3823 };
3824
3825 // Add a dependence to ensure updates if the state changes.
3826 auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3827 const AbstractAttribute *QueryingAA) {
3828 if (QueryingAA) {
3829 A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3830 }
3831 return true;
3832 };
3833
3834 Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3835 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3836 // Whenever we create a custom state machine we will insert calls to
3837 // __kmpc_get_hardware_num_threads_in_block,
3838 // __kmpc_get_warp_size,
3839 // __kmpc_barrier_simple_generic,
3840 // __kmpc_kernel_parallel, and
3841 // __kmpc_kernel_end_parallel.
3842 // Not needed if we are on track for SPMDzation.
3843 if (SPMDCompatibilityTracker.isValidState())
3844 return AddDependence(A, this, QueryingAA);
3845 // Not needed if we can't rewrite due to an invalid state.
3846 if (!ReachedKnownParallelRegions.isValidState())
3847 return AddDependence(A, this, QueryingAA);
3848 return false;
3849 };
3850
3851 // Not needed if we are pre-runtime merge.
3852 if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3853 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3854 CustomStateMachineUseCB);
3855 RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3856 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3857 CustomStateMachineUseCB);
3858 RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3859 CustomStateMachineUseCB);
3860 RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3861 CustomStateMachineUseCB);
3862 }
3863
3864 // If we do not perform SPMDzation we do not need the virtual uses below.
3865 if (SPMDCompatibilityTracker.isAtFixpoint())
3866 return;
3867
3868 Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3869 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3870 // Whenever we perform SPMDzation we will insert
3871 // __kmpc_get_hardware_thread_id_in_block calls.
3872 if (!SPMDCompatibilityTracker.isValidState())
3873 return AddDependence(A, this, QueryingAA);
3874 return false;
3875 };
3876 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3877 HWThreadIdUseCB);
3878
3879 Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3880 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3881 // Whenever we perform SPMDzation with guarding we will insert
3882 // __kmpc_simple_barrier_spmd calls. If SPMDzation failed, there is
3883 // nothing to guard, or there are no parallel regions, we don't need
3884 // the calls.
3885 if (!SPMDCompatibilityTracker.isValidState())
3886 return AddDependence(A, this, QueryingAA);
3887 if (SPMDCompatibilityTracker.empty())
3888 return AddDependence(A, this, QueryingAA);
3889 if (!mayContainParallelRegion())
3890 return AddDependence(A, this, QueryingAA);
3891 return false;
3892 };
3893 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3894 }
3895
3896 /// Sanitize the string \p S such that it is a suitable global symbol name.
3897 static std::string sanitizeForGlobalName(std::string S) {
3898 std::replace_if(
3899 S.begin(), S.end(),
3900 [](const char C) {
3901 return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3902 (C >= '0' && C <= '9') || C == '_');
3903 },
3904 '.');
3905 return S;
3906 }
3907
3908 /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3909 /// finished now.
3910 ChangeStatus manifest(Attributor &A) override {
3911 // If we are not looking at a kernel with __kmpc_target_init and
3912 // __kmpc_target_deinit call we cannot actually manifest the information.
3913 if (!KernelInitCB || !KernelDeinitCB)
3914 return ChangeStatus::UNCHANGED;
3915
3916 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3917
3918 bool HasBuiltStateMachine = true;
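    // Prefer SPMD-ization; if that is not possible, fall back to specializing
    // the generic-mode state machine (only if the init callee is defined).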
3919 if (!changeToSPMDMode(A, Changed)) {
3920 if (!KernelInitCB->getCalledFunction()->isDeclaration())
3921 HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3922 else
3923 HasBuiltStateMachine = false;
3924 }
3925
3926 // We need to reset KernelEnvC if specific rewriting is not done.
3927 ConstantStruct *ExistingKernelEnvC =
3929 ConstantInt *OldUseGenericStateMachineVal =
3930 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3931 ExistingKernelEnvC);
3932 if (!HasBuiltStateMachine)
3933 setUseGenericStateMachineOfKernelEnvironment(
3934 OldUseGenericStateMachineVal);
3935
3936    // Finally, update the KernelEnvC.
3937 GlobalVariable *KernelEnvGV =
3939 if (KernelEnvGV->getInitializer() != KernelEnvC) {
3940 KernelEnvGV->setInitializer(KernelEnvC);
3941 Changed = ChangeStatus::CHANGED;
3942 }
3943
3944 return Changed;
3945 }
3946
3947 void insertInstructionGuardsHelper(Attributor &A) {
3948 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3949
3950 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3951 Instruction *RegionEndI) {
3952 LoopInfo *LI = nullptr;
3953 DominatorTree *DT = nullptr;
3954 MemorySSAUpdater *MSU = nullptr;
3955 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3956
3957 BasicBlock *ParentBB = RegionStartI->getParent();
3958 Function *Fn = ParentBB->getParent();
3959 Module &M = *Fn->getParent();
3960
3961 // Create all the blocks and logic.
3962 // ParentBB:
3963 // goto RegionCheckTidBB
3964 // RegionCheckTidBB:
3965 // Tid = __kmpc_hardware_thread_id()
3966 // if (Tid != 0)
3967 // goto RegionBarrierBB
3968 // RegionStartBB:
3969 // <execute instructions guarded>
3970 // goto RegionEndBB
3971 // RegionEndBB:
3972 // <store escaping values to shared mem>
3973 // goto RegionBarrierBB
3974 // RegionBarrierBB:
3975 // __kmpc_simple_barrier_spmd()
3976 // // second barrier is omitted if lacking escaping values.
3977 // <load escaping values from shared mem>
3978 // __kmpc_simple_barrier_spmd()
3979 // goto RegionExitBB
3980 // RegionExitBB:
3981 // <execute rest of instructions>
3982
3983 BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3984 DT, LI, MSU, "region.guarded.end");
3985 BasicBlock *RegionBarrierBB =
3986 SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3987 MSU, "region.barrier");
3988 BasicBlock *RegionExitBB =
3989 SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3990 DT, LI, MSU, "region.exit");
3991 BasicBlock *RegionStartBB =
3992 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3993
3994 assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3995 "Expected a different CFG");
3996
3997 BasicBlock *RegionCheckTidBB = SplitBlock(
3998 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3999
4000 // Register basic blocks with the Attributor.
4001 A.registerManifestAddedBasicBlock(*RegionEndBB);
4002 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
4003 A.registerManifestAddedBasicBlock(*RegionExitBB);
4004 A.registerManifestAddedBasicBlock(*RegionStartBB);
4005 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4006
4007 bool HasBroadcastValues = false;
4008 // Find escaping outputs from the guarded region to outside users and
4009 // broadcast their values to them.
4010 for (Instruction &I : *RegionStartBB) {
4011 SmallVector<Use *, 4> OutsideUses;
4012 for (Use &U : I.uses()) {
4013 Instruction &UsrI = *cast<Instruction>(U.getUser());
4014 if (UsrI.getParent() != RegionStartBB)
4015 OutsideUses.push_back(&U);
4016 }
4017
4018 if (OutsideUses.empty())
4019 continue;
4020
4021 HasBroadcastValues = true;
4022
4023 // Emit a global variable in shared memory to store the broadcasted
4024 // value.
4025 auto *SharedMem = new GlobalVariable(
4026 M, I.getType(), /* IsConstant */ false,
4028 sanitizeForGlobalName(
4029 (I.getName() + ".guarded.output.alloc").str()),
4030            nullptr, GlobalValue::NotThreadLocal,
4031            static_cast<unsigned>(AddressSpace::Shared));
4032
4033 // Emit a store instruction to update the value.
4034 new StoreInst(&I, SharedMem,
4035 RegionEndBB->getTerminator()->getIterator());
4036
4037 LoadInst *LoadI = new LoadInst(
4038 I.getType(), SharedMem, I.getName() + ".guarded.output.load",
4039 RegionBarrierBB->getTerminator()->getIterator());
4040
4041 // Emit a load instruction and replace uses of the output value.
4042 for (Use *U : OutsideUses)
4043 A.changeUseAfterManifest(*U, *LoadI);
4044 }
4045
4046 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4047
4048 // Go to tid check BB in ParentBB.
4049 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
4050 ParentBB->getTerminator()->eraseFromParent();
4052 InsertPointTy(ParentBB, ParentBB->end()), DL);
4053 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4054 uint32_t SrcLocStrSize;
4055 auto *SrcLocStr =
4056 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4057 Value *Ident =
4058 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4059 BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
4060
4061 // Add check for Tid in RegionCheckTidBB
4062 RegionCheckTidBB->getTerminator()->eraseFromParent();
4063 OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
4064 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4065 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4066 FunctionCallee HardwareTidFn =
4067 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4068 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4069 CallInst *Tid =
4070 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4071 Tid->setDebugLoc(DL);
4072 OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4073 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4074 OMPInfoCache.OMPBuilder.Builder
4075 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4076 ->setDebugLoc(DL);
4077
4078 // First barrier for synchronization, ensures main thread has updated
4079 // values.
4080 FunctionCallee BarrierFn =
4081 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4082 M, OMPRTL___kmpc_barrier_simple_spmd);
4083 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4084 RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
4085 CallInst *Barrier =
4086 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4087 Barrier->setDebugLoc(DL);
4088 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4089
4090 // Second barrier ensures workers have read broadcast values.
4091 if (HasBroadcastValues) {
4092 CallInst *Barrier =
4093 CallInst::Create(BarrierFn, {Ident, Tid}, "",
4094 RegionBarrierBB->getTerminator()->getIterator());
4095 Barrier->setDebugLoc(DL);
4096 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4097 }
4098 };
4099
4100    auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4101    SmallPtrSet<BasicBlock *, 8> Visited;
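    // Cluster guarded, user-less side-effect instructions within each block so
    // that they form as few contiguous guarded regions as possible.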
4102 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4103 BasicBlock *BB = GuardedI->getParent();
4104 if (!Visited.insert(BB).second)
4105 continue;
4106
4107      SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
4108      Instruction *LastEffect = nullptr;
4109 BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
4110 while (++IP != IPEnd) {
4111 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4112 continue;
4113 Instruction *I = &*IP;
4114 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4115 continue;
4116 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4117 LastEffect = nullptr;
4118 continue;
4119 }
4120 if (LastEffect)
4121 Reorders.push_back({I, LastEffect});
4122 LastEffect = &*IP;
4123 }
4124 for (auto &Reorder : Reorders)
4125 Reorder.first->moveBefore(Reorder.second);
4126 }
4127
4128    SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
4129
4130 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4131 BasicBlock *BB = GuardedI->getParent();
4132 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4133 IRPosition::function(*GuardedI->getFunction()), nullptr,
4134 DepClassTy::NONE);
4135 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4136 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4137 // Continue if instruction is already guarded.
4138 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4139 continue;
4140
4141 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4142 for (Instruction &I : *BB) {
4143 // If instruction I needs to be guarded update the guarded region
4144 // bounds.
4145 if (SPMDCompatibilityTracker.contains(&I)) {
4146 CalleeAAFunction.getGuardedInstructions().insert(&I);
4147 if (GuardedRegionStart)
4148 GuardedRegionEnd = &I;
4149 else
4150 GuardedRegionStart = GuardedRegionEnd = &I;
4151
4152 continue;
4153 }
4154
4155 // Instruction I does not need guarding, store
4156 // any region found and reset bounds.
4157 if (GuardedRegionStart) {
4158 GuardedRegions.push_back(
4159 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4160 GuardedRegionStart = nullptr;
4161 GuardedRegionEnd = nullptr;
4162 }
4163 }
4164 }
4165
4166 for (auto &GR : GuardedRegions)
4167 CreateGuardedRegion(GR.first, GR.second);
4168 }
4169
4170 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4171 // Only allow 1 thread per workgroup to continue executing the user code.
4172 //
4173 // InitCB = __kmpc_target_init(...)
4174 // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4175 // if (ThreadIdInBlock != 0) return;
4176 // UserCode:
4177 // // user code
4178 //
4179 auto &Ctx = getAnchorValue().getContext();
4180 Function *Kernel = getAssociatedFunction();
4181 assert(Kernel && "Expected an associated function!");
4182
4183 // Create block for user code to branch to from initial block.
4184 BasicBlock *InitBB = KernelInitCB->getParent();
4185 BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4186 KernelInitCB->getNextNode(), "main.thread.user_code");
4187 BasicBlock *ReturnBB =
4188 BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
4189
4190 // Register blocks with attributor:
4191 A.registerManifestAddedBasicBlock(*InitBB);
4192 A.registerManifestAddedBasicBlock(*UserCodeBB);
4193 A.registerManifestAddedBasicBlock(*ReturnBB);
4194
4195 // Debug location:
4196 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4197 ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
4198 InitBB->getTerminator()->eraseFromParent();
4199
4200 // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4201 Module &M = *Kernel->getParent();
4202 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4203 FunctionCallee ThreadIdInBlockFn =
4204 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4205 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4206
4207 // Get thread ID in block.
4208 CallInst *ThreadIdInBlock =
4209 CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
4210 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4211 ThreadIdInBlock->setDebugLoc(DLoc);
4212
4213 // Eliminate all threads in the block with ID not equal to 0:
4214 Instruction *IsMainThread =
4215 ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
4216 ConstantInt::get(ThreadIdInBlock->getType(), 0),
4217 "thread.is_main", InitBB);
4218 IsMainThread->setDebugLoc(DLoc);
4219 BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
4220 }
4221
4222 bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
4223 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4224
4225 if (!SPMDCompatibilityTracker.isAssumed()) {
4226 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4227 if (!NonCompatibleI)
4228 continue;
4229
4230 // Skip diagnostics on calls to known OpenMP runtime functions for now.
4231 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4232 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4233 continue;
4234
4235 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4236 ORA << "Value has potential side effects preventing SPMD-mode "
4237 "execution";
4238 if (isa<CallBase>(NonCompatibleI)) {
4239 ORA << ". Add `[[omp::assume(\"ompx_spmd_amenable\")]]` to "
4240 "the called function to override";
4241 }
4242 return ORA << ".";
4243 };
4244 A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
4245 Remark);
4246
4247 LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
4248 << *NonCompatibleI << "\n");
4249 }
4250
4251 return false;
4252 }
4253
4254    // Get the actual kernel; it could be the caller of the anchor scope if we
4255    // have a debug wrapper.
4256 Function *Kernel = getAnchorScope();
4257 if (Kernel->hasLocalLinkage()) {
4258 assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
4259 auto *CB = cast<CallBase>(Kernel->user_back());
4260 Kernel = CB->getCaller();
4261 }
4262 assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
4263
4264 // Check if the kernel is already in SPMD mode, if so, return success.
4265 ConstantStruct *ExistingKernelEnvC =
4267 auto *ExecModeC =
4268 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4269 const int8_t ExecModeVal = ExecModeC->getSExtValue();
4270 if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4271 return true;
4272
4273 // We will now unconditionally modify the IR, indicate a change.
4274 Changed = ChangeStatus::CHANGED;
4275
4276 // Do not use instruction guards when no parallel is present inside
4277 // the target region.
4278 if (mayContainParallelRegion())
4279 insertInstructionGuardsHelper(A);
4280 else
4281 forceSingleThreadPerWorkgroupHelper(A);
4282
4283 // Adjust the global exec mode flag that tells the runtime what mode this
4284 // kernel is executed in.
4285 assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4286 "Initially non-SPMD kernel has SPMD exec mode!");
4287 setExecModeOfKernelEnvironment(
4288 ConstantInt::get(ExecModeC->getIntegerType(),
4289 ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
4290
4291 ++NumOpenMPTargetRegionKernelsSPMD;
4292
4293 auto Remark = [&](OptimizationRemark OR) {
4294 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4295 };
4296 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
4297 return true;
4298 };
4299
4300 bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4301 // If we have disabled state machine rewrites, don't make a custom one
4302    if (DisableOpenMPOptStateMachineRewrite)
4303      return false;
4304
4305 // Don't rewrite the state machine if we are not in a valid state.
4306 if (!ReachedKnownParallelRegions.isValidState())
4307 return false;
4308
4309 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4310 if (!OMPInfoCache.runtimeFnsAvailable(
4311 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4312 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4313 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4314 return false;
4315
4316 ConstantStruct *ExistingKernelEnvC =
4318
4319    // Check if the current configuration is non-SPMD with a generic state
4320    // machine. If we already have SPMD mode or a custom state machine we do
4321    // not need to go any further. If it is anything but a constant something
4322    // is weird and we give up.
4323 ConstantInt *UseStateMachineC =
4324 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4325 ExistingKernelEnvC);
4326 ConstantInt *ModeC =
4327 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4328
4329 // If we are stuck with generic mode, try to create a custom device (=GPU)
4330 // state machine which is specialized for the parallel regions that are
4331 // reachable by the kernel.
4332 if (UseStateMachineC->isZero() ||
4333        (ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
4334      return false;
4335
4336 Changed = ChangeStatus::CHANGED;
4337
4338 // If not SPMD mode, indicate we use a custom state machine now.
4339 setUseGenericStateMachineOfKernelEnvironment(
4340 ConstantInt::get(UseStateMachineC->getIntegerType(), false));
4341
4342 // If we don't actually need a state machine we are done here. This can
4343 // happen if there simply are no parallel regions. In the resulting kernel
4344 // all worker threads will simply exit right away, leaving the main thread
4345 // to do the work alone.
4346 if (!mayContainParallelRegion()) {
4347 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4348
4349 auto Remark = [&](OptimizationRemark OR) {
4350 return OR << "Removing unused state machine from generic-mode kernel.";
4351 };
4352 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4353
4354 return true;
4355 }
4356
4357 // Keep track in the statistics of our new shiny custom state machine.
4358 if (ReachedUnknownParallelRegions.empty()) {
4359 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4360
4361 auto Remark = [&](OptimizationRemark OR) {
4362 return OR << "Rewriting generic-mode kernel with a customized state "
4363 "machine.";
4364 };
4365 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4366 } else {
4367 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4368
4369      auto Remark = [&](OptimizationRemarkAnalysis OR) {
4370        return OR << "Generic-mode kernel is executed with a customized state "
4371 "machine that requires a fallback.";
4372 };
4373 A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4374
4375 // Tell the user why we ended up with a fallback.
4376 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4377 if (!UnknownParallelRegionCB)
4378 continue;
4379 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4380 return ORA << "Call may contain unknown parallel regions. Use "
4381 << "`[[omp::assume(\"omp_no_parallelism\")]]` to "
4382 "override.";
4383 };
4384 A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4385 "OMP133", Remark);
4386 }
4387 }
4388
4389 // Create all the blocks:
4390 //
4391 // InitCB = __kmpc_target_init(...)
4392 // BlockHwSize =
4393 // __kmpc_get_hardware_num_threads_in_block();
4394 // WarpSize = __kmpc_get_warp_size();
4395 // BlockSize = BlockHwSize - WarpSize;
4396 // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
4397 // if (IsWorker) {
4398 // if (InitCB >= BlockSize) return;
4399 // SMBeginBB: __kmpc_barrier_simple_generic(...);
4400 // void *WorkFn;
4401 // bool Active = __kmpc_kernel_parallel(&WorkFn);
4402 // if (!WorkFn) return;
4403 // SMIsActiveCheckBB: if (Active) {
4404 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
4405 // ParFn0(...);
4406 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
4407 // ParFn1(...);
4408 // ...
4409 // SMIfCascadeCurrentBB: else
4410 // ((WorkFnTy*)WorkFn)(...);
4411 // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
4412 // }
4413 // SMDoneBB: __kmpc_barrier_simple_generic(...);
4414 // goto SMBeginBB;
4415 // }
4416 // UserCodeEntryBB: // user code
4417 // __kmpc_target_deinit(...)
4418 //
4419 auto &Ctx = getAnchorValue().getContext();
4420 Function *Kernel = getAssociatedFunction();
4421 assert(Kernel && "Expected an associated function!");
4422
4423 BasicBlock *InitBB = KernelInitCB->getParent();
4424 BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4425 KernelInitCB->getNextNode(), "thread.user_code.check");
4426 BasicBlock *IsWorkerCheckBB =
4427 BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4428 BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4429 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4430 BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4431 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4432 BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4433 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4434 BasicBlock *StateMachineIfCascadeCurrentBB =
4435 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4436 Kernel, UserCodeEntryBB);
4437 BasicBlock *StateMachineEndParallelBB =
4438 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4439 Kernel, UserCodeEntryBB);
4440 BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4441 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4442 A.registerManifestAddedBasicBlock(*InitBB);
4443 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4444 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4445 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4446 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4447 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4448 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4449 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4450 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4451
4452 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4453 ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4454 InitBB->getTerminator()->eraseFromParent();
4455
4456 Instruction *IsWorker =
4457 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4458 ConstantInt::get(KernelInitCB->getType(), -1),
4459 "thread.is_worker", InitBB);
4460 IsWorker->setDebugLoc(DLoc);
4461 BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
4462
4463 Module &M = *Kernel->getParent();
4464 FunctionCallee BlockHwSizeFn =
4465 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4466 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4467 FunctionCallee WarpSizeFn =
4468 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4469 M, OMPRTL___kmpc_get_warp_size);
4470 CallInst *BlockHwSize =
4471 CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4472 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4473 BlockHwSize->setDebugLoc(DLoc);
4474 CallInst *WarpSize =
4475 CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4476 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4477 WarpSize->setDebugLoc(DLoc);
4478 Instruction *BlockSize = BinaryOperator::CreateSub(
4479 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4480 BlockSize->setDebugLoc(DLoc);
4481 Instruction *IsMainOrWorker = ICmpInst::Create(
4482 ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4483 "thread.is_main_or_worker", IsWorkerCheckBB);
4484 IsMainOrWorker->setDebugLoc(DLoc);
4485 BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
4486 IsMainOrWorker, IsWorkerCheckBB);
4487
4488 // Create local storage for the work function pointer.
4489 const DataLayout &DL = M.getDataLayout();
4490 Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4491 Instruction *WorkFnAI =
4492 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4493 "worker.work_fn.addr", Kernel->getEntryBlock().begin());
4494 WorkFnAI->setDebugLoc(DLoc);
4495
4496 OMPInfoCache.OMPBuilder.updateToLocation(
4497        OpenMPIRBuilder::LocationDescription(
4498            IRBuilder<>::InsertPoint(StateMachineBeginBB,
4499 StateMachineBeginBB->end()),
4500 DLoc));
4501
4502 Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4503 Value *GTid = KernelInitCB;
4504
4505 FunctionCallee BarrierFn =
4506 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4507 M, OMPRTL___kmpc_barrier_simple_generic);
4508 CallInst *Barrier =
4509 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4510 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4511 Barrier->setDebugLoc(DLoc);
4512
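    // The alloca may live in a non-generic address space (e.g., on AMDGPU);
    // cast it so __kmpc_kernel_parallel can be passed a generic pointer.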
4513 if (WorkFnAI->getType()->getPointerAddressSpace() !=
4514 (unsigned int)AddressSpace::Generic) {
4515 WorkFnAI = new AddrSpaceCastInst(
4516 WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4517 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4518 WorkFnAI->setDebugLoc(DLoc);
4519 }
4520
4521 FunctionCallee KernelParallelFn =
4522 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4523 M, OMPRTL___kmpc_kernel_parallel);
4524 CallInst *IsActiveWorker = CallInst::Create(
4525 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4526 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4527 IsActiveWorker->setDebugLoc(DLoc);
4528 Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4529 StateMachineBeginBB);
4530 WorkFn->setDebugLoc(DLoc);
4531
4532 FunctionType *ParallelRegionFnTy = FunctionType::get(
4533        Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
4534        false);
4535
4536 Instruction *IsDone =
4537 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4538 Constant::getNullValue(VoidPtrTy), "worker.is_done",
4539 StateMachineBeginBB);
4540 IsDone->setDebugLoc(DLoc);
4541 BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
4542 IsDone, StateMachineBeginBB)
4543 ->setDebugLoc(DLoc);
4544
4545 BranchInst::Create(StateMachineIfCascadeCurrentBB,
4546 StateMachineDoneBarrierBB, IsActiveWorker,
4547 StateMachineIsActiveCheckBB)
4548 ->setDebugLoc(DLoc);
4549
4550 Value *ZeroArg =
4551 Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4552
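    // __kmpc_parallel_51 receives the outlined wrapper function as its 7th
    // argument (index 6); that wrapper is the parallel region we dispatch to.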
4553 const unsigned int WrapperFunctionArgNo = 6;
4554
4555 // Now that we have most of the CFG skeleton it is time for the if-cascade
4556 // that checks the function pointer we got from the runtime against the
4557 // parallel regions we expect, if there are any.
4558 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4559 auto *CB = ReachedKnownParallelRegions[I];
4560 auto *ParallelRegion = dyn_cast<Function>(
4561 CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4562 BasicBlock *PRExecuteBB = BasicBlock::Create(
4563 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4564 StateMachineEndParallelBB);
4565 CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4566 ->setDebugLoc(DLoc);
4567 BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4568 ->setDebugLoc(DLoc);
4569
4570 BasicBlock *PRNextBB =
4571 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4572 Kernel, StateMachineEndParallelBB);
4573 A.registerManifestAddedBasicBlock(*PRExecuteBB);
4574 A.registerManifestAddedBasicBlock(*PRNextBB);
4575
4576 // Check if we need to compare the pointer at all or if we can just
4577 // call the parallel region function.
4578 Value *IsPR;
4579 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4580 Instruction *CmpI = ICmpInst::Create(
4581 ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
4582 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4583 CmpI->setDebugLoc(DLoc);
4584 IsPR = CmpI;
4585 } else {
4586 IsPR = ConstantInt::getTrue(Ctx);
4587 }
4588
4589 BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
4590 StateMachineIfCascadeCurrentBB)
4591 ->setDebugLoc(DLoc);
4592 StateMachineIfCascadeCurrentBB = PRNextBB;
4593 }
4594
4595 // At the end of the if-cascade we place the indirect function pointer call
4596 // in case we might need it, that is if there can be parallel regions we
4597 // have not handled in the if-cascade above.
4598 if (!ReachedUnknownParallelRegions.empty()) {
4599 StateMachineIfCascadeCurrentBB->setName(
4600 "worker_state_machine.parallel_region.fallback.execute");
4601 CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
4602 StateMachineIfCascadeCurrentBB)
4603 ->setDebugLoc(DLoc);
4604 }
4605 BranchInst::Create(StateMachineEndParallelBB,
4606 StateMachineIfCascadeCurrentBB)
4607 ->setDebugLoc(DLoc);
4608
4609 FunctionCallee EndParallelFn =
4610 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4611 M, OMPRTL___kmpc_kernel_end_parallel);
4612 CallInst *EndParallel =
4613 CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4614 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4615 EndParallel->setDebugLoc(DLoc);
4616 BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4617 ->setDebugLoc(DLoc);
4618
4619 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4620 ->setDebugLoc(DLoc);
4621 BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4622 ->setDebugLoc(DLoc);
4623
4624 return true;
4625 }
4626
4627 /// Fixpoint iteration update function. Will be called every time a dependence
4628 /// changed its state (and in the beginning).
4629 ChangeStatus updateImpl(Attributor &A) override {
4630 KernelInfoState StateBefore = getState();
4631
4632 // When we leave this function this RAII will make sure the member
4633 // KernelEnvC is updated properly depending on the state. That member is
4634 // used for simplification of values and needs to be up to date at all
4635 // times.
4636 struct UpdateKernelEnvCRAII {
4637 AAKernelInfoFunction &AA;
4638
4639 UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4640
4641 ~UpdateKernelEnvCRAII() {
4642 if (!AA.KernelEnvC)
4643 return;
4644
4645 ConstantStruct *ExistingKernelEnvC =
4647
4648 if (!AA.isValidState()) {
4649 AA.KernelEnvC = ExistingKernelEnvC;
4650 return;
4651 }
4652
4653 if (!AA.ReachedKnownParallelRegions.isValidState())
4654 AA.setUseGenericStateMachineOfKernelEnvironment(
4655 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4656 ExistingKernelEnvC));
4657
4658 if (!AA.SPMDCompatibilityTracker.isValidState())
4659 AA.setExecModeOfKernelEnvironment(
4660 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4661
4662 ConstantInt *MayUseNestedParallelismC =
4663 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4664 AA.KernelEnvC);
4665 ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4666 MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4667 AA.setMayUseNestedParallelismOfKernelEnvironment(
4668 NewMayUseNestedParallelismC);
4669 }
4670 } RAII(*this);
4671
4672 // Callback to check a read/write instruction.
4673 auto CheckRWInst = [&](Instruction &I) {
4674 // We handle calls later.
4675 if (isa<CallBase>(I))
4676 return true;
4677 // We only care about write effects.
4678 if (!I.mayWriteToMemory())
4679 return true;
4680 if (auto *SI = dyn_cast<StoreInst>(&I)) {
4681 const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4682 *this, IRPosition::value(*SI->getPointerOperand()),
4683 DepClassTy::OPTIONAL);
4684 auto *HS = A.getAAFor<AAHeapToStack>(
4685 *this, IRPosition::function(*I.getFunction()),
4686 DepClassTy::OPTIONAL);
4687 if (UnderlyingObjsAA &&
4688 UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4689 if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4690 return true;
4691 // Check for AAHeapToStack moved objects which must not be
4692 // guarded.
4693 auto *CB = dyn_cast<CallBase>(&Obj);
4694 return CB && HS && HS->isAssumedHeapToStack(*CB);
4695 }))
4696 return true;
4697 }
4698
4699 // Insert instruction that needs guarding.
4700 SPMDCompatibilityTracker.insert(&I);
4701 return true;
4702 };
4703
4704 bool UsedAssumedInformationInCheckRWInst = false;
4705 if (!SPMDCompatibilityTracker.isAtFixpoint())
4706 if (!A.checkForAllReadWriteInstructions(
4707 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4708 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4709
4710 bool UsedAssumedInformationFromReachingKernels = false;
4711 if (!IsKernelEntry) {
4712 updateParallelLevels(A);
4713
4714 bool AllReachingKernelsKnown = true;
4715 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4716 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4717
4718 if (!SPMDCompatibilityTracker.empty()) {
4719 if (!ParallelLevels.isValidState())
4720 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4721 else if (!ReachingKernelEntries.isValidState())
4722 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4723 else {
4724            // Check if all reaching kernels agree on the mode as we otherwise
4725            // cannot guard instructions. We might not be sure about the mode so
4726            // we cannot fix the internal SPMD-ization state either.
4727 int SPMD = 0, Generic = 0;
4728 for (auto *Kernel : ReachingKernelEntries) {
4729 auto *CBAA = A.getAAFor<AAKernelInfo>(
4730 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4731 if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4732 CBAA->SPMDCompatibilityTracker.isAssumed())
4733 ++SPMD;
4734 else
4735 ++Generic;
4736 if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4737 UsedAssumedInformationFromReachingKernels = true;
4738 }
4739 if (SPMD != 0 && Generic != 0)
4740 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4741 }
4742 }
4743 }
4744
4745 // Callback to check a call instruction.
4746 bool AllParallelRegionStatesWereFixed = true;
4747 bool AllSPMDStatesWereFixed = true;
4748 auto CheckCallInst = [&](Instruction &I) {
4749 auto &CB = cast<CallBase>(I);
4750 auto *CBAA = A.getAAFor<AAKernelInfo>(
4751 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4752 if (!CBAA)
4753 return false;
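      // Merge the call site's KernelInfoState into ours (operator^= joins the
      // two states).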
4754 getState() ^= CBAA->getState();
4755 AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4756 AllParallelRegionStatesWereFixed &=
4757 CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4758 AllParallelRegionStatesWereFixed &=
4759 CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4760 return true;
4761 };
4762
4763 bool UsedAssumedInformationInCheckCallInst = false;
4764 if (!A.checkForAllCallLikeInstructions(
4765 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4766 LLVM_DEBUG(dbgs() << TAG
4767 << "Failed to visit all call-like instructions!\n";);
4768 return indicatePessimisticFixpoint();
4769 }
4770
4771 // If we haven't used any assumed information for the reached parallel
4772 // region states we can fix it.
4773 if (!UsedAssumedInformationInCheckCallInst &&
4774 AllParallelRegionStatesWereFixed) {
4775 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4776 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4777 }
4778
4779 // If we haven't used any assumed information for the SPMD state we can fix
4780 // it.
4781 if (!UsedAssumedInformationInCheckRWInst &&
4782 !UsedAssumedInformationInCheckCallInst &&
4783 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4784 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4785
4786 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4787 : ChangeStatus::CHANGED;
4788 }
4789
4790private:
4791 /// Update info regarding reaching kernels.
4792 void updateReachingKernelEntries(Attributor &A,
4793 bool &AllReachingKernelsKnown) {
4794 auto PredCallSite = [&](AbstractCallSite ACS) {
4795 Function *Caller = ACS.getInstruction()->getFunction();
4796
4797 assert(Caller && "Caller is nullptr");
4798
4799 auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4800 IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4801 if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4802 ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4803 return true;
4804 }
4805
4806 // We lost track of the caller of the associated function, any kernel
4807 // could reach now.
4808 ReachingKernelEntries.indicatePessimisticFixpoint();
4809
4810 return true;
4811 };
4812
4813 if (!A.checkForAllCallSites(PredCallSite, *this,
4814 true /* RequireAllCallSites */,
4815 AllReachingKernelsKnown))
4816 ReachingKernelEntries.indicatePessimisticFixpoint();
4817 }
4818
4819 /// Update info regarding parallel levels.
4820 void updateParallelLevels(Attributor &A) {
4821 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4822 OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4823 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4824
4825 auto PredCallSite = [&](AbstractCallSite ACS) {
4826 Function *Caller = ACS.getInstruction()->getFunction();
4827
4828 assert(Caller && "Caller is nullptr");
4829
4830 auto *CAA =
4831 A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4832 if (CAA && CAA->ParallelLevels.isValidState()) {
4833 // Any function that is called by `__kmpc_parallel_51` will not be
4834 // folded, because the parallel level inside it is updated by the runtime.
4835 // Getting this right would tie the analysis to the runtime implementation,
4836 // and any future change to that implementation could silently invalidate
4837 // the analysis. As a consequence, we are simply conservative here.
4838 if (Caller == Parallel51RFI.Declaration) {
4839 ParallelLevels.indicatePessimisticFixpoint();
4840 return true;
4841 }
4842
4843 ParallelLevels ^= CAA->ParallelLevels;
4844
4845 return true;
4846 }
4847
4848 // We lost track of the caller of the associated function; any kernel
4849 // could reach it now.
4850 ParallelLevels.indicatePessimisticFixpoint();
4851
4852 return true;
4853 };
4854
4855 bool AllCallSitesKnown = true;
4856 if (!A.checkForAllCallSites(PredCallSite, *this,
4857 true /* RequireAllCallSites */,
4858 AllCallSitesKnown))
4859 ParallelLevels.indicatePessimisticFixpoint();
4860 }
4861};
4862
4863/// The call site kernel info abstract attribute, basically, what can we say
4864/// about a call site with regards to the KernelInfoState. For now this simply
4865/// forwards the information from the callee.
4866struct AAKernelInfoCallSite : AAKernelInfo {
4867 AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4868 : AAKernelInfo(IRP, A) {}
4869
4870 /// See AbstractAttribute::initialize(...).
4871 void initialize(Attributor &A) override {
4872 AAKernelInfo::initialize(A);
4873
4874 CallBase &CB = cast<CallBase>(getAssociatedValue());
4875 auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4876 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4877
4878 // Check for SPMD-mode assumptions.
4879 if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4880 indicateOptimisticFixpoint();
4881 return;
4882 }
4883
4884 // First weed out calls we do not care about, that is, readonly/readnone
4885 // calls, intrinsics, and "no_openmp" calls. None of these can reach a
4886 // parallel region or anything else we are looking for.
4887 if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4888 indicateOptimisticFixpoint();
4889 return;
4890 }
4891
4892 // Next we check if we know the callee. If it is a known OpenMP function
4893 // we will handle them explicitly in the switch below. If it is not, we
4894 // will use an AAKernelInfo object on the callee to gather information and
4895 // merge that into the current state. The latter happens in the updateImpl.
4896 auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
4897 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4898 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4899 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4900 // Unknown callees or declarations are not analyzable; we give up.
4901 if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4902
4903 // Unknown callees might contain parallel regions, except if they have
4904 // an appropriate assumption attached.
4905 if (!AssumptionAA ||
4906 !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4907 AssumptionAA->hasAssumption("omp_no_parallelism")))
4908 ReachedUnknownParallelRegions.insert(&CB);
4909
4910 // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4911 // idea we can run something unknown in SPMD-mode.
4912 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4913 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4914 SPMDCompatibilityTracker.insert(&CB);
4915 }
4916
4917 // We have updated the state for this unknown call properly, there
4918 // won't be any change so we indicate a fixpoint.
4919 indicateOptimisticFixpoint();
4920 }
4921 // If the callee is known and can be used in IPO, we will update the
4922 // state based on the callee state in updateImpl.
4923 return;
4924 }
4925 if (NumCallees > 1) {
4926 indicatePessimisticFixpoint();
4927 return;
4928 }
4929
4930 RuntimeFunction RF = It->getSecond();
4931 switch (RF) {
4932 // All the functions we know are compatible with SPMD mode.
4933 case OMPRTL___kmpc_is_spmd_exec_mode:
4934 case OMPRTL___kmpc_distribute_static_fini:
4935 case OMPRTL___kmpc_for_static_fini:
4936 case OMPRTL___kmpc_global_thread_num:
4937 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4938 case OMPRTL___kmpc_get_hardware_num_blocks:
4939 case OMPRTL___kmpc_single:
4940 case OMPRTL___kmpc_end_single:
4941 case OMPRTL___kmpc_master:
4942 case OMPRTL___kmpc_end_master:
4943 case OMPRTL___kmpc_barrier:
4944 case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4945 case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4946 case OMPRTL___kmpc_error:
4947 case OMPRTL___kmpc_flush:
4948 case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4949 case OMPRTL___kmpc_get_warp_size:
4950 case OMPRTL_omp_get_thread_num:
4951 case OMPRTL_omp_get_num_threads:
4952 case OMPRTL_omp_get_max_threads:
4953 case OMPRTL_omp_in_parallel:
4954 case OMPRTL_omp_get_dynamic:
4955 case OMPRTL_omp_get_cancellation:
4956 case OMPRTL_omp_get_nested:
4957 case OMPRTL_omp_get_schedule:
4958 case OMPRTL_omp_get_thread_limit:
4959 case OMPRTL_omp_get_supported_active_levels:
4960 case OMPRTL_omp_get_max_active_levels:
4961 case OMPRTL_omp_get_level:
4962 case OMPRTL_omp_get_ancestor_thread_num:
4963 case OMPRTL_omp_get_team_size:
4964 case OMPRTL_omp_get_active_level:
4965 case OMPRTL_omp_in_final:
4966 case OMPRTL_omp_get_proc_bind:
4967 case OMPRTL_omp_get_num_places:
4968 case OMPRTL_omp_get_num_procs:
4969 case OMPRTL_omp_get_place_proc_ids:
4970 case OMPRTL_omp_get_place_num:
4971 case OMPRTL_omp_get_partition_num_places:
4972 case OMPRTL_omp_get_partition_place_nums:
4973 case OMPRTL_omp_get_wtime:
4974 break;
4975 case OMPRTL___kmpc_distribute_static_init_4:
4976 case OMPRTL___kmpc_distribute_static_init_4u:
4977 case OMPRTL___kmpc_distribute_static_init_8:
4978 case OMPRTL___kmpc_distribute_static_init_8u:
4979 case OMPRTL___kmpc_for_static_init_4:
4980 case OMPRTL___kmpc_for_static_init_4u:
4981 case OMPRTL___kmpc_for_static_init_8:
4982 case OMPRTL___kmpc_for_static_init_8u: {
4983 // Check the schedule and allow static schedule in SPMD mode.
4984 unsigned ScheduleArgOpNo = 2;
4985 auto *ScheduleTypeCI =
4986 dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4987 unsigned ScheduleTypeVal =
4988 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4989 switch (OMPScheduleType(ScheduleTypeVal)) {
4990 case OMPScheduleType::UnorderedStatic:
4991 case OMPScheduleType::UnorderedStaticChunked:
4992 case OMPScheduleType::OrderedDistribute:
4993 case OMPScheduleType::OrderedDistributeChunked:
4994 break;
4995 default:
4996 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4997 SPMDCompatibilityTracker.insert(&CB);
4998 break;
4999 };
5000 } break;
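// Illustrative note for the static-init cases above (schedule type values
// assumed, not checked against the runtime's sched_type encoding): a call
// with a compile-time constant static schedule, e.g.
//   __kmpc_for_static_init_4(loc, tid, /*schedtype=*/34 /*static*/, ...)
// keeps SPMD compatibility, whereas a dynamic, guided, or unknown schedule
// falls into the default case above and pessimizes the tracker.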
5001 case OMPRTL___kmpc_target_init:
5002 KernelInitCB = &CB;
5003 break;
5004 case OMPRTL___kmpc_target_deinit:
5005 KernelDeinitCB = &CB;
5006 break;
5007 case OMPRTL___kmpc_parallel_51:
5008 if (!handleParallel51(A, CB))
5009 indicatePessimisticFixpoint();
5010 return;
5011 case OMPRTL___kmpc_omp_task:
5012 // We do not look into tasks right now, just give up.
5013 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5014 SPMDCompatibilityTracker.insert(&CB);
5015 ReachedUnknownParallelRegions.insert(&CB);
5016 break;
5017 case OMPRTL___kmpc_alloc_shared:
5018 case OMPRTL___kmpc_free_shared:
5019 // Return without setting a fixpoint, to be resolved in updateImpl.
5020 return;
5021 default:
5022 // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
5023 // generally. However, they do not hide parallel regions.
5024 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5025 SPMDCompatibilityTracker.insert(&CB);
5026 break;
5027 }
5028 // All other OpenMP runtime calls will not reach parallel regions so they
5029 // can be safely ignored for now. Since it is a known OpenMP runtime call
5030 // we have now modeled all effects and there is no need for any update.
5031 indicateOptimisticFixpoint();
5032 };
5033
5034 const auto *AACE =
5035 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5036 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5037 CheckCallee(getAssociatedFunction(), 1);
5038 return;
5039 }
5040 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5041 for (auto *Callee : OptimisticEdges) {
5042 CheckCallee(Callee, OptimisticEdges.size());
5043 if (isAtFixpoint())
5044 break;
5045 }
5046 }
5047
5048 ChangeStatus updateImpl(Attributor &A) override {
5049 // TODO: Once we have call site specific value information we can provide
5050 // call site specific liveness information and then it makes
5051 // sense to specialize attributes for call site arguments instead of
5052 // redirecting requests to the callee argument.
5053 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5054 KernelInfoState StateBefore = getState();
5055
5056 auto CheckCallee = [&](Function *F, int NumCallees) {
5057 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5058
5059 // If F is not a runtime function, propagate the AAKernelInfo of the
5060 // callee.
5061 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5062 const IRPosition &FnPos = IRPosition::function(*F);
5063 auto *FnAA =
5064 A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5065 if (!FnAA)
5066 return indicatePessimisticFixpoint();
5067 if (getState() == FnAA->getState())
5068 return ChangeStatus::UNCHANGED;
5069 getState() = FnAA->getState();
5070 return ChangeStatus::CHANGED;
5071 }
5072 if (NumCallees > 1)
5073 return indicatePessimisticFixpoint();
5074
5075 CallBase &CB = cast<CallBase>(getAssociatedValue());
5076 if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
5077 if (!handleParallel51(A, CB))
5078 return indicatePessimisticFixpoint();
5079 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5080 : ChangeStatus::CHANGED;
5081 }
5082
5083 // F is a runtime function that allocates or frees memory, check
5084 // AAHeapToStack and AAHeapToShared.
5085 assert(
5086 (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5087 It->getSecond() == OMPRTL___kmpc_free_shared) &&
5088 "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5089
5090 auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
5091 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5092 auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5093 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5094
5095 RuntimeFunction RF = It->getSecond();
5096
5097 switch (RF) {
5098 // If neither HeapToStack nor HeapToShared assume the call is removed,
5099 // assume SPMD incompatibility.
5100 case OMPRTL___kmpc_alloc_shared:
5101 if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5102 (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5103 SPMDCompatibilityTracker.insert(&CB);
5104 break;
5105 case OMPRTL___kmpc_free_shared:
5106 if ((!HeapToStackAA ||
5107 !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5108 (!HeapToSharedAA ||
5109 !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5110 SPMDCompatibilityTracker.insert(&CB);
5111 break;
5112 default:
5113 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5114 SPMDCompatibilityTracker.insert(&CB);
5115 }
5116 return ChangeStatus::CHANGED;
5117 };
5118
5119 const auto *AACE =
5120 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5121 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5122 if (Function *F = getAssociatedFunction())
5123 CheckCallee(F, /*NumCallees=*/1);
5124 } else {
5125 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5126 for (auto *Callee : OptimisticEdges) {
5127 CheckCallee(Callee, OptimisticEdges.size());
5128 if (isAtFixpoint())
5129 break;
5130 }
5131 }
5132
5133 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5134 : ChangeStatus::CHANGED;
5135 }
5136
5137 /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call
5138 /// was handled; if a problem occurred, false is returned.
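/// A sketch of the call this expects (argument layout assumed from the
/// device runtime, not quoted verbatim from it):
///   __kmpc_parallel_51(ident, gtid, if_expr, num_threads, proc_bind,
///                      /*arg 5*/ outlined_fn, /*arg 6*/ wrapper_fn,
///                      args, nargs)
/// In SPMD mode the outlined function (operand 5) is the parallel region;
/// otherwise the wrapper function (operand 6) is, which is why the operand
/// index chosen below depends on the SPMD compatibility state.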
5139 bool handleParallel51(Attributor &A, CallBase &CB) {
5140 const unsigned int NonWrapperFunctionArgNo = 5;
5141 const unsigned int WrapperFunctionArgNo = 6;
5142 auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5143 ? NonWrapperFunctionArgNo
5144 : WrapperFunctionArgNo;
5145
5146 auto *ParallelRegion = dyn_cast<Function>(
5147 CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
5148 if (!ParallelRegion)
5149 return false;
5150
5151 ReachedKnownParallelRegions.insert(&CB);
5152 /// Check nested parallelism
5153 auto *FnAA = A.getAAFor<AAKernelInfo>(
5154 *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
5155 NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5156 !FnAA->ReachedKnownParallelRegions.empty() ||
5157 !FnAA->ReachedKnownParallelRegions.isValidState() ||
5158 !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5159 !FnAA->ReachedUnknownParallelRegions.empty();
5160 return true;
5161 }
5162};
5163
5164struct AAFoldRuntimeCall
5165 : public StateWrapper<BooleanState, AbstractAttribute> {
5166 using Base = StateWrapper<BooleanState, AbstractAttribute>;
5167
5168 AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
5169
5170 /// Statistics are tracked as part of manifest for now.
5171 void trackStatistics() const override {}
5172
5173 /// Create an abstract attribute view for the position \p IRP.
5174 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5175 Attributor &A);
5176
5177 /// See AbstractAttribute::getName()
5178 const std::string getName() const override { return "AAFoldRuntimeCall"; }
5179
5180 /// See AbstractAttribute::getIdAddr()
5181 const char *getIdAddr() const override { return &ID; }
5182
5183 /// This function should return true if the type of the \p AA is
5184 /// AAFoldRuntimeCall
5185 static bool classof(const AbstractAttribute *AA) {
5186 return (AA->getIdAddr() == &ID);
5187 }
5188
5189 static const char ID;
5190};
5191
5192struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5193 AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
5194 : AAFoldRuntimeCall(IRP, A) {}
5195
5196 /// See AbstractAttribute::getAsStr()
5197 const std::string getAsStr(Attributor *) const override {
5198 if (!isValidState())
5199 return "<invalid>";
5200
5201 std::string Str("simplified value: ");
5202
5203 if (!SimplifiedValue)
5204 return Str + std::string("none");
5205
5206 if (!*SimplifiedValue)
5207 return Str + std::string("nullptr");
5208
5209 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5210 return Str + std::to_string(CI->getSExtValue());
5211
5212 return Str + std::string("unknown");
5213 }
5214
5215 void initialize(Attributor &A) override {
5216 if (DisableOpenMPOptFolding)
5217 indicatePessimisticFixpoint();
5218
5219 Function *Callee = getAssociatedFunction();
5220
5221 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5222 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5223 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5224 "Expected a known OpenMP runtime function");
5225
5226 RFKind = It->getSecond();
5227
5228 CallBase &CB = cast<CallBase>(getAssociatedValue());
5229 A.registerSimplificationCallback(
5230 IRPosition::callsite_returned(CB),
5231 [&](const IRPosition &IRP, const AbstractAttribute *AA,
5232 bool &UsedAssumedInformation) -> std::optional<Value *> {
5233 assert((isValidState() ||
5234 (SimplifiedValue && *SimplifiedValue == nullptr)) &&
5235 "Unexpected invalid state!");
5236
5237 if (!isAtFixpoint()) {
5238 UsedAssumedInformation = true;
5239 if (AA)
5240 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5241 }
5242 return SimplifiedValue;
5243 });
5244 }
5245
5246 ChangeStatus updateImpl(Attributor &A) override {
5247 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5248 switch (RFKind) {
5249 case OMPRTL___kmpc_is_spmd_exec_mode:
5250 Changed |= foldIsSPMDExecMode(A);
5251 break;
5252 case OMPRTL___kmpc_parallel_level:
5253 Changed |= foldParallelLevel(A);
5254 break;
5255 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5256 Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
5257 break;
5258 case OMPRTL___kmpc_get_hardware_num_blocks:
5259 Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
5260 break;
5261 default:
5262 llvm_unreachable("Unhandled OpenMP runtime function!");
5263 }
5264
5265 return Changed;
5266 }
5267
5268 ChangeStatus manifest(Attributor &A) override {
5269 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5270
5271 if (SimplifiedValue && *SimplifiedValue) {
5272 Instruction &I = *getCtxI();
5273 A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
5274 A.deleteAfterManifest(I);
5275
5276 CallBase *CB = dyn_cast<CallBase>(&I);
5277 auto Remark = [&](OptimizationRemark OR) {
5278 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5279 return OR << "Replacing OpenMP runtime call "
5280 << CB->getCalledFunction()->getName() << " with "
5281 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5282 return OR << "Replacing OpenMP runtime call "
5283 << CB->getCalledFunction()->getName() << ".";
5284 };
5285
5286 if (CB && EnableVerboseRemarks)
5287 A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
5288
5289 LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5290 << **SimplifiedValue << "\n");
5291
5292 Changed = ChangeStatus::CHANGED;
5293 }
5294
5295 return Changed;
5296 }
5297
5298 ChangeStatus indicatePessimisticFixpoint() override {
5299 SimplifiedValue = nullptr;
5300 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5301 }
5302
5303private:
5304 /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
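///
/// Illustrative sketch (hypothetical IR, not taken from a specific test): if
/// every kernel that can reach the caller is (assumed to be) SPMD, a call
/// such as
///   %spmd = call i8 @__kmpc_is_spmd_exec_mode()
/// simplifies to the constant `i8 1`; if every reaching kernel is generic it
/// simplifies to `i8 0`; a mix of both gives up on the fold.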
5305 ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5306 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5307
5308 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5309 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5310 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5311 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5312
5313 if (!CallerKernelInfoAA ||
5314 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5315 return indicatePessimisticFixpoint();
5316
5317 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5318 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5319 DepClassTy::REQUIRED);
5320
5321 if (!AA || !AA->isValidState()) {
5322 SimplifiedValue = nullptr;
5323 return indicatePessimisticFixpoint();
5324 }
5325
5326 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5327 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5328 ++KnownSPMDCount;
5329 else
5330 ++AssumedSPMDCount;
5331 } else {
5332 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5333 ++KnownNonSPMDCount;
5334 else
5335 ++AssumedNonSPMDCount;
5336 }
5337 }
5338
5339 if ((AssumedSPMDCount + KnownSPMDCount) &&
5340 (AssumedNonSPMDCount + KnownNonSPMDCount))
5341 return indicatePessimisticFixpoint();
5342
5343 auto &Ctx = getAnchorValue().getContext();
5344 if (KnownSPMDCount || AssumedSPMDCount) {
5345 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5346 "Expected only SPMD kernels!");
5347 // All reaching kernels are in SPMD mode. Update all function calls to
5348 // __kmpc_is_spmd_exec_mode to 1.
5349 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
5350 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5351 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5352 "Expected only non-SPMD kernels!");
5353 // All reaching kernels are in non-SPMD mode. Update all function
5354 // calls to __kmpc_is_spmd_exec_mode to 0.
5355 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
5356 } else {
5357 // We have empty reaching kernels, therefore we cannot tell if the
5358 // associated call site can be folded. At this moment, SimplifiedValue
5359 // must be none.
5360 assert(!SimplifiedValue && "SimplifiedValue should be none");
5361 }
5362
5363 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5364 : ChangeStatus::CHANGED;
5365 }
5366
5367 /// Fold __kmpc_parallel_level into a constant if possible.
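///
/// Illustrative sketch (hypothetical IR): if the caller is only reachable
/// from SPMD kernels, a call to @__kmpc_parallel_level simplifies to `i8 1`;
/// if it is only reachable from generic-mode kernels it simplifies to `i8 0`;
/// mixed or unknown reachability gives up on the fold.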
5368 ChangeStatus foldParallelLevel(Attributor &A) {
5369 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5370
5371 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5372 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5373
5374 if (!CallerKernelInfoAA ||
5375 !CallerKernelInfoAA->ParallelLevels.isValidState())
5376 return indicatePessimisticFixpoint();
5377
5378 if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5379 return indicatePessimisticFixpoint();
5380
5381 if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5382 assert(!SimplifiedValue &&
5383 "SimplifiedValue should keep none at this point");
5384 return ChangeStatus::UNCHANGED;
5385 }
5386
5387 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5388 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5389 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5390 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5391 DepClassTy::REQUIRED);
5392 if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5393 return indicatePessimisticFixpoint();
5394
5395 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5396 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5397 ++KnownSPMDCount;
5398 else
5399 ++AssumedSPMDCount;
5400 } else {
5401 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5402 ++KnownNonSPMDCount;
5403 else
5404 ++AssumedNonSPMDCount;
5405 }
5406 }
5407
5408 if ((AssumedSPMDCount + KnownSPMDCount) &&
5409 (AssumedNonSPMDCount + KnownNonSPMDCount))
5410 return indicatePessimisticFixpoint();
5411
5412 auto &Ctx = getAnchorValue().getContext();
5413 // If the caller can only be reached by SPMD kernel entries, the parallel
5414 // level is 1. Similarly, if the caller can only be reached by non-SPMD
5415 // kernel entries, it is 0.
5416 if (AssumedSPMDCount || KnownSPMDCount) {
5417 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5418 "Expected only SPMD kernels!");
5419 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
5420 } else {
5421 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5422 "Expected only non-SPMD kernels!");
5423 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
5424 }
5425 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5426 : ChangeStatus::CHANGED;
5427 }
5428
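/// Fold a runtime call to the integer value of a kernel function attribute,
/// e.g. __kmpc_get_hardware_num_threads_in_block via
/// "omp_target_thread_limit" (see updateImpl above). Illustrative sketch
/// (hypothetical attribute value): if every reaching kernel carries
/// "omp_target_thread_limit"="128", the call folds to `i32 128`; if a kernel
/// lacks the attribute or the values disagree, the fold is abandoned.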
5429 ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
5430 // Specialize only if all the calls agree with the attribute constant value
5431 int32_t CurrentAttrValue = -1;
5432 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5433
5434 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5435 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5436
5437 if (!CallerKernelInfoAA ||
5438 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5439 return indicatePessimisticFixpoint();
5440
5441 // Iterate over the kernels that reach this function
5442 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5443 int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
5444
5445 if (NextAttrVal == -1 ||
5446 (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5447 return indicatePessimisticFixpoint();
5448 CurrentAttrValue = NextAttrVal;
5449 }
5450
5451 if (CurrentAttrValue != -1) {
5452 auto &Ctx = getAnchorValue().getContext();
5453 SimplifiedValue =
5454 ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
5455 }
5456 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5457 : ChangeStatus::CHANGED;
5458 }
5459
5460 /// An optional value the associated value is assumed to fold to. That is, we
5461 /// assume the associated value (which is a call) can be replaced by this
5462 /// simplified value.
5463 std::optional<Value *> SimplifiedValue;
5464
5465 /// The runtime function kind of the callee of the associated call site.
5466 RuntimeFunction RFKind;
5467};
5468
5469} // namespace
5470
5471/// Register folding callsite
5472void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5473 auto &RFI = OMPInfoCache.RFIs[RF];
5474 RFI.foreachUse(SCC, [&](Use &U, Function &F) {
5475 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5476 if (!CI)
5477 return false;
5478 A.getOrCreateAAFor<AAFoldRuntimeCall>(
5479 IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
5480 DepClassTy::NONE, /* ForceUpdate */ false,
5481 /* UpdateAfterInit */ false);
5482 return false;
5483 });
5484}
5485
5486void OpenMPOpt::registerAAs(bool IsModulePass) {
5487 if (SCC.empty())
5488 return;
5489
5490 if (IsModulePass) {
5491 // Ensure we create the AAKernelInfo AAs first and without triggering an
5492 // update. This will make sure we register all value simplification
5493 // callbacks before any other AA has the chance to create an AAValueSimplify
5494 // or similar.
5495 auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5496 A.getOrCreateAAFor<AAKernelInfo>(
5497 IRPosition::function(Kernel), /* QueryingAA */ nullptr,
5498 DepClassTy::NONE, /* ForceUpdate */ false,
5499 /* UpdateAfterInit */ false);
5500 return false;
5501 };
5502 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5503 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5504 InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5505
5506 registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5507 registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5508 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5509 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5510 }
5511
5512 // Create CallSite AA for all Getters.
5513 if (DeduceICVValues) {
5514 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5515 auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5516
5517 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5518
5519 auto CreateAA = [&](Use &U, Function &Caller) {
5520 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5521 if (!CI)
5522 return false;
5523
5524 auto &CB = cast<CallBase>(*CI);
5525
5526 IRPosition CBPos = IRPosition::callsite_function(CB);
5527 A.getOrCreateAAFor<AAICVTracker>(CBPos);
5528 return false;
5529 };
5530
5531 GetterRFI.foreachUse(SCC, CreateAA);
5532 }
5533 }
5534
5535 // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5536 // every function if there is a device kernel.
5537 if (!isOpenMPDevice(M))
5538 return;
5539
5540 for (auto *F : SCC) {
5541 if (F->isDeclaration())
5542 continue;
5543
5544 // We look at internal functions only on demand, but if any use is not a
5545 // direct call or is outside the current set of analyzed functions, we
5546 // have to do it eagerly.
5547 if (F->hasLocalLinkage()) {
5548 if (llvm::all_of(F->uses(), [this](const Use &U) {
5549 const auto *CB = dyn_cast<CallBase>(U.getUser());
5550 return CB && CB->isCallee(&U) &&
5551 A.isRunOn(const_cast<Function *>(CB->getCaller()));
5552 }))
5553 continue;
5554 }
5555 registerAAsForFunction(A, *F);
5556 }
5557}
5558
5559void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
5560 if (!DisableOpenMPOptDeglobalization)
5561 A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
5562 A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
5563 if (!DisableOpenMPOptDeglobalization)
5564 A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
5565 if (F.hasFnAttribute(Attribute::Convergent))
5566 A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));
5567
5568 for (auto &I : instructions(F)) {
5569 if (auto *LI = dyn_cast<LoadInst>(&I)) {
5570 bool UsedAssumedInformation = false;
5571 A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
5572 UsedAssumedInformation, AA::Interprocedural);
5573 A.getOrCreateAAFor<AAAddressSpace>(
5574 IRPosition::value(*LI->getPointerOperand()));
5575 continue;
5576 }
5577 if (auto *CI = dyn_cast<CallBase>(&I)) {
5578 if (CI->isIndirectCall())
5579 A.getOrCreateAAFor<AAIndirectCallInfo>(
5580 IRPosition::callsite_function(*CI));
5581 }
5582 if (auto *SI = dyn_cast<StoreInst>(&I)) {
5583 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
5584 A.getOrCreateAAFor<AAAddressSpace>(
5585 IRPosition::value(*SI->getPointerOperand()));
5586 continue;
5587 }
5588 if (auto *FI = dyn_cast<FenceInst>(&I)) {
5589 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
5590 continue;
5591 }
5592 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5593 if (II->getIntrinsicID() == Intrinsic::assume) {
5594 A.getOrCreateAAFor<AAPotentialValues>(
5595 IRPosition::value(*II->getArgOperand(0)));
5596 continue;
5597 }
5598 }
5599 }
5600}
5601
5602const char AAICVTracker::ID = 0;
5603const char AAKernelInfo::ID = 0;
5604const char AAExecutionDomain::ID = 0;
5605const char AAHeapToShared::ID = 0;
5606const char AAFoldRuntimeCall::ID = 0;
5607
5608AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5609 Attributor &A) {
5610 AAICVTracker *AA = nullptr;
5611 switch (IRP.getPositionKind()) {
5612 case IRPosition::IRP_INVALID:
5613 case IRPosition::IRP_FLOAT:
5614 case IRPosition::IRP_ARGUMENT:
5615 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5616 llvm_unreachable("ICVTracker can only be created for function position!");
5617 case IRPosition::IRP_RETURNED:
5618 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5619 break;
5620 case IRPosition::IRP_CALL_SITE_RETURNED:
5621 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5622 break;
5623 case IRPosition::IRP_CALL_SITE:
5624 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5625 break;
5626 case IRPosition::IRP_FUNCTION:
5627 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5628 break;
5629 }
5630
5631 return *AA;
5632}
5633
5634 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
5635 Attributor &A) {
5636 AAExecutionDomainFunction *AA = nullptr;
5637 switch (IRP.getPositionKind()) {
5638 case IRPosition::IRP_INVALID:
5639 case IRPosition::IRP_FLOAT:
5640 case IRPosition::IRP_ARGUMENT:
5641 case IRPosition::IRP_RETURNED:
5642 case IRPosition::IRP_CALL_SITE_RETURNED:
5643 case IRPosition::IRP_CALL_SITE:
5644 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5645 llvm_unreachable(
5646 "AAExecutionDomain can only be created for function position!");
5647 case IRPosition::IRP_FUNCTION:
5648 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5649 break;
5650 }
5651
5652 return *AA;
5653}
5654
5655 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5656 Attributor &A) {
5657 AAHeapToSharedFunction *AA = nullptr;
5658 switch (IRP.getPositionKind()) {
5659 case IRPosition::IRP_INVALID:
5660 case IRPosition::IRP_FLOAT:
5661 case IRPosition::IRP_ARGUMENT:
5662 case IRPosition::IRP_RETURNED:
5663 case IRPosition::IRP_CALL_SITE_RETURNED:
5664 case IRPosition::IRP_CALL_SITE:
5665 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5666 llvm_unreachable(
5667 "AAHeapToShared can only be created for function position!");
5668 case IRPosition::IRP_FUNCTION:
5669 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5670 break;
5671 }
5672
5673 return *AA;
5674}
5675
5676AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5677 Attributor &A) {
5678 AAKernelInfo *AA = nullptr;
5679 switch (IRP.getPositionKind()) {
5680 case IRPosition::IRP_INVALID:
5681 case IRPosition::IRP_FLOAT:
5682 case IRPosition::IRP_ARGUMENT:
5683 case IRPosition::IRP_RETURNED:
5684 case IRPosition::IRP_CALL_SITE_RETURNED:
5685 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5686 llvm_unreachable("KernelInfo can only be created for function position!");
5687 case IRPosition::IRP_CALL_SITE:
5688 AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5689 break;
5690 case IRPosition::IRP_FUNCTION:
5691 AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5692 break;
5693 }
5694
5695 return *AA;
5696}
5697
5698AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5699 Attributor &A) {
5700 AAFoldRuntimeCall *AA = nullptr;
5701 switch (IRP.getPositionKind()) {
5702 case IRPosition::IRP_INVALID:
5703 case IRPosition::IRP_FLOAT:
5704 case IRPosition::IRP_ARGUMENT:
5705 case IRPosition::IRP_RETURNED:
5706 case IRPosition::IRP_FUNCTION:
5707 case IRPosition::IRP_CALL_SITE:
5708 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5709 llvm_unreachable("KernelInfo can only be created for call site position!");
5710 case IRPosition::IRP_CALL_SITE_RETURNED:
5711 AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5712 break;
5713 }
5714
5715 return *AA;
5716}
5717
5718 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
5719 if (!containsOpenMP(M))
5720 return PreservedAnalyses::all();
5721 if (DisableOpenMPOptimizations)
5722 return PreservedAnalyses::all();
5723
5724 FunctionAnalysisManager &FAM =
5725 AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
5726 KernelSet Kernels = getDeviceKernels(M);
5727
5728 if (PrintModuleBeforeOptimizations)
5729 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
5730
5731 auto IsCalled = [&](Function &F) {
5732 if (Kernels.contains(&F))
5733 return true;
5734 for (const User *U : F.users())
5735 if (!isa<BlockAddress>(U))
5736 return true;
5737 return false;
5738 };
5739
5740 auto EmitRemark = [&](Function &F) {
5741 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
5742 ORE.emit([&]() {
5743 OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5744 return ORA << "Could not internalize function. "
5745 << "Some optimizations may not be possible. [OMP140]";
5746 });
5747 };
5748
5749 bool Changed = false;
5750
5751 // Create internal copies of each function if this is a kernel module. This
5752 // allows interprocedural passes to see every call edge.
5753 DenseMap<Function *, Function *> InternalizedMap;
5754 if (isOpenMPDevice(M)) {
5755 SmallPtrSet<Function *, 16> InternalizeFns;
5756 for (Function &F : M)
5757 if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
5758 !DisableInternalization) {
5759 if (Attributor::isInternalizable(F)) {
5760 InternalizeFns.insert(&F);
5761 } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5762 EmitRemark(F);
5763 }
5764 }
5765
5766 Changed |=
5767 Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5768 }
5769
5770 // Look at every function in the Module unless it was internalized.
5771 SetVector<Function *> Functions;
5772 SmallVector<Function *, 16> SCC;
5773 for (Function &F : M)
5774 if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
5775 SCC.push_back(&F);
5776 Functions.insert(&F);
5777 }
5778
5779 if (SCC.empty())
5780 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
5781
5782 AnalysisGetter AG(FAM);
5783
5784 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5785 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5786 };
5787
5788 BumpPtrAllocator Allocator;
5789 CallGraphUpdater CGUpdater;
5790
5791 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5792 LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink;
5793 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);
5794
5795 unsigned MaxFixpointIterations =
5796 (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5797
5798 AttributorConfig AC(CGUpdater);
5799 AC.DefaultInitializeLiveInternals = false;
5800 AC.IsModulePass = true;
5801 AC.RewriteSignatures = false;
5802 AC.MaxFixpointIterations = MaxFixpointIterations;
5803 AC.OREGetter = OREGetter;
5804 AC.PassName = DEBUG_TYPE;
5805 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5806 AC.IPOAmendableCB = [](const Function &F) {
5807 return F.hasFnAttribute("kernel");
5808 };
5809
5810 Attributor A(Functions, InfoCache, AC);
5811
5812 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5813 Changed |= OMPOpt.run(true);
5814
5815 // Optionally inline device functions for potentially better performance.
5816 if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
5817 for (Function &F : M)
5818 if (!F.isDeclaration() && !Kernels.contains(&F) &&
5819 !F.hasFnAttribute(Attribute::NoInline))
5820 F.addFnAttr(Attribute::AlwaysInline);
5821
5823 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5824
5825 if (Changed)
5826 return PreservedAnalyses::none();
5827
5828 return PreservedAnalyses::all();
5829}
5830
5831 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
5832 CGSCCAnalysisManager &AM,
5833 LazyCallGraph &CG,
5834 CGSCCUpdateResult &UR) {
5835 if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5836 return PreservedAnalyses::all();
5837 if (DisableOpenMPOptimizations)
5838 return PreservedAnalyses::all();
5839
5840 SmallVector<Function *, 16> SCC;
5841 // If there are kernels in the module, we have to run on all SCCs.
5842 for (LazyCallGraph::Node &N : C) {
5843 Function *Fn = &N.getFunction();
5844 SCC.push_back(Fn);
5845 }
5846
5847 if (SCC.empty())
5848 return PreservedAnalyses::all();
5849
5850 Module &M = *C.begin()->getFunction().getParent();
5851
5853 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5854
5855 KernelSet Kernels = getDeviceKernels(M);
5856
5857 FunctionAnalysisManager &FAM =
5858 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5859
5860 AnalysisGetter AG(FAM);
5861
5862 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5863 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5864 };
5865
5866 BumpPtrAllocator Allocator;
5867 CallGraphUpdater CGUpdater;
5868 CGUpdater.initialize(CG, C, AM, UR);
5869
5870 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5871 LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink;
5872 SetVector<Function *> Functions(SCC.begin(), SCC.end());
5873 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5874 /*CGSCC*/ &Functions, PostLink);
5875
5876 unsigned MaxFixpointIterations =
5877 (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5878
5879 AttributorConfig AC(CGUpdater);
5880 AC.DefaultInitializeLiveInternals = false;
5881 AC.IsModulePass = false;
5882 AC.RewriteSignatures = false;
5883 AC.MaxFixpointIterations = MaxFixpointIterations;
5884 AC.OREGetter = OREGetter;
5885 AC.PassName = DEBUG_TYPE;
5886 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5887
5888 Attributor A(Functions, InfoCache, AC);
5889
5890 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5891 bool Changed = OMPOpt.run(false);
5892
5894 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5895
5896 if (Changed)
5897 return PreservedAnalyses::none();
5898
5899 return PreservedAnalyses::all();
5900}
5901
5903 return Fn.hasFnAttribute("kernel");
5904}
5905
5906 KernelSet llvm::omp::getDeviceKernels(Module &M) {
5907 // TODO: Create a more cross-platform way of determining device kernels.
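// Illustrative sketch of the metadata this walks (shape assumed from what is
// emitted for NVPTX offload targets; the kernel name is hypothetical):
//   !nvvm.annotations = !{!0}
//   !0 = !{ptr @__omp_offloading_main_l10, !"kernel", i32 1}
// Operand 0 is the kernel function and operand 1 the "kernel" marker string.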
5908 NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");
5909 KernelSet Kernels;
5910
5911 if (!MD)
5912 return Kernels;
5913
5914 for (auto *Op : MD->operands()) {
5915 if (Op->getNumOperands() < 2)
5916 continue;
5917 MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
5918 if (!KindID || KindID->getString() != "kernel")
5919 continue;
5920
5921 Function *KernelFn =
5922 mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
5923 if (!KernelFn)
5924 continue;
5925
5926 // We are only interested in OpenMP target regions. Others, such as kernels
5927 // generated by CUDA but linked together, are not interesting to this pass.
5928 if (isOpenMPKernel(*KernelFn)) {
5929 ++NumOpenMPTargetRegionKernels;
5930 Kernels.insert(KernelFn);
5931 } else
5932 ++NumNonOpenMPTargetRegionKernels;
5933 }
5934
5935 return Kernels;
5936}
5937
5939 Metadata *MD = M.getModuleFlag("openmp");
5940 if (!MD)
5941 return false;
5942
5943 return true;
5944}
5945
5947 Metadata *MD = M.getModuleFlag("openmp-device");
5948 if (!MD)
5949 return false;
5950
5951 return true;
5952}
@ Generic
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static cl::opt< unsigned > SetFixpointIterations("attributor-max-iterations", cl::Hidden, cl::desc("Maximal number of fixpoint iterations."), cl::init(32))
static const Function * getParent(const Value *V)
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
This file provides interfaces used to manipulate a call graph, regardless if it is a "old style" Call...
This file provides interfaces used to build and manipulate a call graph, which is a very useful tool ...
This file contains the declarations for the subclasses of Constant, which represent the different fla...
dxil pretty DXIL Metadata Pretty Printer
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(...)
Definition: Debug.h:106
bool End
Definition: ELF_riscv.cpp:480
This file defines an array type that can be indexed using scoped enum values.
#define DEBUG_TYPE
static void emitRemark(const Function &F, OptimizationRemarkEmitter &ORE, bool Skip)
static bool lookup(const GsymReader &GR, DataExtractor &Data, uint64_t &Offset, uint64_t BaseAddr, uint64_t Addr, SourceLocations &SrcLocs, llvm::Error &Err)
A Lookup helper functions.
Definition: InlineInfo.cpp:108
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
This file provides utility analysis objects describing memory locations.
uint64_t IntrinsicInst * II
This file defines constans and helpers used when dealing with OpenMP.
This file defines constans that will be used by both host and device compilation.
static constexpr auto TAG
Definition: OpenMPOpt.cpp:184
static cl::opt< bool > HideMemoryTransferLatency("openmp-hide-memory-transfer-latency", cl::desc("[WIP] Tries to hide the latency of host to device memory" " transfers"), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptStateMachineRewrite("openmp-opt-disable-state-machine-rewrite", cl::desc("Disable OpenMP optimizations that replace the state machine."), cl::Hidden, cl::init(false))
static cl::opt< bool > EnableParallelRegionMerging("openmp-opt-enable-merging", cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleAfterOptimizations("openmp-opt-print-module-after", cl::desc("Print the current module after OpenMP optimizations."), cl::Hidden, cl::init(false))
#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)
Definition: OpenMPOpt.cpp:240
#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)
Definition: OpenMPOpt.cpp:217
#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER)
Definition: OpenMPOpt.cpp:3676
static cl::opt< bool > PrintOpenMPKernels("openmp-print-gpu-kernels", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptFolding("openmp-opt-disable-folding", cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleBeforeOptimizations("openmp-opt-print-module-before", cl::desc("Print the current module before OpenMP optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden, cl::desc("Maximal number of attributor iterations."), cl::init(256))
static cl::opt< bool > DisableInternalization("openmp-opt-disable-internalization", cl::desc("Disable function internalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintICVValues("openmp-print-icv-values", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptimizations("openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden, cl::desc("Maximum amount of shared memory to use."), cl::init(std::numeric_limits< unsigned >::max()))
static cl::opt< bool > EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::desc("Enables more verbose remarks."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptDeglobalization("openmp-opt-disable-deglobalization", cl::desc("Disable OpenMP optimizations involving deglobalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptBarrierElimination("openmp-opt-disable-barrier-elimination", cl::desc("Disable OpenMP optimizations that eliminate barriers."), cl::Hidden, cl::init(false))
#define DEBUG_TYPE
Definition: OpenMPOpt.cpp:66
static cl::opt< bool > DeduceICVValues("openmp-deduce-icv-values", cl::init(false), cl::Hidden)
#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)
Definition: OpenMPOpt.cpp:209
#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)
Definition: OpenMPOpt.cpp:230
static cl::opt< bool > DisableOpenMPOptSPMDization("openmp-opt-disable-spmdization", cl::desc("Disable OpenMP optimizations involving SPMD-ization."), cl::Hidden, cl::init(false))
static cl::opt< bool > AlwaysInlineDeviceFunctions("openmp-opt-inline-device", cl::desc("Inline all applicable functions on the device."), cl::Hidden, cl::init(false))
#define P(N)
FunctionAnalysisManager FAM
if(PassOpts->AAPipeline)
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
static StringRef getName(Value *V)
R600 Clause Merge
Basic Register Allocator
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static bool contains(SmallPtrSetImpl< ConstantExpr * > &Cache, ConstantExpr *Expr, Constant *C)
Definition: Value.cpp:469
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:166
This file contains some functions that are useful when dealing with strings.
static const int BlockSize
Definition: TarWriter.cpp:33
static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, ArrayRef< StringLiteral > StandardNames)
Initialize the set of available library functions based on the specified target triple.
Value * RHS
static cl::opt< unsigned > MaxThreads("xcore-max-threads", cl::Optional, cl::desc("Maximum number of threads (for emulation thread-local storage)"), cl::Hidden, cl::value_desc("number"), cl::init(8))
AbstractCallSite.
This class represents a conversion between pointers from one address space to another.
an instruction to allocate memory on the stack
Definition: Instructions.h:63
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:410
This class represents an incoming formal argument to a Function.
Definition: Argument.h:31
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:168
AttributeSet getParamAttrs(unsigned ArgNo) const
The attributes for the argument or parameter at the given index are returned.
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
iterator end()
Definition: BasicBlock.h:461
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:448
const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
Definition: BasicBlock.cpp:416
reverse_iterator rbegin()
Definition: BasicBlock.h:464
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:212
BasicBlock * splitBasicBlock(iterator I, const Twine &BBName="", bool Before=false)
Split the basic block into two basic blocks at the specified instruction.
Definition: BasicBlock.cpp:577
const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
Definition: BasicBlock.cpp:497
InstListType::reverse_iterator reverse_iterator
Definition: BasicBlock.h:179
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:219
reverse_iterator rend()
Definition: BasicBlock.h:466
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:239
Conditional or Unconditional Branch instruction.
bool isConditional() const
static BranchInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
BasicBlock * getSuccessor(unsigned i) const
Value * getCondition() const
Allocate memory in an ever growing pool, as if by bump-pointer.
Definition: Allocator.h:66
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Definition: InstrTypes.h:1120
void setCallingConv(CallingConv::ID CC)
Definition: InstrTypes.h:1411
bool arg_empty() const
Definition: InstrTypes.h:1291
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1349
bool doesNotAccessMemory(unsigned OpNo) const
Definition: InstrTypes.h:1718
bool isIndirectCall() const
Return true if the callsite is an indirect call.
bool isCallee(Value::const_user_iterator UI) const
Determine whether the passed iterator points to the callee operand's Use.
Definition: InstrTypes.h:1360
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1294
void setArgOperand(unsigned i, Value *v)
Definition: InstrTypes.h:1299
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
Definition: InstrTypes.h:1285
unsigned getArgOperandNo(const Use *U) const
Given a use for a arg operand, get the arg operand number that corresponds to it.
Definition: InstrTypes.h:1325
unsigned arg_size() const
Definition: InstrTypes.h:1292
AttributeList getAttributes() const
Return the attributes for this call.
Definition: InstrTypes.h:1425
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind)
Adds the attribute to the indicated argument.
Definition: InstrTypes.h:1502
bool isArgOperand(const Use *U) const
Definition: InstrTypes.h:1314
bool hasOperandBundles() const
Return true if this User has any operand bundles.
Definition: InstrTypes.h:1969
Function * getCaller()
Helper to get the caller (the parent function).
Wrapper to unify "old style" CallGraph and "new style" LazyCallGraph.
void initialize(LazyCallGraph &LCG, LazyCallGraph::SCC &SCC, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR)
Initializers for usage outside of a CGSCC pass, inside a CGSCC pass in the old and new pass manager (...
This class represents a function call, abstracting a target machine's calling convention.
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:702
@ ICMP_EQ
equal
Definition: InstrTypes.h:694
@ ICMP_NE
not equal
Definition: InstrTypes.h:695
static Constant * getPointerCast(Constant *C, Type *Ty)
Create a BitCast, AddrSpaceCast, or a PtrToInt cast constant expression.
Definition: Constants.cpp:2253
static Constant * getPointerBitCastOrAddrSpaceCast(Constant *C, Type *Ty)
Create a BitCast or AddrSpaceCast for a pointer type depending on the address space.
Definition: Constants.cpp:2268
This is the shared class of boolean and integer constants.
Definition: Constants.h:83
IntegerType * getIntegerType() const
Variant of the getType() method to always return an IntegerType, which reduces the amount of casting ...
Definition: Constants.h:187
static ConstantInt * getTrue(LLVMContext &Context)
Definition: Constants.cpp:866
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition: Constants.h:208
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
Definition: Constants.h:163
This is an important base class in LLVM.
Definition: Constant.h:42
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Definition: Constants.cpp:373
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
A debug info location.
Definition: DebugLoc.h:33
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:194
iterator begin()
Definition: DenseMap.h:75
iterator end()
Definition: DenseMap.h:84
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:211
Implements a dense probed hash-table based set.
Definition: DenseSet.h:278
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
Instruction * findNearestCommonDominator(Instruction *I1, Instruction *I2) const
Find the nearest instruction I that dominates both I1 and I2, in the sense that a result produced bef...
Definition: Dominators.cpp:344
static ErrorSuccess success()
Create a success value.
Definition: Error.h:337
An instruction for ordering other memory operations.
Definition: Instructions.h:424
AtomicOrdering getOrdering() const
Returns the ordering constraint of this fence instruction.
Definition: Instructions.h:449
A proxy from a FunctionAnalysisManager to an SCC.
A handy container for a FunctionType+Callee-pointer pair, which can be passed around as a single enti...
Definition: DerivedTypes.h:170
const BasicBlock & getEntryBlock() const
Definition: Function.h:809
const BasicBlock & front() const
Definition: Function.h:860
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition: Function.cpp:369
Argument * getArg(unsigned i) const
Definition: Function.h:886
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
Definition: Function.cpp:731
bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
Definition: Globals.cpp:296
bool hasLocalLinkage() const
Definition: GlobalValue.h:528
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:656
@ PrivateLinkage
Like Internal, but omit from symbol table.
Definition: GlobalValue.h:60
@ InternalLinkage
Rename collisions when linking (static functions).
Definition: GlobalValue.h:59
const Constant * getInitializer() const
getInitializer - Return the initializer for this global variable.
void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
Definition: Globals.cpp:492
InsertPoint - A saved insertion point.
Definition: IRBuilder.h:276
BasicBlock * getBlock() const
Definition: IRBuilder.h:291
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Definition: IRBuilder.h:1780
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:199
Value * CreateAddrSpaceCast(Value *V, Type *DestTy, const Twine &Name="")
Definition: IRBuilder.h:2156
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2704
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
Definition: PassManager.h:567
bool isLifetimeStartOrEnd() const LLVM_READONLY
Return true if the instruction is a llvm.lifetime.start or llvm.lifetime.end marker.
bool mayWriteToMemory() const LLVM_READONLY
Return true if this instruction may modify memory.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
Definition: Instruction.h:475
const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr it the function does not...
Definition: Instruction.cpp:68
const Instruction * getPrevNonDebugInstruction(bool SkipPseudoOp=false) const
Return a pointer to the previous non-debug instruction in the same basic block as 'this',...
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:94
const Function * getFunction() const
Return the function this instruction belongs to.
Definition: Instruction.cpp:72
bool mayHaveSideEffects() const LLVM_READONLY
Return true if the instruction may have side effects.
bool mayReadFromMemory() const LLVM_READONLY
Return true if this instruction may read memory.
const Instruction * getNextNonDebugInstruction(bool SkipPseudoOp=false) const
Return a pointer to the next non-debug instruction in the same basic block as 'this',...
void setDebugLoc(DebugLoc Loc)
Set the debug location information for this instruction.
Definition: Instruction.h:472
void setSuccessor(unsigned Idx, BasicBlock *BB)
Update the specified successor to point at the provided block.
void moveBefore(Instruction *MovePos)
Unlink this instruction from its current basic block and insert it into the basic block that MovePos ...
A node in the call graph.
An SCC of the call graph.
A lazily constructed view of the call graph of a module.
An instruction for reading from memory.
Definition: Instructions.h:176
A single uniqued string.
Definition: Metadata.h:720
StringRef getString() const
Definition: Metadata.cpp:616
void eraseFromParent()
This method unlinks 'this' from the containing function and deletes it.
Root of the metadata hierarchy.
Definition: Metadata.h:62
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
const std::string & getTargetTriple() const
Get the target triple which is a string describing the target host.
Definition: Module.h:298
MutableArrayRef - Represent a mutable reference to an array (0 or more elements consecutively in memo...
Definition: ArrayRef.h:310
A tuple of MDNodes.
Definition: Metadata.h:1733
iterator_range< op_iterator > operands()
Definition: Metadata.h:1829
An interface to create LLVM-IR for OpenMP directives.
Definition: OMPIRBuilder.h:474
static std::pair< int32_t, int32_t > readThreadBoundsForKernel(const Triple &T, Function &Kernel)
}
void addAttributes(omp::RuntimeFunction FnID, Function &Fn)
Add attributes known for FnID to Fn.
IRBuilder<>::InsertPoint InsertPointTy
Type used throughout for insertion points.
Definition: OMPIRBuilder.h:520
static std::pair< int32_t, int32_t > readTeamBoundsForKernel(const Triple &T, Function &Kernel)
Read/write a bounds on teams for Kernel.
PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR)
Definition: OpenMPOpt.cpp:5831
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
Definition: OpenMPOpt.cpp:5718
Diagnostic information for optimization analysis remarks.
The optimization diagnostic interface.
Diagnostic information for missed-optimization remarks.
Diagnostic information for applied optimization remarks.
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
Definition: Constants.cpp:1878
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition: Analysis.h:114
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
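The PreservedAnalyses factories above follow the usual new-pass-manager idiom: return all() when nothing changed and none() (or a narrower set) after a transformation. A minimal sketch with a hypothetical pass:
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Module.h"
using namespace llvm;
// Hypothetical example pass, for illustration only.
struct ExamplePass : PassInfoMixin<ExamplePass> {
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
    bool Changed = false;
    // ... transform M and set Changed accordingly ...
    return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
  }
};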
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, InsertPosition InsertBefore=nullptr)
A vector that has set insertion semantics.
Definition: SetVector.h:57
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:98
const value_type & back() const
Return the last element of the SetVector.
Definition: SetVector.h:149
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:264
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:162
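SetVector combines set semantics with deterministic (insertion-order) iteration, which is why it is commonly used for worklists. A minimal usage sketch of the members listed above (the function name is illustrative):
#include "llvm/ADT/SetVector.h"
using namespace llvm;
static void setVectorDemo() {
  SetVector<int> SV;
  SV.insert(1);              // returns true: newly inserted
  SV.insert(2);
  SV.insert(1);              // returns false: duplicate is ignored
  // SV.size() == 2, SV.back() == 2, SV.count(1) == 1
}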
size_type size() const
Definition: SmallPtrSet.h:94
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:363
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:452
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:384
iterator begin() const
Definition: SmallPtrSet.h:472
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or fewer elements.
Definition: SmallPtrSet.h:519
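SmallPtrSet::insert() returns a pair whose second member says whether the pointer was newly added, which makes visited-set checks a one-liner. A small illustrative sketch (the helper name is hypothetical):
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/IR/Value.h"
using namespace llvm;
// Returns true only on the first visit of V.
static bool markSeen(SmallPtrSetImpl<Value *> &Seen, Value *V) {
  return Seen.insert(V).second;   // {iterator, inserted}
}
// Usage: SmallPtrSet<Value *, 8> Seen; if (markSeen(Seen, V)) { /* first time */ }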
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:370
bool empty() const
Definition: SmallVector.h:81
size_t size() const
Definition: SmallVector.h:78
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:573
void assign(size_type NumElts, ValueParamT Elt)
Definition: SmallVector.h:704
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:937
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:683
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
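SmallVector keeps up to the inlined number of elements without a heap allocation; the members listed above cover the common mutation paths. Illustrative sketch:
#include "llvm/ADT/SmallVector.h"
using namespace llvm;
static void smallVectorDemo() {
  SmallVector<int, 4> V;     // inline storage for 4 elements
  V.push_back(1);
  V.emplace_back(2);
  V.append({3, 4});          // append a range or initializer list
  V.assign(2, 7);            // now V == {7, 7}
  // V.empty() == false, V.size() == 2
}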
An instruction for storing to memory.
Definition: Instructions.h:292
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:51
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition: StringRef.h:265
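StringRef::starts_with makes prefix checks cheap; OpenMP runtime entry points conventionally use the __kmpc_ and omp_ prefixes. A small illustrative sketch (the predicate name is hypothetical):
#include "llvm/ADT/StringRef.h"
using namespace llvm;
static bool looksLikeOMPRuntimeCall(StringRef Name) {
  return Name.starts_with("__kmpc_") || Name.starts_with("omp_");
}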
Triple - Helper class for working with autoconf configuration names.
Definition: Triple.h:44
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
static Type * getVoidTy(LLVMContext &C)
static IntegerType * getInt16Ty(LLVMContext &C)
static IntegerType * getInt8Ty(LLVMContext &C)
static IntegerType * getInt32Ty(LLVMContext &C)
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1859
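PoisonValue::get and UndefValue::get (above) produce placeholder constants of a given type; poison is generally the preferred placeholder for newly introduced values. Illustrative sketch:
#include "llvm/IR/Constants.h"
#include "llvm/IR/Type.h"
using namespace llvm;
static Constant *placeholderFor(Type *T, bool UsePoison = true) {
  // Both factories return a constant of type T with no defined value.
  if (UsePoison)
    return PoisonValue::get(T);
  return UndefValue::get(T);
}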
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Definition: User.cpp:21
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
void setName(const Twine &Name)
Change the name of the value.
Definition: Value.cpp:377
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
iterator_range< user_iterator > users()
Definition: Value.h:421
User * user_back()
Definition: Value.h:407
const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs and address space casts.
Definition: Value.cpp:694
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
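replaceAllUsesWith rewrites every use of a value at once, while User::replaceUsesOfWith (above) rewrites uses within a single user; both require matching types. Minimal sketch (the wrapper name is illustrative):
#include "llvm/IR/Value.h"
using namespace llvm;
static void replaceIfCompatible(Value *Old, Value *New) {
  // RAUW is only legal when the types agree; afterwards Old has no users.
  if (Old != New && Old->getType() == New->getType())
    Old->replaceAllUsesWith(New);
}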
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition: ilist_node.h:32
self_iterator getIterator()
Definition: ilist_node.h:132
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition: ilist_node.h:353
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:661
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
GlobalVariable * getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB)
Definition: OpenMPOpt.cpp:259
ConstantStruct * getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB)
Definition: OpenMPOpt.cpp:266
bool isValidAtPosition(const ValueAndContext &VAC, InformationCache &InfoCache)
Return true if the value of VAC is valid at the position of VAC, that is, a constant,...
Definition: Attributor.cpp:290
bool isPotentiallyAffectedByBarrier(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is potentially affected by a barrier.
Definition: Attributor.cpp:889
@ Interprocedural
Definition: Attributor.h:182
bool isNoSyncInst(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is a nosync instruction.
Definition: Attributor.cpp:205
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
E & operator^=(E &LHS, E RHS)
Definition: BitmaskEnum.h:195
@ Entry
Definition: COFF.h:844
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
PointerTypeMap run(const Module &M)
Compute the PointerTypeMap for the module M.
bool isOpenMPDevice(Module &M)
Helper to determine if M is an OpenMP target offloading device module.
Definition: OpenMPOpt.cpp:5946
bool containsOpenMP(Module &M)
Helper to determine if M contains OpenMP.
Definition: OpenMPOpt.cpp:5938
InternalControlVar
IDs for all Internal Control Variables (ICVs).
Definition: OMPConstants.h:26
RuntimeFunction
IDs for all omp runtime library (RTL) functions.
Definition: OMPConstants.h:45
KernelSet getDeviceKernels(Module &M)
Get OpenMP device kernels in M.
Definition: OpenMPOpt.cpp:5906
Function * Kernel
Summary of a kernel (=entry point for target offloading).
Definition: OpenMPOpt.h:21
@ OMP_TGT_EXEC_MODE_GENERIC_SPMD
bool isOpenMPKernel(Function &Fn)
Return true iff Fn is an OpenMP GPU kernel; Fn has the "kernel" attribute.
Definition: OpenMPOpt.cpp:5902
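The helpers above gate most of the work in this file: a module without OpenMP metadata is skipped, and device-side transformations only look at the kernel entry points. Illustrative sketch (the walker function is hypothetical):
#include "llvm/Transforms/IPO/OpenMPOpt.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
using namespace llvm;
static void forEachDeviceKernel(Module &M) {
  if (!omp::containsOpenMP(M) || !omp::isOpenMPDevice(M))
    return;                                  // host-only or non-OpenMP module
  for (Function *K : omp::getDeviceKernels(M)) {
    // K is a target offloading entry point (omp::Kernel above).
    (void)K;
  }
}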
DiagnosticInfoOptimizationBase::Argument NV
const_iterator begin(StringRef path, Style style=Style::native)
Get begin iterator over path.
Definition: Path.cpp:226
const_iterator end(StringRef path)
Get end iterator over path.
Definition: Path.cpp:235
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:480
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1739
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1697
bool succ_empty(const Instruction *I)
Definition: CFG.h:255
std::string to_string(const T &Value)
Definition: ScopedPrinter.h:85
bool operator!=(uint64_t V1, const APInt &V2)
Definition: APInt.h:2082
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
@ CGSCC
Definition: Attributor.h:6477
@ FullLTOPostLink
Full LTO postlink (backend compile) phase.
@ ThinLTOPreLink
ThinLTO prelink (summary) phase.
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
void cantFail(Error Err, const char *Msg=nullptr)
Report a fatal error if Err is a failure value.
Definition: Error.h:756
bool operator&=(SparseBitVector< ElementSize > *LHS, const SparseBitVector< ElementSize > &RHS)
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition: STLExtras.h:1938
BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
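SplitBlock (above) moves everything from the split point to the end of the block into a new block and makes the old block branch unconditionally to it; the dominator tree is updated when one is supplied. Minimal sketch (the wrapper name is illustrative):
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;
static BasicBlock *splitBefore(Instruction &I, DominatorTree *DT) {
  // I becomes the first instruction of the returned block.
  return SplitBlock(I.getParent(), I.getIterator(), DT);
}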
auto predecessors(const MachineBasicBlock *BB)
ChangeStatus
Definition: Attributor.h:489
const char * toString(DWARFSectionKind Kind)
Constant * ConstantFoldInsertValueInstruction(Constant *Agg, Constant *Val, ArrayRef< unsigned > Idxs)
ConstantFoldInsertValueInstruction - Attempt to constant fold an insertvalue instruction with the spe...
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
#define N
An abstract interface for address space information.
Definition: Attributor.h:6285
An abstract attribute for getting assumption information.
Definition: Attributor.h:6211
An abstract state for querying live call edges.
Definition: Attributor.h:5486
virtual const SetVector< Function * > & getOptimisticEdges() const =0
Get the optimistic edges.
bool isExecutedByInitialThreadOnly(const Instruction &I) const
Check if an instruction is executed only by the initial thread.
Definition: Attributor.h:5646
static AAExecutionDomain & createForPosition(const IRPosition &IRP, Attributor &A)
Create an abstract attribute view for the position IRP.
Definition: OpenMPOpt.cpp:5634
virtual ExecutionDomainTy getFunctionExecutionDomain() const =0
virtual ExecutionDomainTy getExecutionDomain(const BasicBlock &) const =0
virtual bool isExecutedInAlignedRegion(Attributor &A, const Instruction &I) const =0
Check if the instruction I is executed in an aligned region, that is, the synchronizing effects befor...
virtual bool isNoOpFence(const FenceInst &FI) const =0
Helper function to determine if FI is a no-op given the information about its execution from ExecDoma...
static const char ID
Unique ID (due to the unique address)
Definition: Attributor.h:5677
An abstract interface for indirect call information interference.
Definition: Attributor.h:6404
An abstract interface for liveness abstract attribute.
Definition: Attributor.h:3976
An abstract interface for all memory location attributes (readnone/argmemonly/inaccessiblememonly/ina...
Definition: Attributor.h:4704
AccessKind
Simple enum to distinguish read/write/read-write accesses.
Definition: Attributor.h:4840
StateType::base_t MemoryLocationsKind
Definition: Attributor.h:4705
static bool isAlignedBarrier(const CallBase &CB, bool ExecutedAligned)
Helper function to determine if CB is an aligned (GPU) barrier.
An abstract Attribute for determining the necessity of the convergent attribute.
Definition: Attributor.h:5722
An abstract attribute for getting all assumption underlying objects.
Definition: Attributor.h:6243
Base struct for all "concrete attribute" deductions.
Definition: Attributor.h:3284
virtual ChangeStatus manifest(Attributor &A)
Hook for the Attributor to trigger the manifestation of the information represented by the abstract a...
Definition: Attributor.h:3399
virtual void initialize(Attributor &A)
Initialize the state with the information in the Attributor A.
Definition: Attributor.h:3348
virtual const std::string getAsStr(Attributor *A) const =0
This function should return the "summarized" assumed state as string.
virtual ChangeStatus updateImpl(Attributor &A)=0
The actual update/transfer function which has to be implemented by the derived classes.
virtual void trackStatistics() const =0
Hook to enable custom statistic tracking, called after manifest that resulted in a change if statisti...
virtual const char * getIdAddr() const =0
This function should return the address of the ID of the AbstractAttribute.
An interface to query the internal state of an abstract attribute.
Definition: Attributor.h:2604
virtual ChangeStatus indicatePessimisticFixpoint()=0
Indicate that the abstract state should converge to the pessimistic state.
virtual bool isAtFixpoint() const =0
Return if this abstract state is fixed, thus does not need to be updated if information changes as it...
virtual bool isValidState() const =0
Return if this abstract state is in a valid state.
virtual ChangeStatus indicateOptimisticFixpoint()=0
Indicate that the abstract state should converge to the optimistic state.
Wrapper for FunctionAnalysisManager.
Definition: Attributor.h:1127
Configuration for the Attributor.
Definition: Attributor.h:1422
std::function< void(Attributor &A, const Function &F)> InitializationCallback
Callback function to be invoked on internal functions marked live.
Definition: Attributor.h:1453
std::optional< unsigned > MaxFixpointIterations
Maximum number of iterations to run until fixpoint.
Definition: Attributor.h:1469
bool RewriteSignatures
Flag to determine if we rewrite function signatures.
Definition: Attributor.h:1439
const char * PassName
Definition: Attributor.h:1479
OptimizationRemarkGetter OREGetter
Definition: Attributor.h:1475
IPOAmendableCBTy IPOAmendableCB
Definition: Attributor.h:1482
bool IsModulePass
Is the user of the Attributor a module pass or not.
Definition: Attributor.h:1433
bool DefaultInitializeLiveInternals
Flag to determine if we want to initialize all default AAs for an internal function marked live.
Definition: Attributor.h:1443
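The AttributorConfig fields above are filled in before an Attributor is constructed. The sketch below mirrors that general shape; the concrete values, and the omission of OREGetter and the callbacks, are assumptions for illustration only:
#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
using namespace llvm;
static void configureAttributor(CallGraphUpdater &CGUpdater) {
  AttributorConfig AC(CGUpdater);
  AC.IsModulePass = true;                  // driven from a module pass
  AC.RewriteSignatures = false;            // keep function signatures stable
  AC.DefaultInitializeLiveInternals = false;
  AC.MaxFixpointIterations = 32;           // cap on fixpoint iterations
  AC.PassName = "openmp-opt";
  // An Attributor would then be built from AC together with the set of
  // functions to analyze and an InformationCache (not shown here).
}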
The fixpoint analysis framework that orchestrates the attribute deduction.
Definition: Attributor.h:1516
static bool isInternalizable(Function &F)
Returns true if the function F can be internalized.
std::function< std::optional< Constant * >(const GlobalVariable &, const AbstractAttribute *, bool &)> GlobalVariableSimplifictionCallbackTy
Register CB as a simplification callback.
Definition: Attributor.h:2033
std::function< bool(Attributor &, const AbstractAttribute *)> VirtualUseCallbackTy
Definition: Attributor.h:2063
static bool internalizeFunctions(SmallPtrSetImpl< Function * > &FnSet, DenseMap< Function *, Function * > &FnMap)
Make copies of each function in the set FnSet such that the copied version has internal linkage after...
std::function< std::optional< Value * >(const IRPosition &, const AbstractAttribute *, bool &)> SimplifictionCallbackTy
Register CB as a simplification callback.
Definition: Attributor.h:2017
Simple wrapper for a single bit (boolean) state.
Definition: Attributor.h:2888
Support structure for SCC passes to communicate updates to the call graph back to the CGSCC pass manager...
Helper to describe and deal with positions in the LLVM-IR.
Definition: Attributor.h:586
static const IRPosition callsite_returned(const CallBase &CB)
Create a position describing the returned value of CB.
Definition: Attributor.h:654
static const IRPosition returned(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the returned value of F.
Definition: Attributor.h:636
static const IRPosition value(const Value &V, const CallBaseContext *CBContext=nullptr)
Create a position describing the value of V.
Definition: Attributor.h:610
static const IRPosition inst(const Instruction &I, const CallBaseContext *CBContext=nullptr)
Create a position describing the instruction I.
Definition: Attributor.h:622
@ IRP_ARGUMENT
An attribute for a function argument.
Definition: Attributor.h:600
@ IRP_RETURNED
An attribute for the function return value.
Definition: Attributor.h:596
@ IRP_CALL_SITE
An attribute for a call site (function scope).
Definition: Attributor.h:599
@ IRP_CALL_SITE_RETURNED
An attribute for a call site return value.
Definition: Attributor.h:597
@ IRP_FUNCTION
An attribute for a function (scope).
Definition: Attributor.h:598
@ IRP_FLOAT
A position that is not associated with a spot suitable for attributes.
Definition: Attributor.h:594
@ IRP_CALL_SITE_ARGUMENT
An attribute for a call site argument.
Definition: Attributor.h:601
@ IRP_INVALID
An invalid position.
Definition: Attributor.h:593
static const IRPosition function(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the function scope of F.
Definition: Attributor.h:629
Kind getPositionKind() const
Return the associated position kind.
Definition: Attributor.h:882
static const IRPosition callsite_function(const CallBase &CB)
Create a position describing the function scope of CB.
Definition: Attributor.h:649
Function * getAnchorScope() const
Return the Function surrounding the anchor value.
Definition: Attributor.h:758
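IRPosition values are the keys the Attributor uses to attach abstract attributes to IR entities; the static factories above construct them for the common anchor kinds. Illustrative sketch:
#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/IR/InstrTypes.h"
using namespace llvm;
static void positionDemo(Function &F, CallBase &CB) {
  IRPosition FnPos  = IRPosition::function(F);            // IRP_FUNCTION
  IRPosition RetPos = IRPosition::returned(F);            // IRP_RETURNED
  IRPosition CSPos  = IRPosition::callsite_returned(CB);  // IRP_CALL_SITE_RETURNED
  // FnPos.getAnchorScope() == &F; each factory fixes getPositionKind().
  (void)FnPos; (void)RetPos; (void)CSPos;
}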
Data structure to hold cached (LLVM-IR) information.
Definition: Attributor.h:1203
bool isValidState() const override
See AbstractState::isValidState() NOTE: For now we simply pretend that the worst possible state is in...
Definition: Attributor.h:2663
ChangeStatus indicatePessimisticFixpoint() override
See AbstractState::indicatePessimisticFixpoint(...)
Definition: Attributor.h:2675
Direction
An enum for the direction of the loop.
Definition: LoopInfo.h:215
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:117
Description of an LLVM-IR insertion point (IP) and a debug/source location (filename,...
Definition: OMPIRBuilder.h:645
Helper to tie an abstract state implementation to an abstract attribute.
Definition: Attributor.h:3173
StateType & getState() override
See AbstractAttribute::getState(...).
Definition: Attributor.h:3181
Defines various target-specific GPU grid values that must be consistent between host RTL (plugin),...
Definition: OMPGridValues.h:57