OpenMPOpt.cpp
//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpenMP specific optimizations:
//
// - Deduplication of runtime calls, e.g., omp_get_thread_num.
// - Replacing globalized device memory with stack memory.
// - Replacing globalized device memory with shared memory.
// - Parallel region merging.
// - Transforming generic-mode device kernels to SPMD mode.
// - Specializing the state machine for generic-mode device kernels.
//
//===----------------------------------------------------------------------===//
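//
// Illustrative sketch of runtime-call deduplication (IR shapes assumed, not
// taken from this file): two calls such as
//   %a = call i32 @omp_get_thread_num()
//   %b = call i32 @omp_get_thread_num()
// in a context where both must return the same value can be collapsed so
// that users of %b reuse %a and the second call is erased.
//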

#include "llvm/Transforms/IPO/OpenMPOpt.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Assumptions.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Debug.h"

#include <algorithm>
#include <optional>
#include <string>

using namespace llvm;
using namespace omp;

#define DEBUG_TYPE "openmp-opt"

70 "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
71 cl::Hidden, cl::init(false));
72
74 "openmp-opt-enable-merging",
75 cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
76 cl::init(false));
77
78static cl::opt<bool>
79 DisableInternalization("openmp-opt-disable-internalization",
80 cl::desc("Disable function internalization."),
81 cl::Hidden, cl::init(false));
82
83static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
84 cl::init(false), cl::Hidden);
85static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
87static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
88 cl::init(false), cl::Hidden);
89
91 "openmp-hide-memory-transfer-latency",
92 cl::desc("[WIP] Tries to hide the latency of host to device memory"
93 " transfers"),
94 cl::Hidden, cl::init(false));
95
97 "openmp-opt-disable-deglobalization",
98 cl::desc("Disable OpenMP optimizations involving deglobalization."),
99 cl::Hidden, cl::init(false));
100
102 "openmp-opt-disable-spmdization",
103 cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
104 cl::Hidden, cl::init(false));
105
107 "openmp-opt-disable-folding",
108 cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
109 cl::init(false));
110
112 "openmp-opt-disable-state-machine-rewrite",
113 cl::desc("Disable OpenMP optimizations that replace the state machine."),
114 cl::Hidden, cl::init(false));
115
117 "openmp-opt-disable-barrier-elimination",
118 cl::desc("Disable OpenMP optimizations that eliminate barriers."),
119 cl::Hidden, cl::init(false));
120
122 "openmp-opt-print-module-after",
123 cl::desc("Print the current module after OpenMP optimizations."),
124 cl::Hidden, cl::init(false));
125
127 "openmp-opt-print-module-before",
128 cl::desc("Print the current module before OpenMP optimizations."),
129 cl::Hidden, cl::init(false));
130
132 "openmp-opt-inline-device",
133 cl::desc("Inline all applicable functions on the device."), cl::Hidden,
134 cl::init(false));
135
136static cl::opt<bool>
137 EnableVerboseRemarks("openmp-opt-verbose-remarks",
138 cl::desc("Enables more verbose remarks."), cl::Hidden,
139 cl::init(false));
140
142 SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
143 cl::desc("Maximal number of attributor iterations."),
144 cl::init(256));
145
147 SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
148 cl::desc("Maximum amount of shared memory to use."),
149 cl::init(std::numeric_limits<unsigned>::max()));
150
STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");
STATISTIC(NumOpenMPTargetRegionKernels,
          "Number of OpenMP target region entry points (=kernels) identified");
STATISTIC(NumNonOpenMPTargetRegionKernels,
          "Number of non-OpenMP target region kernels identified");
STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "SPMD-mode instead of generic-mode");
STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode without a state machine");
STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode with customized state machines with fallback");
STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode with customized state machines without fallback");
STATISTIC(
    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
STATISTIC(NumOpenMPParallelRegionsMerged,
          "Number of OpenMP parallel regions merged");
STATISTIC(NumBytesMovedToSharedMemory,
          "Amount of memory pushed to shared memory");
STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");

#if !defined(NDEBUG)
static constexpr auto TAG = "[" DEBUG_TYPE "]";
#endif

namespace KernelInfo {

// struct ConfigurationEnvironmentTy {
//   uint8_t UseGenericStateMachine;
//   uint8_t MayUseNestedParallelism;
//   llvm::omp::OMPTgtExecModeFlags ExecMode;
//   int32_t MinThreads;
//   int32_t MaxThreads;
//   int32_t MinTeams;
//   int32_t MaxTeams;
// };

// struct DynamicEnvironmentTy {
//   uint16_t DebugIndentionLevel;
// };

// struct KernelEnvironmentTy {
//   ConfigurationEnvironmentTy Configuration;
//   IdentTy *Ident;
//   DynamicEnvironmentTy *DynamicEnv;
// };
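//
// Each kernel receives a pointer to a constant global of this struct type as
// the first argument of __kmpc_target_init (see
// getKernelEnvironementGVFromKernelInitCB below). The global's name is
// typically derived from the kernel name, e.g.
// @<kernel>_kernel_environment (illustrative naming only).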

#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)                                    \
  constexpr unsigned MEMBER##Idx = IDX;
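// For instance, KERNEL_ENVIRONMENT_IDX(Configuration, 0) expands to
//   constexpr unsigned ConfigurationIdx = 0;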

KERNEL_ENVIRONMENT_IDX(Configuration, 0)
KERNEL_ENVIRONMENT_IDX(Ident, 1)

#undef KERNEL_ENVIRONMENT_IDX

#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)                      \
  constexpr unsigned MEMBER##Idx = IDX;

KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6)

#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX

#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)                          \
  RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
    return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx));     \
  }
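// For instance, KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)
// expands to a getConfigurationFromKernelEnvironment(ConstantStruct *)
// accessor returning the aggregate element at ConfigurationIdx.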

KERNEL_ENVIRONMENT_GETTER(Ident, Constant)
KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)

#undef KERNEL_ENVIRONMENT_GETTER

#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)                        \
  ConstantInt *get##MEMBER##FromKernelEnvironment(                             \
      ConstantStruct *KernelEnvC) {                                            \
    ConstantStruct *ConfigC =                                                  \
        getConfigurationFromKernelEnvironment(KernelEnvC);                     \
    return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx));   \
  }

KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams)

#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER

GlobalVariable *
getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) {
  constexpr int InitKernelEnvironmentArgNo = 0;
  return cast<GlobalVariable>(
      KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
          ->stripPointerCasts());
}

} // namespace KernelInfo

namespace {

struct AAHeapToShared;

struct AAICVTracker;

/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
/// Attributor runs.
struct OMPInformationCache : public InformationCache {
  OMPInformationCache(Module &M, AnalysisGetter &AG,
                      BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
                      bool OpenMPPostLink)
      : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
        OpenMPPostLink(OpenMPPostLink) {

    OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
    const Triple T(OMPBuilder.M.getTargetTriple());
    switch (T.getArch()) {
    case Triple::nvptx:
    case Triple::nvptx64:
    case Triple::amdgcn:
      assert(OMPBuilder.Config.IsTargetDevice &&
             "OpenMP AMDGPU/NVPTX is only prepared to deal with device code.");
      OMPBuilder.Config.IsGPU = true;
      break;
    default:
      OMPBuilder.Config.IsGPU = false;
      break;
    }
    OMPBuilder.initialize();
    initializeRuntimeFunctions(M);
    initializeInternalControlVars();
  }

  /// Generic information that describes an internal control variable.
  struct InternalControlVarInfo {
    /// The kind, as described by the InternalControlVar enum.
    InternalControlVar Kind;

    /// The name of the ICV.
    StringRef Name;

    /// Environment variable associated with this ICV.
    StringRef EnvVarName;

    /// Initial value kind.
    ICVInitValue InitKind;

    /// Initial value.
    ConstantInt *InitValue;

    /// Setter RTL function associated with this ICV.
    RuntimeFunction Setter;

    /// Getter RTL function associated with this ICV.
    RuntimeFunction Getter;

    /// RTL function corresponding to the override clause of this ICV.
    RuntimeFunction Clause;
  };

  /// Generic information that describes a runtime function.
  struct RuntimeFunctionInfo {

    /// The kind, as described by the RuntimeFunction enum.
    RuntimeFunction Kind;

    /// The name of the function.
    StringRef Name;

    /// Flag to indicate a variadic function.
    bool IsVarArg;

    /// The return type of the function.
    Type *ReturnType;

    /// The argument types of the function.
    SmallVector<Type *, 8> ArgumentTypes;

    /// The declaration if available.
    Function *Declaration = nullptr;

    /// Uses of this runtime function per function containing the use.
    using UseVector = SmallVector<Use *, 16>;

    /// Clear UsesMap for runtime function.
    void clearUsesMap() { UsesMap.clear(); }

    /// Boolean conversion that is true if the runtime function was found.
    operator bool() const { return Declaration; }

    /// Return the vector of uses in function \p F.
    UseVector &getOrCreateUseVector(Function *F) {
      std::shared_ptr<UseVector> &UV = UsesMap[F];
      if (!UV)
        UV = std::make_shared<UseVector>();
      return *UV;
    }

    /// Return the vector of uses in function \p F or `nullptr` if there are
    /// none.
    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();
      return nullptr;
    }

    /// Return how many functions contain uses of this runtime function.
    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    /// Return the number of arguments (or the minimal number for variadic
    /// functions).
    size_t getNumArgs() const { return ArgumentTypes.size(); }

    /// Run the callback \p CB on each use and forget the use if the result is
    /// true. The callback will be fed the function in which the use was
    /// encountered as second argument.
    void foreachUse(SmallVectorImpl<Function *> &SCC,
                    function_ref<bool(Use &, Function &)> CB) {
      for (Function *F : SCC)
        foreachUse(CB, F);
    }

    /// Run the callback \p CB on each use within the function \p F and forget
    /// the use if the result is true.
    void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
      SmallVector<unsigned, 8> ToBeDeleted;
      ToBeDeleted.clear();

      unsigned Idx = 0;
      UseVector &UV = getOrCreateUseVector(F);

      for (Use *U : UV) {
        if (CB(*U, *F))
          ToBeDeleted.push_back(Idx);
        ++Idx;
      }

      // Remove the to-be-deleted indices in reverse order as prior
      // modifications will not modify the smaller indices.
      while (!ToBeDeleted.empty()) {
        unsigned Idx = ToBeDeleted.pop_back_val();
        UV[Idx] = UV.back();
        UV.pop_back();
      }
    }
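
    // Example use (sketch only): count the remaining call sites of this
    // runtime function across an SCC without forgetting any uses:
    //   unsigned NumCalls = 0;
    //   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
    //     if (OpenMPOpt::getCallIfRegularCall(U, &RFI))
    //       ++NumCalls;
    //     return false; // Returning true would forget (delete) the use.
    //   });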

  private:
    /// Map from functions to all uses of this runtime function contained in
    /// them.
    DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;

  public:
    /// Iterators for the uses of this runtime function.
    decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
    decltype(UsesMap)::iterator end() { return UsesMap.end(); }
  };

  /// An OpenMP-IR-Builder instance
  OpenMPIRBuilder OMPBuilder;

  /// Map from runtime function kind to the runtime function description.
  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
                  RuntimeFunction::OMPRTL___last>
      RFIs;

  /// Map from function declarations/definitions to their runtime enum type.
  DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;

  /// Map from ICV kind to the ICV description.
  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVs;

  /// Helper to initialize all internal control variable information for those
  /// defined in OMPKinds.def.
  void initializeInternalControlVars() {
#define ICV_RT_SET(_Name, RTL)                                                 \
  {                                                                            \
    auto &ICV = ICVs[_Name];                                                   \
    ICV.Setter = RTL;                                                          \
  }
#define ICV_RT_GET(Name, RTL)                                                  \
  {                                                                            \
    auto &ICV = ICVs[Name];                                                    \
    ICV.Getter = RTL;                                                          \
  }
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
  {                                                                            \
    auto &ICV = ICVs[Enum];                                                    \
    ICV.Name = _Name;                                                          \
    ICV.Kind = Enum;                                                           \
    ICV.InitKind = Init;                                                       \
    ICV.EnvVarName = _EnvVarName;                                              \
    switch (ICV.InitKind) {                                                    \
    case ICV_IMPLEMENTATION_DEFINED:                                           \
      ICV.InitValue = nullptr;                                                 \
      break;                                                                   \
    case ICV_ZERO:                                                             \
      ICV.InitValue = ConstantInt::get(                                        \
          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
      break;                                                                   \
    case ICV_FALSE:                                                            \
      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
      break;                                                                   \
    case ICV_LAST:                                                             \
      break;                                                                   \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"
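    // For example (one entry from OMPKinds.def, reproduced here only for
    // illustration):
    //   ICV_DATA_ENV(ICV_nthreads, "nthreads", "OMP_NUM_THREADS",
    //                ICV_IMPLEMENTATION_DEFINED)
    // fills in ICVs[ICV_nthreads] with its name, environment variable, and an
    // implementation-defined initial value.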
  }

  /// Returns true if the function declaration \p F matches the runtime
  /// function types, that is, return type \p RTFRetType, and argument types
  /// \p RTFArgTypes.
  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
                                  SmallVector<Type *, 8> &RTFArgTypes) {
    // TODO: We should output information to the user (under debug output
    // and via remarks).

    if (!F)
      return false;
    if (F->getReturnType() != RTFRetType)
      return false;
    if (F->arg_size() != RTFArgTypes.size())
      return false;

    auto *RTFTyIt = RTFArgTypes.begin();
    for (Argument &Arg : F->args()) {
      if (Arg.getType() != *RTFTyIt)
        return false;

      ++RTFTyIt;
    }

    return true;
  }

  // Helper to collect all uses of the declaration in the UsesMap.
  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
    unsigned NumUses = 0;
    if (!RFI.Declaration)
      return NumUses;
    OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);

    if (CollectStats) {
      NumOpenMPRuntimeFunctionsIdentified += 1;
      NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
    }

    // TODO: We directly convert uses into proper calls and unknown uses.
    for (Use &U : RFI.Declaration->uses()) {
      if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
        if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
          RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
          ++NumUses;
        }
      } else {
        RFI.getOrCreateUseVector(nullptr).push_back(&U);
        ++NumUses;
      }
    }
    return NumUses;
  }

  // Helper function to recollect uses of a runtime function.
  void recollectUsesForFunction(RuntimeFunction RTF) {
    auto &RFI = RFIs[RTF];
    RFI.clearUsesMap();
    collectUses(RFI, /*CollectStats*/ false);
  }

  // Helper function to recollect uses of all runtime functions.
  void recollectUses() {
    for (int Idx = 0; Idx < RFIs.size(); ++Idx)
      recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
  }

  // Helper function to inherit the calling convention of the function callee.
  void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
    if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
      CI->setCallingConv(Fn->getCallingConv());
  }

  // Helper function to determine if it's legal to create a call to the runtime
  // functions.
  bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
    // We can always emit calls if we haven't yet linked in the runtime.
    if (!OpenMPPostLink)
      return true;

    // Once the runtime has already been linked in, we cannot emit calls to
    // any undefined functions.
    for (RuntimeFunction Fn : Fns) {
      RuntimeFunctionInfo &RFI = RFIs[Fn];

      if (!RFI.Declaration || RFI.Declaration->isDeclaration())
        return false;
    }
    return true;
  }

  /// Helper to initialize all runtime function information for those defined
  /// in OpenMPKinds.def.
  void initializeRuntimeFunctions(Module &M) {

    // Helper macros for handling __VA_ARGS__ in OMP_RTL
#define OMP_TYPE(VarName, ...)                                                 \
  Type *VarName = OMPBuilder.VarName;                                          \
  (void)VarName;

#define OMP_ARRAY_TYPE(VarName, ...)                                           \
  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
  (void)VarName##Ty;                                                           \
  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
  (void)VarName##PtrTy;

#define OMP_FUNCTION_TYPE(VarName, ...)                                        \
  FunctionType *VarName = OMPBuilder.VarName;                                  \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_STRUCT_TYPE(VarName, ...)                                          \
  StructType *VarName = OMPBuilder.VarName;                                    \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
  {                                                                            \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
    Function *F = M.getFunction(_Name);                                        \
    RTLFunctions.insert(F);                                                    \
    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
      RuntimeFunctionIDMap[F] = _Enum;                                         \
      auto &RFI = RFIs[_Enum];                                                 \
      RFI.Kind = _Enum;                                                        \
      RFI.Name = _Name;                                                        \
      RFI.IsVarArg = _IsVarArg;                                                \
      RFI.ReturnType = OMPBuilder._ReturnType;                                 \
      RFI.ArgumentTypes = std::move(ArgsTypes);                                \
      RFI.Declaration = F;                                                     \
      unsigned NumUses = collectUses(RFI);                                     \
      (void)NumUses;                                                           \
      LLVM_DEBUG({                                                             \
        dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
               << " found\n";                                                  \
        if (RFI.Declaration)                                                   \
          dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
                 << RFI.getNumFunctionsWithUses()                              \
                 << " different functions.\n";                                 \
      });                                                                      \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"

    // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
    // functions, except if `optnone` is present.
    if (isOpenMPDevice(M)) {
      for (Function &F : M) {
        for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
          if (F.hasFnAttribute(Attribute::NoInline) &&
              F.getName().starts_with(Prefix) &&
              !F.hasFnAttribute(Attribute::OptimizeNone))
            F.removeFnAttr(Attribute::NoInline);
      }
    }

    // TODO: We should attach the attributes defined in OMPKinds.def.
  }

  /// Collection of known OpenMP runtime functions.
  DenseSet<const Function *> RTLFunctions;

  /// Indicates if we have already linked in the OpenMP device library.
  bool OpenMPPostLink = false;
};

template <typename Ty, bool InsertInvalidates = true>
struct BooleanStateWithSetVector : public BooleanState {
  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
  bool insert(const Ty &Elem) {
    if (InsertInvalidates)
      BooleanState::indicatePessimisticFixpoint();
    return Set.insert(Elem);
  }

  const Ty &operator[](int Idx) const { return Set[Idx]; }
  bool operator==(const BooleanStateWithSetVector &RHS) const {
    return BooleanState::operator==(RHS) && Set == RHS.Set;
  }
  bool operator!=(const BooleanStateWithSetVector &RHS) const {
    return !(*this == RHS);
  }

  bool empty() const { return Set.empty(); }
  size_t size() const { return Set.size(); }

  /// "Clamp" this state with \p RHS.
  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
    BooleanState::operator^=(RHS);
    Set.insert_range(RHS.Set);
    return *this;
  }

private:
  /// A set to keep track of elements.
  SetVector<Ty> Set;

public:
  typename decltype(Set)::iterator begin() { return Set.begin(); }
  typename decltype(Set)::iterator end() { return Set.end(); }
  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
  typename decltype(Set)::const_iterator end() const { return Set.end(); }
};

template <typename Ty, bool InsertInvalidates = true>
using BooleanStateWithPtrSetVector =
    BooleanStateWithSetVector<Ty *, InsertInvalidates>;

struct KernelInfoState : AbstractState {
  /// Flag to track if we reached a fixpoint.
  bool IsAtFixpoint = false;

  /// The parallel regions (identified by the outlined parallel functions) that
  /// can be reached from the associated function.
  BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
      ReachedKnownParallelRegions;

  /// State to track what parallel region we might reach.
  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;

  /// State to track if we are in SPMD-mode, assumed or known, and why we
  /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
  /// also be false.
  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;

  /// The __kmpc_target_init call in this kernel, if any. If we find more than
  /// one we abort as the kernel is malformed.
  CallBase *KernelInitCB = nullptr;

  /// The constant kernel environment as taken from and passed to
  /// __kmpc_target_init.
  ConstantStruct *KernelEnvC = nullptr;

  /// The __kmpc_target_deinit call in this kernel, if any. If we find more
  /// than one we abort as the kernel is malformed.
  CallBase *KernelDeinitCB = nullptr;

  /// Flag to indicate if the associated function is a kernel entry.
  bool IsKernelEntry = false;

  /// State to track what kernel entries can reach the associated function.
  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;

  /// State to indicate if we can track the parallel level of the associated
  /// function. We will give up tracking if we encounter an unknown caller or
  /// the caller is __kmpc_parallel_60.
  BooleanStateWithSetVector<uint8_t> ParallelLevels;

  /// Flag that indicates if the kernel has nested parallelism.
  bool NestedParallelism = false;

  /// Abstract State interface
  ///{

  KernelInfoState() = default;
  KernelInfoState(bool BestState) {
    if (!BestState)
      indicatePessimisticFixpoint();
  }

  /// See AbstractState::isValidState(...)
  bool isValidState() const override { return true; }

  /// See AbstractState::isAtFixpoint(...)
  bool isAtFixpoint() const override { return IsAtFixpoint; }

  /// See AbstractState::indicatePessimisticFixpoint(...)
  ChangeStatus indicatePessimisticFixpoint() override {
    IsAtFixpoint = true;
    ParallelLevels.indicatePessimisticFixpoint();
    ReachingKernelEntries.indicatePessimisticFixpoint();
    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
    NestedParallelism = true;
    return ChangeStatus::CHANGED;
  }

  /// See AbstractState::indicateOptimisticFixpoint(...)
  ChangeStatus indicateOptimisticFixpoint() override {
    IsAtFixpoint = true;
    ParallelLevels.indicateOptimisticFixpoint();
    ReachingKernelEntries.indicateOptimisticFixpoint();
    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
    return ChangeStatus::UNCHANGED;
  }

  /// Return the assumed state
  KernelInfoState &getAssumed() { return *this; }
  const KernelInfoState &getAssumed() const { return *this; }

  bool operator==(const KernelInfoState &RHS) const {
    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
      return false;
    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
      return false;
    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
      return false;
    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
      return false;
    if (ParallelLevels != RHS.ParallelLevels)
      return false;
    if (NestedParallelism != RHS.NestedParallelism)
      return false;
    return true;
  }

  /// Returns true if this kernel contains any OpenMP parallel regions.
  bool mayContainParallelRegion() {
    return !ReachedKnownParallelRegions.empty() ||
           !ReachedUnknownParallelRegions.empty();
  }

  /// Return empty set as the best state of potential values.
  static KernelInfoState getBestState() { return KernelInfoState(true); }

  static KernelInfoState getBestState(KernelInfoState &KIS) {
    return getBestState();
  }

  /// Return full set as the worst state of potential values.
  static KernelInfoState getWorstState() { return KernelInfoState(false); }

  /// "Clamp" this state with \p KIS.
  KernelInfoState operator^=(const KernelInfoState &KIS) {
    // Do not merge two different _init and _deinit call sites.
    if (KIS.KernelInitCB) {
      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
                         "assumptions.");
      KernelInitCB = KIS.KernelInitCB;
    }
    if (KIS.KernelDeinitCB) {
      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
                         "assumptions.");
      KernelDeinitCB = KIS.KernelDeinitCB;
    }
    if (KIS.KernelEnvC) {
      if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
                         "assumptions.");
      KernelEnvC = KIS.KernelEnvC;
    }
    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
    NestedParallelism |= KIS.NestedParallelism;
    return *this;
  }

  KernelInfoState operator&=(const KernelInfoState &KIS) {
    return (*this ^= KIS);
  }

  ///}
};

/// Used to map the values physically (in the IR) stored in an offload
/// array, to a vector in memory.
struct OffloadArray {
  /// Physical array (in the IR).
  AllocaInst *Array = nullptr;
  /// Mapped values.
  SmallVector<Value *, 8> StoredValues;
  /// Last stores made in the offload array.
  SmallVector<StoreInst *, 8> LastAccesses;

  OffloadArray() = default;

  /// Initializes the OffloadArray with the values stored in \p Array before
  /// instruction \p Before is reached. Returns false if the initialization
  /// fails.
  /// This MUST be used immediately after the construction of the object.
  bool initialize(AllocaInst &Array, Instruction &Before) {
    if (!getValues(Array, Before))
      return false;

    this->Array = &Array;
    return true;
  }

  static const unsigned DeviceIDArgNum = 1;
  static const unsigned BasePtrsArgNum = 3;
  static const unsigned PtrsArgNum = 4;
  static const unsigned SizesArgNum = 5;
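
  // These positions refer to the arguments of the host runtime entry (sketch;
  // see the call shape documented in getValuesInOffloadArrays below):
  //   __tgt_target_data_begin_mapper(ident, device_id, num_args,
  //                                  base_ptrs, ptrs, sizes, ...)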

private:
  /// Traverses the BasicBlock where \p Array is, collecting the stores made to
  /// \p Array, leaving StoredValues with the values stored before the
  /// instruction \p Before is reached.
  bool getValues(AllocaInst &Array, Instruction &Before) {
    // Initialize containers.
    const DataLayout &DL = Array.getDataLayout();
    std::optional<TypeSize> ArraySize = Array.getAllocationSize(DL);
    if (!ArraySize || !ArraySize->isFixed())
      return false;
    const unsigned int PointerSize = DL.getPointerSize();
    const uint64_t NumValues = ArraySize->getFixedValue() / PointerSize;
    StoredValues.assign(NumValues, nullptr);
    LastAccesses.assign(NumValues, nullptr);

    // TODO: This assumes the instruction \p Before is in the same
    // BasicBlock as Array. Make it general, for any control flow graph.
    BasicBlock *BB = Array.getParent();
    if (BB != Before.getParent())
      return false;

    for (Instruction &I : *BB) {
      if (&I == &Before)
        break;

      if (!isa<StoreInst>(&I))
        continue;

      auto *S = cast<StoreInst>(&I);
      int64_t Offset = -1;
      auto *Dst =
          GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
      if (Dst == &Array) {
        int64_t Idx = Offset / PointerSize;
        // Ignore updates that must be UB (probably in dead code at runtime).
        if ((uint64_t)Idx < NumValues) {
          StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
          LastAccesses[Idx] = S;
        }
      }
    }

    return isFilled();
  }

  /// Returns true if all values in StoredValues and
  /// LastAccesses are not nullptrs.
  bool isFilled() {
    const unsigned NumValues = StoredValues.size();
    for (unsigned I = 0; I < NumValues; ++I) {
      if (!StoredValues[I] || !LastAccesses[I])
        return false;
    }

    return true;
  }
};

struct OpenMPOpt {

  using OptimizationRemarkGetter =
      function_ref<OptimizationRemarkEmitter &(Function *)>;

  OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
            OptimizationRemarkGetter OREGetter,
            OMPInformationCache &OMPInfoCache, Attributor &A)
      : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
        OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}

  /// Check if any remarks are enabled for openmp-opt
  bool remarksEnabled() {
    auto &Ctx = M.getContext();
    return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
  }

  /// Run all OpenMP optimizations on the underlying SCC.
  bool run(bool IsModulePass) {
    if (SCC.empty())
      return false;

    bool Changed = false;

    LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
                      << " functions\n");

    if (IsModulePass) {
      Changed |= runAttributor(IsModulePass);

      // Recollect uses, in case Attributor deleted any.
      OMPInfoCache.recollectUses();

      // TODO: This should be folded into buildCustomStateMachine.
      Changed |= rewriteDeviceCodeStateMachine();

      if (remarksEnabled())
        analysisGlobalization();
    } else {
      if (PrintICVValues)
        printICVs();
      if (PrintOpenMPKernels)
        printKernels();

      Changed |= runAttributor(IsModulePass);

      // Recollect uses, in case Attributor deleted any.
      OMPInfoCache.recollectUses();

      Changed |= deleteParallelRegions();

      if (HideMemoryTransferLatency)
        Changed |= hideMemTransfersLatency();
      Changed |= deduplicateRuntimeCalls();
      if (EnableParallelRegionMerging) {
        if (mergeParallelRegions()) {
          deduplicateRuntimeCalls();
          Changed = true;
        }
      }
    }

    if (OMPInfoCache.OpenMPPostLink)
      Changed |= removeRuntimeSymbols();

    return Changed;
  }

  /// Print initial ICV values for testing.
  /// FIXME: This should be done from the Attributor once it is added.
  void printICVs() const {
    InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
                                 ICV_proc_bind};

    for (Function *F : SCC) {
      for (auto ICV : ICVs) {
        auto ICVInfo = OMPInfoCache.ICVs[ICV];
        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
          return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
                     << " Value: "
                     << (ICVInfo.InitValue
                             ? toString(ICVInfo.InitValue->getValue(), 10, true)
                             : "IMPLEMENTATION_DEFINED");
        };

        emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
      }
    }
  }

  /// Print OpenMP GPU kernels for testing.
  void printKernels() const {
    for (Function *F : SCC) {
      if (!omp::isOpenMPKernel(*F))
        continue;

      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA << "OpenMP GPU kernel "
                   << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
      };

      emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPUAnalysis", Remark);
    }
  }

  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
  /// given it has to be the callee or a nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(U.getUser());
    if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
        (!RFI ||
         (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
      return CI;
    return nullptr;
  }

  /// Return the call if \p V is a regular call. If \p RFI is given it has to
  /// be the callee or a nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(&V);
    if (CI && !CI->hasOperandBundles() &&
        (!RFI ||
         (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
      return CI;
    return nullptr;
  }

private:
  /// Merge parallel regions when it is safe.
  bool mergeParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;
    const unsigned CallbackFirstArgOperand = 3;
    using InsertPointTy = OpenMPIRBuilder::InsertPointTy;

    // Check if there are any __kmpc_fork_call calls to merge.
    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    // Unmergable calls that prevent merging a parallel region.
    OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
    };

    bool Changed = false;
    LoopInfo *LI = nullptr;
    DominatorTree *DT = nullptr;

    SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;

    BasicBlock *StartBB = nullptr, *EndBB = nullptr;
    auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
                         ArrayRef<BasicBlock *> DeallocBlocks) {
      BasicBlock *CGStartBB = CodeGenIP.getBlock();
      BasicBlock *CGEndBB =
          SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
      assert(StartBB != nullptr && "StartBB should not be null");
      CGStartBB->getTerminator()->setSuccessor(0, StartBB);
      assert(EndBB != nullptr && "EndBB should not be null");
      EndBB->getTerminator()->setSuccessor(0, CGEndBB);
      return Error::success();
    };

    auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
                      Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
      ReplacementValue = &Inner;
      return CodeGenIP;
    };

    auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };

    /// Create a sequential execution region within a merged parallel region,
    /// encapsulated in a master construct with a barrier for synchronization.
    auto CreateSequentialRegion = [&](Function *OuterFn,
                                      BasicBlock *OuterPredBB,
                                      Instruction *SeqStartI,
                                      Instruction *SeqEndI) {
      // Isolate the instructions of the sequential region to a separate
      // block.
      BasicBlock *ParentBB = SeqStartI->getParent();
      BasicBlock *SeqEndBB =
          SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
      BasicBlock *SeqAfterBB =
          SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
      BasicBlock *SeqStartBB =
          SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");

      assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
             "Expected a different CFG");
      const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
      ParentBB->getTerminator()->eraseFromParent();

      auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
                           ArrayRef<BasicBlock *> DeallocBlocks) {
        BasicBlock *CGStartBB = CodeGenIP.getBlock();
        BasicBlock *CGEndBB =
            SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
        assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
        CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
        assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
        SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
        return Error::success();
      };
      auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };

      // Find outputs from the sequential region to outside users and
      // broadcast their values to them.
      for (Instruction &I : *SeqStartBB) {
        SmallPtrSet<Instruction *, 4> OutsideUsers;
        for (User *Usr : I.users()) {
          Instruction &UsrI = *cast<Instruction>(Usr);
          // Ignore outputs to LT intrinsics, code extraction for the merged
          // parallel region will fix them.
          if (UsrI.isLifetimeStartOrEnd())
            continue;

          if (UsrI.getParent() != SeqStartBB)
            OutsideUsers.insert(&UsrI);
        }

        if (OutsideUsers.empty())
          continue;

        // Emit an alloca in the outer region to store the broadcasted
        // value.
        const DataLayout &DL = M.getDataLayout();
        AllocaInst *AllocaI = new AllocaInst(
            I.getType(), DL.getAllocaAddrSpace(), nullptr,
            I.getName() + ".seq.output.alloc", OuterFn->front().begin());

        // Emit a store instruction in the sequential BB to update the
        // value.
        new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());

        // Emit a load instruction and replace the use of the output value
        // with it.
        for (Instruction *UsrI : OutsideUsers) {
          LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,
                                         I.getName() + ".seq.output.load",
                                         UsrI->getIterator());
          UsrI->replaceUsesOfWith(&I, LoadI);
        }
      }

      OpenMPIRBuilder::LocationDescription Loc(
          InsertPointTy(ParentBB, ParentBB->end()), DL);
      OpenMPIRBuilder::InsertPointTy SeqAfterIP = cantFail(
          OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB));
      cantFail(
          OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel));

      BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());

      LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
                        << "\n");
    };

    // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
    // contained in BB and only separated by instructions that can be
    // redundantly executed in parallel. The block BB is split before the first
    // call (in MergableCIs) and after the last so the entire region we merge
    // into a single parallel region is contained in a single basic block
    // without any other instructions. We use the OpenMPIRBuilder to outline
    // that block and call the resulting function via __kmpc_fork_call.
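    //
    // Illustrative sketch (shapes assumed):
    //   __kmpc_fork_call(..., @cb0, ...); <code>; __kmpc_fork_call(..., @cb1, ...)
    // becomes a single __kmpc_fork_call whose outlined body invokes @cb0,
    // executes <code> wrapped in a master construct plus a barrier, and then
    // invokes @cb1.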
    auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
                     BasicBlock *BB) {
      // TODO: Change the interface to allow single CIs expanded, e.g., to
      // include an outer loop.
      assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");

      auto Remark = [&](OptimizationRemark OR) {
        OR << "Parallel region merged with parallel region"
           << (MergableCIs.size() > 2 ? "s" : "") << " at ";
        for (auto *CI : llvm::drop_begin(MergableCIs)) {
          OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
          if (CI != MergableCIs.back())
            OR << ", ";
        }
        return OR << ".";
      };

      emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);

      Function *OriginalFn = BB->getParent();
      LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
                        << " parallel regions in " << OriginalFn->getName()
                        << "\n");

      // Isolate the calls to merge in a separate block.
      EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
      BasicBlock *AfterBB =
          SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
      StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
                           "omp.par.merged");

      assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
      const DebugLoc DL = BB->getTerminator()->getDebugLoc();
      BB->getTerminator()->eraseFromParent();

      // Create sequential regions for sequential instructions that are
      // in-between mergable parallel regions.
      for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
           It != End; ++It) {
        Instruction *ForkCI = *It;
        Instruction *NextForkCI = *(It + 1);

        // Continue if there are no in-between instructions.
        if (ForkCI->getNextNode() == NextForkCI)
          continue;

        CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
                               NextForkCI->getPrevNode());
      }

      OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
                                               DL);
      IRBuilder<>::InsertPoint AllocaIP(
          &OriginalFn->getEntryBlock(),
          OriginalFn->getEntryBlock().getFirstInsertionPt());
      // Create the merged parallel region with default proc binding, to
      // avoid overriding binding settings, and without explicit cancellation.
      OpenMPIRBuilder::InsertPointTy AfterIP =
          cantFail(OMPInfoCache.OMPBuilder.createParallel(
              Loc, AllocaIP, /* DeallocBlocks */ {}, BodyGenCB, PrivCB, FiniCB,
              nullptr, nullptr, OMP_PROC_BIND_default,
              /* IsCancellable */ false));
      BranchInst::Create(AfterBB, AfterIP.getBlock());

      // Perform the actual outlining.
      OMPInfoCache.OMPBuilder.finalize(OriginalFn);

      Function *OutlinedFn = MergableCIs.front()->getCaller();

      // Replace the __kmpc_fork_call calls with direct calls to the outlined
      // callbacks.
      SmallVector<Value *, 8> Args;
      for (auto *CI : MergableCIs) {
        Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
        FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
        Args.clear();
        Args.push_back(OutlinedFn->getArg(0));
        Args.push_back(OutlinedFn->getArg(1));
        for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
             ++U)
          Args.push_back(CI->getArgOperand(U));

        CallInst *NewCI =
            CallInst::Create(FT, Callee, Args, "", CI->getIterator());
        if (CI->getDebugLoc())
          NewCI->setDebugLoc(CI->getDebugLoc());

        // Forward parameter attributes from the callback to the callee.
        for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
             ++U)
          for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
            NewCI->addParamAttr(
                U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);

        // Emit an explicit barrier to replace the implicit fork-join barrier.
        if (CI != MergableCIs.back()) {
          // TODO: Remove barrier if the merged parallel region includes the
          // 'nowait' clause.
          cantFail(OMPInfoCache.OMPBuilder.createBarrier(
              InsertPointTy(NewCI->getParent(),
                            NewCI->getNextNode()->getIterator()),
              OMPD_parallel));
        }

        CI->eraseFromParent();
      }

      assert(OutlinedFn != OriginalFn && "Outlining failed");
      CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
      CGUpdater.reanalyzeFunction(*OriginalFn);

      NumOpenMPParallelRegionsMerged += MergableCIs.size();

      return true;
    };

    // Helper function that identifies sequences of
    // __kmpc_fork_call uses in a basic block.
    auto DetectPRsCB = [&](Use &U, Function &F) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      BB2PRMap[CI->getParent()].insert(CI);

      return false;
    };

    BB2PRMap.clear();
    RFI.foreachUse(SCC, DetectPRsCB);
    SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
    // Find mergable parallel regions within a basic block that are
    // safe to merge, that is any in-between instructions can safely
    // execute in parallel after merging.
    // TODO: support merging across basic-blocks.
    for (auto &It : BB2PRMap) {
      auto &CIs = It.getSecond();
      if (CIs.size() < 2)
        continue;

      BasicBlock *BB = It.getFirst();
      SmallVector<CallInst *, 4> MergableCIs;

      /// Returns true if the instruction is mergable, false otherwise.
      /// A terminator instruction is unmergable by definition since merging
      /// works within a BB. Instructions before the mergable region are
      /// mergable if they are not calls to OpenMP runtime functions that may
      /// set different execution parameters for subsequent parallel regions.
      /// Instructions in-between parallel regions are mergable if they are not
      /// calls to any non-intrinsic function since that may call a non-mergable
      /// OpenMP runtime function.
      auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
        // We do not merge across BBs, hence return false (unmergable) if the
        // instruction is a terminator.
        if (I.isTerminator())
          return false;

        if (!isa<CallInst>(&I))
          return true;

        CallInst *CI = cast<CallInst>(&I);
        if (IsBeforeMergableRegion) {
          Function *CalledFunction = CI->getCalledFunction();
          if (!CalledFunction)
            return false;
          // Return false (unmergable) if the call before the parallel
          // region calls an explicit affinity (proc_bind) or number of
          // threads (num_threads) compiler-generated function. Those settings
          // may be incompatible with following parallel regions.
          // TODO: ICV tracking to detect compatibility.
          for (const auto &RFI : UnmergableCallsInfo) {
            if (CalledFunction == RFI.Declaration)
              return false;
          }
        } else {
          // Return false (unmergable) if there is a call instruction
          // in-between parallel regions when it is not an intrinsic. It
          // may call an unmergable OpenMP runtime function in its callpath.
          // TODO: Keep track of possible OpenMP calls in the callpath.
          if (!isa<IntrinsicInst>(CI))
            return false;
        }

        return true;
      };
      // Find maximal number of parallel region CIs that are safe to merge.
      for (auto It = BB->begin(), End = BB->end(); It != End;) {
        Instruction &I = *It;
        ++It;

        if (CIs.count(&I)) {
          MergableCIs.push_back(cast<CallInst>(&I));
          continue;
        }

        // Continue expanding if the instruction is mergable.
        if (IsMergable(I, MergableCIs.empty()))
          continue;

        // Forward the instruction iterator to skip the next parallel region
        // since there is an unmergable instruction which can affect it.
        for (; It != End; ++It) {
          Instruction &SkipI = *It;
          if (CIs.count(&SkipI)) {
            LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
                              << " due to " << I << "\n");
            ++It;
            break;
          }
        }

        // Store mergable regions found.
        if (MergableCIs.size() > 1) {
          MergableCIsVector.push_back(MergableCIs);
          LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
                            << " parallel regions in block " << BB->getName()
                            << " of function " << BB->getParent()->getName()
                            << "\n";);
        }

        MergableCIs.clear();
      }

      if (!MergableCIsVector.empty()) {
        Changed = true;

        for (auto &MergableCIs : MergableCIsVector)
          Merge(MergableCIs, BB);
        MergableCIsVector.clear();
      }
    }

    if (Changed) {
      /// Re-collect uses for fork calls, emitted barrier calls, and
      /// any emitted master/end_master calls.
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
    }

    return Changed;
  }

  /// Try to delete parallel regions if possible.
  bool deleteParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;

    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    bool Changed = false;
    auto DeleteCallCB = [&](Use &U, Function &) {
      CallInst *CI = getCallIfRegularCall(U);
      if (!CI)
        return false;
      auto *Fn = dyn_cast<Function>(
          CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
      if (!Fn)
        return false;
      if (!Fn->onlyReadsMemory())
        return false;
      if (!Fn->hasFnAttribute(Attribute::WillReturn))
        return false;

      LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
                        << CI->getCaller()->getName() << "\n");

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Removing parallel region with no side-effects.";
      };
      emitRemark<OptimizationRemark>(CI, "OMP160", Remark);

      CI->eraseFromParent();
      Changed = true;
      ++NumOpenMPParallelRegionsDeleted;
      return true;
    };

    RFI.foreachUse(SCC, DeleteCallCB);

    return Changed;
  }

  /// Try to eliminate runtime calls by reusing existing ones.
  bool deduplicateRuntimeCalls() {
    bool Changed = false;

    RuntimeFunction DeduplicableRuntimeCallIDs[] = {
        OMPRTL_omp_get_num_threads,
        OMPRTL_omp_in_parallel,
        OMPRTL_omp_get_cancellation,
        OMPRTL_omp_get_supported_active_levels,
        OMPRTL_omp_get_level,
        OMPRTL_omp_get_ancestor_thread_num,
        OMPRTL_omp_get_team_size,
        OMPRTL_omp_get_active_level,
        OMPRTL_omp_in_final,
        OMPRTL_omp_get_proc_bind,
        OMPRTL_omp_get_num_places,
        OMPRTL_omp_get_num_procs,
        OMPRTL_omp_get_place_num,
        OMPRTL_omp_get_partition_num_places,
        OMPRTL_omp_get_partition_place_nums};

    // Global-tid is handled separately.
    SmallSetVector<Value *, 16> GTIdArgs;
    collectGlobalThreadIdArguments(GTIdArgs);
    LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
                      << " global thread ID arguments\n");

    for (Function *F : SCC) {
      for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
        Changed |= deduplicateRuntimeCalls(
            *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);

      // __kmpc_global_thread_num is special as we can replace it with an
      // argument in enough cases to make it worth trying.
      Value *GTIdArg = nullptr;
      for (Argument &Arg : F->args())
        if (GTIdArgs.count(&Arg)) {
          GTIdArg = &Arg;
          break;
        }
      Changed |= deduplicateRuntimeCalls(
          *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
    }

    return Changed;
  }

  /// Tries to remove known runtime symbols that are optional from the module.
  bool removeRuntimeSymbols() {
    // The RPC client symbol is defined in `libc` and indicates that something
    // required an RPC server. If its users were all optimized out then we can
    // safely remove it.
    // TODO: This should be somewhere more common in the future.
    if (GlobalVariable *GV = M.getNamedGlobal("__llvm_rpc_client")) {
      if (GV->hasNUsesOrMore(1))
        return false;

      GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
      GV->eraseFromParent();
      return true;
    }
    return false;
  }

  /// Tries to hide the latency of runtime calls that involve host to
  /// device memory transfers by splitting them into their "issue" and "wait"
  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
  /// moved downwards as much as possible. The "issue" issues the memory
  /// transfer asynchronously, returning a handle. The "wait" waits on the
  /// returned handle for the memory transfer to finish.
  bool hideMemTransfersLatency() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
    bool Changed = false;
    auto SplitMemTransfers = [&](Use &U, Function &Decl) {
      auto *RTCall = getCallIfRegularCall(U, &RFI);
      if (!RTCall)
        return false;

      OffloadArray OffloadArrays[3];
      if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
        return false;

      LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));

      // TODO: Check if can be moved upwards.
      bool WasSplit = false;
      Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
      if (WaitMovementPoint)
        WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);

      Changed |= WasSplit;
      return WasSplit;
    };
    if (OMPInfoCache.runtimeFnsAvailable(
            {OMPRTL___tgt_target_data_begin_mapper_issue,
             OMPRTL___tgt_target_data_begin_mapper_wait}))
      RFI.foreachUse(SCC, SplitMemTransfers);

    return Changed;
  }

  void analysisGlobalization() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];

    auto CheckGlobalization = [&](Use &U, Function &Decl) {
      if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
        auto Remark = [&](OptimizationRemarkMissed ORM) {
          return ORM
                 << "Found thread data sharing on the GPU. "
                 << "Expect degraded performance due to data globalization.";
        };
        emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
      }

      return false;
    };

    RFI.foreachUse(SCC, CheckGlobalization);
  }

  /// Maps the values stored in the offload arrays passed as arguments to
  /// \p RuntimeCall into the offload arrays in \p OAs.
  bool getValuesInOffloadArrays(CallInst &RuntimeCall,
                                MutableArrayRef<OffloadArray> OAs) {
    assert(OAs.size() == 3 && "Need space for three offload arrays!");

    // A runtime call that involves memory offloading looks something like:
    //   call void @__tgt_target_data_begin_mapper(arg0, arg1,
    //     i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
    //     ...)
    // So, the idea is to access the allocas that allocate space for these
    // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
    // Therefore:
    // i8** %offload_baseptrs.
    Value *BasePtrsArg =
        RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
    // i8** %offload_ptrs.
    Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
    // i64* %offload_sizes.
    Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);

    // Get values stored in **offload_baseptrs.
    auto *V = getUnderlyingObject(BasePtrsArg);
    if (!isa<AllocaInst>(V))
      return false;
    auto *BasePtrsArray = cast<AllocaInst>(V);
    if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
      return false;

    // Get values stored in **offload_ptrs.
    V = getUnderlyingObject(PtrsArg);
    if (!isa<AllocaInst>(V))
      return false;
    auto *PtrsArray = cast<AllocaInst>(V);
    if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
      return false;

    // Get values stored in *offload_sizes.
    V = getUnderlyingObject(SizesArg);
    // If it's a [constant] global array don't analyze it.
    if (isa<GlobalValue>(V))
      return isa<Constant>(V);
    if (!isa<AllocaInst>(V))
      return false;

    auto *SizesArray = cast<AllocaInst>(V);
    if (!OAs[2].initialize(*SizesArray, RuntimeCall))
      return false;

    return true;
  }

  /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
  /// For now this is a way to test that the function getValuesInOffloadArrays
  /// is working properly.
  /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
  void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
    assert(OAs.size() == 3 && "There are three offload arrays to debug!");

    LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
    std::string ValuesStr;
    raw_string_ostream Printer(ValuesStr);
    std::string Separator = " --- ";

    for (auto *BP : OAs[0].StoredValues) {
      BP->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << ValuesStr << "\n");
    ValuesStr.clear();

    for (auto *P : OAs[1].StoredValues) {
      P->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << ValuesStr << "\n");
    ValuesStr.clear();

    for (auto *S : OAs[2].StoredValues) {
      S->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << ValuesStr << "\n");
  }
1692
1693  /// Returns the instruction where the "wait" counterpart of \p RuntimeCall can
1694  /// be moved. Returns nullptr if the movement is not possible, or not worth it.
1695 Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1696 // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1697 // Make it traverse the CFG.
1698
1699 Instruction *CurrentI = &RuntimeCall;
1700 bool IsWorthIt = false;
1701 while ((CurrentI = CurrentI->getNextNode())) {
1702
1703 // TODO: Once we detect the regions to be offloaded we should use the
1704 // alias analysis manager to check if CurrentI may modify one of
1705 // the offloaded regions.
1706 if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1707 if (IsWorthIt)
1708 return CurrentI;
1709
1710 return nullptr;
1711 }
1712
1713      // FIXME: For now, moving it over anything without side effects is
1714      // considered worth it.
1715 IsWorthIt = true;
1716 }
1717
1718 // Return end of BasicBlock.
1719 return RuntimeCall.getParent()->getTerminator();
1720 }
1721
1722 /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1723 bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1724 Instruction &WaitMovementPoint) {
1725    // Create a stack-allocated handle (__tgt_async_info) at the beginning of
1726    // the function. It is used to store information about the async transfer,
1727    // allowing us to wait on it later.
1728 auto &IRBuilder = OMPInfoCache.OMPBuilder;
1729 Function *F = RuntimeCall.getCaller();
1730 BasicBlock &Entry = F->getEntryBlock();
1731 IRBuilder.Builder.SetInsertPoint(&Entry,
1732 Entry.getFirstNonPHIOrDbgOrAlloca());
1733 Value *Handle = IRBuilder.Builder.CreateAlloca(
1734 IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1735 Handle =
1736 IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1737
1738 // Add "issue" runtime call declaration:
1739 // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1740 // i8**, i8**, i64*, i64*)
1741 FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1742 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1743
1744    // Replace the RuntimeCall call site with its asynchronous version.
1745 SmallVector<Value *, 16> Args;
1746 for (auto &Arg : RuntimeCall.args())
1747 Args.push_back(Arg.get());
1748 Args.push_back(Handle);
1749
1750 CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"",
1751 RuntimeCall.getIterator());
1752 OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1753 RuntimeCall.eraseFromParent();
1754
1755 // Add "wait" runtime call declaration:
1756 // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1757 FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1758 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1759
1760 Value *WaitParams[2] = {
1761 IssueCallsite->getArgOperand(
1762 OffloadArray::DeviceIDArgNum), // device_id.
1763 Handle // handle to wait on.
1764 };
1765 CallInst *WaitCallsite = CallInst::Create(
1766 WaitDecl, WaitParams, /*NameStr=*/"", WaitMovementPoint.getIterator());
1767 OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1768
1769 return true;
1770 }
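
  // Illustrative sketch of the split (simplified, not verbatim IR): the
  // synchronous mapper call becomes an "issue" whose transfer can overlap with
  // subsequent side-effect-free code, followed by a "wait" moved as far down
  // as deemed safe:
  //   %handle = alloca %struct.__tgt_async_info
  //   call void @__tgt_target_data_begin_mapper_issue(..., ptr %handle)
  //   ... code without side effects the transfer can overlap with ...
  //   call void @__tgt_target_data_begin_mapper_wait(i64 %device_id, ptr %handle)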
1771
1772 static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1773 bool GlobalOnly, bool &SingleChoice) {
1774 if (CurrentIdent == NextIdent)
1775 return CurrentIdent;
1776
1777 // TODO: Figure out how to actually combine multiple debug locations. For
1778 // now we just keep an existing one if there is a single choice.
1779 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1780 SingleChoice = !CurrentIdent;
1781 return NextIdent;
1782 }
1783 return nullptr;
1784 }
1785
1786  /// Return a `struct ident_t*` value that represents the ones used in the
1787  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1788  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1789  /// return value, we create one from scratch. We also do not yet combine
1790  /// information, e.g., the source locations; see combinedIdentStruct.
1791 Value *
1792 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1793 Function &F, bool GlobalOnly) {
1794 bool SingleChoice = true;
1795 Value *Ident = nullptr;
1796 auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1797 CallInst *CI = getCallIfRegularCall(U, &RFI);
1798 if (!CI || &F != &Caller)
1799 return false;
1800 Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1801 /* GlobalOnly */ true, SingleChoice);
1802 return false;
1803 };
1804 RFI.foreachUse(SCC, CombineIdentStruct);
1805
1806 if (!Ident || !SingleChoice) {
1807      // The IRBuilder uses the insertion block to get to the module; this is
1808      // unfortunate, but we work around it for now.
1809 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1810 OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1811 &F.getEntryBlock(), F.getEntryBlock().begin()));
1812      // Create a fallback location if none was found.
1813 // TODO: Use the debug locations of the calls instead.
1814 uint32_t SrcLocStrSize;
1815 Constant *Loc =
1816 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1817 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1818 }
1819 return Ident;
1820 }
1821
1822 /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1823 /// \p ReplVal if given.
1824 bool deduplicateRuntimeCalls(Function &F,
1825 OMPInformationCache::RuntimeFunctionInfo &RFI,
1826 Value *ReplVal = nullptr) {
1827 auto *UV = RFI.getUseVector(F);
1828 if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1829 return false;
1830
1831    LLVM_DEBUG(
1832        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1833               << (ReplVal ? " with an existing value" : "") << "\n");
1834
1835 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1836 cast<Argument>(ReplVal)->getParent() == &F)) &&
1837 "Unexpected replacement value!");
1838
1839 // TODO: Use dominance to find a good position instead.
1840 auto CanBeMoved = [this](CallBase &CB) {
1841 unsigned NumArgs = CB.arg_size();
1842 if (NumArgs == 0)
1843 return true;
1844 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1845 return false;
1846 for (unsigned U = 1; U < NumArgs; ++U)
1847 if (isa<Instruction>(CB.getArgOperand(U)))
1848 return false;
1849 return true;
1850 };
1851
1852 if (!ReplVal) {
1853 auto *DT =
1854 OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1855 if (!DT)
1856 return false;
1857 Instruction *IP = nullptr;
1858 for (Use *U : *UV) {
1859 if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1860 if (IP)
1861 IP = DT->findNearestCommonDominator(IP, CI);
1862 else
1863 IP = CI;
1864 if (!CanBeMoved(*CI))
1865 continue;
1866 if (!ReplVal)
1867 ReplVal = CI;
1868 }
1869 }
1870 if (!ReplVal)
1871 return false;
1872 assert(IP && "Expected insertion point!");
1873 cast<Instruction>(ReplVal)->moveBefore(IP->getIterator());
1874 }
1875
1876 // If we use a call as a replacement value we need to make sure the ident is
1877 // valid at the new location. For now we just pick a global one, either
1878 // existing and used by one of the calls, or created from scratch.
1879 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1880 if (!CI->arg_empty() &&
1881 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1882 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1883 /* GlobalOnly */ true);
1884 CI->setArgOperand(0, Ident);
1885 }
1886 }
1887
1888 bool Changed = false;
1889 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1890 CallInst *CI = getCallIfRegularCall(U, &RFI);
1891 if (!CI || CI == ReplVal || &F != &Caller)
1892 return false;
1893 assert(CI->getCaller() == &F && "Unexpected call!");
1894
1895 auto Remark = [&](OptimizationRemark OR) {
1896 return OR << "OpenMP runtime call "
1897 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1898 };
1899 if (CI->getDebugLoc())
1900      emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1901 else
1902      emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1903
1904 CI->replaceAllUsesWith(ReplVal);
1905 CI->eraseFromParent();
1906 ++NumOpenMPRuntimeCallsDeduplicated;
1907 Changed = true;
1908 return true;
1909 };
1910 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1911
1912 return Changed;
1913 }
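
  // Illustrative example (simplified IR): two identical runtime calls collapse
  // into one whose result is reused; the surviving call may first be moved to
  // a common dominator:
  //   %tid0 = call i32 @__kmpc_global_thread_num(ptr @ident)
  //   %tid1 = call i32 @__kmpc_global_thread_num(ptr @ident) ; replaced by
  //                                                          ; %tid0, erased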
1914
1915 /// Collect arguments that represent the global thread id in \p GTIdArgs.
1916 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1917 // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1918 // initialization. We could define an AbstractAttribute instead and
1919 // run the Attributor here once it can be run as an SCC pass.
1920
1921 // Helper to check the argument \p ArgNo at all call sites of \p F for
1922 // a GTId.
1923 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1924 if (!F.hasLocalLinkage())
1925 return false;
1926 for (Use &U : F.uses()) {
1927 if (CallInst *CI = getCallIfRegularCall(U)) {
1928 Value *ArgOp = CI->getArgOperand(ArgNo);
1929 if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1930 getCallIfRegularCall(
1931 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1932 continue;
1933 }
1934 return false;
1935 }
1936 return true;
1937 };
1938
1939 // Helper to identify uses of a GTId as GTId arguments.
1940 auto AddUserArgs = [&](Value &GTId) {
1941 for (Use &U : GTId.uses())
1942 if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1943 if (CI->isArgOperand(&U))
1944 if (Function *Callee = CI->getCalledFunction())
1945 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1946 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1947 };
1948
1949 // The argument users of __kmpc_global_thread_num calls are GTIds.
1950 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1951 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1952
1953 GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1954 if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1955 AddUserArgs(*CI);
1956 return false;
1957 });
1958
1959 // Transitively search for more arguments by looking at the users of the
1960 // ones we know already. During the search the GTIdArgs vector is extended
1961    // so we cannot cache the size, nor can we use a range-based for loop.
1962 for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1963 AddUserArgs(*GTIdArgs[U]);
1964 }
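
  // Illustrative example (simplified IR; @helper is hypothetical): the result
  // of __kmpc_global_thread_num flowing into an internal callee marks the
  // corresponding argument as a GTId as well:
  //   %gtid = call i32 @__kmpc_global_thread_num(ptr @ident)
  //   call void @helper(i32 %gtid) ; @helper's matching argument becomes a GTId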
1965
1966 /// Kernel (=GPU) optimizations and utility functions
1967 ///
1968 ///{{
1969
1970 /// Cache to remember the unique kernel for a function.
1971 DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
1972
1973 /// Find the unique kernel that will execute \p F, if any.
1974 Kernel getUniqueKernelFor(Function &F);
1975
1976 /// Find the unique kernel that will execute \p I, if any.
1977 Kernel getUniqueKernelFor(Instruction &I) {
1978 return getUniqueKernelFor(*I.getFunction());
1979 }
1980
1981  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
1982  /// the cases where we can avoid taking the address of a function.
1983 bool rewriteDeviceCodeStateMachine();
1984
1985 ///
1986 ///}}
1987
1988 /// Emit a remark generically
1989 ///
1990 /// This template function can be used to generically emit a remark. The
1991 /// RemarkKind should be one of the following:
1992 /// - OptimizationRemark to indicate a successful optimization attempt
1993 /// - OptimizationRemarkMissed to report a failed optimization attempt
1994 /// - OptimizationRemarkAnalysis to provide additional information about an
1995 /// optimization attempt
1996 ///
1997 /// The remark is built using a callback function provided by the caller that
1998 /// takes a RemarkKind as input and returns a RemarkKind.
1999 template <typename RemarkKind, typename RemarkCallBack>
2000 void emitRemark(Instruction *I, StringRef RemarkName,
2001 RemarkCallBack &&RemarkCB) const {
2002 Function *F = I->getParent()->getParent();
2003 auto &ORE = OREGetter(F);
2004
2005 if (RemarkName.starts_with("OMP"))
2006 ORE.emit([&]() {
2007 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2008 << " [" << RemarkName << "]";
2009 });
2010 else
2011 ORE.emit(
2012 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2013 }
2014
2015 /// Emit a remark on a function.
2016 template <typename RemarkKind, typename RemarkCallBack>
2017 void emitRemark(Function *F, StringRef RemarkName,
2018 RemarkCallBack &&RemarkCB) const {
2019 auto &ORE = OREGetter(F);
2020
2021 if (RemarkName.starts_with("OMP"))
2022 ORE.emit([&]() {
2023 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2024 << " [" << RemarkName << "]";
2025 });
2026 else
2027 ORE.emit(
2028 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2029 }
2030
2031 /// The underlying module.
2032 Module &M;
2033
2034 /// The SCC we are operating on.
2035 SmallVectorImpl<Function *> &SCC;
2036
2037 /// Callback to update the call graph, the first argument is a removed call,
2038 /// the second an optional replacement call.
2039 CallGraphUpdater &CGUpdater;
2040
2041 /// Callback to get an OptimizationRemarkEmitter from a Function *
2042 OptimizationRemarkGetter OREGetter;
2043
2044  /// OpenMP-specific information cache. Also used for Attributor runs.
2045 OMPInformationCache &OMPInfoCache;
2046
2047 /// Attributor instance.
2048 Attributor &A;
2049
2050 /// Helper function to run Attributor on SCC.
2051 bool runAttributor(bool IsModulePass) {
2052 if (SCC.empty())
2053 return false;
2054
2055 registerAAs(IsModulePass);
2056
2057 ChangeStatus Changed = A.run();
2058
2059 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2060 << " functions, result: " << Changed << ".\n");
2061
2062 if (Changed == ChangeStatus::CHANGED)
2063 OMPInfoCache.invalidateAnalyses();
2064
2065 return Changed == ChangeStatus::CHANGED;
2066 }
2067
2068 void registerFoldRuntimeCall(RuntimeFunction RF);
2069
2070 /// Populate the Attributor with abstract attribute opportunities in the
2071 /// functions.
2072 void registerAAs(bool IsModulePass);
2073
2074public:
2075 /// Callback to register AAs for live functions, including internal functions
2076 /// marked live during the traversal.
2077 static void registerAAsForFunction(Attributor &A, const Function &F);
2078};
2079
2080Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2081 if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2082 !OMPInfoCache.CGSCC->contains(&F))
2083 return nullptr;
2084
2085 // Use a scope to keep the lifetime of the CachedKernel short.
2086 {
2087 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2088 if (CachedKernel)
2089 return *CachedKernel;
2090
2091 // TODO: We should use an AA to create an (optimistic and callback
2092 // call-aware) call graph. For now we stick to simple patterns that
2093 // are less powerful, basically the worst fixpoint.
2094 if (isOpenMPKernel(F)) {
2095 CachedKernel = Kernel(&F);
2096 return *CachedKernel;
2097 }
2098
2099 CachedKernel = nullptr;
2100 if (!F.hasLocalLinkage()) {
2101
2102 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2103 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2104 return ORA << "Potentially unknown OpenMP target region caller.";
2105 };
2106      emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2107
2108 return nullptr;
2109 }
2110 }
2111
2112 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2113 if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2114 // Allow use in equality comparisons.
2115 if (Cmp->isEquality())
2116 return getUniqueKernelFor(*Cmp);
2117 return nullptr;
2118 }
2119 if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2120 // Allow direct calls.
2121 if (CB->isCallee(&U))
2122 return getUniqueKernelFor(*CB);
2123
2124 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2125 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_60];
2126 // Allow the use in __kmpc_parallel_60 calls.
2127 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2128 return getUniqueKernelFor(*CB);
2129 return nullptr;
2130 }
2131 // Disallow every other use.
2132 return nullptr;
2133 };
2134
2135 // TODO: In the future we want to track more than just a unique kernel.
2136 SmallPtrSet<Kernel, 2> PotentialKernels;
2137 OMPInformationCache::foreachUse(F, [&](const Use &U) {
2138 PotentialKernels.insert(GetUniqueKernelForUse(U));
2139 });
2140
2141 Kernel K = nullptr;
2142 if (PotentialKernels.size() == 1)
2143 K = *PotentialKernels.begin();
2144
2145 // Cache the result.
2146 UniqueKernelMap[&F] = K;
2147
2148 return K;
2149}
2150
2151bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2152 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2153 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_60];
2154
2155 bool Changed = false;
2156 if (!KernelParallelRFI)
2157 return Changed;
2158
2159  // If we have disabled state machine changes, exit.
2160  if (DisableOpenMPOptStateMachineRewrite)
2161    return Changed;
2162
2163 for (Function *F : SCC) {
2164
2165    // Check if the function is used in a __kmpc_parallel_60 call at
2166    // all.
2167 bool UnknownUse = false;
2168 bool KernelParallelUse = false;
2169 unsigned NumDirectCalls = 0;
2170
2171 SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2172 OMPInformationCache::foreachUse(*F, [&](Use &U) {
2173 if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2174 if (CB->isCallee(&U)) {
2175 ++NumDirectCalls;
2176 return;
2177 }
2178
2179 if (isa<ICmpInst>(U.getUser())) {
2180 ToBeReplacedStateMachineUses.push_back(&U);
2181 return;
2182 }
2183
2184 // Find wrapper functions that represent parallel kernels.
2185 CallInst *CI =
2186 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2187 const unsigned int WrapperFunctionArgNo = 6;
2188 if (!KernelParallelUse && CI &&
2189 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2190 KernelParallelUse = true;
2191 ToBeReplacedStateMachineUses.push_back(&U);
2192 return;
2193 }
2194 UnknownUse = true;
2195 });
2196
2197    // Do not emit a remark if we haven't seen a __kmpc_parallel_60 use.
2199 if (!KernelParallelUse)
2200 continue;
2201
2202 // If this ever hits, we should investigate.
2203 // TODO: Checking the number of uses is not a necessary restriction and
2204 // should be lifted.
2205 if (UnknownUse || NumDirectCalls != 1 ||
2206 ToBeReplacedStateMachineUses.size() > 2) {
2207 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2208 return ORA << "Parallel region is used in "
2209 << (UnknownUse ? "unknown" : "unexpected")
2210 << " ways. Will not attempt to rewrite the state machine.";
2211 };
2212      emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2213 continue;
2214 }
2215
2216    // Even if we have __kmpc_parallel_60 calls, we (for now) give up if the
2217    // function is not called from a unique kernel.
2218 Kernel K = getUniqueKernelFor(*F);
2219 if (!K) {
2220 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2221 return ORA << "Parallel region is not called from a unique kernel. "
2222 "Will not attempt to rewrite the state machine.";
2223 };
2224      emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2225 continue;
2226 }
2227
2228 // We now know F is a parallel body function called only from the kernel K.
2229    // We also identified the state machine uses in which we replace the
2230    // function pointer with a new global symbol for identification purposes.
2231    // This ensures only direct calls to the function are left.
2232
2233 Module &M = *F->getParent();
2234 Type *Int8Ty = Type::getInt8Ty(M.getContext());
2235
2236 auto *ID = new GlobalVariable(
2237 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2238 UndefValue::get(Int8Ty), F->getName() + ".ID");
2239
2240 for (Use *U : ToBeReplacedStateMachineUses)
2241      U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2242 ID, U->get()->getType()));
2243
2244 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2245
2246 Changed = true;
2247 }
2248
2249 return Changed;
2250}
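
// Illustrative sketch (simplified IR; names are hypothetical): after the
// rewrite, the state machine compares the work function pointer against the
// new identifier instead of the function address, leaving only direct calls:
//   @__omp_outlined__wrapper.ID = private constant i8 undef
//   %is.wrapper = icmp eq ptr %work_fn, @__omp_outlined__wrapper.ID
//   br i1 %is.wrapper, label %direct.call, label %check.next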
2251
2252/// Abstract Attribute for tracking ICV values.
2253struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2254 using Base = StateWrapper<BooleanState, AbstractAttribute>;
2255 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2256
2257 /// Returns true if value is assumed to be tracked.
2258 bool isAssumedTracked() const { return getAssumed(); }
2259
2260 /// Returns true if value is known to be tracked.
2261 bool isKnownTracked() const { return getAssumed(); }
2262
2263  /// Create an abstract attribute view for the position \p IRP.
2264 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2265
2266 /// Return the value with which \p I can be replaced for specific \p ICV.
2267 virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2268 const Instruction *I,
2269 Attributor &A) const {
2270 return std::nullopt;
2271 }
2272
2273  /// Return an assumed unique ICV value if a single candidate is found. If
2274  /// there cannot be one, return nullptr. If it is not clear yet, return
2275  /// std::nullopt.
2276 virtual std::optional<Value *>
2277 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2278
2279 // Currently only nthreads is being tracked.
2280  // This array will only grow with time.
2281 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2282
2283 /// See AbstractAttribute::getName()
2284 StringRef getName() const override { return "AAICVTracker"; }
2285
2286 /// See AbstractAttribute::getIdAddr()
2287 const char *getIdAddr() const override { return &ID; }
2288
2289 /// This function should return true if the type of the \p AA is AAICVTracker
2290 static bool classof(const AbstractAttribute *AA) {
2291 return (AA->getIdAddr() == &ID);
2292 }
2293
2294 static const char ID;
2295};
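
// Illustrative example (assuming the nthreads ICV, whose setter/getter pair is
// omp_set_num_threads/omp_get_max_threads): a getter that is only reachable
// through a known setter can be folded to the set value:
//   call void @omp_set_num_threads(i32 4)
//   %n = call i32 @omp_get_max_threads() ; replaceable by i32 4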
2296
2297struct AAICVTrackerFunction : public AAICVTracker {
2298 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2299 : AAICVTracker(IRP, A) {}
2300
2301 // FIXME: come up with better string.
2302 const std::string getAsStr(Attributor *) const override {
2303 return "ICVTrackerFunction";
2304 }
2305
2306 // FIXME: come up with some stats.
2307 void trackStatistics() const override {}
2308
2309 /// We don't manifest anything for this AA.
2310 ChangeStatus manifest(Attributor &A) override {
2311 return ChangeStatus::UNCHANGED;
2312 }
2313
2314 // Map of ICV to their values at specific program point.
2315 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2316 InternalControlVar::ICV___last>
2317 ICVReplacementValuesMap;
2318
2319 ChangeStatus updateImpl(Attributor &A) override {
2320 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2321
2322 Function *F = getAnchorScope();
2323
2324 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2325
2326 for (InternalControlVar ICV : TrackableICVs) {
2327 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2328
2329 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2330 auto TrackValues = [&](Use &U, Function &) {
2331 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2332 if (!CI)
2333 return false;
2334
2335        // FIXME: handle setters with more than one argument.
2336 /// Track new value.
2337 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2338 HasChanged = ChangeStatus::CHANGED;
2339
2340 return false;
2341 };
2342
2343 auto CallCheck = [&](Instruction &I) {
2344 std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2345 if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2346 HasChanged = ChangeStatus::CHANGED;
2347
2348 return true;
2349 };
2350
2351 // Track all changes of an ICV.
2352 SetterRFI.foreachUse(TrackValues, F);
2353
2354 bool UsedAssumedInformation = false;
2355 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2356 UsedAssumedInformation,
2357 /* CheckBBLivenessOnly */ true);
2358
2359      /// TODO: Figure out a way to avoid adding an entry to
2360      /// ICVReplacementValuesMap.
2361 Instruction *Entry = &F->getEntryBlock().front();
2362 if (HasChanged == ChangeStatus::CHANGED)
2363 ValuesMap.try_emplace(Entry);
2364 }
2365
2366 return HasChanged;
2367 }
2368
2369 /// Helper to check if \p I is a call and get the value for it if it is
2370 /// unique.
2371 std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2372 InternalControlVar &ICV) const {
2373
2374 const auto *CB = dyn_cast<CallBase>(&I);
2375 if (!CB || CB->hasFnAttr("no_openmp") ||
2376 CB->hasFnAttr("no_openmp_routines") ||
2377 CB->hasFnAttr("no_openmp_constructs"))
2378 return std::nullopt;
2379
2380 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2381 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2382 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2383 Function *CalledFunction = CB->getCalledFunction();
2384
2385 // Indirect call, assume ICV changes.
2386 if (CalledFunction == nullptr)
2387 return nullptr;
2388 if (CalledFunction == GetterRFI.Declaration)
2389 return std::nullopt;
2390 if (CalledFunction == SetterRFI.Declaration) {
2391 if (ICVReplacementValuesMap[ICV].count(&I))
2392 return ICVReplacementValuesMap[ICV].lookup(&I);
2393
2394 return nullptr;
2395 }
2396
2397 // Since we don't know, assume it changes the ICV.
2398 if (CalledFunction->isDeclaration())
2399 return nullptr;
2400
2401 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2402 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2403
2404 if (ICVTrackingAA->isAssumedTracked()) {
2405 std::optional<Value *> URV =
2406 ICVTrackingAA->getUniqueReplacementValue(ICV);
2407 if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2408 OMPInfoCache)))
2409 return URV;
2410 }
2411
2412 // If we don't know, assume it changes.
2413 return nullptr;
2414 }
2415
2416 // We don't check unique value for a function, so return std::nullopt.
2417 std::optional<Value *>
2418 getUniqueReplacementValue(InternalControlVar ICV) const override {
2419 return std::nullopt;
2420 }
2421
2422 /// Return the value with which \p I can be replaced for specific \p ICV.
2423 std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2424 const Instruction *I,
2425 Attributor &A) const override {
2426 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2427 if (ValuesMap.count(I))
2428 return ValuesMap.lookup(I);
2429
2430    SmallVector<const Instruction *, 16> Worklist;
2431 SmallPtrSet<const Instruction *, 16> Visited;
2432 Worklist.push_back(I);
2433
2434 std::optional<Value *> ReplVal;
2435
2436 while (!Worklist.empty()) {
2437 const Instruction *CurrInst = Worklist.pop_back_val();
2438 if (!Visited.insert(CurrInst).second)
2439 continue;
2440
2441 const BasicBlock *CurrBB = CurrInst->getParent();
2442
2443 // Go up and look for all potential setters/calls that might change the
2444 // ICV.
2445 while ((CurrInst = CurrInst->getPrevNode())) {
2446 if (ValuesMap.count(CurrInst)) {
2447 std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2448 // Unknown value, track new.
2449 if (!ReplVal) {
2450 ReplVal = NewReplVal;
2451 break;
2452 }
2453
2454          // If we found a new value, we can't know the ICV value anymore.
2455 if (NewReplVal)
2456 if (ReplVal != NewReplVal)
2457 return nullptr;
2458
2459 break;
2460 }
2461
2462 std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2463 if (!NewReplVal)
2464 continue;
2465
2466 // Unknown value, track new.
2467 if (!ReplVal) {
2468 ReplVal = NewReplVal;
2469 break;
2470 }
2471
2473        // We found a new value, so we can't know the ICV value anymore.
2474 if (ReplVal != NewReplVal)
2475 return nullptr;
2476 }
2477
2478 // If we are in the same BB and we have a value, we are done.
2479 if (CurrBB == I->getParent() && ReplVal)
2480 return ReplVal;
2481
2482 // Go through all predecessors and add terminators for analysis.
2483 for (const BasicBlock *Pred : predecessors(CurrBB))
2484 if (const Instruction *Terminator = Pred->getTerminator())
2485 Worklist.push_back(Terminator);
2486 }
2487
2488 return ReplVal;
2489 }
2490};
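
// Illustrative sketch of the backward walk above (simplified): scanning upward
// from the query, the nearest relevant instruction decides the value; an
// unknown call makes the ICV value unknown even if a setter precedes it:
//   call void @omp_set_num_threads(i32 2) ; known setter
//   call void @unknown()                  ; may change the ICV -> unknown
//   %n = call i32 @omp_get_max_threads()  ; query: no replacement found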
2491
2492struct AAICVTrackerFunctionReturned : AAICVTracker {
2493 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2494 : AAICVTracker(IRP, A) {}
2495
2496 // FIXME: come up with better string.
2497 const std::string getAsStr(Attributor *) const override {
2498 return "ICVTrackerFunctionReturned";
2499 }
2500
2501 // FIXME: come up with some stats.
2502 void trackStatistics() const override {}
2503
2504 /// We don't manifest anything for this AA.
2505 ChangeStatus manifest(Attributor &A) override {
2506 return ChangeStatus::UNCHANGED;
2507 }
2508
2509 // Map of ICV to their values at specific program point.
2510 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2511 InternalControlVar::ICV___last>
2512 ICVReplacementValuesMap;
2513
2514 /// Return the value with which \p I can be replaced for specific \p ICV.
2515 std::optional<Value *>
2516 getUniqueReplacementValue(InternalControlVar ICV) const override {
2517 return ICVReplacementValuesMap[ICV];
2518 }
2519
2520 ChangeStatus updateImpl(Attributor &A) override {
2521 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2522 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2523 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2524
2525 if (!ICVTrackingAA->isAssumedTracked())
2526 return indicatePessimisticFixpoint();
2527
2528 for (InternalControlVar ICV : TrackableICVs) {
2529 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2530 std::optional<Value *> UniqueICVValue;
2531
2532 auto CheckReturnInst = [&](Instruction &I) {
2533 std::optional<Value *> NewReplVal =
2534 ICVTrackingAA->getReplacementValue(ICV, &I, A);
2535
2536 // If we found a second ICV value there is no unique returned value.
2537 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2538 return false;
2539
2540 UniqueICVValue = NewReplVal;
2541
2542 return true;
2543 };
2544
2545 bool UsedAssumedInformation = false;
2546 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2547 UsedAssumedInformation,
2548 /* CheckBBLivenessOnly */ true))
2549 UniqueICVValue = nullptr;
2550
2551 if (UniqueICVValue == ReplVal)
2552 continue;
2553
2554 ReplVal = UniqueICVValue;
2555 Changed = ChangeStatus::CHANGED;
2556 }
2557
2558 return Changed;
2559 }
2560};
2561
2562struct AAICVTrackerCallSite : AAICVTracker {
2563 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2564 : AAICVTracker(IRP, A) {}
2565
2566 void initialize(Attributor &A) override {
2567 assert(getAnchorScope() && "Expected anchor function");
2568
2569 // We only initialize this AA for getters, so we need to know which ICV it
2570 // gets.
2571 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2572 for (InternalControlVar ICV : TrackableICVs) {
2573 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2574 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2575 if (Getter.Declaration == getAssociatedFunction()) {
2576 AssociatedICV = ICVInfo.Kind;
2577 return;
2578 }
2579 }
2580
2581 /// Unknown ICV.
2582 indicatePessimisticFixpoint();
2583 }
2584
2585 ChangeStatus manifest(Attributor &A) override {
2586 if (!ReplVal || !*ReplVal)
2587 return ChangeStatus::UNCHANGED;
2588
2589 A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2590 A.deleteAfterManifest(*getCtxI());
2591
2592 return ChangeStatus::CHANGED;
2593 }
2594
2595 // FIXME: come up with better string.
2596 const std::string getAsStr(Attributor *) const override {
2597 return "ICVTrackerCallSite";
2598 }
2599
2600 // FIXME: come up with some stats.
2601 void trackStatistics() const override {}
2602
2603 InternalControlVar AssociatedICV;
2604 std::optional<Value *> ReplVal;
2605
2606 ChangeStatus updateImpl(Attributor &A) override {
2607 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2608 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2609
2610 // We don't have any information, so we assume it changes the ICV.
2611 if (!ICVTrackingAA->isAssumedTracked())
2612 return indicatePessimisticFixpoint();
2613
2614 std::optional<Value *> NewReplVal =
2615 ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
2616
2617 if (ReplVal == NewReplVal)
2618 return ChangeStatus::UNCHANGED;
2619
2620 ReplVal = NewReplVal;
2621 return ChangeStatus::CHANGED;
2622 }
2623
2624  // Return the value with which the associated value can be replaced for a
2625  // specific \p ICV.
2626 std::optional<Value *>
2627 getUniqueReplacementValue(InternalControlVar ICV) const override {
2628 return ReplVal;
2629 }
2630};
2631
2632struct AAICVTrackerCallSiteReturned : AAICVTracker {
2633 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2634 : AAICVTracker(IRP, A) {}
2635
2636 // FIXME: come up with better string.
2637 const std::string getAsStr(Attributor *) const override {
2638 return "ICVTrackerCallSiteReturned";
2639 }
2640
2641 // FIXME: come up with some stats.
2642 void trackStatistics() const override {}
2643
2644 /// We don't manifest anything for this AA.
2645 ChangeStatus manifest(Attributor &A) override {
2646 return ChangeStatus::UNCHANGED;
2647 }
2648
2649 // Map of ICV to their values at specific program point.
2650 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2651 InternalControlVar::ICV___last>
2652 ICVReplacementValuesMap;
2653
2654  /// Return the value with which the associated value can be replaced for a
2655  /// specific \p ICV.
2656 std::optional<Value *>
2657 getUniqueReplacementValue(InternalControlVar ICV) const override {
2658 return ICVReplacementValuesMap[ICV];
2659 }
2660
2661 ChangeStatus updateImpl(Attributor &A) override {
2662 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2663 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2664 *this, IRPosition::returned(*getAssociatedFunction()),
2665 DepClassTy::REQUIRED);
2666
2667 // We don't have any information, so we assume it changes the ICV.
2668 if (!ICVTrackingAA->isAssumedTracked())
2669 return indicatePessimisticFixpoint();
2670
2671 for (InternalControlVar ICV : TrackableICVs) {
2672 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2673 std::optional<Value *> NewReplVal =
2674 ICVTrackingAA->getUniqueReplacementValue(ICV);
2675
2676 if (ReplVal == NewReplVal)
2677 continue;
2678
2679 ReplVal = NewReplVal;
2680 Changed = ChangeStatus::CHANGED;
2681 }
2682 return Changed;
2683 }
2684};
2685
2686/// Determines if \p BB exits the function unconditionally itself, or reaches a
2687/// block that does so through unique successors only.
2688static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2689 if (succ_empty(BB))
2690 return true;
2691 const BasicBlock *const Successor = BB->getUniqueSuccessor();
2692 if (!Successor)
2693 return false;
2694 return hasFunctionEndAsUniqueSuccessor(Successor);
2695}
2696
2697struct AAExecutionDomainFunction : public AAExecutionDomain {
2698 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2699 : AAExecutionDomain(IRP, A) {}
2700
2701 ~AAExecutionDomainFunction() override { delete RPOT; }
2702
2703 void initialize(Attributor &A) override {
2704 Function *F = getAnchorScope();
2705 assert(F && "Expected anchor function");
2706 RPOT = new ReversePostOrderTraversal<Function *>(F);
2707 }
2708
2709 const std::string getAsStr(Attributor *) const override {
2710 unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2711 for (auto &It : BEDMap) {
2712 if (!It.getFirst())
2713 continue;
2714 TotalBlocks++;
2715 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2716 AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2717 It.getSecond().IsReachingAlignedBarrierOnly;
2718 }
2719 return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
2720 std::to_string(AlignedBlocks) + " of " +
2721 std::to_string(TotalBlocks) +
2722 " executed by initial thread / aligned";
2723 }
2724
2725 /// See AbstractAttribute::trackStatistics().
2726 void trackStatistics() const override {}
2727
2728 ChangeStatus manifest(Attributor &A) override {
2729 LLVM_DEBUG({
2730 for (const BasicBlock &BB : *getAnchorScope()) {
2731 if (!isExecutedByInitialThreadOnly(BB))
2732 continue;
2733 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2734 << BB.getName() << " is executed by a single thread.\n";
2735 }
2736 });
2737
2738 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2739
2740    if (DisableOpenMPOptBarrierElimination)
2741 return Changed;
2742
2743 SmallPtrSet<CallBase *, 16> DeletedBarriers;
2744 auto HandleAlignedBarrier = [&](CallBase *CB) {
2745 const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
2746 if (!ED.IsReachedFromAlignedBarrierOnly ||
2747 ED.EncounteredNonLocalSideEffect)
2748 return;
2749 if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
2750 return;
2751
2752 // We can remove this barrier, if it is one, or aligned barriers reaching
2753 // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
2754 // end should only be removed if the kernel end is their unique successor;
2755 // otherwise, they may have side-effects that aren't accounted for in the
2756 // kernel end in their other successors. If those barriers have other
2757 // barriers reaching them, those can be transitively removed as well as
2758 // long as the kernel end is also their unique successor.
2759 if (CB) {
2760 DeletedBarriers.insert(CB);
2761 A.deleteAfterManifest(*CB);
2762 ++NumBarriersEliminated;
2763 Changed = ChangeStatus::CHANGED;
2764 } else if (!ED.AlignedBarriers.empty()) {
2765 Changed = ChangeStatus::CHANGED;
2766 SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
2767 ED.AlignedBarriers.end());
2768 SmallSetVector<CallBase *, 16> Visited;
2769 while (!Worklist.empty()) {
2770 CallBase *LastCB = Worklist.pop_back_val();
2771 if (!Visited.insert(LastCB))
2772 continue;
2773 if (LastCB->getFunction() != getAnchorScope())
2774 continue;
2775 if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
2776 continue;
2777 if (!DeletedBarriers.count(LastCB)) {
2778 ++NumBarriersEliminated;
2779 A.deleteAfterManifest(*LastCB);
2780 continue;
2781 }
2782        // The final aligned barrier (LastCB) reaching the kernel end was
2783        // removed already. This means we can go one step further and remove
2784        // the barriers encountered last before LastCB.
2785 const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2786 Worklist.append(LastED.AlignedBarriers.begin(),
2787 LastED.AlignedBarriers.end());
2788 }
2789 }
2790
2791 // If we actually eliminated a barrier we need to eliminate the associated
2792 // llvm.assumes as well to avoid creating UB.
2793 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2794 for (auto *AssumeCB : ED.EncounteredAssumes)
2795 A.deleteAfterManifest(*AssumeCB);
2796 };
2797
2798 for (auto *CB : AlignedBarriers)
2799 HandleAlignedBarrier(CB);
2800
2801 // Handle the "kernel end barrier" for kernels too.
2802 if (omp::isOpenMPKernel(*getAnchorScope()))
2803 HandleAlignedBarrier(nullptr);
2804
2805 return Changed;
2806 }
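
  // Illustrative example (simplified IR): a barrier that is only reached from
  // aligned barriers, with no non-local side effects in between, is redundant
  // and can be removed:
  //   call void @__kmpc_barrier_simple_spmd(ptr @ident, i32 %tid) ; aligned
  //   ... thread-local code only ...
  //   call void @__kmpc_barrier_simple_spmd(ptr @ident, i32 %tid) ; eliminated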
2807
2808 bool isNoOpFence(const FenceInst &FI) const override {
2809 return getState().isValidState() && !NonNoOpFences.count(&FI);
2810 }
2811
2812 /// Merge barrier and assumption information from \p PredED into the successor
2813 /// \p ED.
2814 void
2815 mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
2816 const ExecutionDomainTy &PredED);
2817
2818 /// Merge all information from \p PredED into the successor \p ED. If
2819 /// \p InitialEdgeOnly is set, only the initial edge will enter the block
2820 /// represented by \p ED from this predecessor.
2821 bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
2822 const ExecutionDomainTy &PredED,
2823 bool InitialEdgeOnly = false);
2824
2825 /// Accumulate information for the entry block in \p EntryBBED.
2826 bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);
2827
2828 /// See AbstractAttribute::updateImpl.
2829 ChangeStatus updateImpl(Attributor &A) override;
2830
2831 /// Query interface, see AAExecutionDomain
2832 ///{
2833 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2834 if (!isValidState())
2835 return false;
2836 assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
2837 return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2838 }
2839
2840 bool isExecutedInAlignedRegion(Attributor &A,
2841 const Instruction &I) const override {
2842 assert(I.getFunction() == getAnchorScope() &&
2843 "Instruction is out of scope!");
2844 if (!isValidState())
2845 return false;
2846
2847 bool ForwardIsOk = true;
2848 const Instruction *CurI;
2849
2850 // Check forward until a call or the block end is reached.
2851 CurI = &I;
2852 do {
2853 auto *CB = dyn_cast<CallBase>(CurI);
2854 if (!CB)
2855 continue;
2856 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2857 return true;
2858 const auto &It = CEDMap.find({CB, PRE});
2859 if (It == CEDMap.end())
2860 continue;
2861 if (!It->getSecond().IsReachingAlignedBarrierOnly)
2862 ForwardIsOk = false;
2863 break;
2864 } while ((CurI = CurI->getNextNode()));
2865
2866 if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
2867 ForwardIsOk = false;
2868
2869 // Check backward until a call or the block beginning is reached.
2870 CurI = &I;
2871 do {
2872 auto *CB = dyn_cast<CallBase>(CurI);
2873 if (!CB)
2874 continue;
2875 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2876 return true;
2877 const auto &It = CEDMap.find({CB, POST});
2878 if (It == CEDMap.end())
2879 continue;
2880 if (It->getSecond().IsReachedFromAlignedBarrierOnly)
2881 break;
2882 return false;
2883 } while ((CurI = CurI->getPrevNode()));
2884
2885 // Delayed decision on the forward pass to allow aligned barrier detection
2886 // in the backwards traversal.
2887 if (!ForwardIsOk)
2888 return false;
2889
2890 if (!CurI) {
2891 const BasicBlock *BB = I.getParent();
2892 if (BB == &BB->getParent()->getEntryBlock())
2893 return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
2894 if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
2895 return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
2896 })) {
2897 return false;
2898 }
2899 }
2900
2901    // On neither traversal did we find anything but aligned barriers.
2902 return true;
2903 }
2904
2905 ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
2906 assert(isValidState() &&
2907 "No request should be made against an invalid state!");
2908 return BEDMap.lookup(&BB);
2909 }
2910 std::pair<ExecutionDomainTy, ExecutionDomainTy>
2911 getExecutionDomain(const CallBase &CB) const override {
2912 assert(isValidState() &&
2913 "No request should be made against an invalid state!");
2914 return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
2915 }
2916 ExecutionDomainTy getFunctionExecutionDomain() const override {
2917 assert(isValidState() &&
2918 "No request should be made against an invalid state!");
2919 return InterProceduralED;
2920 }
2921 ///}
2922
2923 // Check if the edge into the successor block contains a condition that only
2924 // lets the main thread execute it.
2925 static bool isInitialThreadOnlyEdge(Attributor &A, CondBrInst *Edge,
2926 BasicBlock &SuccessorBB) {
2927 if (!Edge)
2928 return false;
2929 if (Edge->getSuccessor(0) != &SuccessorBB)
2930 return false;
2931
2932 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2933 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2934 return false;
2935
2936 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2937 if (!C)
2938 return false;
2939
2940 // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2941 if (C->isAllOnesValue()) {
2942 auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2943 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2944 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2945 CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2946 if (!CB)
2947 return false;
2948 ConstantStruct *KernelEnvC =
2949          KernelInfo::getKernelEnvironementFromKernelInitCB(CB);
2950 ConstantInt *ExecModeC =
2951 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
2952 return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
2953 }
2954
2955 if (C->isZero()) {
2956 // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2957 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2958 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2959 return true;
2960
2961 // Match: 0 == llvm.amdgcn.workitem.id.x()
2962 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2963 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2964 return true;
2965 }
2966
2967 return false;
2968 };
2969
2970 /// Mapping containing information about the function for other AAs.
2971 ExecutionDomainTy InterProceduralED;
2972
2973 enum Direction { PRE = 0, POST = 1 };
2974 /// Mapping containing information per block.
2975 DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
2976 DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>
2977 CEDMap;
2978 SmallSetVector<CallBase *, 16> AlignedBarriers;
2979
2980 ReversePostOrderTraversal<Function *> *RPOT = nullptr;
2981
2982  /// Set \p R to \p V and report true if that changed \p R.
2983 static bool setAndRecord(bool &R, bool V) {
2984 bool Eq = (R == V);
2985 R = V;
2986 return !Eq;
2987 }
2988
2989  /// Collection of fences known to be non-no-op. All fences not in this set
2990  /// can be assumed no-ops.
2991 SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
2992};
2993
2994void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2995 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2996 for (auto *EA : PredED.EncounteredAssumes)
2997 ED.addAssumeInst(A, *EA);
2998
2999 for (auto *AB : PredED.AlignedBarriers)
3000 ED.addAlignedBarrier(A, *AB);
3001}
3002
3003bool AAExecutionDomainFunction::mergeInPredecessor(
3004 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
3005 bool InitialEdgeOnly) {
3006
3007 bool Changed = false;
3008 Changed |=
3009 setAndRecord(ED.IsExecutedByInitialThreadOnly,
3010 InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3011 ED.IsExecutedByInitialThreadOnly));
3012
3013 Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3014 ED.IsReachedFromAlignedBarrierOnly &&
3015 PredED.IsReachedFromAlignedBarrierOnly);
3016 Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3017 ED.EncounteredNonLocalSideEffect |
3018 PredED.EncounteredNonLocalSideEffect);
3019 // Do not track assumptions and barriers as part of Changed.
3020 if (ED.IsReachedFromAlignedBarrierOnly)
3021 mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3022 else
3023 ED.clearAssumeInstAndAlignedBarriers();
3024 return Changed;
3025}
3026
3027bool AAExecutionDomainFunction::handleCallees(Attributor &A,
3028 ExecutionDomainTy &EntryBBED) {
3030 auto PredForCallSite = [&](AbstractCallSite ACS) {
3031 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3032 *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
3033 DepClassTy::OPTIONAL);
3034 if (!EDAA || !EDAA->getState().isValidState())
3035 return false;
3036 CallSiteEDs.emplace_back(
3037 EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
3038 return true;
3039 };
3040
3041 ExecutionDomainTy ExitED;
3042 bool AllCallSitesKnown;
3043 if (A.checkForAllCallSites(PredForCallSite, *this,
3044 /* RequiresAllCallSites */ true,
3045 AllCallSitesKnown)) {
3046 for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
3047 mergeInPredecessor(A, EntryBBED, CSInED);
3048 ExitED.IsReachingAlignedBarrierOnly &=
3049 CSOutED.IsReachingAlignedBarrierOnly;
3050 }
3051
3052 } else {
3053 // We could not find all predecessors, so this is either a kernel or a
3054 // function with external linkage (or with some other weird uses).
3055 if (omp::isOpenMPKernel(*getAnchorScope())) {
3056 EntryBBED.IsExecutedByInitialThreadOnly = false;
3057 EntryBBED.IsReachedFromAlignedBarrierOnly = true;
3058 EntryBBED.EncounteredNonLocalSideEffect = false;
3059 ExitED.IsReachingAlignedBarrierOnly = false;
3060 } else {
3061 EntryBBED.IsExecutedByInitialThreadOnly = false;
3062 EntryBBED.IsReachedFromAlignedBarrierOnly = false;
3063 EntryBBED.EncounteredNonLocalSideEffect = true;
3064 ExitED.IsReachingAlignedBarrierOnly = false;
3065 }
3066 }
3067
3068 bool Changed = false;
3069 auto &FnED = BEDMap[nullptr];
3070 Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
3071 FnED.IsReachedFromAlignedBarrierOnly &
3072 EntryBBED.IsReachedFromAlignedBarrierOnly);
3073 Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
3074 FnED.IsReachingAlignedBarrierOnly &
3075 ExitED.IsReachingAlignedBarrierOnly);
3076 Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
3077 EntryBBED.IsExecutedByInitialThreadOnly);
3078 return Changed;
3079}
3080
3081ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
3082
3083 bool Changed = false;
3084
3085 // Helper to deal with an aligned barrier encountered during the forward
3086 // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
3087 // it was encountered.
3088 auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
3089 Changed |= AlignedBarriers.insert(&CB);
3090 // First, update the barrier ED kept in the separate CEDMap.
3091 auto &CallInED = CEDMap[{&CB, PRE}];
3092 Changed |= mergeInPredecessor(A, CallInED, ED);
3093 CallInED.IsReachingAlignedBarrierOnly = true;
3094 // Next adjust the ED we use for the traversal.
3095 ED.EncounteredNonLocalSideEffect = false;
3096 ED.IsReachedFromAlignedBarrierOnly = true;
3097 // Aligned barrier collection has to come last.
3098 ED.clearAssumeInstAndAlignedBarriers();
3099 ED.addAlignedBarrier(A, CB);
3100 auto &CallOutED = CEDMap[{&CB, POST}];
3101 Changed |= mergeInPredecessor(A, CallOutED, ED);
3102 };
3103
3104 auto *LivenessAA =
3105 A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
3106
3107 Function *F = getAnchorScope();
3108 BasicBlock &EntryBB = F->getEntryBlock();
3109 bool IsKernel = omp::isOpenMPKernel(*F);
3110
3111 SmallVector<Instruction *> SyncInstWorklist;
3112 for (auto &RIt : *RPOT) {
3113 BasicBlock &BB = *RIt;
3114
3115 bool IsEntryBB = &BB == &EntryBB;
3116 // TODO: We use local reasoning since we don't have a divergence analysis
3117 // running as well. We could basically allow uniform branches here.
3118 bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
3119 bool IsExplicitlyAligned = IsEntryBB && IsKernel;
3120 ExecutionDomainTy ED;
3121 // Propagate "incoming edges" into information about this block.
3122 if (IsEntryBB) {
3123 Changed |= handleCallees(A, ED);
3124 } else {
3125 // For live non-entry blocks we only propagate
3126 // information via live edges.
3127 if (LivenessAA && LivenessAA->isAssumedDead(&BB))
3128 continue;
3129
3130 for (auto *PredBB : predecessors(&BB)) {
3131 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
3132 continue;
3133 bool InitialEdgeOnly = isInitialThreadOnlyEdge(
3134 A, dyn_cast<CondBrInst>(PredBB->getTerminator()), BB);
3135 mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
3136 }
3137 }
3138
3139 // Now we traverse the block, accumulate effects in ED and attach
3140 // information to calls.
3141 for (Instruction &I : BB) {
3142 bool UsedAssumedInformation;
3143 if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
3144 /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
3145 /* CheckForDeadStore */ true))
3146 continue;
3147
3148      // Assumes and "assume-like" (dbg, lifetime, ...) are handled first; the
3149      // former are collected, the latter are ignored.
3150 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
3151 if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
3152 ED.addAssumeInst(A, *AI);
3153 continue;
3154 }
3155 // TODO: Should we also collect and delete lifetime markers?
3156 if (II->isAssumeLikeIntrinsic())
3157 continue;
3158 }
3159
3160 if (auto *FI = dyn_cast<FenceInst>(&I)) {
3161 if (!ED.EncounteredNonLocalSideEffect) {
3162 // An aligned fence without non-local side-effects is a no-op.
3163 if (ED.IsReachedFromAlignedBarrierOnly)
3164 continue;
3165 // A non-aligned fence without non-local side-effects is a no-op
3166 // if the ordering only publishes non-local side-effects (or less).
3167 switch (FI->getOrdering()) {
3168 case AtomicOrdering::NotAtomic:
3169 continue;
3170 case AtomicOrdering::Unordered:
3171 continue;
3172 case AtomicOrdering::Monotonic:
3173 continue;
3174 case AtomicOrdering::Acquire:
3175 break;
3176 case AtomicOrdering::Release:
3177 continue;
3178 case AtomicOrdering::AcquireRelease:
3179 break;
3180 case AtomicOrdering::SequentiallyConsistent:
3181 break;
3182 };
3183 }
3184 NonNoOpFences.insert(FI);
3185 }
3186
3187 auto *CB = dyn_cast<CallBase>(&I);
3188 bool IsNoSync = AA::isNoSyncInst(A, I, *this);
3189 bool IsAlignedBarrier =
3190 !IsNoSync && CB &&
3191 AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);
3192
3193 AlignedBarrierLastInBlock &= IsNoSync;
3194 IsExplicitlyAligned &= IsNoSync;
3195
3196      // Next we check for calls. Aligned barriers are handled explicitly;
3197      // everything else is kept for the backward traversal and will also
3198      // affect our state.
3199 if (CB) {
3200 if (IsAlignedBarrier) {
3201 HandleAlignedBarrier(*CB, ED);
3202 AlignedBarrierLastInBlock = true;
3203 IsExplicitlyAligned = true;
3204 continue;
3205 }
3206
3207 // Check the pointer(s) of a memory intrinsic explicitly.
3208 if (isa<MemIntrinsic>(&I)) {
3209 if (!ED.EncounteredNonLocalSideEffect &&
3210              AA::isPotentiallyAffectedByBarrier(A, I, *this))
3211 ED.EncounteredNonLocalSideEffect = true;
3212 if (!IsNoSync) {
3213 ED.IsReachedFromAlignedBarrierOnly = false;
3214 SyncInstWorklist.push_back(&I);
3215 }
3216 continue;
3217 }
3218
3219 // Record how we entered the call, then accumulate the effect of the
3220 // call in ED for potential use by the callee.
3221 auto &CallInED = CEDMap[{CB, PRE}];
3222 Changed |= mergeInPredecessor(A, CallInED, ED);
3223
3224 // If we have a sync-definition we can check if it starts/ends in an
3225 // aligned barrier. If we are unsure we assume any sync breaks
3226 // alignment.
3227        Function *Callee = CB->getCalledFunction();
3228 if (!IsNoSync && Callee && !Callee->isDeclaration()) {
3229 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3230 *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
3231 if (EDAA && EDAA->getState().isValidState()) {
3232 const auto &CalleeED = EDAA->getFunctionExecutionDomain();
3233 ED.IsReachedFromAlignedBarrierOnly =
3234 CalleeED.IsReachedFromAlignedBarrierOnly;
3235 AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3236 if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3237 ED.EncounteredNonLocalSideEffect |=
3238 CalleeED.EncounteredNonLocalSideEffect;
3239 else
3240 ED.EncounteredNonLocalSideEffect =
3241 CalleeED.EncounteredNonLocalSideEffect;
3242 if (!CalleeED.IsReachingAlignedBarrierOnly) {
3243 Changed |=
3244 setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3245 SyncInstWorklist.push_back(&I);
3246 }
3247 if (CalleeED.IsReachedFromAlignedBarrierOnly)
3248 mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
3249 auto &CallOutED = CEDMap[{CB, POST}];
3250 Changed |= mergeInPredecessor(A, CallOutED, ED);
3251 continue;
3252 }
3253 }
3254 if (!IsNoSync) {
3255 ED.IsReachedFromAlignedBarrierOnly = false;
3256 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3257 SyncInstWorklist.push_back(&I);
3258 }
3259 AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3260 ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
3261 auto &CallOutED = CEDMap[{CB, POST}];
3262 Changed |= mergeInPredecessor(A, CallOutED, ED);
3263 }
3264
3265 if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
3266 continue;
3267
3268 // If we have a callee we try to use fine-grained information to
3269 // determine local side-effects.
3270 if (CB) {
3271 const auto *MemAA = A.getAAFor<AAMemoryLocation>(
3272 *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
3273
3274 auto AccessPred = [&](const Instruction *I, const Value *Ptr,
3275                              AAMemoryLocation::AccessKind,
3276                              AAMemoryLocation::MemoryLocationsKind) {
3277 return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
3278 };
3279 if (MemAA && MemAA->getState().isValidState() &&
3280 MemAA->checkForAllAccessesToMemoryKind(
3281                AccessPred, AAMemoryLocation::ALL_LOCATIONS))
3282 continue;
3283 }
3284
3285 auto &InfoCache = A.getInfoCache();
3286 if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
3287 continue;
3288
3289 if (auto *LI = dyn_cast<LoadInst>(&I))
3290 if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3291 continue;
3292
3293 if (!ED.EncounteredNonLocalSideEffect &&
3294            AA::isPotentiallyAffectedByBarrier(A, I, *this))
3295 ED.EncounteredNonLocalSideEffect = true;
3296 }
3297
3298 bool IsEndAndNotReachingAlignedBarriersOnly = false;
3299 if (!isa<UnreachableInst>(BB.getTerminator()) &&
3300 !BB.getTerminator()->getNumSuccessors()) {
3301
3302 Changed |= mergeInPredecessor(A, InterProceduralED, ED);
3303
3304 auto &FnED = BEDMap[nullptr];
3305 if (IsKernel && !IsExplicitlyAligned)
3306 FnED.IsReachingAlignedBarrierOnly = false;
3307 Changed |= mergeInPredecessor(A, FnED, ED);
3308
3309 if (!FnED.IsReachingAlignedBarrierOnly) {
3310 IsEndAndNotReachingAlignedBarriersOnly = true;
3311 SyncInstWorklist.push_back(BB.getTerminator());
3312 auto &BBED = BEDMap[&BB];
3313 Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
3314 }
3315 }
3316
3317 ExecutionDomainTy &StoredED = BEDMap[&BB];
3318 ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3319 !IsEndAndNotReachingAlignedBarriersOnly;
3320
3321 // Check if we computed anything different as part of the forward
3322 // traversal. We do not take assumptions and aligned barriers into account
3323 // as they do not influence the state we iterate. Backward traversal values
3324 // are handled later on.
3325 if (ED.IsExecutedByInitialThreadOnly !=
3326 StoredED.IsExecutedByInitialThreadOnly ||
3327 ED.IsReachedFromAlignedBarrierOnly !=
3328 StoredED.IsReachedFromAlignedBarrierOnly ||
3329 ED.EncounteredNonLocalSideEffect !=
3330 StoredED.EncounteredNonLocalSideEffect)
3331 Changed = true;
3332
3333 // Update the state with the new value.
3334 StoredED = std::move(ED);
3335 }
3336
3337 // Propagate (non-aligned) sync instruction effects backwards until the
3338 // entry or an aligned barrier is hit.
3339 SmallSetVector<BasicBlock *, 16> Visited;
3340 while (!SyncInstWorklist.empty()) {
3341 Instruction *SyncInst = SyncInstWorklist.pop_back_val();
3342 Instruction *CurInst = SyncInst;
3343 bool HitAlignedBarrierOrKnownEnd = false;
3344 while ((CurInst = CurInst->getPrevNode())) {
3345 auto *CB = dyn_cast<CallBase>(CurInst);
3346 if (!CB)
3347 continue;
3348 auto &CallOutED = CEDMap[{CB, POST}];
3349 Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
3350 auto &CallInED = CEDMap[{CB, PRE}];
3351 HitAlignedBarrierOrKnownEnd =
3352 AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
3353 if (HitAlignedBarrierOrKnownEnd)
3354 break;
3355 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3356 }
3357 if (HitAlignedBarrierOrKnownEnd)
3358 continue;
3359 BasicBlock *SyncBB = SyncInst->getParent();
3360 for (auto *PredBB : predecessors(SyncBB)) {
3361 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
3362 continue;
3363 if (!Visited.insert(PredBB))
3364 continue;
3365 auto &PredED = BEDMap[PredBB];
3366 if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
3367 Changed = true;
3368 SyncInstWorklist.push_back(PredBB->getTerminator());
3369 }
3370 }
3371 if (SyncBB != &EntryBB)
3372 continue;
3373 Changed |=
3374 setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
3375 }
3376
3377 return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
3378}
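// Added illustration (not part of the upstream source): consider a kernel
// block of the form
//
//   store i32 0, ptr %shared_ptr                  ; non-local side effect
//   call void @__kmpc_barrier_simple_spmd(...)    ; aligned barrier
//   call void @unknown_sync()                     ; unknown, non-aligned sync
//
// The forward pass above clears EncounteredNonLocalSideEffect at the aligned
// barrier, while the backward pass seeded with @unknown_sync() marks the
// preceding calls and predecessor blocks as not "reaching an aligned barrier
// only" until it runs into the aligned barrier and stops.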
3379
3380 /// Try to replace memory allocation calls executed by a single thread with a
3381/// static buffer of shared memory.
3382struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
3383 using Base = StateWrapper<BooleanState, AbstractAttribute>;
3384 AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3385
3386 /// Create an abstract attribute view for the position \p IRP.
3387 static AAHeapToShared &createForPosition(const IRPosition &IRP,
3388 Attributor &A);
3389
3390 /// Returns true if HeapToShared conversion is assumed to be possible.
3391 virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3392
3393 /// Returns true if HeapToShared conversion is assumed and the CB is a
3394 /// callsite to a free operation to be removed.
3395 virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3396
3397 /// See AbstractAttribute::getName().
3398 StringRef getName() const override { return "AAHeapToShared"; }
3399
3400 /// See AbstractAttribute::getIdAddr().
3401 const char *getIdAddr() const override { return &ID; }
3402
3403 /// This function should return true if the type of the \p AA is
3404 /// AAHeapToShared.
3405 static bool classof(const AbstractAttribute *AA) {
3406 return (AA->getIdAddr() == &ID);
3407 }
3408
3409 /// Unique ID (due to the unique address)
3410 static const char ID;
3411};
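// Rough before/after sketch of the rewrite implemented below (illustrative
// IR, not taken from a real module):
//
//   %p = call ptr @__kmpc_alloc_shared(i64 4)
//   ...
//   call void @__kmpc_free_shared(ptr %p, i64 4)
//
// becomes a static buffer in the shared address space,
//
//   @p_shared = internal addrspace(3) global [4 x i8] poison
//
// with all uses of %p rerouted to a pointer cast of @p_shared and both
// runtime calls deleted.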
3412
3413struct AAHeapToSharedFunction : public AAHeapToShared {
3414 AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
3415 : AAHeapToShared(IRP, A) {}
3416
3417 const std::string getAsStr(Attributor *) const override {
3418 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3419 " malloc calls eligible.";
3420 }
3421
3422 /// See AbstractAttribute::trackStatistics().
3423 void trackStatistics() const override {}
3424
3425 /// This function finds the free calls that will be removed by the
3426 /// HeapToShared transformation.
3427 void findPotentialRemovedFreeCalls(Attributor &A) {
3428 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3429 auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3430
3431 PotentialRemovedFreeCalls.clear();
3432 // Update free call users of found malloc calls.
3433 for (CallBase *CB : MallocCalls) {
3434 SmallVector<CallBase *, 4> FreeCalls;
3435 for (auto *U : CB->users()) {
3436 CallBase *C = dyn_cast<CallBase>(U);
3437 if (C && C->getCalledFunction() == FreeRFI.Declaration)
3438 FreeCalls.push_back(C);
3439 }
3440
3441 if (FreeCalls.size() != 1)
3442 continue;
3443
3444 PotentialRemovedFreeCalls.insert(FreeCalls.front());
3445 }
3446 }
3447
3448 void initialize(Attributor &A) override {
3449 if (DisableOpenMPOptDeglobalization) {
3450 indicatePessimisticFixpoint();
3451 return;
3452 }
3453
3454 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3455 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3456 if (!RFI.Declaration)
3457 return;
3458
3459 Attributor::SimplifictionCallbackTy SCB =
3460 [](const IRPosition &, const AbstractAttribute *,
3461 bool &) -> std::optional<Value *> { return nullptr; };
3462
3463 Function *F = getAnchorScope();
3464 for (User *U : RFI.Declaration->users())
3465 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3466 if (CB->getFunction() != F)
3467 continue;
3468 MallocCalls.insert(CB);
3469 A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
3470 SCB);
3471 }
3472
3473 findPotentialRemovedFreeCalls(A);
3474 }
3475
3476 bool isAssumedHeapToShared(CallBase &CB) const override {
3477 return isValidState() && MallocCalls.count(&CB);
3478 }
3479
3480 bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
3481 return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3482 }
3483
3484 ChangeStatus manifest(Attributor &A) override {
3485 if (MallocCalls.empty())
3486 return ChangeStatus::UNCHANGED;
3487
3488 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3489 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3490
3491 Function *F = getAnchorScope();
3492 auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3493 DepClassTy::OPTIONAL);
3494
3495 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3496 for (CallBase *CB : MallocCalls) {
3497 // Skip replacing this if HeapToStack has already claimed it.
3498 if (HS && HS->isAssumedHeapToStack(*CB))
3499 continue;
3500
3501 // Find the unique free call to remove it.
3502 SmallVector<CallBase *, 4> FreeCalls;
3503 for (auto *U : CB->users()) {
3504 CallBase *C = dyn_cast<CallBase>(U);
3505 if (C && C->getCalledFunction() == FreeCall.Declaration)
3506 FreeCalls.push_back(C);
3507 }
3508 if (FreeCalls.size() != 1)
3509 continue;
3510
3511 auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3512
3513 if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3514 LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3515 << " with shared memory."
3516 << " Shared memory usage is limited to "
3517 << SharedMemoryLimit << " bytes\n");
3518 continue;
3519 }
3520
3521 LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3522 << " with " << AllocSize->getZExtValue()
3523 << " bytes of shared memory\n");
3524
3525 // Create a new shared memory buffer of the same size as the allocation
3526 // and replace all the uses of the original allocation with it.
3527 Module *M = CB->getModule();
3528 Type *Int8Ty = Type::getInt8Ty(M->getContext());
3529 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3530 auto *SharedMem = new GlobalVariable(
3531 *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3532 PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3533 GlobalValue::NotThreadLocal,
3534 static_cast<unsigned>(AddressSpace::Shared));
3535 auto *NewBuffer = ConstantExpr::getPointerCast(
3536 SharedMem, PointerType::getUnqual(M->getContext()));
3537
3538 auto Remark = [&](OptimizationRemark OR) {
3539 return OR << "Replaced globalized variable with "
3540 << ore::NV("SharedMemory", AllocSize->getZExtValue())
3541 << (AllocSize->isOne() ? " byte " : " bytes ")
3542 << "of shared memory.";
3543 };
3544 A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3545
3546 MaybeAlign Alignment = CB->getRetAlign();
3547 assert(Alignment &&
3548 "HeapToShared on allocation without alignment attribute");
3549 SharedMem->setAlignment(*Alignment);
3550
3551 A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3552 A.deleteAfterManifest(*CB);
3553 A.deleteAfterManifest(*FreeCalls.front());
3554
3555 SharedMemoryUsed += AllocSize->getZExtValue();
3556 NumBytesMovedToSharedMemory = SharedMemoryUsed;
3557 Changed = ChangeStatus::CHANGED;
3558 }
3559
3560 return Changed;
3561 }
3562
3563 ChangeStatus updateImpl(Attributor &A) override {
3564 if (MallocCalls.empty())
3565 return indicatePessimisticFixpoint();
3566 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3567 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3568 if (!RFI.Declaration)
3569 return ChangeStatus::UNCHANGED;
3570
3571 Function *F = getAnchorScope();
3572
3573 auto NumMallocCalls = MallocCalls.size();
3574
3575 // Only consider malloc calls executed by a single thread with a constant.
3576 for (User *U : RFI.Declaration->users()) {
3577 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3578 if (CB->getCaller() != F)
3579 continue;
3580 if (!MallocCalls.count(CB))
3581 continue;
3582 if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3583 MallocCalls.remove(CB);
3584 continue;
3585 }
3586 const auto *ED = A.getAAFor<AAExecutionDomain>(
3587 *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3588 if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
3589 MallocCalls.remove(CB);
3590 }
3591 }
3592
3593 findPotentialRemovedFreeCalls(A);
3594
3595 if (NumMallocCalls != MallocCalls.size())
3596 return ChangeStatus::CHANGED;
3597
3598 return ChangeStatus::UNCHANGED;
3599 }
3600
3601 /// Collection of all malloc calls in a function.
3602 SmallSetVector<CallBase *, 4> MallocCalls;
3603 /// Collection of potentially removed free calls in a function.
3604 SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3605 /// The total amount of shared memory that has been used for HeapToShared.
3606 unsigned SharedMemoryUsed = 0;
3607};
3608
3609struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3610 using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3611 AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3612
3613 /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
3614 /// unknown callees.
3615 static bool requiresCalleeForCallBase() { return false; }
3616
3617 /// Statistics are tracked as part of manifest for now.
3618 void trackStatistics() const override {}
3619
3620 /// See AbstractAttribute::getAsStr()
3621 const std::string getAsStr(Attributor *) const override {
3622 if (!isValidState())
3623 return "<invalid>";
3624 return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3625 : "generic") +
3626 std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3627 : "") +
3628 std::string(" #PRs: ") +
3629 (ReachedKnownParallelRegions.isValidState()
3630 ? std::to_string(ReachedKnownParallelRegions.size())
3631 : "<invalid>") +
3632 ", #Unknown PRs: " +
3633 (ReachedUnknownParallelRegions.isValidState()
3634 ? std::to_string(ReachedUnknownParallelRegions.size())
3635 : "<invalid>") +
3636 ", #Reaching Kernels: " +
3637 (ReachingKernelEntries.isValidState()
3638 ? std::to_string(ReachingKernelEntries.size())
3639 : "<invalid>") +
3640 ", #ParLevels: " +
3641 (ParallelLevels.isValidState()
3642 ? std::to_string(ParallelLevels.size())
3643 : "<invalid>") +
3644 ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3645 }
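// For illustration, a kernel that is still assumed SPMD-compatible, with a
// fixed SPMD state and two known parallel regions, would print roughly:
//   SPMD [FIX] #PRs: 2, #Unknown PRs: 0, #Reaching Kernels: 1, #ParLevels: 1,
//   NestedPar: no
// (hypothetical values; the exact string depends on the tracked state).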
3646
3647 /// Create an abstract attribute view for the position \p IRP.
3648 static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3649
3650 /// See AbstractAttribute::getName()
3651 StringRef getName() const override { return "AAKernelInfo"; }
3652
3653 /// See AbstractAttribute::getIdAddr()
3654 const char *getIdAddr() const override { return &ID; }
3655
3656 /// This function should return true if the type of the \p AA is AAKernelInfo
3657 static bool classof(const AbstractAttribute *AA) {
3658 return (AA->getIdAddr() == &ID);
3659 }
3660
3661 static const char ID;
3662};
3663
3664/// The function kernel info abstract attribute, basically, what can we say
3665 /// about a function with regard to the KernelInfoState.
3666struct AAKernelInfoFunction : AAKernelInfo {
3667 AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3668 : AAKernelInfo(IRP, A) {}
3669
3670 SmallPtrSet<Instruction *, 4> GuardedInstructions;
3671
3672 SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3673 return GuardedInstructions;
3674 }
3675
3676 void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3677 Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction(
3678 KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3679 assert(NewKernelEnvC && "Failed to create new kernel environment");
3680 KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3681 }
3682
3683#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \
3684 void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \
3685 ConstantStruct *ConfigC = \
3686 KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \
3687 Constant *NewConfigC = ConstantFoldInsertValueInstruction( \
3688 ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \
3689 assert(NewConfigC && "Failed to create new configuration environment"); \
3690 setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \
3691 }
3692
3693 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
3694 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
3695 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)
3696 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads)
3697 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads)
3698 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams)
3699 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams)
3700
3701#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
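// For reference, KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode) expands
// to a setter equivalent to:
//
//   void setExecModeOfKernelEnvironment(ConstantInt *NewVal) {
//     ConstantStruct *ConfigC =
//         KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC);
//     Constant *NewConfigC = ConstantFoldInsertValueInstruction(
//         ConfigC, NewVal, {KernelInfo::ExecModeIdx});
//     assert(NewConfigC && "Failed to create new configuration environment");
//     setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC));
//   }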
3702
3703 /// See AbstractAttribute::initialize(...).
3704 void initialize(Attributor &A) override {
3705 // This is a high-level transform that might change the constant arguments
3706 // of the init and deinit calls. We need to tell the Attributor about this
3707 // to avoid other parts using the current constant value for simplification.
3708 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3709
3710 Function *Fn = getAnchorScope();
3711
3712 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3713 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3714 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3715 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3716
3717 // For kernels we perform more initialization work, first we find the init
3718 // and deinit calls.
3719 auto StoreCallBase = [](Use &U,
3720 OMPInformationCache::RuntimeFunctionInfo &RFI,
3721 CallBase *&Storage) {
3722 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3723 assert(CB &&
3724 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3725 assert(!Storage &&
3726 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3727 Storage = CB;
3728 return false;
3729 };
3730 InitRFI.foreachUse(
3731 [&](Use &U, Function &) {
3732 StoreCallBase(U, InitRFI, KernelInitCB);
3733 return false;
3734 },
3735 Fn);
3736 DeinitRFI.foreachUse(
3737 [&](Use &U, Function &) {
3738 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3739 return false;
3740 },
3741 Fn);
3742
3743 // Ignore kernels without initializers such as global constructors.
3744 if (!KernelInitCB || !KernelDeinitCB)
3745 return;
3746
3747 // Add itself to the reaching kernel and set IsKernelEntry.
3748 ReachingKernelEntries.insert(Fn);
3749 IsKernelEntry = true;
3750
3751 KernelEnvC =
3752 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3753 GlobalVariable *KernelEnvGV =
3754 KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3755
3756 Attributor::GlobalVariableSimplifictionCallbackTy
3755
3757 KernelConfigurationSimplifyCB =
3758 [&](const GlobalVariable &GV, const AbstractAttribute *AA,
3759 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3760 if (!isAtFixpoint()) {
3761 if (!AA)
3762 return nullptr;
3763 UsedAssumedInformation = true;
3764 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3765 }
3766 return KernelEnvC;
3767 };
3768
3769 A.registerGlobalVariableSimplificationCallback(
3770 *KernelEnvGV, KernelConfigurationSimplifyCB);
3771
3772 // We cannot change to SPMD mode if the runtime functions aren't available.
3773 bool CanChangeToSPMD = OMPInfoCache.runtimeFnsAvailable(
3774 {OMPRTL___kmpc_get_hardware_thread_id_in_block,
3775 OMPRTL___kmpc_barrier_simple_spmd});
3776
3777 // Check if we know we are in SPMD-mode already.
3778 ConstantInt *ExecModeC =
3779 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3780 ConstantInt *AssumedExecModeC = ConstantInt::get(
3781 ExecModeC->getIntegerType(),
3782 OMP_TGT_EXEC_MODE_GENERIC | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
3783 if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
3784 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3785 else if (DisableOpenMPOptSPMDization || !CanChangeToSPMD)
3786 // This is a generic region but SPMDization is disabled so stop
3787 // tracking.
3788 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3789 else
3790 setExecModeOfKernelEnvironment(AssumedExecModeC);
3791
3792 const Triple T(Fn->getParent()->getTargetTriple());
3793 auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
3794 auto [MinThreads, MaxThreads] =
3795 OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn);
3796 if (MinThreads)
3797 setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3798 if (MaxThreads)
3799 setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3800 auto [MinTeams, MaxTeams] =
3801 OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn);
3802 if (MinTeams)
3803 setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3804 if (MaxTeams)
3805 setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3806
3807 ConstantInt *MayUseNestedParallelismC =
3808 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3809 ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3810 MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
3811 setMayUseNestedParallelismOfKernelEnvironment(
3812 AssumedMayUseNestedParallelismC);
3813
3814 if (!DisableOpenMPOptStateMachineRewrite) {
3815 ConstantInt *UseGenericStateMachineC =
3816 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3817 KernelEnvC);
3818 ConstantInt *AssumedUseGenericStateMachineC =
3819 ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3820 setUseGenericStateMachineOfKernelEnvironment(
3821 AssumedUseGenericStateMachineC);
3822 }
3823
3824 // Register virtual uses of functions we might need to preserve.
3825 auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3826 Attributor::VirtualUseCallbackTy &CB) {
3827 if (!OMPInfoCache.RFIs[RFKind].Declaration)
3828 return;
3829 A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3830 };
3831
3832 // Add a dependence to ensure updates if the state changes.
3833 auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3834 const AbstractAttribute *QueryingAA) {
3835 if (QueryingAA) {
3836 A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3837 }
3838 return true;
3839 };
3840
3841 Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3842 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3843 // Whenever we create a custom state machine we will insert calls to
3844 // __kmpc_get_hardware_num_threads_in_block,
3845 // __kmpc_get_warp_size,
3846 // __kmpc_barrier_simple_generic,
3847 // __kmpc_kernel_parallel, and
3848 // __kmpc_kernel_end_parallel.
3849 // Not needed if we are on track for SPMDzation.
3850 if (SPMDCompatibilityTracker.isValidState())
3851 return AddDependence(A, this, QueryingAA);
3852 // Not needed if we can't rewrite due to an invalid state.
3853 if (!ReachedKnownParallelRegions.isValidState())
3854 return AddDependence(A, this, QueryingAA);
3855 return false;
3856 };
3857
3858 // Not needed if we are pre-runtime merge.
3859 if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3860 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3861 CustomStateMachineUseCB);
3862 RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3863 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3864 CustomStateMachineUseCB);
3865 RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3866 CustomStateMachineUseCB);
3867 RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3868 CustomStateMachineUseCB);
3869 }
3870
3871 // If we do not perform SPMDzation we do not need the virtual uses below.
3872 if (SPMDCompatibilityTracker.isAtFixpoint())
3873 return;
3874
3875 Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3876 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3877 // Whenever we perform SPMDzation we will insert
3878 // __kmpc_get_hardware_thread_id_in_block calls.
3879 if (!SPMDCompatibilityTracker.isValidState())
3880 return AddDependence(A, this, QueryingAA);
3881 return false;
3882 };
3883 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3884 HWThreadIdUseCB);
3885
3886 Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3887 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3888 // Whenever we perform SPMDzation with guarding we will insert
3889 // __kmpc_simple_barrier_spmd calls. If SPMDzation failed, there is
3890 // nothing to guard, or there are no parallel regions, we don't need
3891 // the calls.
3892 if (!SPMDCompatibilityTracker.isValidState())
3893 return AddDependence(A, this, QueryingAA);
3894 if (SPMDCompatibilityTracker.empty())
3895 return AddDependence(A, this, QueryingAA);
3896 if (!mayContainParallelRegion())
3897 return AddDependence(A, this, QueryingAA);
3898 return false;
3899 };
3900 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3901 }
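// Added note: the virtual uses registered above matter because runtime
// declarations such as __kmpc_kernel_parallel have no IR uses until
// manifest() inserts the custom state machine or the SPMD guards. Without a
// registered virtual use they could be treated as dead and dropped before
// the rewrite happens; the callbacks keep them alive exactly as long as the
// corresponding rewrite is still possible.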
3902
3903 /// Sanitize the string \p S such that it is a suitable global symbol name.
3904 static std::string sanitizeForGlobalName(std::string S) {
3905 std::replace_if(
3906 S.begin(), S.end(),
3907 [](const char C) {
3908 return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3909 (C >= '0' && C <= '9') || C == '_');
3910 },
3911 '.');
3912 return S;
3913 }
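// E.g., sanitizeForGlobalName("x$y<0>") yields "x.y.0." since every
// character outside [a-zA-Z0-9_] is replaced by '.'.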
3914
3915 /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3916 /// finished now.
3917 ChangeStatus manifest(Attributor &A) override {
3918 // If we are not looking at a kernel with __kmpc_target_init and
3919 // __kmpc_target_deinit call we cannot actually manifest the information.
3920 if (!KernelInitCB || !KernelDeinitCB)
3921 return ChangeStatus::UNCHANGED;
3922
3923 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3924
3925 bool HasBuiltStateMachine = true;
3926 if (!changeToSPMDMode(A, Changed)) {
3927 if (!KernelInitCB->getCalledFunction()->isDeclaration())
3928 HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3929 else
3930 HasBuiltStateMachine = false;
3931 }
3932
3933 // We need to reset KernelEnvC if specific rewriting is not done.
3934 ConstantStruct *ExistingKernelEnvC =
3935 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3936 ConstantInt *OldUseGenericStateMachineVal =
3937 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3938 ExistingKernelEnvC);
3939 if (!HasBuiltStateMachine)
3940 setUseGenericStateMachineOfKernelEnvironment(
3941 OldUseGenericStateMachineVal);
3942
3943 // Finally, update the KernelEnvC.
3944 GlobalVariable *KernelEnvGV =
3945 KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3946 if (KernelEnvGV->getInitializer() != KernelEnvC) {
3947 KernelEnvGV->setInitializer(KernelEnvC);
3948 Changed = ChangeStatus::CHANGED;
3949 }
3950
3951 return Changed;
3952 }
3953
3954 void insertInstructionGuardsHelper(Attributor &A) {
3955 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3956
3957 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3958 Instruction *RegionEndI) {
3959 LoopInfo *LI = nullptr;
3960 DominatorTree *DT = nullptr;
3961 MemorySSAUpdater *MSU = nullptr;
3962 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3963
3964 BasicBlock *ParentBB = RegionStartI->getParent();
3965 Function *Fn = ParentBB->getParent();
3966 Module &M = *Fn->getParent();
3967
3968 // Create all the blocks and logic.
3969 // ParentBB:
3970 // goto RegionCheckTidBB
3971 // RegionCheckTidBB:
3972 // Tid = __kmpc_hardware_thread_id()
3973 // if (Tid != 0)
3974 // goto RegionBarrierBB
3975 // RegionStartBB:
3976 // <execute instructions guarded>
3977 // goto RegionEndBB
3978 // RegionEndBB:
3979 // <store escaping values to shared mem>
3980 // goto RegionBarrierBB
3981 // RegionBarrierBB:
3982 // __kmpc_simple_barrier_spmd()
3983 // // second barrier is omitted if lacking escaping values.
3984 // <load escaping values from shared mem>
3985 // __kmpc_simple_barrier_spmd()
3986 // goto RegionExitBB
3987 // RegionExitBB:
3988 // <execute rest of instructions>
3989
3990 BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3991 DT, LI, MSU, "region.guarded.end");
3992 BasicBlock *RegionBarrierBB =
3993 SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3994 MSU, "region.barrier");
3995 BasicBlock *RegionExitBB =
3996 SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3997 DT, LI, MSU, "region.exit");
3998 BasicBlock *RegionStartBB =
3999 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
4000
4001 assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
4002 "Expected a different CFG");
4003
4004 BasicBlock *RegionCheckTidBB = SplitBlock(
4005 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
4006
4007 // Register basic blocks with the Attributor.
4008 A.registerManifestAddedBasicBlock(*RegionEndBB);
4009 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
4010 A.registerManifestAddedBasicBlock(*RegionExitBB);
4011 A.registerManifestAddedBasicBlock(*RegionStartBB);
4012 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4013
4014 bool HasBroadcastValues = false;
4015 // Find escaping outputs from the guarded region to outside users and
4016 // broadcast their values to them.
4017 for (Instruction &I : *RegionStartBB) {
4018 SmallVector<Use *, 4> OutsideUses;
4019 for (Use &U : I.uses()) {
4020 Instruction &UsrI = *cast<Instruction>(U.getUser());
4021 if (UsrI.getParent() != RegionStartBB)
4022 OutsideUses.push_back(&U);
4023 }
4024
4025 if (OutsideUses.empty())
4026 continue;
4027
4028 HasBroadcastValues = true;
4029
4030 // Emit a global variable in shared memory to store the broadcasted
4031 // value.
4032 auto *SharedMem = new GlobalVariable(
4033 M, I.getType(), /* IsConstant */ false,
4034 GlobalValue::InternalLinkage, PoisonValue::get(I.getType()),
4035 sanitizeForGlobalName(
4036 (I.getName() + ".guarded.output.alloc").str()),
4037 nullptr, GlobalValue::NotThreadLocal,
4038 static_cast<unsigned>(AddressSpace::Shared));
4039
4040 // Emit a store instruction to update the value.
4041 new StoreInst(&I, SharedMem,
4042 RegionEndBB->getTerminator()->getIterator());
4043
4044 LoadInst *LoadI = new LoadInst(
4045 I.getType(), SharedMem, I.getName() + ".guarded.output.load",
4046 RegionBarrierBB->getTerminator()->getIterator());
4047
4048 // Emit a load instruction and replace uses of the output value.
4049 for (Use *U : OutsideUses)
4050 A.changeUseAfterManifest(*U, *LoadI);
4051 }
4052
4053 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4054
4055 // Go to tid check BB in ParentBB.
4056 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
4057 ParentBB->getTerminator()->eraseFromParent();
4058 OpenMPIRBuilder::LocationDescription Loc(
4059 InsertPointTy(ParentBB, ParentBB->end()), DL);
4060 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4061 uint32_t SrcLocStrSize;
4062 auto *SrcLocStr =
4063 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4064 Value *Ident =
4065 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4066 UncondBrInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
4067
4068 // Add check for Tid in RegionCheckTidBB
4069 RegionCheckTidBB->getTerminator()->eraseFromParent();
4070 OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
4071 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4072 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4073 FunctionCallee HardwareTidFn =
4074 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4075 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4076 CallInst *Tid =
4077 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4078 Tid->setDebugLoc(DL);
4079 OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4080 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4081 OMPInfoCache.OMPBuilder.Builder
4082 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4083 ->setDebugLoc(DL);
4084
4085 // First barrier for synchronization, ensures main thread has updated
4086 // values.
4087 FunctionCallee BarrierFn =
4088 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4089 M, OMPRTL___kmpc_barrier_simple_spmd);
4090 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4091 RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
4092 CallInst *Barrier =
4093 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4094 Barrier->setDebugLoc(DL);
4095 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4096
4097 // Second barrier ensures workers have read broadcast values.
4098 if (HasBroadcastValues) {
4099 CallInst *Barrier =
4100 CallInst::Create(BarrierFn, {Ident, Tid}, "",
4101 RegionBarrierBB->getTerminator()->getIterator());
4102 Barrier->setDebugLoc(DL);
4103 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4104 }
4105 };
4106
4107 auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4108 SmallPtrSet<BasicBlock *, 8> Visited;
4109 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4110 BasicBlock *BB = GuardedI->getParent();
4111 if (!Visited.insert(BB).second)
4112 continue;
4113
4114 SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
4115 Instruction *LastEffect = nullptr;
4116 BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
4117 while (++IP != IPEnd) {
4118 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4119 continue;
4120 Instruction *I = &*IP;
4121 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4122 continue;
4123 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4124 LastEffect = nullptr;
4125 continue;
4126 }
4127 if (LastEffect)
4128 Reorders.push_back({I, LastEffect});
4129 LastEffect = &*IP;
4130 }
4131 for (auto &Reorder : Reorders)
4132 Reorder.first->moveBefore(Reorder.second->getIterator());
4133 }
4134
4135 SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
4136
4137 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4138 BasicBlock *BB = GuardedI->getParent();
4139 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4140 IRPosition::function(*GuardedI->getFunction()), nullptr,
4141 DepClassTy::NONE);
4142 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4143 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4144 // Continue if instruction is already guarded.
4145 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4146 continue;
4147
4148 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4149 for (Instruction &I : *BB) {
4150 // If instruction I needs to be guarded update the guarded region
4151 // bounds.
4152 if (SPMDCompatibilityTracker.contains(&I)) {
4153 CalleeAAFunction.getGuardedInstructions().insert(&I);
4154 if (GuardedRegionStart)
4155 GuardedRegionEnd = &I;
4156 else
4157 GuardedRegionStart = GuardedRegionEnd = &I;
4158
4159 continue;
4160 }
4161
4162 // Instruction I does not need guarding, store
4163 // any region found and reset bounds.
4164 if (GuardedRegionStart) {
4165 GuardedRegions.push_back(
4166 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4167 GuardedRegionStart = nullptr;
4168 GuardedRegionEnd = nullptr;
4169 }
4170 }
4171 }
4172
4173 for (auto &GR : GuardedRegions)
4174 CreateGuardedRegion(GR.first, GR.second);
4175 }
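// Source-level intuition for the guarding above (an illustrative sketch,
// not the exact emitted IR): a guarded side effect such as `*p = v;` in an
// SPMD-ized kernel ends up as
//
//   if (__kmpc_get_hardware_thread_id_in_block() == 0)
//     *p = v;                            // region.guarded, thread 0 only
//   __kmpc_barrier_simple_spmd(...);     // make the effect visible to all
//
// with a second barrier and shared-memory reloads only when values computed
// in the guarded region are used outside of it.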
4176
4177 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4178 // Only allow 1 thread per workgroup to continue executing the user code.
4179 //
4180 // InitCB = __kmpc_target_init(...)
4181 // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4182 // if (ThreadIdInBlock != 0) return;
4183 // UserCode:
4184 // // user code
4185 //
4186 auto &Ctx = getAnchorValue().getContext();
4187 Function *Kernel = getAssociatedFunction();
4188 assert(Kernel && "Expected an associated function!");
4189
4190 // Create block for user code to branch to from initial block.
4191 BasicBlock *InitBB = KernelInitCB->getParent();
4192 BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4193 KernelInitCB->getNextNode(), "main.thread.user_code");
4194 BasicBlock *ReturnBB =
4195 BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
4196
4197 // Register blocks with attributor:
4198 A.registerManifestAddedBasicBlock(*InitBB);
4199 A.registerManifestAddedBasicBlock(*UserCodeBB);
4200 A.registerManifestAddedBasicBlock(*ReturnBB);
4201
4202 // Debug location:
4203 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4204 ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
4205 InitBB->getTerminator()->eraseFromParent();
4206
4207 // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4208 Module &M = *Kernel->getParent();
4209 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4210 FunctionCallee ThreadIdInBlockFn =
4211 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4212 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4213
4214 // Get thread ID in block.
4215 CallInst *ThreadIdInBlock =
4216 CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
4217 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4218 ThreadIdInBlock->setDebugLoc(DLoc);
4219
4220 // Eliminate all threads in the block with ID not equal to 0:
4221 Instruction *IsMainThread =
4222 ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
4223 ConstantInt::get(ThreadIdInBlock->getType(), 0),
4224 "thread.is_main", InitBB);
4225 IsMainThread->setDebugLoc(DLoc);
4226 CondBrInst::Create(IsMainThread, ReturnBB, UserCodeBB, InitBB);
4227 }
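// Added note: this helper is the cheap alternative to instruction guarding.
// It is only correct because the caller has established that the kernel
// contains no parallel region, so retiring every thread but thread 0 up
// front cannot drop any work and no value broadcasting is required.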
4228
4229 bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
4230 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4231
4232 if (!SPMDCompatibilityTracker.isAssumed()) {
4233 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4234 if (!NonCompatibleI)
4235 continue;
4236
4237 // Skip diagnostics on calls to known OpenMP runtime functions for now.
4238 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4239 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4240 continue;
4241
4242 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4243 ORA << "Value has potential side effects preventing SPMD-mode "
4244 "execution";
4245 if (isa<CallBase>(NonCompatibleI)) {
4246 ORA << ". Add `[[omp::assume(\"ompx_spmd_amenable\")]]` to "
4247 "the called function to override";
4248 }
4249 return ORA << ".";
4250 };
4251 A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
4252 Remark);
4253
4254 LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
4255 << *NonCompatibleI << "\n");
4256 }
4257
4258 return false;
4259 }
4260
4261 // Get the actual kernel; it could be the caller of the anchor scope if we
4262 // have a debug wrapper.
4263 Function *Kernel = getAnchorScope();
4264 if (Kernel->hasLocalLinkage()) {
4265 assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
4266 auto *CB = cast<CallBase>(Kernel->user_back());
4267 Kernel = CB->getCaller();
4268 }
4269 assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
4270
4271 // Check if the kernel is already in SPMD mode, if so, return success.
4272 ConstantStruct *ExistingKernelEnvC =
4273 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4274 auto *ExecModeC =
4275 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4276 const int8_t ExecModeVal = ExecModeC->getSExtValue();
4277 if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4278 return true;
4279
4280 // We will now unconditionally modify the IR, indicate a change.
4281 Changed = ChangeStatus::CHANGED;
4282
4283 // Do not use instruction guards when no parallel is present inside
4284 // the target region.
4285 if (mayContainParallelRegion())
4286 insertInstructionGuardsHelper(A);
4287 else
4288 forceSingleThreadPerWorkgroupHelper(A);
4289
4290 // Adjust the global exec mode flag that tells the runtime what mode this
4291 // kernel is executed in.
4292 assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4293 "Initially non-SPMD kernel has SPMD exec mode!");
4294 setExecModeOfKernelEnvironment(
4295 ConstantInt::get(ExecModeC->getIntegerType(),
4296 ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
4297
4298 ++NumOpenMPTargetRegionKernelsSPMD;
4299
4300 auto Remark = [&](OptimizationRemark OR) {
4301 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4302 };
4303 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
4304 return true;
4305 };
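// Added sketch of the effect on the kernel environment, assuming the usual
// libomptarget encoding OMP_TGT_EXEC_MODE_GENERIC == 1 and
// OMP_TGT_EXEC_MODE_SPMD == 2 (so OMP_TGT_EXEC_MODE_GENERIC_SPMD == 3): a
// generic kernel enters with ExecMode == 1 and leaves this rewrite with
// ExecMode == 3, telling the device runtime to launch it in SPMD fashion
// while recording its generic origin.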
4306
4307 bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4308 // If we have disabled state machine rewrites, don't make a custom one.
4309 if (DisableOpenMPOptStateMachineRewrite)
4310 return false;
4311
4312 // Don't rewrite the state machine if we are not in a valid state.
4313 if (!ReachedKnownParallelRegions.isValidState())
4314 return false;
4315
4316 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4317 if (!OMPInfoCache.runtimeFnsAvailable(
4318 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4319 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4320 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4321 return false;
4322
4323 ConstantStruct *ExistingKernelEnvC =
4324 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4325
4326 // Check if the current configuration is non-SPMD and generic state machine.
4327 // If we already have SPMD mode or a custom state machine we do not need to
4328 // go any further. If it is anything but a constant, something is weird and
4329 // we give up.
4330 ConstantInt *UseStateMachineC =
4331 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4332 ExistingKernelEnvC);
4333 ConstantInt *ModeC =
4334 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4335
4336 // If we are stuck with generic mode, try to create a custom device (=GPU)
4337 // state machine which is specialized for the parallel regions that are
4338 // reachable by the kernel.
4339 if (UseStateMachineC->isZero() ||
4340 (ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
4341 return false;
4342
4343 Changed = ChangeStatus::CHANGED;
4344
4345 // If not SPMD mode, indicate we use a custom state machine now.
4346 setUseGenericStateMachineOfKernelEnvironment(
4347 ConstantInt::get(UseStateMachineC->getIntegerType(), false));
4348
4349 // If we don't actually need a state machine we are done here. This can
4350 // happen if there simply are no parallel regions. In the resulting kernel
4351 // all worker threads will simply exit right away, leaving the main thread
4352 // to do the work alone.
4353 if (!mayContainParallelRegion()) {
4354 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4355
4356 auto Remark = [&](OptimizationRemark OR) {
4357 return OR << "Removing unused state machine from generic-mode kernel.";
4358 };
4359 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4360
4361 return true;
4362 }
4363
4364 // Keep track in the statistics of our new shiny custom state machine.
4365 if (ReachedUnknownParallelRegions.empty()) {
4366 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4367
4368 auto Remark = [&](OptimizationRemark OR) {
4369 return OR << "Rewriting generic-mode kernel with a customized state "
4370 "machine.";
4371 };
4372 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4373 } else {
4374 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4375
4376 auto Remark = [&](OptimizationRemarkAnalysis OR) {
4377 return OR << "Generic-mode kernel is executed with a customized state "
4378 "machine that requires a fallback.";
4379 };
4380 A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4381
4382 // Tell the user why we ended up with a fallback.
4383 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4384 if (!UnknownParallelRegionCB)
4385 continue;
4386 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4387 return ORA << "Call may contain unknown parallel regions. Use "
4388 << "`[[omp::assume(\"omp_no_parallelism\")]]` to "
4389 "override.";
4390 };
4391 A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4392 "OMP133", Remark);
4393 }
4394 }
4395
4396 // Create all the blocks:
4397 //
4398 // InitCB = __kmpc_target_init(...)
4399 // BlockHwSize =
4400 // __kmpc_get_hardware_num_threads_in_block();
4401 // WarpSize = __kmpc_get_warp_size();
4402 // BlockSize = BlockHwSize - WarpSize;
4403 // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
4404 // if (IsWorker) {
4405 // if (InitCB >= BlockSize) return;
4406 // SMBeginBB: __kmpc_barrier_simple_generic(...);
4407 // void *WorkFn;
4408 // bool Active = __kmpc_kernel_parallel(&WorkFn);
4409 // if (!WorkFn) return;
4410 // SMIsActiveCheckBB: if (Active) {
4411 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
4412 // ParFn0(...);
4413 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
4414 // ParFn1(...);
4415 // ...
4416 // SMIfCascadeCurrentBB: else
4417 // ((WorkFnTy*)WorkFn)(...);
4418 // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
4419 // }
4420 // SMDoneBB: __kmpc_barrier_simple_generic(...);
4421 // goto SMBeginBB;
4422 // }
4423 // UserCodeEntryBB: // user code
4424 // __kmpc_target_deinit(...)
4425 //
4426 auto &Ctx = getAnchorValue().getContext();
4427 Function *Kernel = getAssociatedFunction();
4428 assert(Kernel && "Expected an associated function!");
4429
4430 BasicBlock *InitBB = KernelInitCB->getParent();
4431 BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4432 KernelInitCB->getNextNode(), "thread.user_code.check");
4433 BasicBlock *IsWorkerCheckBB =
4434 BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4435 BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4436 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4437 BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4438 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4439 BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4440 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4441 BasicBlock *StateMachineIfCascadeCurrentBB =
4442 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4443 Kernel, UserCodeEntryBB);
4444 BasicBlock *StateMachineEndParallelBB =
4445 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4446 Kernel, UserCodeEntryBB);
4447 BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4448 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4449 A.registerManifestAddedBasicBlock(*InitBB);
4450 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4451 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4452 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4453 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4454 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4455 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4456 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4457 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4458
4459 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4460 ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4461 InitBB->getTerminator()->eraseFromParent();
4462
4463 Instruction *IsWorker =
4464 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4465 ConstantInt::getAllOnesValue(KernelInitCB->getType()),
4466 "thread.is_worker", InitBB);
4467 IsWorker->setDebugLoc(DLoc);
4468 CondBrInst::Create(IsWorker, IsWorkerCheckBB, UserCodeEntryBB, InitBB);
4469
4470 Module &M = *Kernel->getParent();
4471 FunctionCallee BlockHwSizeFn =
4472 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4473 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4474 FunctionCallee WarpSizeFn =
4475 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4476 M, OMPRTL___kmpc_get_warp_size);
4477 CallInst *BlockHwSize =
4478 CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4479 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4480 BlockHwSize->setDebugLoc(DLoc);
4481 CallInst *WarpSize =
4482 CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4483 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4484 WarpSize->setDebugLoc(DLoc);
4485 Instruction *BlockSize = BinaryOperator::CreateSub(
4486 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4487 BlockSize->setDebugLoc(DLoc);
4488 Instruction *IsMainOrWorker = ICmpInst::Create(
4489 ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4490 "thread.is_main_or_worker", IsWorkerCheckBB);
4491 IsMainOrWorker->setDebugLoc(DLoc);
4492 CondBrInst::Create(IsMainOrWorker, StateMachineBeginBB,
4493 StateMachineFinishedBB, IsWorkerCheckBB);
4494
4495 // Create local storage for the work function pointer.
4496 const DataLayout &DL = M.getDataLayout();
4497 Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4498 Instruction *WorkFnAI =
4499 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4500 "worker.work_fn.addr", Kernel->getEntryBlock().begin());
4501 WorkFnAI->setDebugLoc(DLoc);
4502
4503 OMPInfoCache.OMPBuilder.updateToLocation(
4504 OpenMPIRBuilder::LocationDescription(
4505 IRBuilder<>::InsertPoint(StateMachineBeginBB,
4506 StateMachineBeginBB->end()),
4507 DLoc));
4508
4509 Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4510 Value *GTid = KernelInitCB;
4511
4512 FunctionCallee BarrierFn =
4513 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4514 M, OMPRTL___kmpc_barrier_simple_generic);
4515 CallInst *Barrier =
4516 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4517 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4518 Barrier->setDebugLoc(DLoc);
4519
4520 if (WorkFnAI->getType()->getPointerAddressSpace() !=
4521 (unsigned int)AddressSpace::Generic) {
4522 WorkFnAI = new AddrSpaceCastInst(
4523 WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4524 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4525 WorkFnAI->setDebugLoc(DLoc);
4526 }
4527
4528 FunctionCallee KernelParallelFn =
4529 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4530 M, OMPRTL___kmpc_kernel_parallel);
4531 CallInst *IsActiveWorker = CallInst::Create(
4532 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4533 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4534 IsActiveWorker->setDebugLoc(DLoc);
4535 Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4536 StateMachineBeginBB);
4537 WorkFn->setDebugLoc(DLoc);
4538
4539 FunctionType *ParallelRegionFnTy = FunctionType::get(
4540 Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
4541 false);
4542
4543 Instruction *IsDone =
4544 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4545 Constant::getNullValue(VoidPtrTy), "worker.is_done",
4546 StateMachineBeginBB);
4547 IsDone->setDebugLoc(DLoc);
4548 CondBrInst::Create(IsDone, StateMachineFinishedBB,
4549 StateMachineIsActiveCheckBB, StateMachineBeginBB)
4550 ->setDebugLoc(DLoc);
4551
4552 CondBrInst::Create(IsActiveWorker, StateMachineIfCascadeCurrentBB,
4553 StateMachineDoneBarrierBB, StateMachineIsActiveCheckBB)
4554 ->setDebugLoc(DLoc);
4555
4556 Value *ZeroArg =
4557 Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4558
4559 const unsigned int WrapperFunctionArgNo = 6;
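// Added note: argument number 6 is assumed here to be the outlined wrapper
// function operand of a __kmpc_parallel_60 call site; it is the value the
// if-cascade below compares against the work function handed back by
// __kmpc_kernel_parallel.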
4560
4561 // Now that we have most of the CFG skeleton it is time for the if-cascade
4562 // that checks the function pointer we got from the runtime against the
4563 // parallel regions we expect, if there are any.
4564 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4565 auto *CB = ReachedKnownParallelRegions[I];
4566 auto *ParallelRegion = dyn_cast<Function>(
4567 CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4568 BasicBlock *PRExecuteBB = BasicBlock::Create(
4569 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4570 StateMachineEndParallelBB);
4571 CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4572 ->setDebugLoc(DLoc);
4573 UncondBrInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4574 ->setDebugLoc(DLoc);
4575
4576 BasicBlock *PRNextBB =
4577 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4578 Kernel, StateMachineEndParallelBB);
4579 A.registerManifestAddedBasicBlock(*PRExecuteBB);
4580 A.registerManifestAddedBasicBlock(*PRNextBB);
4581
4582 // Check if we need to compare the pointer at all or if we can just
4583 // call the parallel region function.
4584 Value *IsPR;
4585 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4586 Instruction *CmpI = ICmpInst::Create(
4587 ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
4588 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4589 CmpI->setDebugLoc(DLoc);
4590 IsPR = CmpI;
4591 } else {
4592 IsPR = ConstantInt::getTrue(Ctx);
4593 }
4594
4595 CondBrInst::Create(IsPR, PRExecuteBB, PRNextBB,
4596 StateMachineIfCascadeCurrentBB)
4597 ->setDebugLoc(DLoc);
4598 StateMachineIfCascadeCurrentBB = PRNextBB;
4599 }
4600
4601 // At the end of the if-cascade we place the indirect function pointer call
4602 // in case we might need it, that is if there can be parallel regions we
4603 // have not handled in the if-cascade above.
4604 if (!ReachedUnknownParallelRegions.empty()) {
4605 StateMachineIfCascadeCurrentBB->setName(
4606 "worker_state_machine.parallel_region.fallback.execute");
4607 CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
4608 StateMachineIfCascadeCurrentBB)
4609 ->setDebugLoc(DLoc);
4610 }
4611 UncondBrInst::Create(StateMachineEndParallelBB,
4612 StateMachineIfCascadeCurrentBB)
4613 ->setDebugLoc(DLoc);
4614
4615 FunctionCallee EndParallelFn =
4616 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4617 M, OMPRTL___kmpc_kernel_end_parallel);
4618 CallInst *EndParallel =
4619 CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4620 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4621 EndParallel->setDebugLoc(DLoc);
4622 UncondBrInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4623 ->setDebugLoc(DLoc);
4624
4625 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4626 ->setDebugLoc(DLoc);
4627 UncondBrInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4628 ->setDebugLoc(DLoc);
4629
4630 return true;
4631 }
4632
4633 /// Fixpoint iteration update function. Will be called every time a dependence
4634 /// changed its state (and in the beginning).
4635 ChangeStatus updateImpl(Attributor &A) override {
4636 KernelInfoState StateBefore = getState();
4637
4638 // When we leave this function this RAII will make sure the member
4639 // KernelEnvC is updated properly depending on the state. That member is
4640 // used for simplification of values and needs to be up to date at all
4641 // times.
4642 struct UpdateKernelEnvCRAII {
4643 AAKernelInfoFunction &AA;
4644
4645 UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4646
4647 ~UpdateKernelEnvCRAII() {
4648 if (!AA.KernelEnvC)
4649 return;
4650
4651 ConstantStruct *ExistingKernelEnvC =
4652 KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB);
4653
4654 if (!AA.isValidState()) {
4655 AA.KernelEnvC = ExistingKernelEnvC;
4656 return;
4657 }
4658
4659 if (!AA.ReachedKnownParallelRegions.isValidState())
4660 AA.setUseGenericStateMachineOfKernelEnvironment(
4661 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4662 ExistingKernelEnvC));
4663
4664 if (!AA.SPMDCompatibilityTracker.isValidState())
4665 AA.setExecModeOfKernelEnvironment(
4666 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4667
4668 ConstantInt *MayUseNestedParallelismC =
4669 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4670 AA.KernelEnvC);
4671 ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4672 MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4673 AA.setMayUseNestedParallelismOfKernelEnvironment(
4674 NewMayUseNestedParallelismC);
4675 }
4676 } RAII(*this);
4677
4678 // Callback to check a read/write instruction.
4679 auto CheckRWInst = [&](Instruction &I) {
4680 // We handle calls later.
4681 if (isa<CallBase>(I))
4682 return true;
4683 // We only care about write effects.
4684 if (!I.mayWriteToMemory())
4685 return true;
4686 if (auto *SI = dyn_cast<StoreInst>(&I)) {
4687 const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4688 *this, IRPosition::value(*SI->getPointerOperand()),
4689 DepClassTy::OPTIONAL);
4690 auto *HS = A.getAAFor<AAHeapToStack>(
4691 *this, IRPosition::function(*I.getFunction()),
4692 DepClassTy::OPTIONAL);
4693 if (UnderlyingObjsAA &&
4694 UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4695 if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4696 return true;
4697 // Check for AAHeapToStack moved objects which must not be
4698 // guarded.
4699 auto *CB = dyn_cast<CallBase>(&Obj);
4700 return CB && HS && HS->isAssumedHeapToStack(*CB);
4701 }))
4702 return true;
4703 }
4704
4705 // Insert instruction that needs guarding.
4706 SPMDCompatibilityTracker.insert(&I);
4707 return true;
4708 };
4709
4710 bool UsedAssumedInformationInCheckRWInst = false;
4711 if (!SPMDCompatibilityTracker.isAtFixpoint())
4712 if (!A.checkForAllReadWriteInstructions(
4713 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4714 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4715
4716 bool UsedAssumedInformationFromReachingKernels = false;
4717 if (!IsKernelEntry) {
4718 updateParallelLevels(A);
4719
4720 bool AllReachingKernelsKnown = true;
4721 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4722 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4723
4724 if (!SPMDCompatibilityTracker.empty()) {
4725 if (!ParallelLevels.isValidState())
4726 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4727 else if (!ReachingKernelEntries.isValidState())
4728 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4729 else {
4730 // Check if all reaching kernels agree on the mode as we can otherwise
4731 // not guard instructions. We might not be sure about the mode, so we
4732 // cannot fix the internal SPMD-ization state either.
4733 int SPMD = 0, Generic = 0;
4734 for (auto *Kernel : ReachingKernelEntries) {
4735 auto *CBAA = A.getAAFor<AAKernelInfo>(
4736 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4737 if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4738 CBAA->SPMDCompatibilityTracker.isAssumed())
4739 ++SPMD;
4740 else
4741 ++Generic;
4742 if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4743 UsedAssumedInformationFromReachingKernels = true;
4744 }
4745 if (SPMD != 0 && Generic != 0)
4746 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4747 }
4748 }
4749 }
4750
4751 // Callback to check a call instruction.
4752 bool AllParallelRegionStatesWereFixed = true;
4753 bool AllSPMDStatesWereFixed = true;
4754 auto CheckCallInst = [&](Instruction &I) {
4755 auto &CB = cast<CallBase>(I);
4756 auto *CBAA = A.getAAFor<AAKernelInfo>(
4757 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4758 if (!CBAA)
4759 return false;
4760 getState() ^= CBAA->getState();
4761 AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4762 AllParallelRegionStatesWereFixed &=
4763 CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4764 AllParallelRegionStatesWereFixed &=
4765 CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4766 return true;
4767 };
4768
4769 bool UsedAssumedInformationInCheckCallInst = false;
4770 if (!A.checkForAllCallLikeInstructions(
4771 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4772 LLVM_DEBUG(dbgs() << TAG
4773 << "Failed to visit all call-like instructions!\n";);
4774 return indicatePessimisticFixpoint();
4775 }
4776
4777 // If we haven't used any assumed information for the reached parallel
4778 // region states, we can fix them.
4779 if (!UsedAssumedInformationInCheckCallInst &&
4780 AllParallelRegionStatesWereFixed) {
4781 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4782 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4783 }
4784
4785 // If we haven't used any assumed information for the SPMD state, we can
4786 // fix it.
4787 if (!UsedAssumedInformationInCheckRWInst &&
4788 !UsedAssumedInformationInCheckCallInst &&
4789 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4790 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4791
4792 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4793 : ChangeStatus::CHANGED;
4794 }
4795
4796private:
4797 /// Update info regarding reaching kernels.
4798 void updateReachingKernelEntries(Attributor &A,
4799 bool &AllReachingKernelsKnown) {
4800 auto PredCallSite = [&](AbstractCallSite ACS) {
4801 Function *Caller = ACS.getInstruction()->getFunction();
4802
4803 assert(Caller && "Caller is nullptr");
4804
4805 auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4806 IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4807 if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4808 ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4809 return true;
4810 }
4811
4812 // We lost track of the callers of the associated function; any kernel
4813 // could reach it now.
4814 ReachingKernelEntries.indicatePessimisticFixpoint();
4815
4816 return true;
4817 };
4818
4819 if (!A.checkForAllCallSites(PredCallSite, *this,
4820 true /* RequireAllCallSites */,
4821 AllReachingKernelsKnown))
4822 ReachingKernelEntries.indicatePessimisticFixpoint();
4823 }
4824
4825 /// Update info regarding parallel levels.
4826 void updateParallelLevels(Attributor &A) {
4827 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4828 OMPInformationCache::RuntimeFunctionInfo &Parallel60RFI =
4829 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_60];
4830
4831 auto PredCallSite = [&](AbstractCallSite ACS) {
4832 Function *Caller = ACS.getInstruction()->getFunction();
4833
4834 assert(Caller && "Caller is nullptr");
4835
4836 auto *CAA =
4837 A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4838 if (CAA && CAA->ParallelLevels.isValidState()) {
4839 // Any function that is called by `__kmpc_parallel_60` will not be
4840 // folded, as the parallel level is updated inside it. Getting this
4841 // right would tie the analysis to the runtime implementation, and any
4842 // future change to that implementation could silently invalidate the
4843 // analysis. As a consequence, we are just conservative here.
4844 if (Caller == Parallel60RFI.Declaration) {
4845 ParallelLevels.indicatePessimisticFixpoint();
4846 return true;
4847 }
4848
4849 ParallelLevels ^= CAA->ParallelLevels;
4850
4851 return true;
4852 }
4853
4854 // We lost track of the callers of the associated function; any kernel
4855 // could reach it now.
4856 ParallelLevels.indicatePessimisticFixpoint();
4857
4858 return true;
4859 };
4860
4861 bool AllCallSitesKnown = true;
4862 if (!A.checkForAllCallSites(PredCallSite, *this,
4863 true /* RequireAllCallSites */,
4864 AllCallSitesKnown))
4865 ParallelLevels.indicatePessimisticFixpoint();
4866 }
4867};
4868
4869/// The call site kernel info abstract attribute, basically, what can we say
4870 /// about a call site with regard to the KernelInfoState. For now this simply
4871/// forwards the information from the callee.
4872struct AAKernelInfoCallSite : AAKernelInfo {
4873 AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4874 : AAKernelInfo(IRP, A) {}
4875
4876 /// See AbstractAttribute::initialize(...).
4877 void initialize(Attributor &A) override {
4878 AAKernelInfo::initialize(A);
4879
4880 CallBase &CB = cast<CallBase>(getAssociatedValue());
4881 auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4882 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4883
4884 // Check for SPMD-mode assumptions.
4885 if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4886 indicateOptimisticFixpoint();
4887 return;
4888 }
4889
4890 // First weed out calls we do not care about, that is readonly/readnone
4891 // calls, intrinsics, and "no_openmp" calls. None of these can reach a
4892 // parallel region or anything else we are looking for.
4893 if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4894 indicateOptimisticFixpoint();
4895 return;
4896 }
4897
4898 // Next we check if we know the callee. If it is a known OpenMP function
4899 // we will handle it explicitly in the switch below. If it is not, we
4900 // will use an AAKernelInfo object on the callee to gather information and
4901 // merge that into the current state. The latter happens in the updateImpl.
4902 auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
4903 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4904 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4905 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4906 // Unknown callees or declarations are not analyzable; we give up.
4907 if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4908
4909 // Unknown callees might contain parallel regions, except if they have
4910 // an appropriate assumption attached.
4911 if (!AssumptionAA ||
4912 !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4913 AssumptionAA->hasAssumption("omp_no_parallelism")))
4914 ReachedUnknownParallelRegions.insert(&CB);
4915
4916 // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4917 // idea we can run something unknown in SPMD-mode.
4918 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4919 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4920 SPMDCompatibilityTracker.insert(&CB);
4921 }
4922
4923 // We have updated the state for this unknown call properly; there
4924 // won't be any change, so we indicate a fixpoint.
4925 indicateOptimisticFixpoint();
4926 }
4927 // If the callee is known and can be used in IPO, we will update the
4928 // state based on the callee state in updateImpl.
4929 return;
4930 }
4931 if (NumCallees > 1) {
4932 indicatePessimisticFixpoint();
4933 return;
4934 }
4935
4936 RuntimeFunction RF = It->getSecond();
4937 switch (RF) {
4938 // All the functions we know are compatible with SPMD mode.
4939 case OMPRTL___kmpc_is_spmd_exec_mode:
4940 case OMPRTL___kmpc_distribute_static_fini:
4941 case OMPRTL___kmpc_for_static_fini:
4942 case OMPRTL___kmpc_global_thread_num:
4943 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4944 case OMPRTL___kmpc_get_hardware_num_blocks:
4945 case OMPRTL___kmpc_single:
4946 case OMPRTL___kmpc_end_single:
4947 case OMPRTL___kmpc_master:
4948 case OMPRTL___kmpc_end_master:
4949 case OMPRTL___kmpc_barrier:
4950 case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4951 case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4952 case OMPRTL___kmpc_error:
4953 case OMPRTL___kmpc_flush:
4954 case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4955 case OMPRTL___kmpc_get_warp_size:
4956 case OMPRTL_omp_get_thread_num:
4957 case OMPRTL_omp_get_num_threads:
4958 case OMPRTL_omp_get_max_threads:
4959 case OMPRTL_omp_in_parallel:
4960 case OMPRTL_omp_get_dynamic:
4961 case OMPRTL_omp_get_cancellation:
4962 case OMPRTL_omp_get_nested:
4963 case OMPRTL_omp_get_schedule:
4964 case OMPRTL_omp_get_thread_limit:
4965 case OMPRTL_omp_get_supported_active_levels:
4966 case OMPRTL_omp_get_max_active_levels:
4967 case OMPRTL_omp_get_level:
4968 case OMPRTL_omp_get_ancestor_thread_num:
4969 case OMPRTL_omp_get_team_size:
4970 case OMPRTL_omp_get_active_level:
4971 case OMPRTL_omp_in_final:
4972 case OMPRTL_omp_get_proc_bind:
4973 case OMPRTL_omp_get_num_places:
4974 case OMPRTL_omp_get_num_procs:
4975 case OMPRTL_omp_get_place_proc_ids:
4976 case OMPRTL_omp_get_place_num:
4977 case OMPRTL_omp_get_partition_num_places:
4978 case OMPRTL_omp_get_partition_place_nums:
4979 case OMPRTL_omp_get_wtime:
4980 break;
4981 case OMPRTL___kmpc_distribute_static_init_4:
4982 case OMPRTL___kmpc_distribute_static_init_4u:
4983 case OMPRTL___kmpc_distribute_static_init_8:
4984 case OMPRTL___kmpc_distribute_static_init_8u:
4985 case OMPRTL___kmpc_for_static_init_4:
4986 case OMPRTL___kmpc_for_static_init_4u:
4987 case OMPRTL___kmpc_for_static_init_8:
4988 case OMPRTL___kmpc_for_static_init_8u: {
4989 // Check the schedule and allow static schedule in SPMD mode.
4990 unsigned ScheduleArgOpNo = 2;
4991 auto *ScheduleTypeCI =
4992 dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4993 unsigned ScheduleTypeVal =
4994 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4995 switch (OMPScheduleType(ScheduleTypeVal)) {
4996 case OMPScheduleType::UnorderedStatic:
4997 case OMPScheduleType::UnorderedStaticChunked:
4998 case OMPScheduleType::OrderedDistribute:
4999 case OMPScheduleType::OrderedDistributeChunked:
5000 break;
5001 default:
5002 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5003 SPMDCompatibilityTracker.insert(&CB);
5004 break;
5005 };
5006 } break;
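// Thus a plain worksharing loop with schedule(static) decodes to one of
// the accepted unordered-static kinds above, while a non-constant or
// unrecognized schedule operand (ScheduleTypeVal defaults to 0) falls
// into the default case and blocks SPMD-ization.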
5007 case OMPRTL___kmpc_target_init:
5008 KernelInitCB = &CB;
5009 break;
5010 case OMPRTL___kmpc_target_deinit:
5011 KernelDeinitCB = &CB;
5012 break;
5013 case OMPRTL___kmpc_parallel_60:
5014 if (!handleParallel60(A, CB))
5015 indicatePessimisticFixpoint();
5016 return;
5017 case OMPRTL___kmpc_omp_task:
5018 // We do not look into tasks right now, just give up.
5019 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5020 SPMDCompatibilityTracker.insert(&CB);
5021 ReachedUnknownParallelRegions.insert(&CB);
5022 break;
5023 case OMPRTL___kmpc_alloc_shared:
5024 case OMPRTL___kmpc_free_shared:
5025 // Return without setting a fixpoint, to be resolved in updateImpl.
5026 return;
5027 case OMPRTL___kmpc_distribute_static_loop_4:
5028 case OMPRTL___kmpc_distribute_static_loop_4u:
5029 case OMPRTL___kmpc_distribute_static_loop_8:
5030 case OMPRTL___kmpc_distribute_static_loop_8u:
5031 case OMPRTL___kmpc_distribute_for_static_loop_4:
5032 case OMPRTL___kmpc_distribute_for_static_loop_4u:
5033 case OMPRTL___kmpc_distribute_for_static_loop_8:
5034 case OMPRTL___kmpc_distribute_for_static_loop_8u:
5035 case OMPRTL___kmpc_for_static_loop_4:
5036 case OMPRTL___kmpc_for_static_loop_4u:
5037 case OMPRTL___kmpc_for_static_loop_8:
5038 case OMPRTL___kmpc_for_static_loop_8u:
5039 // Parallel regions might be reached by these calls, as they take a
5040 // callback argument potentially containing arbitrary user-provided
5041 // code.
5042 ReachedUnknownParallelRegions.insert(&CB);
5043 // TODO: The presence of these calls on their own does not prevent a
5044 // kernel from being SPMD-izable. We mark it as such because we need
5045 // further changes in order to also consider the contents of the
5046 // callbacks passed to them.
5047 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5048 SPMDCompatibilityTracker.insert(&CB);
5049 break;
5050 default:
5051 // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
5052 // generally. However, they do not hide parallel regions.
5053 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5054 SPMDCompatibilityTracker.insert(&CB);
5055 break;
5056 }
5057 // All other OpenMP runtime calls will not reach parallel regions, so they
5058 // can be safely ignored for now. Since this is a known OpenMP runtime call,
5059 // we have now modeled all effects and there is no need for any update.
5060 indicateOptimisticFixpoint();
5061 };
5062
5063 const auto *AACE =
5064 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5065 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5066 CheckCallee(getAssociatedFunction(), 1);
5067 return;
5068 }
5069 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5070 for (auto *Callee : OptimisticEdges) {
5071 CheckCallee(Callee, OptimisticEdges.size());
5072 if (isAtFixpoint())
5073 break;
5074 }
5075 }
5076
5077 ChangeStatus updateImpl(Attributor &A) override {
5078 // TODO: Once we have call site specific value information we can provide
5079 // call site specific liveness information and then it makes
5080 // sense to specialize attributes for call sites arguments instead of
5081 // redirecting requests to the callee argument.
5082 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5083 KernelInfoState StateBefore = getState();
5084
5085 auto CheckCallee = [&](Function *F, int NumCallees) {
5086 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5087
5088 // If F is not a runtime function, propagate the AAKernelInfo of the
5089 // callee.
5090 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5091 const IRPosition &FnPos = IRPosition::function(*F);
5092 auto *FnAA =
5093 A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5094 if (!FnAA)
5095 return indicatePessimisticFixpoint();
5096 if (getState() == FnAA->getState())
5097 return ChangeStatus::UNCHANGED;
5098 getState() = FnAA->getState();
5099 return ChangeStatus::CHANGED;
5100 }
5101 if (NumCallees > 1)
5102 return indicatePessimisticFixpoint();
5103
5104 CallBase &CB = cast<CallBase>(getAssociatedValue());
5105 if (It->getSecond() == OMPRTL___kmpc_parallel_60) {
5106 if (!handleParallel60(A, CB))
5107 return indicatePessimisticFixpoint();
5108 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5109 : ChangeStatus::CHANGED;
5110 }
5111
5112 // F is a runtime function that allocates or frees memory, check
5113 // AAHeapToStack and AAHeapToShared.
5114 assert(
5115 (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5116 It->getSecond() == OMPRTL___kmpc_free_shared) &&
5117 "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5118
5119 auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
5120 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5121 auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5122 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5123
5124 RuntimeFunction RF = It->getSecond();
5125
5126 switch (RF) {
5127 // If neither HeapToStack nor HeapToShared assume the call is removed,
5128 // assume SPMD incompatibility.
5129 case OMPRTL___kmpc_alloc_shared:
5130 if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5131 (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5132 SPMDCompatibilityTracker.insert(&CB);
5133 break;
5134 case OMPRTL___kmpc_free_shared:
5135 if ((!HeapToStackAA ||
5136 !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5137 (!HeapToSharedAA ||
5138 !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5139 SPMDCompatibilityTracker.insert(&CB);
5140 break;
5141 default:
5142 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5143 SPMDCompatibilityTracker.insert(&CB);
5144 }
5145 return ChangeStatus::CHANGED;
5146 };
5147
5148 const auto *AACE =
5149 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5150 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5151 if (Function *F = getAssociatedFunction())
5152 CheckCallee(F, /*NumCallees=*/1);
5153 } else {
5154 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5155 for (auto *Callee : OptimisticEdges) {
5156 CheckCallee(Callee, OptimisticEdges.size());
5157 if (isAtFixpoint())
5158 break;
5159 }
5160 }
5161
5162 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5163 : ChangeStatus::CHANGED;
5164 }
5165
5166 /// Deal with a __kmpc_parallel_60 call (\p CB). Returns true if the call
5167 /// was handled; if a problem occurred, false is returned.
5168 bool handleParallel60(Attributor &A, CallBase &CB) {
5169 const unsigned int NonWrapperFunctionArgNo = 5;
5170 const unsigned int WrapperFunctionArgNo = 6;
5171 auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5172 ? NonWrapperFunctionArgNo
5173 : WrapperFunctionArgNo;
5174
5175 auto *ParallelRegion = dyn_cast<Function>(
5176 CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
5177 if (!ParallelRegion)
5178 return false;
5179
5180 ReachedKnownParallelRegions.insert(&CB);
5181 // Check nested parallelism.
5182 auto *FnAA = A.getAAFor<AAKernelInfo>(
5183 *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
5184 NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5185 !FnAA->ReachedKnownParallelRegions.empty() ||
5186 !FnAA->ReachedKnownParallelRegions.isValidState() ||
5187 !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5188 !FnAA->ReachedUnknownParallelRegions.empty();
5189 return true;
5190 }
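// Operand layout assumed above (illustrative sketch; operand names are
// hypothetical, trailing arguments elided):
//   call void @__kmpc_parallel_60(ptr @ident, i32 %gtid, i32 %if_expr,
//                                 i32 %num_threads, i32 %proc_bind,
//                                 ptr @outlined,         ; arg 5
//                                 ptr @outlined_wrapper, ; arg 6
//                                 ...)
// In SPMD mode the outlined function (arg 5) is invoked directly, so it
// is the parallel region; otherwise the wrapper (arg 6) is used.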
5191};
5192
5193struct AAFoldRuntimeCall
5194 : public StateWrapper<BooleanState, AbstractAttribute> {
5195 using Base = StateWrapper<BooleanState, AbstractAttribute>;
5196
5197 AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
5198
5199 /// Statistics are tracked as part of manifest for now.
5200 void trackStatistics() const override {}
5201
5202 /// Create an abstract attribute view for the position \p IRP.
5203 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5204 Attributor &A);
5205
5206 /// See AbstractAttribute::getName()
5207 StringRef getName() const override { return "AAFoldRuntimeCall"; }
5208
5209 /// See AbstractAttribute::getIdAddr()
5210 const char *getIdAddr() const override { return &ID; }
5211
5212 /// This function should return true if the type of the \p AA is
5213 /// AAFoldRuntimeCall
5214 static bool classof(const AbstractAttribute *AA) {
5215 return (AA->getIdAddr() == &ID);
5216 }
5217
5218 static const char ID;
5219};
5220
5221struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5222 AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
5223 : AAFoldRuntimeCall(IRP, A) {}
5224
5225 /// See AbstractAttribute::getAsStr()
5226 const std::string getAsStr(Attributor *) const override {
5227 if (!isValidState())
5228 return "<invalid>";
5229
5230 std::string Str("simplified value: ");
5231
5232 if (!SimplifiedValue)
5233 return Str + std::string("none");
5234
5235 if (!*SimplifiedValue)
5236 return Str + std::string("nullptr");
5237
5238 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5239 return Str + std::to_string(CI->getSExtValue());
5240
5241 return Str + std::string("unknown");
5242 }
5243
5244 void initialize(Attributor &A) override {
5245 if (DisableOpenMPOptFolding)
5246 indicatePessimisticFixpoint();
5247
5248 Function *Callee = getAssociatedFunction();
5249
5250 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5251 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5252 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5253 "Expected a known OpenMP runtime function");
5254
5255 RFKind = It->getSecond();
5256
5257 CallBase &CB = cast<CallBase>(getAssociatedValue());
5258 A.registerSimplificationCallback(
5259 IRPosition::callsite_returned(CB),
5260 [&](const IRPosition &IRP, const AbstractAttribute *AA,
5261 bool &UsedAssumedInformation) -> std::optional<Value *> {
5262 assert((isValidState() || SimplifiedValue == nullptr) &&
5263 "Unexpected invalid state!");
5264
5265 if (!isAtFixpoint()) {
5266 UsedAssumedInformation = true;
5267 if (AA)
5268 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5269 }
5270 return SimplifiedValue;
5271 });
5272 }
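// Note: while this AA is not at a fixpoint, the callback above reports
// the current SimplifiedValue as assumed information and records an
// optional dependence, so dependent queries are re-evaluated whenever
// the value changes.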
5273
5274 ChangeStatus updateImpl(Attributor &A) override {
5275 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5276 switch (RFKind) {
5277 case OMPRTL___kmpc_is_spmd_exec_mode:
5278 Changed |= foldIsSPMDExecMode(A);
5279 break;
5280 case OMPRTL___kmpc_parallel_level:
5281 Changed |= foldParallelLevel(A);
5282 break;
5283 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5284 Changed |= foldKernelFnAttribute(A, "omp_target_thread_limit");
5285 break;
5286 case OMPRTL___kmpc_get_hardware_num_blocks:
5287 Changed |= foldKernelFnAttribute(A, "omp_target_num_teams");
5288 break;
5289 default:
5290 llvm_unreachable("Unhandled OpenMP runtime function!");
5291 }
5292
5293 return Changed;
5294 }
5295
5296 ChangeStatus manifest(Attributor &A) override {
5297 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5298
5299 if (SimplifiedValue && *SimplifiedValue) {
5300 Instruction &I = *getCtxI();
5301 A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
5302 A.deleteAfterManifest(I);
5303
5304 CallBase *CB = dyn_cast<CallBase>(&I);
5305 auto Remark = [&](OptimizationRemark OR) {
5306 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5307 return OR << "Replacing OpenMP runtime call "
5308 << CB->getCalledFunction()->getName() << " with "
5309 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5310 return OR << "Replacing OpenMP runtime call "
5311 << CB->getCalledFunction()->getName() << ".";
5312 };
5313
5314 if (CB && EnableVerboseRemarks)
5315 A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
5316
5317 LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5318 << **SimplifiedValue << "\n");
5319
5320 Changed = ChangeStatus::CHANGED;
5321 }
5322
5323 return Changed;
5324 }
5325
5326 ChangeStatus indicatePessimisticFixpoint() override {
5327 SimplifiedValue = nullptr;
5328 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5329 }
5330
5331private:
5332 /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
5333 ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5334 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5335
5336 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5337 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5338 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5339 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5340
5341 if (!CallerKernelInfoAA ||
5342 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5343 return indicatePessimisticFixpoint();
5344
5345 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5346 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5347 DepClassTy::REQUIRED);
5348
5349 if (!AA || !AA->isValidState()) {
5350 SimplifiedValue = nullptr;
5351 return indicatePessimisticFixpoint();
5352 }
5353
5354 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5355 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5356 ++KnownSPMDCount;
5357 else
5358 ++AssumedSPMDCount;
5359 } else {
5360 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5361 ++KnownNonSPMDCount;
5362 else
5363 ++AssumedNonSPMDCount;
5364 }
5365 }
5366
5367 if ((AssumedSPMDCount + KnownSPMDCount) &&
5368 (AssumedNonSPMDCount + KnownNonSPMDCount))
5369 return indicatePessimisticFixpoint();
5370
5371 auto &Ctx = getAnchorValue().getContext();
5372 if (KnownSPMDCount || AssumedSPMDCount) {
5373 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5374 "Expected only SPMD kernels!");
5375 // All reaching kernels are in SPMD mode. Update all function calls to
5376 // __kmpc_is_spmd_exec_mode to 1.
5377 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
5378 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5379 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5380 "Expected only non-SPMD kernels!");
5381 // All reaching kernels are in non-SPMD mode. Update all function
5382 // calls to __kmpc_is_spmd_exec_mode to 0.
5383 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
5384 } else {
5385 // The set of reaching kernels is empty; therefore we cannot tell if the
5386 // associated call site can be folded. At this moment, SimplifiedValue
5387 // must be none.
5388 assert(!SimplifiedValue && "SimplifiedValue should be none");
5389 }
5390
5391 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5392 : ChangeStatus::CHANGED;
5393 }
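// For example (sketch): if every reaching kernel is (assumed) SPMD,
//   %mode = call i8 @__kmpc_is_spmd_exec_mode()
// simplifies to i8 1; if every reaching kernel is generic, to i8 0. A
// mix, or an invalid kernel state, leaves the call unfolded.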
5394
5395 /// Fold __kmpc_parallel_level into a constant if possible.
5396 ChangeStatus foldParallelLevel(Attributor &A) {
5397 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5398
5399 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5400 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5401
5402 if (!CallerKernelInfoAA ||
5403 !CallerKernelInfoAA->ParallelLevels.isValidState())
5404 return indicatePessimisticFixpoint();
5405
5406 if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5407 return indicatePessimisticFixpoint();
5408
5409 if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5410 assert(!SimplifiedValue &&
5411 "SimplifiedValue should be none at this point");
5412 return ChangeStatus::UNCHANGED;
5413 }
5414
5415 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5416 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5417 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5418 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5419 DepClassTy::REQUIRED);
5420 if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5421 return indicatePessimisticFixpoint();
5422
5423 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5424 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5425 ++KnownSPMDCount;
5426 else
5427 ++AssumedSPMDCount;
5428 } else {
5429 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5430 ++KnownNonSPMDCount;
5431 else
5432 ++AssumedNonSPMDCount;
5433 }
5434 }
5435
5436 if ((AssumedSPMDCount + KnownSPMDCount) &&
5437 (AssumedNonSPMDCount + KnownNonSPMDCount))
5438 return indicatePessimisticFixpoint();
5439
5440 auto &Ctx = getAnchorValue().getContext();
5441 // If the caller can only be reached by SPMD kernel entries, the parallel
5442 // level is 1. Similarly, if the caller can only be reached by non-SPMD
5443 // kernel entries, it is 0.
5444 if (AssumedSPMDCount || KnownSPMDCount) {
5445 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5446 "Expected only SPMD kernels!");
5447 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
5448 } else {
5449 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5450 "Expected only non-SPMD kernels!");
5451 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
5452 }
5453 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5454 : ChangeStatus::CHANGED;
5455 }
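// For example (sketch): a __kmpc_parallel_level call reached only from
// SPMD kernels folds to i8 1 (all threads run inside the implicit
// parallel region), and one reached only from generic-mode kernels
// folds to i8 0.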
5456
5457 ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
5458 // Specialize only if all reaching kernels agree on the attribute's constant value.
5459 int32_t CurrentAttrValue = -1;
5460 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5461
5462 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5463 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5464
5465 if (!CallerKernelInfoAA ||
5466 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5467 return indicatePessimisticFixpoint();
5468
5469 // Iterate over the kernels that reach this function
5470 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5471 int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
5472
5473 if (NextAttrVal == -1 ||
5474 (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5475 return indicatePessimisticFixpoint();
5476 CurrentAttrValue = NextAttrVal;
5477 }
5478
5479 if (CurrentAttrValue != -1) {
5480 auto &Ctx = getAnchorValue().getContext();
5481 SimplifiedValue =
5482 ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
5483 }
5484 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5485 : ChangeStatus::CHANGED;
5486 }
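// For example (sketch, hypothetical kernel name): if the only reaching
// kernel carries "omp_target_thread_limit"="128", a call to
// __kmpc_get_hardware_num_threads_in_block in this function folds to
// i32 128; two reaching kernels with different values block the fold.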
5487
5488 /// An optional value the associated value is assumed to fold to. That is, we
5489 /// assume the associated value (which is a call) can be replaced by this
5490 /// simplified value.
5491 std::optional<Value *> SimplifiedValue;
5492
5493 /// The runtime function kind of the callee of the associated call site.
5494 RuntimeFunction RFKind;
5495};
5496
5497} // namespace
5498
5499 /// Register folding for each call site of the runtime function \p RF.
5500void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5501 auto &RFI = OMPInfoCache.RFIs[RF];
5502 RFI.foreachUse(SCC, [&](Use &U, Function &F) {
5503 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5504 if (!CI)
5505 return false;
5506 A.getOrCreateAAFor<AAFoldRuntimeCall>(
5507 IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
5508 DepClassTy::NONE, /* ForceUpdate */ false,
5509 /* UpdateAfterInit */ false);
5510 return false;
5511 });
5512}
5513
5514void OpenMPOpt::registerAAs(bool IsModulePass) {
5515 if (SCC.empty())
5516 return;
5517
5518 if (IsModulePass) {
5519 // Ensure we create the AAKernelInfo AAs first and without triggering an
5520 // update. This will make sure we register all value simplification
5521 // callbacks before any other AA has the chance to create an AAValueSimplify
5522 // or similar.
5523 auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5524 A.getOrCreateAAFor<AAKernelInfo>(
5525 IRPosition::function(Kernel), /* QueryingAA */ nullptr,
5526 DepClassTy::NONE, /* ForceUpdate */ false,
5527 /* UpdateAfterInit */ false);
5528 return false;
5529 };
5530 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5531 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5532 InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5533
5534 registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5535 registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5536 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5537 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5538 }
5539
5540 // Create CallSite AA for all Getters.
5541 if (DeduceICVValues) {
5542 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5543 auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5544
5545 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5546
5547 auto CreateAA = [&](Use &U, Function &Caller) {
5548 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5549 if (!CI)
5550 return false;
5551
5552 auto &CB = cast<CallBase>(*CI);
5553
5554 IRPosition CBPos = IRPosition::callsite_function(CB);
5555 A.getOrCreateAAFor<AAICVTracker>(CBPos);
5556 return false;
5557 };
5558
5559 GetterRFI.foreachUse(SCC, CreateAA);
5560 }
5561 }
5562
5563 // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5564 // every function if there is a device kernel.
5565 if (!isOpenMPDevice(M))
5566 return;
5567
5568 for (auto *F : SCC) {
5569 if (F->isDeclaration())
5570 continue;
5571
5572 // We look at internal functions only on-demand, but if any use is not a
5573 // direct call, or lies outside the current set of analyzed functions, we
5574 // have to do it eagerly.
5575 if (F->hasLocalLinkage()) {
5576 if (llvm::all_of(F->uses(), [this](const Use &U) {
5577 const auto *CB = dyn_cast<CallBase>(U.getUser());
5578 return CB && CB->isCallee(&U) &&
5579 A.isRunOn(const_cast<Function *>(CB->getCaller()));
5580 }))
5581 continue;
5582 }
5583 registerAAsForFunction(A, *F);
5584 }
5585}
5586
5587 void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
5588 if (!DisableOpenMPOptDeglobalization)
5589 A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
5590 A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
5591 if (!DisableOpenMPOptDeglobalization)
5592 A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
5593 if (F.hasFnAttribute(Attribute::Convergent))
5594 A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));
5595
5596 for (auto &I : instructions(F)) {
5597 if (auto *LI = dyn_cast<LoadInst>(&I)) {
5598 bool UsedAssumedInformation = false;
5599 A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
5600 UsedAssumedInformation, AA::Interprocedural);
5601 A.getOrCreateAAFor<AAAddressSpace>(
5602 IRPosition::value(*LI->getPointerOperand()));
5603 continue;
5604 }
5605 if (auto *CI = dyn_cast<CallBase>(&I)) {
5606 if (CI->isIndirectCall())
5607 A.getOrCreateAAFor<AAIndirectCallInfo>(
5608 IRPosition::value(*CI));
5609 }
5610 if (auto *SI = dyn_cast<StoreInst>(&I)) {
5611 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
5612 A.getOrCreateAAFor<AAAddressSpace>(
5613 IRPosition::value(*SI->getPointerOperand()));
5614 continue;
5615 }
5616 if (auto *FI = dyn_cast<FenceInst>(&I)) {
5617 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
5618 continue;
5619 }
5620 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5621 if (II->getIntrinsicID() == Intrinsic::assume) {
5622 A.getOrCreateAAFor<AAPotentialValues>(
5623 IRPosition::value(*II->getArgOperand(0)));
5624 continue;
5625 }
5626 }
5627 }
5628}
5629
5630const char AAICVTracker::ID = 0;
5631const char AAKernelInfo::ID = 0;
5632const char AAExecutionDomain::ID = 0;
5633const char AAHeapToShared::ID = 0;
5634const char AAFoldRuntimeCall::ID = 0;
5635
5636AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5637 Attributor &A) {
5638 AAICVTracker *AA = nullptr;
5639 switch (IRP.getPositionKind()) {
5640 case IRPosition::IRP_INVALID:
5641 case IRPosition::IRP_FLOAT:
5642 case IRPosition::IRP_ARGUMENT:
5643 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5644 llvm_unreachable("ICVTracker can only be created for function position!");
5645 case IRPosition::IRP_RETURNED:
5646 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5647 break;
5648 case IRPosition::IRP_CALL_SITE_RETURNED:
5649 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5650 break;
5651 case IRPosition::IRP_CALL_SITE:
5652 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5653 break;
5654 case IRPosition::IRP_FUNCTION:
5655 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5656 break;
5657 }
5658
5659 return *AA;
5660}
5661
5662 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
5663 Attributor &A) {
5664 AAExecutionDomainFunction *AA = nullptr;
5665 switch (IRP.getPositionKind()) {
5666 case IRPosition::IRP_INVALID:
5667 case IRPosition::IRP_FLOAT:
5668 case IRPosition::IRP_ARGUMENT:
5669 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5670 case IRPosition::IRP_RETURNED:
5671 case IRPosition::IRP_CALL_SITE_RETURNED:
5672 case IRPosition::IRP_CALL_SITE:
5673 llvm_unreachable(
5674 "AAExecutionDomain can only be created for function position!");
5675 case IRPosition::IRP_FUNCTION:
5676 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5677 break;
5678 }
5679
5680 return *AA;
5681}
5682
5683AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5684 Attributor &A) {
5685 AAHeapToSharedFunction *AA = nullptr;
5686 switch (IRP.getPositionKind()) {
5687 case IRPosition::IRP_INVALID:
5688 case IRPosition::IRP_FLOAT:
5689 case IRPosition::IRP_ARGUMENT:
5690 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5691 case IRPosition::IRP_RETURNED:
5692 case IRPosition::IRP_CALL_SITE_RETURNED:
5693 case IRPosition::IRP_CALL_SITE:
5694 llvm_unreachable(
5695 "AAHeapToShared can only be created for function position!");
5696 case IRPosition::IRP_FUNCTION:
5697 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5698 break;
5699 }
5700
5701 return *AA;
5702}
5703
5704AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5705 Attributor &A) {
5706 AAKernelInfo *AA = nullptr;
5707 switch (IRP.getPositionKind()) {
5708 case IRPosition::IRP_INVALID:
5709 case IRPosition::IRP_FLOAT:
5710 case IRPosition::IRP_ARGUMENT:
5711 case IRPosition::IRP_RETURNED:
5712 case IRPosition::IRP_CALL_SITE_RETURNED:
5713 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5714 llvm_unreachable("KernelInfo can only be created for function position!");
5715 case IRPosition::IRP_CALL_SITE:
5716 AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5717 break;
5718 case IRPosition::IRP_FUNCTION:
5719 AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5720 break;
5721 }
5722
5723 return *AA;
5724}
5725
5726AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5727 Attributor &A) {
5728 AAFoldRuntimeCall *AA = nullptr;
5729 switch (IRP.getPositionKind()) {
5730 case IRPosition::IRP_INVALID:
5731 case IRPosition::IRP_FLOAT:
5732 case IRPosition::IRP_ARGUMENT:
5733 case IRPosition::IRP_RETURNED:
5734 case IRPosition::IRP_FUNCTION:
5735 case IRPosition::IRP_CALL_SITE:
5736 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5737 llvm_unreachable("AAFoldRuntimeCall can only be created for call site position!");
5738 case IRPosition::IRP_CALL_SITE_RETURNED:
5739 AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5740 break;
5741 }
5742
5743 return *AA;
5744}
5745
5746 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
5747 if (!containsOpenMP(M))
5748 return PreservedAnalyses::all();
5749 if (DisableOpenMPOptimizations)
5750 return PreservedAnalyses::all();
5751
5752 FunctionAnalysisManager &FAM =
5753 AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
5754 KernelSet Kernels = getDeviceKernels(M);
5755
5756 if (PrintModuleBeforeOptimizations)
5757 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
5758
5759 auto IsCalled = [&](Function &F) {
5760 if (Kernels.contains(&F))
5761 return true;
5762 return !F.use_empty();
5763 };
5764
5765 auto EmitRemark = [&](Function &F) {
5766 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
5767 ORE.emit([&]() {
5768 OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5769 return ORA << "Could not internalize function. "
5770 << "Some optimizations may not be possible. [OMP140]";
5771 });
5772 };
5773
5774 bool Changed = false;
5775
5776 // Create internal copies of each function if this is a kernel Module. This
5777 // allows interprocedural passes to see every call edge.
5778 DenseMap<Function *, Function *> InternalizedMap;
5779 if (isOpenMPDevice(M)) {
5780 SmallPtrSet<Function *, 16> InternalizeFns;
5781 for (Function &F : M)
5782 if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
5783 !DisableInternalization) {
5784 if (Attributor::isInternalizable(F)) {
5785 InternalizeFns.insert(&F);
5786 } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5787 EmitRemark(F);
5788 }
5789 }
5790
5791 Changed |=
5792 Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5793 }
5794
5795 // Look at every function in the Module unless it was internalized.
5796 SetVector<Function *> Functions;
5797 SmallVector<Function *, 16> SCC;
5798 for (Function &F : M)
5799 if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
5800 SCC.push_back(&F);
5801 Functions.insert(&F);
5802 }
5803
5804 if (SCC.empty())
5805 return PreservedAnalyses::all();
5806
5807 AnalysisGetter AG(FAM);
5808
5809 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5810 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5811 };
5812
5813 BumpPtrAllocator Allocator;
5814 CallGraphUpdater CGUpdater;
5815
5816 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5817 LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink;
5818
5819 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);
5820
5821 unsigned MaxFixpointIterations =
5822 (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5823
5824 AttributorConfig AC(CGUpdater);
5825 AC.DefaultInitializeLiveInternals = false;
5826 AC.IsModulePass = true;
5827 AC.RewriteSignatures = false;
5828 AC.MaxFixpointIterations = MaxFixpointIterations;
5829 AC.OREGetter = OREGetter;
5830 AC.PassName = DEBUG_TYPE;
5831 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5832 AC.IPOAmendableCB = [](const Function &F) {
5833 return F.hasFnAttribute("kernel");
5834 };
5835
5836 Attributor A(Functions, InfoCache, AC);
5837
5838 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5839 Changed |= OMPOpt.run(true);
5840
5841 // Optionally inline device functions for potentially better performance.
5842 if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
5843 for (Function &F : M)
5844 if (!F.isDeclaration() && !Kernels.contains(&F) &&
5845 !F.hasFnAttribute(Attribute::NoInline))
5846 F.addFnAttr(Attribute::AlwaysInline);
5847
5848 if (PrintModuleAfterOptimizations)
5849 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5850
5851 if (Changed)
5852 return PreservedAnalyses::none();
5853
5854 return PreservedAnalyses::all();
5855}
5856
5857 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
5858 CGSCCAnalysisManager &AM,
5859 LazyCallGraph &CG,
5860 CGSCCUpdateResult &UR) {
5861 if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5862 return PreservedAnalyses::all();
5863 if (DisableOpenMPOptimizations)
5864 return PreservedAnalyses::all();
5865
5866 SmallVector<Function *, 16> SCC;
5867 // If there are kernels in the module, we have to run on all SCC's.
5868 for (LazyCallGraph::Node &N : C) {
5869 Function *Fn = &N.getFunction();
5870 SCC.push_back(Fn);
5871 }
5872
5873 if (SCC.empty())
5874 return PreservedAnalyses::all();
5875
5876 Module &M = *C.begin()->getFunction().getParent();
5877
5878 if (PrintModuleBeforeOptimizations)
5879 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5880
5881 FunctionAnalysisManager &FAM =
5882 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5883
5884 AnalysisGetter AG(FAM);
5885
5886 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5887 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5888 };
5889
5890 BumpPtrAllocator Allocator;
5891 CallGraphUpdater CGUpdater;
5892 CGUpdater.initialize(CG, C, AM, UR);
5893
5894 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5895 LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink;
5896
5897 SetVector<Function *> Functions(SCC.begin(), SCC.end());
5898 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5899 /*CGSCC*/ &Functions, PostLink);
5900
5901 unsigned MaxFixpointIterations =
5902 (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5903
5904 AttributorConfig AC(CGUpdater);
5905 AC.DefaultInitializeLiveInternals = false;
5906 AC.IsModulePass = false;
5907 AC.RewriteSignatures = false;
5908 AC.MaxFixpointIterations = MaxFixpointIterations;
5909 AC.OREGetter = OREGetter;
5910 AC.PassName = DEBUG_TYPE;
5911 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5912
5913 Attributor A(Functions, InfoCache, AC);
5914
5915 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5916 bool Changed = OMPOpt.run(false);
5917
5918 if (PrintModuleAfterOptimizations)
5919 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5920
5921 if (Changed)
5922 return PreservedAnalyses::none();
5923
5924 return PreservedAnalyses::all();
5925}
5926
5927 bool llvm::omp::isOpenMPKernel(Function &Fn) {
5928 return Fn.hasFnAttribute("kernel");
5929}
5930
5931 KernelSet llvm::omp::getDeviceKernels(Module &M) {
5932 KernelSet Kernels;
5933
5934 for (Function &F : M)
5935 if (F.hasKernelCallingConv()) {
5936 // We are only interested in OpenMP target regions. Others, such as
5937 // kernels generated by CUDA but linked together, are not interesting to
5938 // this pass.
5939 if (isOpenMPKernel(F)) {
5940 ++NumOpenMPTargetRegionKernels;
5941 Kernels.insert(&F);
5942 } else
5943 ++NumNonOpenMPTargetRegionKernels;
5944 }
5945
5946 return Kernels;
5947}
5948
5950 Metadata *MD = M.getModuleFlag("openmp");
5951 if (!MD)
5952 return false;
5953
5954 return true;
5955}
5956
5957 bool llvm::omp::isOpenMPDevice(Module &M) {
5958 Metadata *MD = M.getModuleFlag("openmp-device");
5959 if (!MD)
5960 return false;
5961
5962 return true;
5963}
@ Generic
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
amdgpu next use AMDGPU Next Use Analysis Printer
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static cl::opt< unsigned > SetFixpointIterations("attributor-max-iterations", cl::Hidden, cl::desc("Maximal number of fixpoint iterations."), cl::init(32))
static const Function * getParent(const Value *V)
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
This file provides interfaces used to manipulate a call graph, regardless if it is a "old style" Call...
This file provides interfaces used to build and manipulate a call graph, which is a very useful tool ...
This file contains the declarations for the subclasses of Constant, which represent the different fla...
This file defines the DenseSet and SmallDenseSet classes.
This file defines an array type that can be indexed using scoped enum values.
#define DEBUG_TYPE
static void emitRemark(const Function &F, OptimizationRemarkEmitter &ORE, bool Skip)
Loop::LoopBounds::Direction Direction
Definition LoopInfo.cpp:253
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
Machine Check Debug Module
This file provides utility analysis objects describing memory locations.
#define T
uint64_t IntrinsicInst * II
This file defines constans and helpers used when dealing with OpenMP.
This file defines constans that will be used by both host and device compilation.
static constexpr auto TAG
static cl::opt< bool > HideMemoryTransferLatency("openmp-hide-memory-transfer-latency", cl::desc("[WIP] Tries to hide the latency of host to device memory" " transfers"), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptStateMachineRewrite("openmp-opt-disable-state-machine-rewrite", cl::desc("Disable OpenMP optimizations that replace the state machine."), cl::Hidden, cl::init(false))
static cl::opt< bool > EnableParallelRegionMerging("openmp-opt-enable-merging", cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleAfterOptimizations("openmp-opt-print-module-after", cl::desc("Print the current module after OpenMP optimizations."), cl::Hidden, cl::init(false))
#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)
#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)
#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER)
static cl::opt< bool > PrintOpenMPKernels("openmp-print-gpu-kernels", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptFolding("openmp-opt-disable-folding", cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleBeforeOptimizations("openmp-opt-print-module-before", cl::desc("Print the current module before OpenMP optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden, cl::desc("Maximal number of attributor iterations."), cl::init(256))
static cl::opt< bool > DisableInternalization("openmp-opt-disable-internalization", cl::desc("Disable function internalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintICVValues("openmp-print-icv-values", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptimizations("openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden, cl::desc("Maximum amount of shared memory to use."), cl::init(std::numeric_limits< unsigned >::max()))
static cl::opt< bool > EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::desc("Enables more verbose remarks."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptDeglobalization("openmp-opt-disable-deglobalization", cl::desc("Disable OpenMP optimizations involving deglobalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptBarrierElimination("openmp-opt-disable-barrier-elimination", cl::desc("Disable OpenMP optimizations that eliminate barriers."), cl::Hidden, cl::init(false))
#define DEBUG_TYPE
Definition OpenMPOpt.cpp:67
static cl::opt< bool > DeduceICVValues("openmp-deduce-icv-values", cl::init(false), cl::Hidden)
#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)
#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)
static cl::opt< bool > DisableOpenMPOptSPMDization("openmp-opt-disable-spmdization", cl::desc("Disable OpenMP optimizations involving SPMD-ization."), cl::Hidden, cl::init(false))
static cl::opt< bool > AlwaysInlineDeviceFunctions("openmp-opt-inline-device", cl::desc("Inline all applicable functions on the device."), cl::Hidden, cl::init(false))
#define P(N)
FunctionAnalysisManager FAM
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
static StringRef getName(Value *V)
R600 Clause Merge
Basic Register Allocator
std::pair< BasicBlock *, BasicBlock * > Edge
static bool contains(SmallPtrSetImpl< ConstantExpr * > &Cache, ConstantExpr *Expr, Constant *C)
Definition Value.cpp:483
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
static const int BlockSize
Definition TarWriter.cpp:33
static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, const llvm::StringTable &StandardNames, VectorLibrary VecLib)
Initialize the set of available library functions based on the specified target triple.
Value * RHS
static cl::opt< unsigned > MaxThreads("xcore-max-threads", cl::Optional, cl::desc("Maximum number of threads (for emulation thread-local storage)"), cl::Hidden, cl::value_desc("number"), cl::init(8))
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
size_t size() const
size - Get the array size.
Definition ArrayRef.h:142
iterator end()
Definition BasicBlock.h:474
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:461
LLVM_ABI const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
LLVM_ABI BasicBlock * splitBasicBlock(iterator I, const Twine &BBName="")
Split the basic block into two basic blocks at the specified instruction.
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
reverse_iterator rbegin()
Definition BasicBlock.h:477
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition BasicBlock.h:206
LLVM_ABI const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
InstListType::reverse_iterator reverse_iterator
Definition BasicBlock.h:172
reverse_iterator rend()
Definition BasicBlock.h:479
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
Definition BasicBlock.h:237
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
void setCallingConv(CallingConv::ID CC)
bool arg_empty() const
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
bool doesNotAccessMemory(unsigned OpNo) const
LLVM_ABI bool isIndirectCall() const
Return true if the callsite is an indirect call.
bool isCallee(Value::const_user_iterator UI) const
Determine whether the passed iterator points to the callee operand's Use.
Value * getArgOperand(unsigned i) const
void setArgOperand(unsigned i, Value *v)
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
unsigned getArgOperandNo(const Use *U) const
Given a use for a arg operand, get the arg operand number that corresponds to it.
unsigned arg_size() const
AttributeList getAttributes() const
Return the attributes for this call.
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind)
Adds the attribute to the indicated argument.
bool isArgOperand(const Use *U) const
bool hasOperandBundles() const
Return true if this User has any operand bundles.
LLVM_ABI Function * getCaller()
Helper to get the caller (the parent function).
Wrapper to unify "old style" CallGraph and "new style" LazyCallGraph.
void initialize(LazyCallGraph &LCG, LazyCallGraph::SCC &SCC, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR)
Initializers for usage outside of a CGSCC pass, inside a CGSCC pass in the old and new pass manager (...
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
@ ICMP_SLT
signed less than
Definition InstrTypes.h:705
@ ICMP_NE
not equal
Definition InstrTypes.h:698
static CondBrInst * Create(Value *Cond, BasicBlock *IfTrue, BasicBlock *IfFalse, InsertPosition InsertBefore=nullptr)
static LLVM_ABI Constant * getPointerCast(Constant *C, Type *Ty)
Create a BitCast, AddrSpaceCast, or a PtrToInt cast constant expression.
static LLVM_ABI Constant * getPointerBitCastOrAddrSpaceCast(Constant *C, Type *Ty)
Create a BitCast or AddrSpaceCast for a pointer type depending on the address space.
This is the shared class of boolean and integer constants.
Definition Constants.h:87
IntegerType * getIntegerType() const
Variant of the getType() method to always return an IntegerType, which reduces the amount of casting ...
Definition Constants.h:198
static LLVM_ABI ConstantInt * getTrue(LLVMContext &Context)
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
Definition Constants.h:174
This is an important base class in LLVM.
Definition Constant.h:43
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:205
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:241
LLVM_ABI Instruction * findNearestCommonDominator(Instruction *I1, Instruction *I2) const
Find the nearest instruction I that dominates both I1 and I2, in the sense that a result produced bef...
static ErrorSuccess success()
Create a success value.
Definition Error.h:336
AtomicOrdering getOrdering() const
Returns the ordering constraint of this fence instruction.
A proxy from a FunctionAnalysisManager to an SCC.
const BasicBlock & getEntryBlock() const
Definition Function.h:809
const BasicBlock & front() const
Definition Function.h:860
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition Function.cpp:358
Argument * getArg(unsigned i) const
Definition Function.h:886
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
Definition Function.cpp:728
LLVM_ABI bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
Definition Globals.cpp:337
bool hasLocalLinkage() const
Module * getParent()
Get the module that this global value is contained inside of...
@ PrivateLinkage
Like Internal, but omit from symbol table.
Definition GlobalValue.h:61
@ InternalLinkage
Rename collisions when linking (static functions).
Definition GlobalValue.h:60
const Constant * getInitializer() const
getInitializer - Return the initializer for this global variable.
LLVM_ABI void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
Definition Globals.cpp:542
BasicBlock * getBlock() const
Definition IRBuilder.h:313
CondBrInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a conditional 'br Cond, TrueDest, FalseDest' instruction.
Definition IRBuilder.h:1237
CallInst * CreateCall(FunctionType *FTy, Value *Callee, ArrayRef< Value * > Args={}, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:2546
Value * CreateIsNull(Value *Arg, const Twine &Name="")
Return a boolean value testing if Arg == 0.
Definition IRBuilder.h:2695
LLVM_ABI bool isLifetimeStartOrEnd() const LLVM_READONLY
Return true if the instruction is a llvm.lifetime.start or llvm.lifetime.end marker.
LLVM_ABI bool mayWriteToMemory() const LLVM_READONLY
Return true if this instruction may modify memory.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
LLVM_ABI const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr it the function does not...
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool mayHaveSideEffects() const LLVM_READONLY
Return true if the instruction may have side effects.
LLVM_ABI bool mayReadFromMemory() const LLVM_READONLY
Return true if this instruction may read memory.
void setDebugLoc(DebugLoc Loc)
Set the debug location information for this instruction.
LLVM_ABI void setSuccessor(unsigned Idx, BasicBlock *BB)
Update the specified successor to point at the provided block.
A node in the call graph.
An SCC of the call graph.
A lazily constructed view of the call graph of a module.
LLVM_ABI void eraseFromParent()
This method unlinks 'this' from the containing function and deletes it.
LLVM_ABI StringRef getName() const
Return the name of the corresponding LLVM basic block, or an empty string.
Root of the metadata hierarchy.
Definition Metadata.h:64
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
const Triple & getTargetTriple() const
Get the target triple which is a string describing the target host.
Definition Module.h:281
LLVM_ABI Constant * getOrCreateIdent(Constant *SrcLocStr, uint32_t SrcLocStrSize, omp::IdentFlag Flags=omp::IdentFlag(0), unsigned Reserve2Flags=0)
Return an ident_t* encoding the source location SrcLocStr and Flags.
LLVM_ABI FunctionCallee getOrCreateRuntimeFunction(Module &M, omp::RuntimeFunction FnID)
Return the function declaration for the runtime function with FnID.
static LLVM_ABI std::pair< int32_t, int32_t > readThreadBoundsForKernel(const Triple &T, Function &Kernel)
}
LLVM_ABI Constant * getOrCreateSrcLocStr(StringRef LocStr, uint32_t &SrcLocStrSize)
Return the (LLVM-IR) string describing the source location LocStr.
IRBuilder<>::InsertPoint InsertPointTy
Type used throughout for insertion points.
IRBuilder Builder
The LLVM-IR Builder used to create IR.
static LLVM_ABI std::pair< int32_t, int32_t > readTeamBoundsForKernel(const Triple &T, Function &Kernel)
Read/write a bounds on teams for Kernel.
bool updateToLocation(const LocationDescription &Loc)
Update the internal location to Loc.
PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR)
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
Diagnostic information for optimization analysis remarks.
The optimization diagnostic interface.
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition Analysis.h:115
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, InsertPosition InsertBefore=nullptr)
A vector that has set insertion semantics.
Definition SetVector.h:57
size_type size() const
Determine the number of elements in the SetVector.
Definition SetVector.h:103
size_type count(const_arg_type key) const
Count the number of elements of a given key in the SetVector.
Definition SetVector.h:262
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition SetVector.h:151
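A small illustrative use of these SetVector members; the worklist name is hypothetical. KernelSet further below is exactly such a SetVector.

#include "llvm/ADT/SetVector.h"
#include "llvm/IR/Function.h"
using namespace llvm;

static void enqueue(SetVector<Function *> &Worklist, Function *F) {
  // insert() rejects duplicates and reports whether F was new.
  if (Worklist.insert(F)) {
    // First time F is seen; it now sits at the end of the vector.
  }
  // count() is 0 or 1, size() the number of unique elements so far.
}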
size_type size() const
Definition SmallPtrSet.h:99
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
iterator begin() const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
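The insert() signature above enables the common visited-set idiom; a minimal sketch with assumed names.

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

static bool markVisited(SmallPtrSet<const Instruction *, 8> &Visited,
                        const Instruction *I) {
  // insert() returns {iterator, bool}; the bool is true only on the
  // first insertion, i.e. the first time I is visited.
  return Visited.insert(I).second;
}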
reference emplace_back(ArgTypes &&... Args)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition StringRef.h:258
Triple - Helper class for working with autoconf configuration names.
Definition Triple.h:47
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:313
LLVM_ABI unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
static UncondBrInst * Create(BasicBlock *Target, InsertPosition InsertBefore=nullptr)
static LLVM_ABI UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
LLVM_ABI bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Definition User.cpp:25
Type * getType() const
All values are typed; get the type of this value.
Definition Value.h:255
LLVM_ABI void setName(const Twine &Name)
Change the name of the value.
Definition Value.cpp:393
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition Value.h:439
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:549
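replaceAllUsesWith pairs with eraseFromParent in the usual replace-then-erase idiom for removing a redundant instruction. A hedged sketch; the names are illustrative, and both calls must produce the same type.

#include "llvm/IR/Instructions.h"
using namespace llvm;

static void replaceAndErase(CallInst *Kept, CallInst *Redundant) {
  // Redirect every use of the redundant call to the surviving one.
  Redundant->replaceAllUsesWith(Kept);
  // Unlink the now-dead instruction from its block and delete it.
  Redundant->eraseFromParent();
}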
iterator_range< user_iterator > users()
Definition Value.h:426
User * user_back()
Definition Value.h:412
LLVM_ABI const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs and address space casts.
Definition Value.cpp:709
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:318
const ParentTy * getParent() const
Definition ilist_node.h:34
self_iterator getIterator()
Definition ilist_node.h:123
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
GlobalVariable * getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB)
ConstantStruct * getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB)
Abstract Attribute helper functions.
Definition Attributor.h:165
LLVM_ABI bool isValidAtPosition(const ValueAndContext &VAC, InformationCache &InfoCache)
Return true if the value of VAC is valid at the position of VAC, that is, a constant,...
LLVM_ABI bool isPotentiallyAffectedByBarrier(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is potentially affected by a barrier.
@ Interprocedural
Definition Attributor.h:196
LLVM_ABI bool isNoSyncInst(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is a nosync instruction.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
E & operator^=(E &LHS, E RHS)
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
initializer< Ty > init(const Ty &Val)
DXILDebugInfoMap run(Module &M)
bool isOpenMPDevice(Module &M)
Helper to determine if M is an OpenMP target offloading device module.
bool containsOpenMP(Module &M)
Helper to determine if M contains OpenMP.
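These two helpers suggest the usual early-exit guard; a sketch under assumed names.

#include "llvm/IR/Module.h"
#include "llvm/Transforms/IPO/OpenMPOpt.h"
using namespace llvm;

static bool shouldRunOn(Module &M, bool &IsDevice) {
  if (!omp::containsOpenMP(M))
    return false; // nothing OpenMP-related in this module
  // Device modules get kernel-centric treatment, host modules do not.
  IsDevice = omp::isOpenMPDevice(M);
  return true;
}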
InternalControlVar
IDs for all Internal Control Variables (ICVs).
RuntimeFunction
IDs for all omp runtime library (RTL) functions.
KernelSet getDeviceKernels(Module &M)
Get OpenMP device kernels in M.
@ OMP_TGT_EXEC_MODE_GENERIC_SPMD
SetVector< Kernel > KernelSet
Set of kernels in the module.
Definition OpenMPOpt.h:24
Function * Kernel
Summary of a kernel (=entry point for target offloading).
Definition OpenMPOpt.h:21
bool isOpenMPKernel(Function &Fn)
Return true iff Fn is an OpenMP GPU kernel; Fn has the "kernel" attribute.
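Combining the helpers above, a sketch of enumerating device kernels; the wrapper function is hypothetical.

#include "llvm/IR/Module.h"
#include "llvm/Transforms/IPO/OpenMPOpt.h"
using namespace llvm;

static void forEachKernel(Module &M) {
  // Kernel is a plain Function*, and KernelSet iterates in discovery order.
  for (Function *K : omp::getDeviceKernels(M))
    if (omp::isOpenMPKernel(*K)) {
      // ... inspect or transform the kernel entry point ...
    }
}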
DiagnosticInfoOptimizationBase::Argument NV
NodeAddr< UseNode * > Use
Definition RDFGraph.h:385
bool empty() const
Definition BasicBlock.h:101
iterator end() const
Definition BasicBlock.h:89
friend class Instruction
Iterator for Instructions in a BasicBlock.
Definition BasicBlock.h:73
LLVM_ABI iterator begin() const
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:315
@ Offset
Definition DWP.cpp:557
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1738
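all_of composes naturally with Value::users(); an illustrative check, with a hypothetical helper name.

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/InstrTypes.h"
using namespace llvm;

static bool onlyUsedByCalls(Value &V) {
  // Ranges replace the explicit begin/end pair of std::all_of.
  return all_of(V.users(), [](const User *U) { return isa<CallBase>(U); });
}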
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition STLExtras.h:1668
bool succ_empty(const Instruction *I)
Definition CFG.h:153
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
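The canonical dyn_cast pattern, shown as a small sketch with an assumed helper name.

#include "llvm/IR/InstrTypes.h"
using namespace llvm;

static CallBase *asCall(Instruction &I) {
  // dyn_cast yields nullptr on a type mismatch, so the result doubles
  // as the condition.
  if (auto *CB = dyn_cast<CallBase>(&I))
    return CB;
  return nullptr;
}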
bool operator!=(uint64_t V1, const APInt &V2)
Definition APInt.h:2142
constexpr from_range_t from_range
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.
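A sketch of the decomposition contract, with a hypothetical wrapper name: on success the returned base plus Offset equals Ptr; otherwise Ptr comes back unchanged and Offset stays 0.

#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
using namespace llvm;

static Value *baseAndOffset(Value *Ptr, const DataLayout &DL,
                            int64_t &Offset) {
  Offset = 0;
  // Walks constant-offset GEPs and casts, accumulating into Offset.
  return GetPointerBaseWithConstantOffset(Ptr, Offset, DL);
}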
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
AnalysisManager< LazyCallGraph::SCC, LazyCallGraph & > CGSCCAnalysisManager
The CGSCC analysis manager.
@ ThinLTOPostLink
ThinLTO postlink (backend compile) phase.
Definition Pass.h:83
@ FullLTOPostLink
Full LTO postlink (backend compile) phase.
Definition Pass.h:87
@ ThinLTOPreLink
ThinLTO prelink (summary) phase.
Definition Pass.h:81
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
MutableArrayRef(T &OneElt) -> MutableArrayRef< T >
void cantFail(Error Err, const char *Msg=nullptr)
Report a fatal error if Err is a failure value.
Definition Error.h:769
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
bool operator&=(SparseBitVector< ElementSize > *LHS, const SparseBitVector< ElementSize > &RHS)
LLVM_ABI BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="")
Split the specified block at the specified instruction.
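A minimal sketch of splitting before an instruction; the wrapper is illustrative.

#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
using namespace llvm;

static BasicBlock *splitBefore(Instruction *I, DominatorTree *DT) {
  // The original block keeps everything before I and falls through to
  // the returned tail block, which starts at I; DT is kept up to date.
  return SplitBlock(I->getParent(), I->getIterator(), DT);
}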
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2011
ArrayRef(const T &OneElt) -> ArrayRef< T >
std::string toString(const APInt &I, unsigned Radix, bool Signed, bool formatAsCLiteral=false, bool UpperCase=true, bool InsertSeparators=false)
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
auto predecessors(const MachineBasicBlock *BB)
ChangeStatus
Definition Attributor.h:508
LLVM_ABI Constant * ConstantFoldInsertValueInstruction(Constant *Agg, Constant *Val, ArrayRef< unsigned > Idxs)
Attempt to constant fold an insertvalue instruction with the specified operands and indices.
@ OPTIONAL
The target may be valid if the source is not.
Definition Attributor.h:520
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
BumpPtrAllocatorImpl<> BumpPtrAllocator
The standard BumpPtrAllocator which just uses the default template parameters.
Definition Allocator.h:383
LLVM_ABI const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=MaxLookupSearchDepth)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
Definition MIRParser.h:39
#define N
static LLVM_ABI AAExecutionDomain & createForPosition(const IRPosition &IRP, Attributor &A)
Create an abstract attribute view for the position IRP.
AAExecutionDomain(const IRPosition &IRP, Attributor &A)
static LLVM_ABI const char ID
Unique ID (due to the unique address)
AccessKind
Simple enum to distinguish read/write/read-write accesses.
StateType::base_t MemoryLocationsKind
static LLVM_ABI bool isAlignedBarrier(const CallBase &CB, bool ExecutedAligned)
Helper function to determine if CB is an aligned (GPU) barrier.
Base struct for all "concrete attribute" deductions.
virtual const char * getIdAddr() const =0
This function should return the address of the ID of the AbstractAttribute.
An interface to query the internal state of an abstract attribute.
Wrapper for FunctionAnalysisManager.
Configuration for the Attributor.
std::function< void(Attributor &A, const Function &F)> InitializationCallback
Callback function to be invoked on internal functions marked live.
std::optional< unsigned > MaxFixpointIterations
Maximum number of iterations to run until fixpoint.
bool RewriteSignatures
Flag to determine if we rewrite function signatures.
const char * PassName
OptimizationRemarkGetter OREGetter
IPOAmendableCBTy IPOAmendableCB
bool IsModulePass
Is the user of the Attributor a module pass or not.
bool DefaultInitializeLiveInternals
Flag to determine if we want to initialize all default AAs for an internal function marked live.
The fixpoint analysis framework that orchestrates the attribute deduction.
static LLVM_ABI bool isInternalizable(Function &F)
Returns true if the function F can be internalized.
std::function< std::optional< Value * >( const IRPosition &, const AbstractAttribute *, bool &)> SimplifictionCallbackTy
Register CB as a simplification callback.
std::function< std::optional< Constant * >( const GlobalVariable &, const AbstractAttribute *, bool &)> GlobalVariableSimplifictionCallbackTy
Register CB as a simplification callback.
std::function< bool(Attributor &, const AbstractAttribute *)> VirtualUseCallbackTy
static LLVM_ABI bool internalizeFunctions(SmallPtrSetImpl< Function * > &FnSet, DenseMap< Function *, Function * > &FnMap)
Make copies of each function in the set FnSet such that the copied version has internal linkage after...
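A sketch of how the two static helpers combine, with assumed set and map names: collect only isInternalizable() candidates, then internalize them as a batch.

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Transforms/IPO/Attributor.h"
using namespace llvm;

static void internalizeCandidates(Module &M) {
  SmallPtrSet<Function *, 16> Candidates;
  for (Function &F : M)
    if (!F.isDeclaration() && Attributor::isInternalizable(F))
      Candidates.insert(&F);
  // Creates internal-linkage copies and records original -> copy.
  DenseMap<Function *, Function *> InternalizedMap;
  Attributor::internalizeFunctions(Candidates, InternalizedMap);
}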
Simple wrapper for a single bit (boolean) state.
Support structure for SCC passes to communicate updates to the call graph back to the CGSCC pass manager...
Helper to describe and deal with positions in the LLVM-IR.
Definition Attributor.h:605
static const IRPosition callsite_returned(const CallBase &CB)
Create a position describing the returned value of CB.
Definition Attributor.h:673
static const IRPosition returned(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the returned value of F.
Definition Attributor.h:655
static const IRPosition value(const Value &V, const CallBaseContext *CBContext=nullptr)
Create a position describing the value of V.
Definition Attributor.h:629
static const IRPosition inst(const Instruction &I, const CallBaseContext *CBContext=nullptr)
Create a position describing the instruction I.
Definition Attributor.h:641
@ IRP_ARGUMENT
An attribute for a function argument.
Definition Attributor.h:619
@ IRP_RETURNED
An attribute for the function return value.
Definition Attributor.h:615
@ IRP_CALL_SITE
An attribute for a call site (function scope).
Definition Attributor.h:618
@ IRP_CALL_SITE_RETURNED
An attribute for a call site return value.
Definition Attributor.h:616
@ IRP_FUNCTION
An attribute for a function (scope).
Definition Attributor.h:617
@ IRP_FLOAT
A position that is not associated with a spot suitable for attributes.
Definition Attributor.h:613
@ IRP_CALL_SITE_ARGUMENT
An attribute for a call site argument.
Definition Attributor.h:620
@ IRP_INVALID
An invalid position.
Definition Attributor.h:612
static const IRPosition function(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the function scope of F.
Definition Attributor.h:648
Kind getPositionKind() const
Return the associated position kind.
Definition Attributor.h:901
static const IRPosition callsite_function(const CallBase &CB)
Create a position describing the function scope of CB.
Definition Attributor.h:668
Data structure to hold cached (LLVM-IR) information.
Defines various target-specific GPU grid values that must be consistent between host RTL (plugin),...