LLVM 23.0.0git
OpenMPOpt.cpp
Go to the documentation of this file.
1//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpenMP specific optimizations:
10//
11// - Deduplication of runtime calls, e.g., omp_get_thread_num.
12// - Replacing globalized device memory with stack memory.
13// - Replacing globalized device memory with shared memory.
14// - Parallel region merging.
15// - Transforming generic-mode device kernels to SPMD mode.
16// - Specializing the state machine for generic-mode device kernels.
17//
18//===----------------------------------------------------------------------===//
19
21
22#include "llvm/ADT/DenseSet.h"
25#include "llvm/ADT/SetVector.h"
28#include "llvm/ADT/Statistic.h"
30#include "llvm/ADT/StringRef.h"
38#include "llvm/IR/Assumptions.h"
39#include "llvm/IR/BasicBlock.h"
40#include "llvm/IR/Constants.h"
42#include "llvm/IR/Dominators.h"
43#include "llvm/IR/Function.h"
44#include "llvm/IR/GlobalValue.h"
46#include "llvm/IR/InstrTypes.h"
47#include "llvm/IR/Instruction.h"
50#include "llvm/IR/IntrinsicsAMDGPU.h"
51#include "llvm/IR/IntrinsicsNVPTX.h"
52#include "llvm/IR/LLVMContext.h"
55#include "llvm/Support/Debug.h"
59
60#include <algorithm>
61#include <optional>
62#include <string>
63
64using namespace llvm;
65using namespace omp;
66
67#define DEBUG_TYPE "openmp-opt"
68
70 "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
71 cl::Hidden, cl::init(false));
72
74 "openmp-opt-enable-merging",
75 cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
76 cl::init(false));
77
78static cl::opt<bool>
79 DisableInternalization("openmp-opt-disable-internalization",
80 cl::desc("Disable function internalization."),
81 cl::Hidden, cl::init(false));
82
83static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
84 cl::init(false), cl::Hidden);
85static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
87static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
88 cl::init(false), cl::Hidden);
89
91 "openmp-hide-memory-transfer-latency",
92 cl::desc("[WIP] Tries to hide the latency of host to device memory"
93 " transfers"),
94 cl::Hidden, cl::init(false));
95
97 "openmp-opt-disable-deglobalization",
98 cl::desc("Disable OpenMP optimizations involving deglobalization."),
99 cl::Hidden, cl::init(false));
100
102 "openmp-opt-disable-spmdization",
103 cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
104 cl::Hidden, cl::init(false));
105
107 "openmp-opt-disable-folding",
108 cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
109 cl::init(false));
110
112 "openmp-opt-disable-state-machine-rewrite",
113 cl::desc("Disable OpenMP optimizations that replace the state machine."),
114 cl::Hidden, cl::init(false));
115
117 "openmp-opt-disable-barrier-elimination",
118 cl::desc("Disable OpenMP optimizations that eliminate barriers."),
119 cl::Hidden, cl::init(false));
120
122 "openmp-opt-print-module-after",
123 cl::desc("Print the current module after OpenMP optimizations."),
124 cl::Hidden, cl::init(false));
125
127 "openmp-opt-print-module-before",
128 cl::desc("Print the current module before OpenMP optimizations."),
129 cl::Hidden, cl::init(false));
130
132 "openmp-opt-inline-device",
133 cl::desc("Inline all applicable functions on the device."), cl::Hidden,
134 cl::init(false));
135
136static cl::opt<bool>
137 EnableVerboseRemarks("openmp-opt-verbose-remarks",
138 cl::desc("Enables more verbose remarks."), cl::Hidden,
139 cl::init(false));
140
142 SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
143 cl::desc("Maximal number of attributor iterations."),
144 cl::init(256));
145
147 SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
148 cl::desc("Maximum amount of shared memory to use."),
149 cl::init(std::numeric_limits<unsigned>::max()));
150
151STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
152 "Number of OpenMP runtime calls deduplicated");
153STATISTIC(NumOpenMPParallelRegionsDeleted,
154 "Number of OpenMP parallel regions deleted");
155STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
156 "Number of OpenMP runtime functions identified");
157STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
158 "Number of OpenMP runtime function uses identified");
159STATISTIC(NumOpenMPTargetRegionKernels,
160 "Number of OpenMP target region entry points (=kernels) identified");
161STATISTIC(NumNonOpenMPTargetRegionKernels,
162 "Number of non-OpenMP target region kernels identified");
163STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
164 "Number of OpenMP target region entry points (=kernels) executed in "
165 "SPMD-mode instead of generic-mode");
166STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
167 "Number of OpenMP target region entry points (=kernels) executed in "
168 "generic-mode without a state machines");
169STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
170 "Number of OpenMP target region entry points (=kernels) executed in "
171 "generic-mode with customized state machines with fallback");
172STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
173 "Number of OpenMP target region entry points (=kernels) executed in "
174 "generic-mode with customized state machines without fallback");
176 NumOpenMPParallelRegionsReplacedInGPUStateMachine,
177 "Number of OpenMP parallel regions replaced with ID in GPU state machines");
178STATISTIC(NumOpenMPParallelRegionsMerged,
179 "Number of OpenMP parallel regions merged");
180STATISTIC(NumBytesMovedToSharedMemory,
181 "Amount of memory pushed to shared memory");
182STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
183
184#if !defined(NDEBUG)
185static constexpr auto TAG = "[" DEBUG_TYPE "]";
186#endif
187
188namespace KernelInfo {
189
190// struct ConfigurationEnvironmentTy {
191// uint8_t UseGenericStateMachine;
192// uint8_t MayUseNestedParallelism;
193// llvm::omp::OMPTgtExecModeFlags ExecMode;
194// int32_t MinThreads;
195// int32_t MaxThreads;
196// int32_t MinTeams;
197// int32_t MaxTeams;
198// };
199
200// struct DynamicEnvironmentTy {
201// uint16_t DebugIndentionLevel;
202// };
203
204// struct KernelEnvironmentTy {
205// ConfigurationEnvironmentTy Configuration;
206// IdentTy *Ident;
207// DynamicEnvironmentTy *DynamicEnv;
208// };
209
210#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX) \
211 constexpr unsigned MEMBER##Idx = IDX;
212
213KERNEL_ENVIRONMENT_IDX(Configuration, 0)
215
216#undef KERNEL_ENVIRONMENT_IDX
217
218#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX) \
219 constexpr unsigned MEMBER##Idx = IDX;
220
221KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
222KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
228
229#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX
230
231#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE) \
232 RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
233 return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx)); \
234 }
235
238
239#undef KERNEL_ENVIRONMENT_GETTER
240
241#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER) \
242 ConstantInt *get##MEMBER##FromKernelEnvironment( \
243 ConstantStruct *KernelEnvC) { \
244 ConstantStruct *ConfigC = \
245 getConfigurationFromKernelEnvironment(KernelEnvC); \
246 return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx)); \
247 }
248
249KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
250KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
256
257#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
258
261 constexpr int InitKernelEnvironmentArgNo = 0;
263 KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
265}
266
272} // namespace KernelInfo
273
274namespace {
275
276struct AAHeapToShared;
277
278struct AAICVTracker;
279
280/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
281/// Attributor runs.
282struct OMPInformationCache : public InformationCache {
283 OMPInformationCache(Module &M, AnalysisGetter &AG,
284 BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
285 bool OpenMPPostLink)
286 : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
287 OpenMPPostLink(OpenMPPostLink) {
288
289 OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
290 const Triple T(OMPBuilder.M.getTargetTriple());
291 switch (T.getArch()) {
295 assert(OMPBuilder.Config.IsTargetDevice &&
296 "OpenMP AMDGPU/NVPTX is only prepared to deal with device code.");
297 OMPBuilder.Config.IsGPU = true;
298 break;
299 default:
300 OMPBuilder.Config.IsGPU = false;
301 break;
302 }
303 OMPBuilder.initialize();
304 initializeRuntimeFunctions(M);
305 initializeInternalControlVars();
306 }
307
308 /// Generic information that describes an internal control variable.
309 struct InternalControlVarInfo {
310 /// The kind, as described by InternalControlVar enum.
312
313 /// The name of the ICV.
314 StringRef Name;
315
316 /// Environment variable associated with this ICV.
317 StringRef EnvVarName;
318
319 /// Initial value kind.
320 ICVInitValue InitKind;
321
322 /// Initial value.
323 ConstantInt *InitValue;
324
325 /// Setter RTL function associated with this ICV.
326 RuntimeFunction Setter;
327
328 /// Getter RTL function associated with this ICV.
329 RuntimeFunction Getter;
330
331 /// RTL Function corresponding to the override clause of this ICV
332 RuntimeFunction Clause;
333 };
334
335 /// Generic information that describes a runtime function
336 struct RuntimeFunctionInfo {
337
338 /// The kind, as described by the RuntimeFunction enum.
339 RuntimeFunction Kind;
340
341 /// The name of the function.
342 StringRef Name;
343
344 /// Flag to indicate a variadic function.
345 bool IsVarArg;
346
347 /// The return type of the function.
348 Type *ReturnType;
349
350 /// The argument types of the function.
351 SmallVector<Type *, 8> ArgumentTypes;
352
353 /// The declaration if available.
354 Function *Declaration = nullptr;
355
356 /// Uses of this runtime function per function containing the use.
357 using UseVector = SmallVector<Use *, 16>;
358
359 /// Clear UsesMap for runtime function.
360 void clearUsesMap() { UsesMap.clear(); }
361
362 /// Boolean conversion that is true if the runtime function was found.
363 operator bool() const { return Declaration; }
364
365 /// Return the vector of uses in function \p F.
366 UseVector &getOrCreateUseVector(Function *F) {
367 std::shared_ptr<UseVector> &UV = UsesMap[F];
368 if (!UV)
369 UV = std::make_shared<UseVector>();
370 return *UV;
371 }
372
373 /// Return the vector of uses in function \p F or `nullptr` if there are
374 /// none.
375 const UseVector *getUseVector(Function &F) const {
376 auto I = UsesMap.find(&F);
377 if (I != UsesMap.end())
378 return I->second.get();
379 return nullptr;
380 }
381
382 /// Return how many functions contain uses of this runtime function.
383 size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
384
385 /// Return the number of arguments (or the minimal number for variadic
386 /// functions).
387 size_t getNumArgs() const { return ArgumentTypes.size(); }
388
389 /// Run the callback \p CB on each use and forget the use if the result is
390 /// true. The callback will be fed the function in which the use was
391 /// encountered as second argument.
392 void foreachUse(SmallVectorImpl<Function *> &SCC,
393 function_ref<bool(Use &, Function &)> CB) {
394 for (Function *F : SCC)
395 foreachUse(CB, F);
396 }
397
398 /// Run the callback \p CB on each use within the function \p F and forget
399 /// the use if the result is true.
400 void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
401 SmallVector<unsigned, 8> ToBeDeleted;
402 ToBeDeleted.clear();
403
404 unsigned Idx = 0;
405 UseVector &UV = getOrCreateUseVector(F);
406
407 for (Use *U : UV) {
408 if (CB(*U, *F))
409 ToBeDeleted.push_back(Idx);
410 ++Idx;
411 }
412
413 // Remove the to-be-deleted indices in reverse order as prior
414 // modifications will not modify the smaller indices.
415 while (!ToBeDeleted.empty()) {
416 unsigned Idx = ToBeDeleted.pop_back_val();
417 UV[Idx] = UV.back();
418 UV.pop_back();
419 }
420 }
421
422 private:
423 /// Map from functions to all uses of this runtime function contained in
424 /// them.
425 DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
426
427 public:
428 /// Iterators for the uses of this runtime function.
429 decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
430 decltype(UsesMap)::iterator end() { return UsesMap.end(); }
431 };
432
433 /// An OpenMP-IR-Builder instance
434 OpenMPIRBuilder OMPBuilder;
435
436 /// Map from runtime function kind to the runtime function description.
437 EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
438 RuntimeFunction::OMPRTL___last>
439 RFIs;
440
441 /// Map from function declarations/definitions to their runtime enum type.
442 DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
443
444 /// Map from ICV kind to the ICV description.
445 EnumeratedArray<InternalControlVarInfo, InternalControlVar,
446 InternalControlVar::ICV___last>
447 ICVs;
448
  /// Helper to initialize all internal control variable information for those
  /// defined in OMPKinds.def. The ICV_* macros below are expanded by the
  /// OMPKinds.def include at the end of this function, once per known ICV.
  void initializeInternalControlVars() {
// Record the runtime setter function associated with an ICV.
#define ICV_RT_SET(_Name, RTL)                                                 \
  {                                                                            \
    auto &ICV = ICVs[_Name];                                                   \
    ICV.Setter = RTL;                                                          \
  }
// Record the runtime getter function associated with an ICV.
#define ICV_RT_GET(Name, RTL)                                                  \
  {                                                                            \
    auto &ICV = ICVs[Name];                                                    \
    ICV.Getter = RTL;                                                          \
  }
// Record name, kind, environment variable, and initial value of an ICV. The
// initial value is materialized as a constant (or nullptr) based on the
// declared init kind.
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
  {                                                                            \
    auto &ICV = ICVs[Enum];                                                    \
    ICV.Name = _Name;                                                          \
    ICV.Kind = Enum;                                                           \
    ICV.InitKind = Init;                                                       \
    ICV.EnvVarName = _EnvVarName;                                              \
    switch (ICV.InitKind) {                                                    \
    case ICV_IMPLEMENTATION_DEFINED:                                           \
      ICV.InitValue = nullptr;                                                 \
      break;                                                                   \
    case ICV_ZERO:                                                             \
      ICV.InitValue = ConstantInt::get(                                        \
          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
      break;                                                                   \
    case ICV_FALSE:                                                            \
      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
      break;                                                                   \
    case ICV_LAST:                                                             \
      break;                                                                   \
    }                                                                          \
  }
// Expanding this header invokes the ICV_* macros above for each known ICV.
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }
486
487 /// Returns true if the function declaration \p F matches the runtime
488 /// function types, that is, return type \p RTFRetType, and argument types
489 /// \p RTFArgTypes.
490 static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
491 SmallVector<Type *, 8> &RTFArgTypes) {
492 // TODO: We should output information to the user (under debug output
493 // and via remarks).
494
495 if (!F)
496 return false;
497 if (F->getReturnType() != RTFRetType)
498 return false;
499 if (F->arg_size() != RTFArgTypes.size())
500 return false;
501
502 auto *RTFTyIt = RTFArgTypes.begin();
503 for (Argument &Arg : F->args()) {
504 if (Arg.getType() != *RTFTyIt)
505 return false;
506
507 ++RTFTyIt;
508 }
509
510 return true;
511 }
512
513 // Helper to collect all uses of the declaration in the UsesMap.
514 unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
515 unsigned NumUses = 0;
516 if (!RFI.Declaration)
517 return NumUses;
518 OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
519
520 if (CollectStats) {
521 NumOpenMPRuntimeFunctionsIdentified += 1;
522 NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
523 }
524
525 // TODO: We directly convert uses into proper calls and unknown uses.
526 for (Use &U : RFI.Declaration->uses()) {
527 if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
528 if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
529 RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
530 ++NumUses;
531 }
532 } else {
533 RFI.getOrCreateUseVector(nullptr).push_back(&U);
534 ++NumUses;
535 }
536 }
537 return NumUses;
538 }
539
540 // Helper function to recollect uses of a runtime function.
541 void recollectUsesForFunction(RuntimeFunction RTF) {
542 auto &RFI = RFIs[RTF];
543 RFI.clearUsesMap();
544 collectUses(RFI, /*CollectStats*/ false);
545 }
546
547 // Helper function to recollect uses of all runtime functions.
548 void recollectUses() {
549 for (int Idx = 0; Idx < RFIs.size(); ++Idx)
550 recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
551 }
552
553 // Helper function to inherit the calling convention of the function callee.
554 void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
555 if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
556 CI->setCallingConv(Fn->getCallingConv());
557 }
558
559 // Helper function to determine if it's legal to create a call to the runtime
560 // functions.
561 bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
562 // We can always emit calls if we haven't yet linked in the runtime.
563 if (!OpenMPPostLink)
564 return true;
565
566 // Once the runtime has been already been linked in we cannot emit calls to
567 // any undefined functions.
568 for (RuntimeFunction Fn : Fns) {
569 RuntimeFunctionInfo &RFI = RFIs[Fn];
570
571 if (!RFI.Declaration || RFI.Declaration->isDeclaration())
572 return false;
573 }
574 return true;
575 }
576
  /// Helper to initialize all runtime function information for those defined
  /// in OpenMPKinds.def. The OMP_* macros below are expanded by the
  /// OMPKinds.def include, once per declared type/function.
  void initializeRuntimeFunctions(Module &M) {

    // Helper macros for handling __VA_ARGS__ in OMP_RTL
// Bind the OpenMPIRBuilder's cached types to local names so that the OMP_RTL
// expansions below can reference them directly; (void)-casts silence unused
// warnings for names no expansion happens to use.
#define OMP_TYPE(VarName, ...)                                                 \
  Type *VarName = OMPBuilder.VarName;                                          \
  (void)VarName;

#define OMP_ARRAY_TYPE(VarName, ...)                                           \
  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
  (void)VarName##Ty;                                                           \
  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
  (void)VarName##PtrTy;

#define OMP_FUNCTION_TYPE(VarName, ...)                                        \
  FunctionType *VarName = OMPBuilder.VarName;                                  \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_STRUCT_TYPE(VarName, ...)                                          \
  StructType *VarName = OMPBuilder.VarName;                                    \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

// For each known runtime function: look up its declaration in the module and,
// if the declared signature matches the expected one, populate the
// corresponding RuntimeFunctionInfo and collect its uses.
#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
  {                                                                            \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
    Function *F = M.getFunction(_Name);                                        \
    RTLFunctions.insert(F);                                                    \
    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
      RuntimeFunctionIDMap[F] = _Enum;                                         \
      auto &RFI = RFIs[_Enum];                                                 \
      RFI.Kind = _Enum;                                                        \
      RFI.Name = _Name;                                                        \
      RFI.IsVarArg = _IsVarArg;                                                \
      RFI.ReturnType = OMPBuilder._ReturnType;                                 \
      RFI.ArgumentTypes = std::move(ArgsTypes);                                \
      RFI.Declaration = F;                                                     \
      unsigned NumUses = collectUses(RFI);                                     \
      (void)NumUses;                                                           \
      LLVM_DEBUG({                                                             \
        dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
               << " found\n";                                                  \
        if (RFI.Declaration)                                                   \
          dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
                 << RFI.getNumFunctionsWithUses()                              \
                 << " different functions.\n";                                 \
      });                                                                      \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"

    // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
    // functions, except if `optnone` is present.
    if (isOpenMPDevice(M)) {
      for (Function &F : M) {
        for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
          if (F.hasFnAttribute(Attribute::NoInline) &&
              F.getName().starts_with(Prefix) &&
              !F.hasFnAttribute(Attribute::OptimizeNone))
            F.removeFnAttr(Attribute::NoInline);
      }
    }

    // TODO: We should attach the attributes defined in OMPKinds.def.
  }
646
647 /// Collection of known OpenMP runtime functions..
648 DenseSet<const Function *> RTLFunctions;
649
650 /// Indicates if we have already linked in the OpenMP device library.
651 bool OpenMPPostLink = false;
652};
653
654template <typename Ty, bool InsertInvalidates = true>
655struct BooleanStateWithSetVector : public BooleanState {
656 bool contains(const Ty &Elem) const { return Set.contains(Elem); }
657 bool insert(const Ty &Elem) {
658 if (InsertInvalidates)
659 BooleanState::indicatePessimisticFixpoint();
660 return Set.insert(Elem);
661 }
662
663 const Ty &operator[](int Idx) const { return Set[Idx]; }
664 bool operator==(const BooleanStateWithSetVector &RHS) const {
665 return BooleanState::operator==(RHS) && Set == RHS.Set;
666 }
667 bool operator!=(const BooleanStateWithSetVector &RHS) const {
668 return !(*this == RHS);
669 }
670
671 bool empty() const { return Set.empty(); }
672 size_t size() const { return Set.size(); }
673
674 /// "Clamp" this state with \p RHS.
675 BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
676 BooleanState::operator^=(RHS);
677 Set.insert_range(RHS.Set);
678 return *this;
679 }
680
681private:
682 /// A set to keep track of elements.
683 SetVector<Ty> Set;
684
685public:
686 typename decltype(Set)::iterator begin() { return Set.begin(); }
687 typename decltype(Set)::iterator end() { return Set.end(); }
688 typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
689 typename decltype(Set)::const_iterator end() const { return Set.end(); }
690};
691
692template <typename Ty, bool InsertInvalidates = true>
693using BooleanStateWithPtrSetVector =
694 BooleanStateWithSetVector<Ty *, InsertInvalidates>;
695
struct KernelInfoState : AbstractState {
  /// Flag to track if we reached a fixpoint.
  bool IsAtFixpoint = false;

  /// The parallel regions (identified by the outlined parallel functions) that
  /// can be reached from the associated function.
  BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
      ReachedKnownParallelRegions;

  /// State to track what parallel region we might reach.
  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;

  /// State to track if we are in SPMD-mode, assumed or known, and why we
  /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
  /// also be false.
  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;

  /// The __kmpc_target_init call in this kernel, if any. If we find more than
  /// one we abort as the kernel is malformed.
  CallBase *KernelInitCB = nullptr;

  /// The constant kernel environment as taken from and passed to
  /// __kmpc_target_init.
  ConstantStruct *KernelEnvC = nullptr;

  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
  /// one we abort as the kernel is malformed.
  CallBase *KernelDeinitCB = nullptr;

  /// Flag to indicate if the associated function is a kernel entry.
  bool IsKernelEntry = false;

  /// State to track what kernel entries can reach the associated function.
  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;

  /// State to indicate if we can track parallel level of the associated
  /// function. We will give up tracking if we encounter unknown caller or the
  /// caller is __kmpc_parallel_60.
  BooleanStateWithSetVector<uint8_t> ParallelLevels;

  /// Flag that indicates if the kernel has nested parallelism.
  bool NestedParallelism = false;

  /// Abstract State interface
  ///{

  KernelInfoState() = default;
  // NOTE(review): non-explicit single-argument constructor; BestState == false
  // yields the pessimistic-fixpoint ("worst") state.
  KernelInfoState(bool BestState) {
    if (!BestState)
      indicatePessimisticFixpoint();
  }

  /// See AbstractState::isValidState(...)
  bool isValidState() const override { return true; }

  /// See AbstractState::isAtFixpoint(...)
  bool isAtFixpoint() const override { return IsAtFixpoint; }

  /// See AbstractState::indicatePessimisticFixpoint(...)
  ChangeStatus indicatePessimisticFixpoint() override {
    IsAtFixpoint = true;
    ParallelLevels.indicatePessimisticFixpoint();
    ReachingKernelEntries.indicatePessimisticFixpoint();
    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
    // Pessimistically assume nested parallelism as well.
    NestedParallelism = true;
    return ChangeStatus::CHANGED;
  }

  /// See AbstractState::indicateOptimisticFixpoint(...)
  ChangeStatus indicateOptimisticFixpoint() override {
    IsAtFixpoint = true;
    ParallelLevels.indicateOptimisticFixpoint();
    ReachingKernelEntries.indicateOptimisticFixpoint();
    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
    return ChangeStatus::UNCHANGED;
  }

  /// Return the assumed state
  KernelInfoState &getAssumed() { return *this; }
  const KernelInfoState &getAssumed() const { return *this; }

  bool operator==(const KernelInfoState &RHS) const {
    // NOTE(review): only the abstract sub-states are compared; KernelInitCB,
    // KernelDeinitCB, KernelEnvC, IsAtFixpoint, and IsKernelEntry are not
    // part of the comparison — confirm this is intended.
    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
      return false;
    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
      return false;
    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
      return false;
    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
      return false;
    if (ParallelLevels != RHS.ParallelLevels)
      return false;
    if (NestedParallelism != RHS.NestedParallelism)
      return false;
    return true;
  }

  /// Returns true if this kernel contains any OpenMP parallel regions.
  bool mayContainParallelRegion() {
    return !ReachedKnownParallelRegions.empty() ||
           !ReachedUnknownParallelRegions.empty();
  }

  /// Return empty set as the best state of potential values.
  static KernelInfoState getBestState() { return KernelInfoState(true); }

  static KernelInfoState getBestState(KernelInfoState &KIS) {
    return getBestState();
  }

  /// Return full set as the worst state of potential values.
  static KernelInfoState getWorstState() { return KernelInfoState(false); }

  /// "Clamp" this state with \p KIS.
  KernelInfoState operator^=(const KernelInfoState &KIS) {
    // Do not merge two different _init and _deinit call sites. Conflicting
    // call sites would mean one kernel calls another, which is treated as a
    // malformed input and aborts the compilation.
    if (KIS.KernelInitCB) {
      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
                         "assumptions.");
      KernelInitCB = KIS.KernelInitCB;
    }
    if (KIS.KernelDeinitCB) {
      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
                         "assumptions.");
      KernelDeinitCB = KIS.KernelDeinitCB;
    }
    if (KIS.KernelEnvC) {
      if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
                         "assumptions.");
      KernelEnvC = KIS.KernelEnvC;
    }
    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
    NestedParallelism |= KIS.NestedParallelism;
    return *this;
  }

  // NOTE(review): operator&= is implemented in terms of operator^=; both
  // perform the same pessimistic merge.
  KernelInfoState operator&=(const KernelInfoState &KIS) {
    return (*this ^= KIS);
  }

  ///}
};
847
/// Used to map the values physically (in the IR) stored in an offload
/// array, to a vector in memory.
struct OffloadArray {
  /// Physical array (in the IR).
  AllocaInst *Array = nullptr;
  /// Mapped values.
  SmallVector<Value *, 8> StoredValues;
  /// Last stores made in the offload array.
  SmallVector<StoreInst *, 8> LastAccesses;

  OffloadArray() = default;

  /// Initializes the OffloadArray with the values stored in \p Array before
  /// instruction \p Before is reached. Returns false if the initialization
  /// fails.
  /// This MUST be used immediately after the construction of the object.
  bool initialize(AllocaInst &Array, Instruction &Before) {
    if (!getValues(Array, Before))
      return false;

    this->Array = &Array;
    return true;
  }

  // Argument positions of the offload arrays in the runtime call that
  // consumes them.
  // NOTE(review): assumed to match the target runtime call signature —
  // confirm against the __tgt_target_* interface.
  static const unsigned DeviceIDArgNum = 1;
  static const unsigned BasePtrsArgNum = 3;
  static const unsigned PtrsArgNum = 4;
  static const unsigned SizesArgNum = 5;

private:
  /// Traverses the BasicBlock where \p Array is, collecting the stores made to
  /// \p Array, leaving StoredValues with the values stored before the
  /// instruction \p Before is reached.
  bool getValues(AllocaInst &Array, Instruction &Before) {
    // Initialize containers. The element count is derived from the allocation
    // size, assuming every element is pointer sized.
    const DataLayout &DL = Array.getDataLayout();
    std::optional<TypeSize> ArraySize = Array.getAllocationSize(DL);
    // Bail on unknown or scalable (non-fixed) allocation sizes.
    if (!ArraySize || !ArraySize->isFixed())
      return false;
    const unsigned int PointerSize = DL.getPointerSize();
    const uint64_t NumValues = ArraySize->getFixedValue() / PointerSize;
    StoredValues.assign(NumValues, nullptr);
    LastAccesses.assign(NumValues, nullptr);

    // TODO: This assumes the instruction \p Before is in the same
    // BasicBlock as Array. Make it general, for any control flow graph.
    BasicBlock *BB = Array.getParent();
    if (BB != Before.getParent())
      return false;

    // Scan the block up to (but not including) \p Before; later stores to the
    // same slot overwrite earlier ones, so each slot keeps its last store.
    for (Instruction &I : *BB) {
      if (&I == &Before)
        break;

      if (!isa<StoreInst>(&I))
        continue;

      auto *S = cast<StoreInst>(&I);
      int64_t Offset = -1;
      auto *Dst =
          GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
      if (Dst == &Array) {
        int64_t Idx = Offset / PointerSize;
        // Ignore updates that must be UB (probably in dead code at runtime).
        // The unsigned comparison also rejects negative indices.
        if ((uint64_t)Idx < NumValues) {
          StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
          LastAccesses[Idx] = S;
        }
      }
    }

    return isFilled();
  }

  /// Returns true if all values in StoredValues and
  /// LastAccesses are not nullptrs.
  bool isFilled() {
    const unsigned NumValues = StoredValues.size();
    for (unsigned I = 0; I < NumValues; ++I) {
      if (!StoredValues[I] || !LastAccesses[I])
        return false;
    }

    return true;
  }
};
934
935struct OpenMPOpt {
936
937 using OptimizationRemarkGetter =
938 function_ref<OptimizationRemarkEmitter &(Function *)>;
939
  /// Construct an OpenMPOpt driver for the given \p SCC.
  /// NOTE(review): dereferences *SCC.begin() to obtain the module, so a
  /// non-empty \p SCC is required — confirm all construction sites guarantee
  /// this.
  OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
            OptimizationRemarkGetter OREGetter,
            OMPInformationCache &OMPInfoCache, Attributor &A)
      : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
        OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
945
946 /// Check if any remarks are enabled for openmp-opt
947 bool remarksEnabled() {
948 auto &Ctx = M.getContext();
949 return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
950 }
951
952 /// Run all OpenMP optimizations on the underlying SCC.
953 bool run(bool IsModulePass) {
954 if (SCC.empty())
955 return false;
956
957 bool Changed = false;
958
959 LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
960 << " functions\n");
961
962 if (IsModulePass) {
963 Changed |= runAttributor(IsModulePass);
964
965 // Recollect uses, in case Attributor deleted any.
966 OMPInfoCache.recollectUses();
967
968 // TODO: This should be folded into buildCustomStateMachine.
969 Changed |= rewriteDeviceCodeStateMachine();
970
971 if (remarksEnabled())
972 analysisGlobalization();
973 } else {
974 if (PrintICVValues)
975 printICVs();
977 printKernels();
978
979 Changed |= runAttributor(IsModulePass);
980
981 // Recollect uses, in case Attributor deleted any.
982 OMPInfoCache.recollectUses();
983
984 Changed |= deleteParallelRegions();
985
987 Changed |= hideMemTransfersLatency();
988 Changed |= deduplicateRuntimeCalls();
990 if (mergeParallelRegions()) {
991 deduplicateRuntimeCalls();
992 Changed = true;
993 }
994 }
995 }
996
997 if (OMPInfoCache.OpenMPPostLink)
998 Changed |= removeRuntimeSymbols();
999
1000 return Changed;
1001 }
1002
1003 /// Print initial ICV values for testing.
1004 /// FIXME: This should be done from the Attributor once it is added.
1005 void printICVs() const {
1006 InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
1007 ICV_proc_bind};
1008
1009 for (Function *F : SCC) {
1010 for (auto ICV : ICVs) {
1011 auto ICVInfo = OMPInfoCache.ICVs[ICV];
1012 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1013 return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
1014 << " Value: "
1015 << (ICVInfo.InitValue
1016 ? toString(ICVInfo.InitValue->getValue(), 10, true)
1017 : "IMPLEMENTATION_DEFINED");
1018 };
1019
1020 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
1021 }
1022 }
1023 }
1024
1025 /// Print OpenMP GPU kernels for testing.
1026 void printKernels() const {
1027 for (Function *F : SCC) {
1028 if (!omp::isOpenMPKernel(*F))
1029 continue;
1030
1031 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1032 return ORA << "OpenMP GPU kernel "
1033 << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
1034 };
1035
1037 }
1038 }
1039
1040 /// Return the call if \p U is a callee use in a regular call. If \p RFI is
1041 /// given it has to be the callee or a nullptr is returned.
1042 static CallInst *getCallIfRegularCall(
1043 Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1044 CallInst *CI = dyn_cast<CallInst>(U.getUser());
1045 if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
1046 (!RFI ||
1047 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1048 return CI;
1049 return nullptr;
1050 }
1051
1052 /// Return the call if \p V is a regular call. If \p RFI is given it has to be
1053 /// the callee or a nullptr is returned.
1054 static CallInst *getCallIfRegularCall(
1055 Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1056 CallInst *CI = dyn_cast<CallInst>(&V);
1057 if (CI && !CI->hasOperandBundles() &&
1058 (!RFI ||
1059 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1060 return CI;
1061 return nullptr;
1062 }
1063
1064private:
1065 /// Merge parallel regions when it is safe.
1066 bool mergeParallelRegions() {
1067 const unsigned CallbackCalleeOperand = 2;
1068 const unsigned CallbackFirstArgOperand = 3;
1069 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1070
1071 // Check if there are any __kmpc_fork_call calls to merge.
1072 OMPInformationCache::RuntimeFunctionInfo &RFI =
1073 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1074
1075 if (!RFI.Declaration)
1076 return false;
1077
1078 // Unmergable calls that prevent merging a parallel region.
1079 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
1080 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
1081 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
1082 };
1083
1084 bool Changed = false;
1085 LoopInfo *LI = nullptr;
1086 DominatorTree *DT = nullptr;
1087
1088 SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
1089
1090 BasicBlock *StartBB = nullptr, *EndBB = nullptr;
1091 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1092 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1093 BasicBlock *CGEndBB =
1094 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1095 assert(StartBB != nullptr && "StartBB should not be null");
1096 CGStartBB->getTerminator()->setSuccessor(0, StartBB);
1097 assert(EndBB != nullptr && "EndBB should not be null");
1098 EndBB->getTerminator()->setSuccessor(0, CGEndBB);
1099 return Error::success();
1100 };
1101
1102 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
1103 Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
1104 ReplacementValue = &Inner;
1105 return CodeGenIP;
1106 };
1107
1108 auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };
1109
1110 /// Create a sequential execution region within a merged parallel region,
1111 /// encapsulated in a master construct with a barrier for synchronization.
1112 auto CreateSequentialRegion = [&](Function *OuterFn,
1113 BasicBlock *OuterPredBB,
1114 Instruction *SeqStartI,
1115 Instruction *SeqEndI) {
1116 // Isolate the instructions of the sequential region to a separate
1117 // block.
1118 BasicBlock *ParentBB = SeqStartI->getParent();
1119 BasicBlock *SeqEndBB =
1120 SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
1121 BasicBlock *SeqAfterBB =
1122 SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
1123 BasicBlock *SeqStartBB =
1124 SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
1125
1126 assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
1127 "Expected a different CFG");
1128 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
1129 ParentBB->getTerminator()->eraseFromParent();
1130
1131 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1132 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1133 BasicBlock *CGEndBB =
1134 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1135 assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
1136 CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
1137 assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
1138 SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
1139 return Error::success();
1140 };
1141 auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };
1142
1143 // Find outputs from the sequential region to outside users and
1144 // broadcast their values to them.
1145 for (Instruction &I : *SeqStartBB) {
1146 SmallPtrSet<Instruction *, 4> OutsideUsers;
1147 for (User *Usr : I.users()) {
1148 Instruction &UsrI = *cast<Instruction>(Usr);
1149 // Ignore outputs to LT intrinsics, code extraction for the merged
1150 // parallel region will fix them.
1151 if (UsrI.isLifetimeStartOrEnd())
1152 continue;
1153
1154 if (UsrI.getParent() != SeqStartBB)
1155 OutsideUsers.insert(&UsrI);
1156 }
1157
1158 if (OutsideUsers.empty())
1159 continue;
1160
1161 // Emit an alloca in the outer region to store the broadcasted
1162 // value.
1163 const DataLayout &DL = M.getDataLayout();
1164 AllocaInst *AllocaI = new AllocaInst(
1165 I.getType(), DL.getAllocaAddrSpace(), nullptr,
1166 I.getName() + ".seq.output.alloc", OuterFn->front().begin());
1167
1168 // Emit a store instruction in the sequential BB to update the
1169 // value.
1170 new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());
1171
1172 // Emit a load instruction and replace the use of the output value
1173 // with it.
1174 for (Instruction *UsrI : OutsideUsers) {
1175 LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,
1176 I.getName() + ".seq.output.load",
1177 UsrI->getIterator());
1178 UsrI->replaceUsesOfWith(&I, LoadI);
1179 }
1180 }
1181
1182 OpenMPIRBuilder::LocationDescription Loc(
1183 InsertPointTy(ParentBB, ParentBB->end()), DL);
1185 OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB));
1186 cantFail(
1187 OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel));
1188
1189 UncondBrInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1190
1191 LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1192 << "\n");
1193 };
1194
1195 // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1196 // contained in BB and only separated by instructions that can be
1197 // redundantly executed in parallel. The block BB is split before the first
1198 // call (in MergableCIs) and after the last so the entire region we merge
1199 // into a single parallel region is contained in a single basic block
1200 // without any other instructions. We use the OpenMPIRBuilder to outline
1201 // that block and call the resulting function via __kmpc_fork_call.
1202 auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1203 BasicBlock *BB) {
1204 // TODO: Change the interface to allow single CIs expanded, e.g, to
1205 // include an outer loop.
1206 assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1207
1208 auto Remark = [&](OptimizationRemark OR) {
1209 OR << "Parallel region merged with parallel region"
1210 << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1211 for (auto *CI : llvm::drop_begin(MergableCIs)) {
1212 OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1213 if (CI != MergableCIs.back())
1214 OR << ", ";
1215 }
1216 return OR << ".";
1217 };
1218
1219 emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1220
1221 Function *OriginalFn = BB->getParent();
1222 LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1223 << " parallel regions in " << OriginalFn->getName()
1224 << "\n");
1225
1226 // Isolate the calls to merge in a separate block.
1227 EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1228 BasicBlock *AfterBB =
1229 SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1230 StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1231 "omp.par.merged");
1232
1233 assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1234 const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1235 BB->getTerminator()->eraseFromParent();
1236
1237 // Create sequential regions for sequential instructions that are
1238 // in-between mergable parallel regions.
1239 for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1240 It != End; ++It) {
1241 Instruction *ForkCI = *It;
1242 Instruction *NextForkCI = *(It + 1);
1243
1244 // Continue if there are not in-between instructions.
1245 if (ForkCI->getNextNode() == NextForkCI)
1246 continue;
1247
1248 CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1249 NextForkCI->getPrevNode());
1250 }
1251
1252 OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1253 DL);
1254 IRBuilder<>::InsertPoint AllocaIP(
1255 &OriginalFn->getEntryBlock(),
1256 OriginalFn->getEntryBlock().getFirstInsertionPt());
1257 // Create the merged parallel region with default proc binding, to
1258 // avoid overriding binding settings, and without explicit cancellation.
1260 cantFail(OMPInfoCache.OMPBuilder.createParallel(
1261 Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1262 OMP_PROC_BIND_default, /* IsCancellable */ false));
1263 UncondBrInst::Create(AfterBB, AfterIP.getBlock());
1264
1265 // Perform the actual outlining.
1266 OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1267
1268 Function *OutlinedFn = MergableCIs.front()->getCaller();
1269
1270 // Replace the __kmpc_fork_call calls with direct calls to the outlined
1271 // callbacks.
1272 SmallVector<Value *, 8> Args;
1273 for (auto *CI : MergableCIs) {
1274 Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1275 FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1276 Args.clear();
1277 Args.push_back(OutlinedFn->getArg(0));
1278 Args.push_back(OutlinedFn->getArg(1));
1279 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1280 ++U)
1281 Args.push_back(CI->getArgOperand(U));
1282
1283 CallInst *NewCI =
1284 CallInst::Create(FT, Callee, Args, "", CI->getIterator());
1285 if (CI->getDebugLoc())
1286 NewCI->setDebugLoc(CI->getDebugLoc());
1287
1288 // Forward parameter attributes from the callback to the callee.
1289 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1290 ++U)
1291 for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1292 NewCI->addParamAttr(
1293 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1294
1295 // Emit an explicit barrier to replace the implicit fork-join barrier.
1296 if (CI != MergableCIs.back()) {
1297 // TODO: Remove barrier if the merged parallel region includes the
1298 // 'nowait' clause.
1299 cantFail(OMPInfoCache.OMPBuilder.createBarrier(
1300 InsertPointTy(NewCI->getParent(),
1301 NewCI->getNextNode()->getIterator()),
1302 OMPD_parallel));
1303 }
1304
1305 CI->eraseFromParent();
1306 }
1307
1308 assert(OutlinedFn != OriginalFn && "Outlining failed");
1309 CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1310 CGUpdater.reanalyzeFunction(*OriginalFn);
1311
1312 NumOpenMPParallelRegionsMerged += MergableCIs.size();
1313
1314 return true;
1315 };
1316
1317 // Helper function that identifes sequences of
1318 // __kmpc_fork_call uses in a basic block.
1319 auto DetectPRsCB = [&](Use &U, Function &F) {
1320 CallInst *CI = getCallIfRegularCall(U, &RFI);
1321 BB2PRMap[CI->getParent()].insert(CI);
1322
1323 return false;
1324 };
1325
1326 BB2PRMap.clear();
1327 RFI.foreachUse(SCC, DetectPRsCB);
1328 SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1329 // Find mergable parallel regions within a basic block that are
1330 // safe to merge, that is any in-between instructions can safely
1331 // execute in parallel after merging.
1332 // TODO: support merging across basic-blocks.
1333 for (auto &It : BB2PRMap) {
1334 auto &CIs = It.getSecond();
1335 if (CIs.size() < 2)
1336 continue;
1337
1338 BasicBlock *BB = It.getFirst();
1339 SmallVector<CallInst *, 4> MergableCIs;
1340
1341 /// Returns true if the instruction is mergable, false otherwise.
1342 /// A terminator instruction is unmergable by definition since merging
1343 /// works within a BB. Instructions before the mergable region are
1344 /// mergable if they are not calls to OpenMP runtime functions that may
1345 /// set different execution parameters for subsequent parallel regions.
1346 /// Instructions in-between parallel regions are mergable if they are not
1347 /// calls to any non-intrinsic function since that may call a non-mergable
1348 /// OpenMP runtime function.
1349 auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1350 // We do not merge across BBs, hence return false (unmergable) if the
1351 // instruction is a terminator.
1352 if (I.isTerminator())
1353 return false;
1354
1355 if (!isa<CallInst>(&I))
1356 return true;
1357
1358 CallInst *CI = cast<CallInst>(&I);
1359 if (IsBeforeMergableRegion) {
1360 Function *CalledFunction = CI->getCalledFunction();
1361 if (!CalledFunction)
1362 return false;
1363 // Return false (unmergable) if the call before the parallel
1364 // region calls an explicit affinity (proc_bind) or number of
1365 // threads (num_threads) compiler-generated function. Those settings
1366 // may be incompatible with following parallel regions.
1367 // TODO: ICV tracking to detect compatibility.
1368 for (const auto &RFI : UnmergableCallsInfo) {
1369 if (CalledFunction == RFI.Declaration)
1370 return false;
1371 }
1372 } else {
1373 // Return false (unmergable) if there is a call instruction
1374 // in-between parallel regions when it is not an intrinsic. It
1375 // may call an unmergable OpenMP runtime function in its callpath.
1376 // TODO: Keep track of possible OpenMP calls in the callpath.
1377 if (!isa<IntrinsicInst>(CI))
1378 return false;
1379 }
1380
1381 return true;
1382 };
1383 // Find maximal number of parallel region CIs that are safe to merge.
1384 for (auto It = BB->begin(), End = BB->end(); It != End;) {
1385 Instruction &I = *It;
1386 ++It;
1387
1388 if (CIs.count(&I)) {
1389 MergableCIs.push_back(cast<CallInst>(&I));
1390 continue;
1391 }
1392
1393 // Continue expanding if the instruction is mergable.
1394 if (IsMergable(I, MergableCIs.empty()))
1395 continue;
1396
1397 // Forward the instruction iterator to skip the next parallel region
1398 // since there is an unmergable instruction which can affect it.
1399 for (; It != End; ++It) {
1400 Instruction &SkipI = *It;
1401 if (CIs.count(&SkipI)) {
1402 LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1403 << " due to " << I << "\n");
1404 ++It;
1405 break;
1406 }
1407 }
1408
1409 // Store mergable regions found.
1410 if (MergableCIs.size() > 1) {
1411 MergableCIsVector.push_back(MergableCIs);
1412 LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1413 << " parallel regions in block " << BB->getName()
1414 << " of function " << BB->getParent()->getName()
1415 << "\n";);
1416 }
1417
1418 MergableCIs.clear();
1419 }
1420
1421 if (!MergableCIsVector.empty()) {
1422 Changed = true;
1423
1424 for (auto &MergableCIs : MergableCIsVector)
1425 Merge(MergableCIs, BB);
1426 MergableCIsVector.clear();
1427 }
1428 }
1429
1430 if (Changed) {
1431 /// Re-collect use for fork calls, emitted barrier calls, and
1432 /// any emitted master/end_master calls.
1433 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1434 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1435 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1436 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1437 }
1438
1439 return Changed;
1440 }
1441
1442 /// Try to delete parallel regions if possible.
1443 bool deleteParallelRegions() {
1444 const unsigned CallbackCalleeOperand = 2;
1445
1446 OMPInformationCache::RuntimeFunctionInfo &RFI =
1447 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1448
1449 if (!RFI.Declaration)
1450 return false;
1451
1452 bool Changed = false;
1453 auto DeleteCallCB = [&](Use &U, Function &) {
1454 CallInst *CI = getCallIfRegularCall(U);
1455 if (!CI)
1456 return false;
1457 auto *Fn = dyn_cast<Function>(
1458 CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1459 if (!Fn)
1460 return false;
1461 if (!Fn->onlyReadsMemory())
1462 return false;
1463 if (!Fn->hasFnAttribute(Attribute::WillReturn))
1464 return false;
1465
1466 LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1467 << CI->getCaller()->getName() << "\n");
1468
1469 auto Remark = [&](OptimizationRemark OR) {
1470 return OR << "Removing parallel region with no side-effects.";
1471 };
1473
1474 CI->eraseFromParent();
1475 Changed = true;
1476 ++NumOpenMPParallelRegionsDeleted;
1477 return true;
1478 };
1479
1480 RFI.foreachUse(SCC, DeleteCallCB);
1481
1482 return Changed;
1483 }
1484
1485 /// Try to eliminate runtime calls by reusing existing ones.
1486 bool deduplicateRuntimeCalls() {
1487 bool Changed = false;
1488
1489 RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1490 OMPRTL_omp_get_num_threads,
1491 OMPRTL_omp_in_parallel,
1492 OMPRTL_omp_get_cancellation,
1493 OMPRTL_omp_get_supported_active_levels,
1494 OMPRTL_omp_get_level,
1495 OMPRTL_omp_get_ancestor_thread_num,
1496 OMPRTL_omp_get_team_size,
1497 OMPRTL_omp_get_active_level,
1498 OMPRTL_omp_in_final,
1499 OMPRTL_omp_get_proc_bind,
1500 OMPRTL_omp_get_num_places,
1501 OMPRTL_omp_get_num_procs,
1502 OMPRTL_omp_get_place_num,
1503 OMPRTL_omp_get_partition_num_places,
1504 OMPRTL_omp_get_partition_place_nums};
1505
1506 // Global-tid is handled separately.
1507 SmallSetVector<Value *, 16> GTIdArgs;
1508 collectGlobalThreadIdArguments(GTIdArgs);
1509 LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1510 << " global thread ID arguments\n");
1511
1512 for (Function *F : SCC) {
1513 for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1514 Changed |= deduplicateRuntimeCalls(
1515 *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1516
1517 // __kmpc_global_thread_num is special as we can replace it with an
1518 // argument in enough cases to make it worth trying.
1519 Value *GTIdArg = nullptr;
1520 for (Argument &Arg : F->args())
1521 if (GTIdArgs.count(&Arg)) {
1522 GTIdArg = &Arg;
1523 break;
1524 }
1525 Changed |= deduplicateRuntimeCalls(
1526 *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1527 }
1528
1529 return Changed;
1530 }
1531
1532 /// Tries to remove known runtime symbols that are optional from the module.
1533 bool removeRuntimeSymbols() {
1534 // The RPC client symbol is defined in `libc` and indicates that something
1535 // required an RPC server. If its users were all optimized out then we can
1536 // safely remove it.
1537 // TODO: This should be somewhere more common in the future.
1538 if (GlobalVariable *GV = M.getNamedGlobal("__llvm_rpc_client")) {
1539 if (GV->hasNUsesOrMore(1))
1540 return false;
1541
1542 GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
1543 GV->eraseFromParent();
1544 return true;
1545 }
1546 return false;
1547 }
1548
1549 /// Tries to hide the latency of runtime calls that involve host to
1550 /// device memory transfers by splitting them into their "issue" and "wait"
1551 /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1552 /// moved downards as much as possible. The "issue" issues the memory transfer
1553 /// asynchronously, returning a handle. The "wait" waits in the returned
1554 /// handle for the memory transfer to finish.
1555 bool hideMemTransfersLatency() {
1556 auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1557 bool Changed = false;
1558 auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1559 auto *RTCall = getCallIfRegularCall(U, &RFI);
1560 if (!RTCall)
1561 return false;
1562
1563 OffloadArray OffloadArrays[3];
1564 if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1565 return false;
1566
1567 LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1568
1569 // TODO: Check if can be moved upwards.
1570 bool WasSplit = false;
1571 Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1572 if (WaitMovementPoint)
1573 WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1574
1575 Changed |= WasSplit;
1576 return WasSplit;
1577 };
1578 if (OMPInfoCache.runtimeFnsAvailable(
1579 {OMPRTL___tgt_target_data_begin_mapper_issue,
1580 OMPRTL___tgt_target_data_begin_mapper_wait}))
1581 RFI.foreachUse(SCC, SplitMemTransfers);
1582
1583 return Changed;
1584 }
1585
1586 void analysisGlobalization() {
1587 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1588
1589 auto CheckGlobalization = [&](Use &U, Function &Decl) {
1590 if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1591 auto Remark = [&](OptimizationRemarkMissed ORM) {
1592 return ORM
1593 << "Found thread data sharing on the GPU. "
1594 << "Expect degraded performance due to data globalization.";
1595 };
1597 }
1598
1599 return false;
1600 };
1601
1602 RFI.foreachUse(SCC, CheckGlobalization);
1603 }
1604
1605 /// Maps the values stored in the offload arrays passed as arguments to
1606 /// \p RuntimeCall into the offload arrays in \p OAs.
1607 bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1609 assert(OAs.size() == 3 && "Need space for three offload arrays!");
1610
1611 // A runtime call that involves memory offloading looks something like:
1612 // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1613 // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1614 // ...)
1615 // So, the idea is to access the allocas that allocate space for these
1616 // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1617 // Therefore:
1618 // i8** %offload_baseptrs.
1619 Value *BasePtrsArg =
1620 RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1621 // i8** %offload_ptrs.
1622 Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1623 // i8** %offload_sizes.
1624 Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1625
1626 // Get values stored in **offload_baseptrs.
1627 auto *V = getUnderlyingObject(BasePtrsArg);
1628 if (!isa<AllocaInst>(V))
1629 return false;
1630 auto *BasePtrsArray = cast<AllocaInst>(V);
1631 if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1632 return false;
1633
1634 // Get values stored in **offload_baseptrs.
1635 V = getUnderlyingObject(PtrsArg);
1636 if (!isa<AllocaInst>(V))
1637 return false;
1638 auto *PtrsArray = cast<AllocaInst>(V);
1639 if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1640 return false;
1641
1642 // Get values stored in **offload_sizes.
1643 V = getUnderlyingObject(SizesArg);
1644 // If it's a [constant] global array don't analyze it.
1645 if (isa<GlobalValue>(V))
1646 return isa<Constant>(V);
1647 if (!isa<AllocaInst>(V))
1648 return false;
1649
1650 auto *SizesArray = cast<AllocaInst>(V);
1651 if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1652 return false;
1653
1654 return true;
1655 }
1656
1657 /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1658 /// For now this is a way to test that the function getValuesInOffloadArrays
1659 /// is working properly.
1660 /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1661 void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1662 assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1663
1664 LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1665 std::string ValuesStr;
1666 raw_string_ostream Printer(ValuesStr);
1667 std::string Separator = " --- ";
1668
1669 for (auto *BP : OAs[0].StoredValues) {
1670 BP->print(Printer);
1671 Printer << Separator;
1672 }
1673 LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << ValuesStr << "\n");
1674 ValuesStr.clear();
1675
1676 for (auto *P : OAs[1].StoredValues) {
1677 P->print(Printer);
1678 Printer << Separator;
1679 }
1680 LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << ValuesStr << "\n");
1681 ValuesStr.clear();
1682
1683 for (auto *S : OAs[2].StoredValues) {
1684 S->print(Printer);
1685 Printer << Separator;
1686 }
1687 LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << ValuesStr << "\n");
1688 }
1689
1690 /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
1691 /// moved. Returns nullptr if the movement is not possible, or not worth it.
1692 Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1693 // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1694 // Make it traverse the CFG.
1695
1696 Instruction *CurrentI = &RuntimeCall;
1697 bool IsWorthIt = false;
1698 while ((CurrentI = CurrentI->getNextNode())) {
1699
1700 // TODO: Once we detect the regions to be offloaded we should use the
1701 // alias analysis manager to check if CurrentI may modify one of
1702 // the offloaded regions.
1703 if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1704 if (IsWorthIt)
1705 return CurrentI;
1706
1707 return nullptr;
1708 }
1709
1710 // FIXME: For now if we move it over anything without side effect
1711 // is worth it.
1712 IsWorthIt = true;
1713 }
1714
1715 // Return end of BasicBlock.
1716 return RuntimeCall.getParent()->getTerminator();
1717 }
1718
1719 /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1720 bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1721 Instruction &WaitMovementPoint) {
1722 // Create stack allocated handle (__tgt_async_info) at the beginning of the
1723 // function. Used for storing information of the async transfer, allowing to
1724 // wait on it later.
1725 auto &IRBuilder = OMPInfoCache.OMPBuilder;
1726 Function *F = RuntimeCall.getCaller();
1727 BasicBlock &Entry = F->getEntryBlock();
1728 IRBuilder.Builder.SetInsertPoint(&Entry,
1729 Entry.getFirstNonPHIOrDbgOrAlloca());
1730 Value *Handle = IRBuilder.Builder.CreateAlloca(
1731 IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1732 Handle =
1733 IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1734
1735 // Add "issue" runtime call declaration:
1736 // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1737 // i8**, i8**, i64*, i64*)
1738 FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1739 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1740
1741 // Change RuntimeCall call site for its asynchronous version.
1742 SmallVector<Value *, 16> Args;
1743 for (auto &Arg : RuntimeCall.args())
1744 Args.push_back(Arg.get());
1745 Args.push_back(Handle);
1746
1747 CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"",
1748 RuntimeCall.getIterator());
1749 OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1750 RuntimeCall.eraseFromParent();
1751
1752 // Add "wait" runtime call declaration:
1753 // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1754 FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1755 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1756
1757 Value *WaitParams[2] = {
1758 IssueCallsite->getArgOperand(
1759 OffloadArray::DeviceIDArgNum), // device_id.
1760 Handle // handle to wait on.
1761 };
1762 CallInst *WaitCallsite = CallInst::Create(
1763 WaitDecl, WaitParams, /*NameStr=*/"", WaitMovementPoint.getIterator());
1764 OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1765
1766 return true;
1767 }
1768
1769 static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1770 bool GlobalOnly, bool &SingleChoice) {
1771 if (CurrentIdent == NextIdent)
1772 return CurrentIdent;
1773
1774 // TODO: Figure out how to actually combine multiple debug locations. For
1775 // now we just keep an existing one if there is a single choice.
1776 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1777 SingleChoice = !CurrentIdent;
1778 return NextIdent;
1779 }
1780 return nullptr;
1781 }
1782
1783 /// Return an `struct ident_t*` value that represents the ones used in the
1784 /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1785 /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1786 /// return value we create one from scratch. We also do not yet combine
1787 /// information, e.g., the source locations, see combinedIdentStruct.
1788 Value *
1789 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1790 Function &F, bool GlobalOnly) {
1791 bool SingleChoice = true;
1792 Value *Ident = nullptr;
1793 auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1794 CallInst *CI = getCallIfRegularCall(U, &RFI);
1795 if (!CI || &F != &Caller)
1796 return false;
1797 Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1798 /* GlobalOnly */ true, SingleChoice);
1799 return false;
1800 };
1801 RFI.foreachUse(SCC, CombineIdentStruct);
1802
1803 if (!Ident || !SingleChoice) {
1804 // The IRBuilder uses the insertion block to get to the module, this is
1805 // unfortunate but we work around it for now.
1806 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1807 OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1808 &F.getEntryBlock(), F.getEntryBlock().begin()));
1809 // Create a fallback location if non was found.
1810 // TODO: Use the debug locations of the calls instead.
1811 uint32_t SrcLocStrSize;
1812 Constant *Loc =
1813 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1814 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1815 }
1816 return Ident;
1817 }
1818
1819 /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1820 /// \p ReplVal if given.
 /// \returns true if at least one redundant call was replaced and erased.
1821 bool deduplicateRuntimeCalls(Function &F,
1822 OMPInformationCache::RuntimeFunctionInfo &RFI,
1823 Value *ReplVal = nullptr)  {
 // Bail out unless there are at least two "uses"; a caller-provided
 // replacement value counts as one.
1824 auto *UV = RFI.getUseVector(F);
1825 if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1826 return false;
1827
1828 LLVM_DEBUG(
1829 dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1830 << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1831
 // A caller-provided replacement must be an argument of \p F; anything else
 // might not be valid at the uses rewritten below.
1832 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1833 cast<Argument>(ReplVal)->getParent() == &F)) &&
1834 "Unexpected replacement value!");
1835
1836 // TODO: Use dominance to find a good position instead.
 // A call can only be hoisted if its operands stay valid at the new
 // location: no instruction operands, except that an ident pointer as the
 // first argument is tolerated (it is rewritten to a global value below).
1837 auto CanBeMoved = [this](CallBase &CB) {
1838 unsigned NumArgs = CB.arg_size();
1839 if (NumArgs == 0)
1840 return true;
1841 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1842 return false;
1843 for (unsigned U = 1; U < NumArgs; ++U)
1844 if (isa<Instruction>(CB.getArgOperand(U)))
1845 return false;
1846 return true;
1847 };
1848
 // No replacement given: pick the first movable call as the survivor and
 // hoist it to the nearest common dominator of all calls so its result is
 // available at every use.
1849 if (!ReplVal) {
1850 auto *DT =
1851 OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1852 if (!DT)
1853 return false;
1854 Instruction *IP = nullptr;
1855 for (Use *U : *UV) {
1856 if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1857 if (IP)
1858 IP = DT->findNearestCommonDominator(IP, CI);
1859 else
1860 IP = CI;
1861 if (!CanBeMoved(*CI))
1862 continue;
1863 if (!ReplVal)
1864 ReplVal = CI;
1865 }
1866 }
1867 if (!ReplVal)
1868 return false;
1869 assert(IP && "Expected insertion point!");
1870 cast<Instruction>(ReplVal)->moveBefore(IP->getIterator());
1871 }
1872
1873 // If we use a call as a replacement value we need to make sure the ident is
1874 // valid at the new location. For now we just pick a global one, either
1875 // existing and used by one of the calls, or created from scratch.
1876 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1877 if (!CI->arg_empty() &&
1878 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1879 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1880 /* GlobalOnly */ true);
1881 CI->setArgOperand(0, Ident);
1882 }
1883 }
1884
 // Replace and erase every remaining call in \p F (other than the survivor),
 // emitting an optimization remark for each.
1885 bool Changed = false;
1886 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1887 CallInst *CI = getCallIfRegularCall(U, &RFI);
1888 if (!CI || CI == ReplVal || &F != &Caller)
1889 return false;
1890 assert(CI->getCaller() == &F && "Unexpected call!");
1891
1892 auto Remark = [&](OptimizationRemark OR) {
1893 return OR << "OpenMP runtime call "
1894 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1895 };
 // NOTE(review): the remark-emission calls for both branches are not
 // visible in this listing — confirm against the full source.
1896 if (CI->getDebugLoc())
1898 else
1900
1901 CI->replaceAllUsesWith(ReplVal);
1902 CI->eraseFromParent();
1903 ++NumOpenMPRuntimeCallsDeduplicated;
1904 Changed = true;
1905 return true;
1906 };
1907 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1908
1909 return Changed;
1910 }
1911
1912 /// Collect arguments that represent the global thread id in \p GTIdArgs.
1913 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1914 // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1915 // initialization. We could define an AbstractAttribute instead and
1916 // run the Attributor here once it can be run as an SCC pass.
1917
1918 // Helper to check the argument \p ArgNo at all call sites of \p F for
1919 // a GTId.
1920 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1921 if (!F.hasLocalLinkage())
1922 return false;
1923 for (Use &U : F.uses()) {
1924 if (CallInst *CI = getCallIfRegularCall(U)) {
1925 Value *ArgOp = CI->getArgOperand(ArgNo);
1926 if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1927 getCallIfRegularCall(
1928 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1929 continue;
1930 }
1931 return false;
1932 }
1933 return true;
1934 };
1935
1936 // Helper to identify uses of a GTId as GTId arguments.
1937 auto AddUserArgs = [&](Value &GTId) {
1938 for (Use &U : GTId.uses())
1939 if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1940 if (CI->isArgOperand(&U))
1941 if (Function *Callee = CI->getCalledFunction())
1942 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1943 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1944 };
1945
1946 // The argument users of __kmpc_global_thread_num calls are GTIds.
1947 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1948 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1949
1950 GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1951 if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1952 AddUserArgs(*CI);
1953 return false;
1954 });
1955
1956 // Transitively search for more arguments by looking at the users of the
1957 // ones we know already. During the search the GTIdArgs vector is extended
1958 // so we cannot cache the size nor can we use a range based for.
1959 for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1960 AddUserArgs(*GTIdArgs[U]);
1961 }
1962
1963 /// Kernel (=GPU) optimizations and utility functions
1964 ///
1965 ///{{
1966
1967 /// Cache to remember the unique kernel for a function.
 /// std::nullopt means "not computed yet"; a null Kernel means "no unique
 /// kernel was found" (see getUniqueKernelFor(Function&)).
1968 DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
1969
1970 /// Find the unique kernel that will execute \p F, if any.
1971 Kernel getUniqueKernelFor(Function &F);
1972
1973 /// Find the unique kernel that will execute \p I, if any.
1974 Kernel getUniqueKernelFor(Instruction &I) {
1975 return getUniqueKernelFor(*I.getFunction());
1976 }
1977
1978 /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
1979 /// the cases we can avoid taking the address of a function.
1980 bool rewriteDeviceCodeStateMachine();
1981
1982 ///
1983 ///}}
1984
1985 /// Emit a remark generically
1986 ///
1987 /// This template function can be used to generically emit a remark. The
1988 /// RemarkKind should be one of the following:
1989 /// - OptimizationRemark to indicate a successful optimization attempt
1990 /// - OptimizationRemarkMissed to report a failed optimization attempt
1991 /// - OptimizationRemarkAnalysis to provide additional information about an
1992 /// optimization attempt
1993 ///
1994 /// The remark is built using a callback function provided by the caller that
1995 /// takes a RemarkKind as input and returns a RemarkKind.
1996 template <typename RemarkKind, typename RemarkCallBack>
1997 void emitRemark(Instruction *I, StringRef RemarkName,
1998 RemarkCallBack &&RemarkCB) const {
1999 Function *F = I->getParent()->getParent();
2000 auto &ORE = OREGetter(F);
2001
 // "OMP"-prefixed remark names get a trailing "[Name]" tag appended
 // (presumably so users can look them up in the OpenMP remark docs).
2002 if (RemarkName.starts_with("OMP"))
2003 ORE.emit([&]() {
2004 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2005 << " [" << RemarkName << "]";
2006 });
2007 else
2008 ORE.emit(
2009 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2010 }
2011
2012 /// Emit a remark on a function.
 /// Same contract as the instruction overload, but the remark is anchored at
 /// \p F instead of a specific instruction.
2013 template <typename RemarkKind, typename RemarkCallBack>
2014 void emitRemark(Function *F, StringRef RemarkName,
2015 RemarkCallBack &&RemarkCB) const {
2016 auto &ORE = OREGetter(F);
2017
2018 if (RemarkName.starts_with("OMP"))
2019 ORE.emit([&]() {
2020 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2021 << " [" << RemarkName << "]";
2022 });
2023 else
2024 ORE.emit(
2025 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2026 }
2027
2028 /// The underlying module.
2029 Module &M;
2030
2031 /// The SCC we are operating on.
2032 SmallVectorImpl<Function *> &SCC;
2033
2034 /// Callback to update the call graph, the first argument is a removed call,
2035 /// the second an optional replacement call.
2036 CallGraphUpdater &CGUpdater;
2037
2038 /// Callback to get an OptimizationRemarkEmitter from a Function *
2039 OptimizationRemarkGetter OREGetter;
2040
2041 /// OpenMP-specific information cache. Also Used for Attributor runs.
2042 OMPInformationCache &OMPInfoCache;
2043
2044 /// Attributor instance.
2045 Attributor &A;
2046
2047 /// Helper function to run Attributor on SCC.
 /// \returns true iff the Attributor run changed the IR; cached analyses are
 /// invalidated in that case.
2048 bool runAttributor(bool IsModulePass) {
2049 if (SCC.empty())
2050 return false;
2051
2052 registerAAs(IsModulePass);
2053
2054 ChangeStatus Changed = A.run();
2055
2056 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2057 << " functions, result: " << Changed << ".\n");
2058
 // Any IR change may make cached analysis results (e.g. dominator trees)
 // stale, so drop them.
2059 if (Changed == ChangeStatus::CHANGED)
2060 OMPInfoCache.invalidateAnalyses();
2061
2062 return Changed == ChangeStatus::CHANGED;
2063 }
2064
 /// Register the AA that folds calls to runtime function \p RF.
2065 void registerFoldRuntimeCall(RuntimeFunction RF);
2066
2067 /// Populate the Attributor with abstract attribute opportunities in the
2068 /// functions.
2069 void registerAAs(bool IsModulePass);
2070
2071public:
2072 /// Callback to register AAs for live functions, including internal functions
2073 /// marked live during the traversal.
2074 static void registerAAsForFunction(Attributor &A, const Function &F);
2075};
2076
2077Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
 // Only consider functions inside the current SCC, if one is set.
2078 if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2079 !OMPInfoCache.CGSCC->contains(&F))
2080 return nullptr;
2081
2082 // Use a scope to keep the lifetime of the CachedKernel short.
2083 {
2084 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2085 if (CachedKernel)
2086 return *CachedKernel;
2087
2088 // TODO: We should use an AA to create an (optimistic and callback
2089 // call-aware) call graph. For now we stick to simple patterns that
2090 // are less powerful, basically the worst fixpoint.
2091 if (isOpenMPKernel(F)) {
2092 CachedKernel = Kernel(&F);
2093 return *CachedKernel;
2094 }
2095
 // Externally visible functions may have callers we cannot see; give up.
2096 CachedKernel = nullptr;
2097 if (!F.hasLocalLinkage()) {
2098
2099 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2100 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2101 return ORA << "Potentially unknown OpenMP target region caller.";
2102 };
2104
2105 return nullptr;
2106 }
2107 }
2108
 // Classify a single use of \p F: equality comparisons and direct calls
 // recurse to the user's kernel; __kmpc_parallel_60 call operands recurse to
 // the call's kernel; everything else yields a null kernel.
2109 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2110 if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2111 // Allow use in equality comparisons.
2112 if (Cmp->isEquality())
2113 return getUniqueKernelFor(*Cmp);
2114 return nullptr;
2115 }
2116 if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2117 // Allow direct calls.
2118 if (CB->isCallee(&U))
2119 return getUniqueKernelFor(*CB);
2120
2121 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2122 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_60];
2123 // Allow the use in __kmpc_parallel_60 calls.
2124 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2125 return getUniqueKernelFor(*CB);
2126 return nullptr;
2127 }
2128 // Disallow every other use.
2129 return nullptr;
2130 };
2131
2132 // TODO: In the future we want to track more than just a unique kernel.
 // A null element in the set records a disallowed/unknown use, so any such
 // use prevents a unique (non-null) kernel from being reported.
2133 SmallPtrSet<Kernel, 2> PotentialKernels;
2134 OMPInformationCache::foreachUse(F, [&](const Use &U) {
2135 PotentialKernels.insert(GetUniqueKernelForUse(U));
2136 });
2137
2138 Kernel K = nullptr;
2139 if (PotentialKernels.size() == 1)
2140 K = *PotentialKernels.begin();
2141
2142 // Cache the result.
2143 UniqueKernelMap[&F] = K;
2144
2145 return K;
2146}
2147
2148bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2149 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2150 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_60];
2151
2152 bool Changed = false;
2153 if (!KernelParallelRFI)
2154 return Changed;
2155
2156 // If we have disabled state machine changes, exit
 // NOTE(review): the flag check itself is not visible in this listing.
2158 return Changed;
2159
2160 for (Function *F : SCC) {
2161
2162 // Check if the function is a use in a __kmpc_parallel_60 call at
2163 // all.
2164 bool UnknownUse = false;
2165 bool KernelParallelUse = false;
2166 unsigned NumDirectCalls = 0;
2167
2168 SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2169 OMPInformationCache::foreachUse(*F, [&](Use &U) {
2170 if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2171 if (CB->isCallee(&U)) {
2172 ++NumDirectCalls;
2173 return;
2174 }
2175
 // Equality comparisons against the function pointer are part of the
 // state machine and will be redirected to the new ID global.
2176 if (isa<ICmpInst>(U.getUser())) {
2177 ToBeReplacedStateMachineUses.push_back(&U);
2178 return;
2179 }
2180
2181 // Find wrapper functions that represent parallel kernels.
2182 CallInst *CI =
2183 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
 // Operand index of the wrapper function in a __kmpc_parallel_60 call.
2184 const unsigned int WrapperFunctionArgNo = 6;
2185 if (!KernelParallelUse && CI &&
2186 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2187 KernelParallelUse = true;
2188 ToBeReplacedStateMachineUses.push_back(&U);
2189 return;
2190 }
2191 UnknownUse = true;
2192 });
2193
2194 // Do not emit a remark if we haven't seen a __kmpc_parallel_60
2195 // use.
2196 if (!KernelParallelUse)
2197 continue;
2198
2199 // If this ever hits, we should investigate.
2200 // TODO: Checking the number of uses is not a necessary restriction and
2201 // should be lifted.
2202 if (UnknownUse || NumDirectCalls != 1 ||
2203 ToBeReplacedStateMachineUses.size() > 2) {
2204 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2205 return ORA << "Parallel region is used in "
2206 << (UnknownUse ? "unknown" : "unexpected")
2207 << " ways. Will not attempt to rewrite the state machine.";
2208 };
2210 continue;
2211 }
2212
2213 // Even if we have __kmpc_parallel_60 calls, we (for now) give
2214 // up if the function is not called from a unique kernel.
2215 Kernel K = getUniqueKernelFor(*F);
2216 if (!K) {
2217 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2218 return ORA << "Parallel region is not called from a unique kernel. "
2219 "Will not attempt to rewrite the state machine.";
2220 };
2222 continue;
2223 }
2224
2225 // We now know F is a parallel body function called only from the kernel K.
2226 // We also identified the state machine uses in which we replace the
2227 // function pointer by a new global symbol for identification purposes. This
2228 // ensures only direct calls to the function are left.
2229
2230 Module &M = *F->getParent();
2231 Type *Int8Ty = Type::getInt8Ty(M.getContext());
2232
 // A one-byte private global whose address uniquely identifies F without
 // requiring F's address to be taken.
2233 auto *ID = new GlobalVariable(
2234 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2235 UndefValue::get(Int8Ty), F->getName() + ".ID");
2236
2237 for (Use *U : ToBeReplacedStateMachineUses)
2239 ID, U->get()->getType()));
2240
2241 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2242
2243 Changed = true;
2244 }
2245
2246 return Changed;
2247}
2248
2249/// Abstract Attribute for tracking ICV values.
2250struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2251 using Base = StateWrapper<BooleanState, AbstractAttribute>;
2252 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2253
2254 /// Returns true if value is assumed to be tracked.
2255 bool isAssumedTracked() const { return getAssumed(); }
2256
2257 /// Returns true if value is known to be tracked.
 /// NOTE(review): this returns the *assumed* bit, not the known bit —
 /// matches the code as written, but looks suspicious; confirm upstream.
2258 bool isKnownTracked() const { return getAssumed(); }
2259
2260 /// Create an abstract attribute view for the position \p IRP.
2261 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2262
2263 /// Return the value with which \p I can be replaced for specific \p ICV.
2264 virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2265 const Instruction *I,
2266 Attributor &A) const {
2267 return std::nullopt;
2268 }
2269
2270 /// Return an assumed unique ICV value if a single candidate is found. If
2271 /// there cannot be one, return a nullptr. If it is not clear yet, return
2272 /// std::nullopt.
2273 virtual std::optional<Value *>
2274 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2275
2276 // Currently only nthreads is being tracked.
2277 // this array will only grow with time.
2278 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2279
2280 /// See AbstractAttribute::getName()
2281 StringRef getName() const override { return "AAICVTracker"; }
2282
2283 /// See AbstractAttribute::getIdAddr()
2284 const char *getIdAddr() const override { return &ID; }
2285
2286 /// This function should return true if the type of the \p AA is AAICVTracker
2287 static bool classof(const AbstractAttribute *AA) {
2288 return (AA->getIdAddr() == &ID);
2289 }
2290
2291 static const char ID;
2292};
2293
2294struct AAICVTrackerFunction : public AAICVTracker {
2295 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2296 : AAICVTracker(IRP, A) {}
2297
2298 // FIXME: come up with better string.
2299 const std::string getAsStr(Attributor *) const override {
2300 return "ICVTrackerFunction";
2301 }
2302
2303 // FIXME: come up with some stats.
2304 void trackStatistics() const override {}
2305
2306 /// We don't manifest anything for this AA.
2307 ChangeStatus manifest(Attributor &A) override {
2308 return ChangeStatus::UNCHANGED;
2309 }
2310
2311 // Map of ICV to their values at specific program point.
2312 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2313 InternalControlVar::ICV___last>
2314 ICVReplacementValuesMap;
2315
 /// Record, for each trackable ICV, every program point at which its value
 /// becomes known (setter calls) or unknown (opaque calls).
2316 ChangeStatus updateImpl(Attributor &A) override {
2317 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2318
2319 Function *F = getAnchorScope();
2320
2321 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2322
2323 for (InternalControlVar ICV : TrackableICVs) {
2324 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2325
2326 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2327 auto TrackValues = [&](Use &U, Function &) {
2328 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2329 if (!CI)
2330 return false;
2331
2332 // FIXME: handle setters with more than one argument.
2333 /// Track new value.
2334 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2335 HasChanged = ChangeStatus::CHANGED;
2336
2337 return false;
2338 };
2339
 // Record calls whose effect on the ICV we can classify.
2340 auto CallCheck = [&](Instruction &I) {
2341 std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2342 if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2343 HasChanged = ChangeStatus::CHANGED;
2344
2345 return true;
2346 };
2347
2348 // Track all changes of an ICV.
2349 SetterRFI.foreachUse(TrackValues, F);
2350
2351 bool UsedAssumedInformation = false;
2352 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2353 UsedAssumedInformation,
2354 /* CheckBBLivenessOnly */ true);
2355
2356 /// TODO: Figure out a way to avoid adding entry in
2357 /// ICVReplacementValuesMap
2358 Instruction *Entry = &F->getEntryBlock().front();
2359 if (HasChanged == ChangeStatus::CHANGED)
2360 ValuesMap.try_emplace(Entry);
2361 }
2362
2363 return HasChanged;
2364 }
2365
2366 /// Helper to check if \p I is a call and get the value for it if it is
2367 /// unique.
 /// \returns std::nullopt if the call does not affect \p ICV, nullptr if it
 /// (may) change the ICV to an unknown value, or the concrete new value.
2368 std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2369 InternalControlVar &ICV) const {
2370
2371 const auto *CB = dyn_cast<CallBase>(&I);
2372 if (!CB || CB->hasFnAttr("no_openmp") ||
2373 CB->hasFnAttr("no_openmp_routines") ||
2374 CB->hasFnAttr("no_openmp_constructs"))
2375 return std::nullopt;
2376
2377 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2378 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2379 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2380 Function *CalledFunction = CB->getCalledFunction();
2381
2382 // Indirect call, assume ICV changes.
2383 if (CalledFunction == nullptr)
2384 return nullptr;
2385 if (CalledFunction == GetterRFI.Declaration)
2386 return std::nullopt;
2387 if (CalledFunction == SetterRFI.Declaration) {
2388 if (ICVReplacementValuesMap[ICV].count(&I))
2389 return ICVReplacementValuesMap[ICV].lookup(&I);
2390
2391 return nullptr;
2392 }
2393
2394 // Since we don't know, assume it changes the ICV.
2395 if (CalledFunction->isDeclaration())
2396 return nullptr;
2397
 // Ask the callee's tracker for a unique value valid at this call site.
2398 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2399 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2400
2401 if (ICVTrackingAA->isAssumedTracked()) {
2402 std::optional<Value *> URV =
2403 ICVTrackingAA->getUniqueReplacementValue(ICV);
2404 if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2405 OMPInfoCache)))
2406 return URV;
2407 }
2408
2409 // If we don't know, assume it changes.
2410 return nullptr;
2411 }
2412
2413 // We don't check unique value for a function, so return std::nullopt.
2414 std::optional<Value *>
2415 getUniqueReplacementValue(InternalControlVar ICV) const override {
2416 return std::nullopt;
2417 }
2418
2419 /// Return the value with which \p I can be replaced for specific \p ICV.
 /// Walks the CFG backwards from \p I, collecting the last recorded ICV
 /// value on every path; returns nullptr on conflicting values.
2420 std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2421 const Instruction *I,
2422 Attributor &A) const override {
2423 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2424 if (ValuesMap.count(I))
2425 return ValuesMap.lookup(I);
2426
 // NOTE(review): the worklist declaration is not visible in this listing.
2428 SmallPtrSet<const Instruction *, 16> Visited;
2429 Worklist.push_back(I);
2430
2431 std::optional<Value *> ReplVal;
2432
2433 while (!Worklist.empty()) {
2434 const Instruction *CurrInst = Worklist.pop_back_val();
2435 if (!Visited.insert(CurrInst).second)
2436 continue;
2437
2438 const BasicBlock *CurrBB = CurrInst->getParent();
2439
2440 // Go up and look for all potential setters/calls that might change the
2441 // ICV.
2442 while ((CurrInst = CurrInst->getPrevNode())) {
2443 if (ValuesMap.count(CurrInst)) {
2444 std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2445 // Unknown value, track new.
2446 if (!ReplVal) {
2447 ReplVal = NewReplVal;
2448 break;
2449 }
2450
2451 // If we found a new value, we can't know the icv value anymore.
2452 if (NewReplVal)
2453 if (ReplVal != NewReplVal)
2454 return nullptr;
2455
2456 break;
2457 }
2458
2459 std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2460 if (!NewReplVal)
2461 continue;
2462
2463 // Unknown value, track new.
2464 if (!ReplVal) {
2465 ReplVal = NewReplVal;
2466 break;
2467 }
2468
2469 // Otherwise a different value was found; the ICV is no longer known.
2470 // We found a new value, we can't know the icv value anymore.
2471 if (ReplVal != NewReplVal)
2472 return nullptr;
2473 }
2474
2475 // If we are in the same BB and we have a value, we are done.
2476 if (CurrBB == I->getParent() && ReplVal)
2477 return ReplVal;
2478
2479 // Go through all predecessors and add terminators for analysis.
2480 for (const BasicBlock *Pred : predecessors(CurrBB))
2481 if (const Instruction *Terminator = Pred->getTerminator())
2482 Worklist.push_back(Terminator);
2483 }
2484
2485 return ReplVal;
2486 }
2487};
2488
2489struct AAICVTrackerFunctionReturned : AAICVTracker {
2490 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2491 : AAICVTracker(IRP, A) {}
2492
2493 // FIXME: come up with better string.
2494 const std::string getAsStr(Attributor *) const override {
2495 return "ICVTrackerFunctionReturned";
2496 }
2497
2498 // FIXME: come up with some stats.
2499 void trackStatistics() const override {}
2500
2501 /// We don't manifest anything for this AA.
2502 ChangeStatus manifest(Attributor &A) override {
2503 return ChangeStatus::UNCHANGED;
2504 }
2505
2506 // Map of ICV to their values at specific program point.
2507 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2508 InternalControlVar::ICV___last>
2509 ICVReplacementValuesMap;
2510
2511 /// Return the value with which \p I can be replaced for specific \p ICV.
2512 std::optional<Value *>
2513 getUniqueReplacementValue(InternalControlVar ICV) const override {
2514 return ICVReplacementValuesMap[ICV];
2515 }
2516
 /// Compute, per ICV, the unique value seen at every return of the anchor
 /// function, or nullptr if the returns disagree / cannot all be checked.
2517 ChangeStatus updateImpl(Attributor &A) override {
2518 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2519 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2520 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2521
2522 if (!ICVTrackingAA->isAssumedTracked())
2523 return indicatePessimisticFixpoint();
2524
2525 for (InternalControlVar ICV : TrackableICVs) {
2526 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2527 std::optional<Value *> UniqueICVValue;
2528
2529 auto CheckReturnInst = [&](Instruction &I) {
2530 std::optional<Value *> NewReplVal =
2531 ICVTrackingAA->getReplacementValue(ICV, &I, A);
2532
2533 // If we found a second ICV value there is no unique returned value.
2534 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2535 return false;
2536
2537 UniqueICVValue = NewReplVal;
2538
2539 return true;
2540 };
2541
2542 bool UsedAssumedInformation = false;
2543 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2544 UsedAssumedInformation,
2545 /* CheckBBLivenessOnly */ true))
2546 UniqueICVValue = nullptr;
2547
2548 if (UniqueICVValue == ReplVal)
2549 continue;
2550
2551 ReplVal = UniqueICVValue;
2552 Changed = ChangeStatus::CHANGED;
2553 }
2554
2555 return Changed;
2556 }
2557};
2558
2559struct AAICVTrackerCallSite : AAICVTracker {
2560 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2561 : AAICVTracker(IRP, A) {}
2562
2563 void initialize(Attributor &A) override {
2564 assert(getAnchorScope() && "Expected anchor function");
2565
2566 // We only initialize this AA for getters, so we need to know which ICV it
2567 // gets.
2568 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2569 for (InternalControlVar ICV : TrackableICVs) {
2570 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2571 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2572 if (Getter.Declaration == getAssociatedFunction()) {
2573 AssociatedICV = ICVInfo.Kind;
2574 return;
2575 }
2576 }
2577
2578 /// Unknown ICV.
2579 indicatePessimisticFixpoint();
2580 }
2581
 /// Replace the getter call with the tracked ICV value and delete it.
2582 ChangeStatus manifest(Attributor &A) override {
2583 if (!ReplVal || !*ReplVal)
2584 return ChangeStatus::UNCHANGED;
2585
2586 A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2587 A.deleteAfterManifest(*getCtxI());
2588
2589 return ChangeStatus::CHANGED;
2590 }
2591
2592 // FIXME: come up with better string.
2593 const std::string getAsStr(Attributor *) const override {
2594 return "ICVTrackerCallSite";
2595 }
2596
2597 // FIXME: come up with some stats.
2598 void trackStatistics() const override {}
2599
 // The ICV read by the associated getter call.
2600 InternalControlVar AssociatedICV;
 // The value the getter call can be replaced with, if any.
2601 std::optional<Value *> ReplVal;
2602
2603 ChangeStatus updateImpl(Attributor &A) override {
2604 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2605 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2606
2607 // We don't have any information, so we assume it changes the ICV.
2608 if (!ICVTrackingAA->isAssumedTracked())
2609 return indicatePessimisticFixpoint();
2610
2611 std::optional<Value *> NewReplVal =
2612 ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
2613
2614 if (ReplVal == NewReplVal)
2615 return ChangeStatus::UNCHANGED;
2616
2617 ReplVal = NewReplVal;
2618 return ChangeStatus::CHANGED;
2619 }
2620
2621 // Return the value with which associated value can be replaced for specific
2622 // \p ICV.
2623 std::optional<Value *>
2624 getUniqueReplacementValue(InternalControlVar ICV) const override {
2625 return ReplVal;
2626 }
2627};
2628
2629struct AAICVTrackerCallSiteReturned : AAICVTracker {
2630 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2631 : AAICVTracker(IRP, A) {}
2632
2633 // FIXME: come up with better string.
2634 const std::string getAsStr(Attributor *) const override {
2635 return "ICVTrackerCallSiteReturned";
2636 }
2637
2638 // FIXME: come up with some stats.
2639 void trackStatistics() const override {}
2640
2641 /// We don't manifest anything for this AA.
2642 ChangeStatus manifest(Attributor &A) override {
2643 return ChangeStatus::UNCHANGED;
2644 }
2645
2646 // Map of ICV to their values at specific program point.
2647 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2648 InternalControlVar::ICV___last>
2649 ICVReplacementValuesMap;
2650
2651 /// Return the value with which associated value can be replaced for specific
2652 /// \p ICV.
2653 std::optional<Value *>
2654 getUniqueReplacementValue(InternalControlVar ICV) const override {
2655 return ICVReplacementValuesMap[ICV];
2656 }
2657
2658 ChangeStatus updateImpl(Attributor &A) override {
2659 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2660 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2661 *this, IRPosition::returned(*getAssociatedFunction()),
2662 DepClassTy::REQUIRED);
2663
2664 // We don't have any information, so we assume it changes the ICV.
2665 if (!ICVTrackingAA->isAssumedTracked())
2666 return indicatePessimisticFixpoint();
2667
2668 for (InternalControlVar ICV : TrackableICVs) {
2669 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2670 std::optional<Value *> NewReplVal =
2671 ICVTrackingAA->getUniqueReplacementValue(ICV);
2672
2673 if (ReplVal == NewReplVal)
2674 continue;
2675
2676 ReplVal = NewReplVal;
2677 Changed = ChangeStatus::CHANGED;
2678 }
2679 return Changed;
2680 }
2681};
2682
2683/// Determines if \p BB exits the function unconditionally itself or reaches a
2684/// block that does through only unique successors.
2685static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2686 if (succ_empty(BB))
2687 return true;
2688 const BasicBlock *const Successor = BB->getUniqueSuccessor();
2689 if (!Successor)
2690 return false;
2691 return hasFunctionEndAsUniqueSuccessor(Successor);
2692}
2693
2694struct AAExecutionDomainFunction : public AAExecutionDomain {
2695 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2696 : AAExecutionDomain(IRP, A) {}
2697
 // NOTE(review): RPOT is a raw owning pointer (allocated in initialize());
 // consider std::unique_ptr to get RAII instead of manual delete.
2698 ~AAExecutionDomainFunction() override { delete RPOT; }
2699
2700 void initialize(Attributor &A) override {
2701 Function *F = getAnchorScope();
2702 assert(F && "Expected anchor function");
2703 RPOT = new ReversePostOrderTraversal<Function *>(F);
2704 }
2705
 /// Summarize how many blocks run on the initial thread / between aligned
 /// barriers, for debug output.
2706 const std::string getAsStr(Attributor *) const override {
2707 unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2708 for (auto &It : BEDMap) {
 // The nullptr key holds function-exit state, not a real block.
2709 if (!It.getFirst())
2710 continue;
2711 TotalBlocks++;
2712 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2713 AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2714 It.getSecond().IsReachingAlignedBarrierOnly;
2715 }
2716 return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
2717 std::to_string(AlignedBlocks) + " of " +
2718 std::to_string(TotalBlocks) +
2719 " executed by initial thread / aligned";
2720 }
2721
2722 /// See AbstractAttribute::trackStatistics().
2723 void trackStatistics() const override {}
2724
2725 ChangeStatus manifest(Attributor &A) override {
2726 LLVM_DEBUG({
2727 for (const BasicBlock &BB : *getAnchorScope()) {
2728 if (!isExecutedByInitialThreadOnly(BB))
2729 continue;
2730 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2731 << BB.getName() << " is executed by a single thread.\n";
2732 }
2733 });
2734
2735 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2736
 // NOTE(review): a guard (early return) sits here in the full source; the
 // condition is not visible in this listing.
2738 return Changed;
2739
 // Delete aligned barriers that are provably redundant. A nullptr \p CB
 // stands for the implicit "kernel end barrier".
2740 SmallPtrSet<CallBase *, 16> DeletedBarriers;
2741 auto HandleAlignedBarrier = [&](CallBase *CB) {
2742 const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
2743 if (!ED.IsReachedFromAlignedBarrierOnly ||
2744 ED.EncounteredNonLocalSideEffect)
2745 return;
2746 if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
2747 return;
2748
2749 // We can remove this barrier, if it is one, or aligned barriers reaching
2750 // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
2751 // end should only be removed if the kernel end is their unique successor;
2752 // otherwise, they may have side-effects that aren't accounted for in the
2753 // kernel end in their other successors. If those barriers have other
2754 // barriers reaching them, those can be transitively removed as well as
2755 // long as the kernel end is also their unique successor.
2756 if (CB) {
2757 DeletedBarriers.insert(CB);
2758 A.deleteAfterManifest(*CB);
2759 ++NumBarriersEliminated;
2760 Changed = ChangeStatus::CHANGED;
2761 } else if (!ED.AlignedBarriers.empty()) {
2762 Changed = ChangeStatus::CHANGED;
2763 SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
2764 ED.AlignedBarriers.end());
2765 SmallSetVector<CallBase *, 16> Visited;
2766 while (!Worklist.empty()) {
2767 CallBase *LastCB = Worklist.pop_back_val();
2768 if (!Visited.insert(LastCB))
2769 continue;
2770 if (LastCB->getFunction() != getAnchorScope())
2771 continue;
2772 if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
2773 continue;
2774 if (!DeletedBarriers.count(LastCB)) {
2775 ++NumBarriersEliminated;
2776 A.deleteAfterManifest(*LastCB);
2777 continue;
2778 }
2779 // The final aligned barrier (LastCB) reaching the kernel end was
2780 // removed already. This means we can go one step further and remove
2781 // the barriers encountered last before (LastCB).
2782 const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2783 Worklist.append(LastED.AlignedBarriers.begin(),
2784 LastED.AlignedBarriers.end());
2785 }
2786 }
2787
2788 // If we actually eliminated a barrier we need to eliminate the associated
2789 // llvm.assumes as well to avoid creating UB.
2790 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2791 for (auto *AssumeCB : ED.EncounteredAssumes)
2792 A.deleteAfterManifest(*AssumeCB);
2793 };
2794
2795 for (auto *CB : AlignedBarriers)
2796 HandleAlignedBarrier(CB);
2797
2798 // Handle the "kernel end barrier" for kernels too.
2799 if (omp::isOpenMPKernel(*getAnchorScope()))
2800 HandleAlignedBarrier(nullptr);
2801
2802 return Changed;
2803 }
2804
2805 bool isNoOpFence(const FenceInst &FI) const override {
2806 return getState().isValidState() && !NonNoOpFences.count(&FI);
2807 }
2808
  /// Merge barrier and assumption information from \p PredED into the successor
  /// \p ED.
  void
  mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
                                           const ExecutionDomainTy &PredED);

  /// Merge all information from \p PredED into the successor \p ED. If
  /// \p InitialEdgeOnly is set, only the initial edge will enter the block
  /// represented by \p ED from this predecessor.
  /// \returns true if \p ED changed (assumption/barrier sets excluded).
  bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
                          const ExecutionDomainTy &PredED,
                          bool InitialEdgeOnly = false);

  /// Accumulate information for the entry block in \p EntryBBED.
  /// \returns true if the cached function-entry state changed.
  bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);

  /// See AbstractAttribute::updateImpl.
  ChangeStatus updateImpl(Attributor &A) override;

  /// Query interface, see AAExecutionDomain
  ///{
2830 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2831 if (!isValidState())
2832 return false;
2833 assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
2834 return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2835 }
2836
  /// Check whether \p I lies in a region bounded by aligned barriers in both
  /// directions, i.e., every thread reaches and leaves \p I together.
  bool isExecutedInAlignedRegion(Attributor &A,
                                 const Instruction &I) const override {
    assert(I.getFunction() == getAnchorScope() &&
           "Instruction is out of scope!");
    if (!isValidState())
      return false;

    bool ForwardIsOk = true;
    const Instruction *CurI;

    // Check forward until a call or the block end is reached.
    CurI = &I;
    do {
      auto *CB = dyn_cast<CallBase>(CurI);
      if (!CB)
        continue;
      // Hitting a known aligned barrier in either direction settles the query.
      if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
        return true;
      const auto &It = CEDMap.find({CB, PRE});
      if (It == CEDMap.end())
        continue;
      if (!It->getSecond().IsReachingAlignedBarrierOnly)
        ForwardIsOk = false;
      break;
    } while ((CurI = CurI->getNextNode()));

    // Fell off the block end: consult the per-block forward information.
    if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
      ForwardIsOk = false;

    // Check backward until a call or the block beginning is reached.
    CurI = &I;
    do {
      auto *CB = dyn_cast<CallBase>(CurI);
      if (!CB)
        continue;
      if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
        return true;
      const auto &It = CEDMap.find({CB, POST});
      if (It == CEDMap.end())
        continue;
      if (It->getSecond().IsReachedFromAlignedBarrierOnly)
        break;
      return false;
    } while ((CurI = CurI->getPrevNode()));

    // Delayed decision on the forward pass to allow aligned barrier detection
    // in the backwards traversal.
    if (!ForwardIsOk)
      return false;

    // Reached the block start: check the entry state or all live predecessors.
    if (!CurI) {
      const BasicBlock *BB = I.getParent();
      if (BB == &BB->getParent()->getEntryBlock())
        return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
      if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
            return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
          })) {
        return false;
      }
    }

    // On neither traversal did we find anything but aligned barriers.
    return true;
  }
2901
  /// Return the execution domain recorded for block \p BB.
  ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
    assert(isValidState() &&
           "No request should be made against an invalid state!");
    return BEDMap.lookup(&BB);
  }
  /// Return the execution domains before (PRE) and after (POST) the call \p CB.
  std::pair<ExecutionDomainTy, ExecutionDomainTy>
  getExecutionDomain(const CallBase &CB) const override {
    assert(isValidState() &&
           "No request should be made against an invalid state!");
    return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
  }
  /// Return the inter-procedural execution domain kept for the function itself.
  ExecutionDomainTy getFunctionExecutionDomain() const override {
    assert(isValidState() &&
           "No request should be made against an invalid state!");
    return InterProceduralED;
  }
  ///}
2919
  // Check if the edge into the successor block contains a condition that only
  // lets the main thread execute it.
  static bool isInitialThreadOnlyEdge(Attributor &A, CondBrInst *Edge,
                                      BasicBlock &SuccessorBB) {
    // Only the true-successor of a conditional branch can qualify here.
    if (!Edge)
      return false;
    if (Edge->getSuccessor(0) != &SuccessorBB)
      return false;

    // The branch condition must be an equality compare against a constant.
    auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
    if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
      return false;

    ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
    if (!C)
      return false;

    // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
    if (C->isAllOnesValue()) {
      auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
      auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
      CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
      if (!CB)
        return false;
      // NOTE(review): the initializer expression of KernelEnvC appears to be
      // truncated in this rendering; verify against upstream.
      ConstantStruct *KernelEnvC =
      ConstantInt *ExecModeC =
          KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
      return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
    }

    if (C->isZero()) {
      // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
      if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
        if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
          return true;

      // Match: 0 == llvm.amdgcn.workitem.id.x()
      if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
        if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
          return true;
    }

    return false;
  };
2966
  /// Mapping containing information about the function for other AAs.
  ExecutionDomainTy InterProceduralED;

  /// Whether a call-site domain describes the state before (PRE) or after
  /// (POST) the call.
  enum Direction { PRE = 0, POST = 1 };
  /// Mapping containing information per block.
  DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
  /// Per-call-site execution domains, keyed by the call plus a PRE/POST tag.
  DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>
      CEDMap;
  /// All calls identified as aligned barriers during the traversal.
  SmallSetVector<CallBase *, 16> AlignedBarriers;

  /// Reverse post-order traversal of the anchor function; set up elsewhere.
  ReversePostOrderTraversal<Function *> *RPOT = nullptr;
2979 /// Set \p R to \V and report true if that changed \p R.
2980 static bool setAndRecord(bool &R, bool V) {
2981 bool Eq = (R == V);
2982 R = V;
2983 return !Eq;
2984 }
2985
  /// Collection of fences known to be non-no-op. All fences not in this set
  /// can be assumed to be no-ops.
  SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
};
2990
2991void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2992 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2993 for (auto *EA : PredED.EncounteredAssumes)
2994 ED.addAssumeInst(A, *EA);
2995
2996 for (auto *AB : PredED.AlignedBarriers)
2997 ED.addAlignedBarrier(A, *AB);
2998}
2999
3000bool AAExecutionDomainFunction::mergeInPredecessor(
3001 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
3002 bool InitialEdgeOnly) {
3003
3004 bool Changed = false;
3005 Changed |=
3006 setAndRecord(ED.IsExecutedByInitialThreadOnly,
3007 InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3008 ED.IsExecutedByInitialThreadOnly));
3009
3010 Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3011 ED.IsReachedFromAlignedBarrierOnly &&
3012 PredED.IsReachedFromAlignedBarrierOnly);
3013 Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3014 ED.EncounteredNonLocalSideEffect |
3015 PredED.EncounteredNonLocalSideEffect);
3016 // Do not track assumptions and barriers as part of Changed.
3017 if (ED.IsReachedFromAlignedBarrierOnly)
3018 mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3019 else
3020 ED.clearAssumeInstAndAlignedBarriers();
3021 return Changed;
3022}
3023
// Seed the entry-block execution domain from all known call sites of the
// function, or from conservative defaults when not all callers are visible.
bool AAExecutionDomainFunction::handleCallees(Attributor &A,
                                              ExecutionDomainTy &EntryBBED) {
  // NOTE(review): the declaration of `CallSiteEDs` appears to be missing from
  // this rendering; verify against upstream.
  auto PredForCallSite = [&](AbstractCallSite ACS) {
    // Require a valid execution-domain AA for the calling function; otherwise
    // report failure so we fall back to the conservative path below.
    const auto *EDAA = A.getAAFor<AAExecutionDomain>(
        *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
        DepClassTy::OPTIONAL);
    if (!EDAA || !EDAA->getState().isValidState())
      return false;
    CallSiteEDs.emplace_back(
        EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
    return true;
  };

  ExecutionDomainTy ExitED;
  bool AllCallSitesKnown;
  if (A.checkForAllCallSites(PredForCallSite, *this,
                             /* RequiresAllCallSites */ true,
                             AllCallSitesKnown)) {
    // All call sites are known: merge their pre-call domains into the entry
    // state and accumulate the post-call reachability into the exit state.
    for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
      mergeInPredecessor(A, EntryBBED, CSInED);
      ExitED.IsReachingAlignedBarrierOnly &=
          CSOutED.IsReachingAlignedBarrierOnly;
    }

  } else {
    // We could not find all predecessors, so this is either a kernel or a
    // function with external linkage (or with some other weird uses).
    if (omp::isOpenMPKernel(*getAnchorScope())) {
      // A kernel entry starts aligned (all threads enter together) and has
      // seen no side effects yet.
      EntryBBED.IsExecutedByInitialThreadOnly = false;
      EntryBBED.IsReachedFromAlignedBarrierOnly = true;
      EntryBBED.EncounteredNonLocalSideEffect = false;
      ExitED.IsReachingAlignedBarrierOnly = false;
    } else {
      // Unknown external callers: assume the worst.
      EntryBBED.IsExecutedByInitialThreadOnly = false;
      EntryBBED.IsReachedFromAlignedBarrierOnly = false;
      EntryBBED.EncounteredNonLocalSideEffect = true;
      ExitED.IsReachingAlignedBarrierOnly = false;
    }
  }

  // Fold the results into the function-level entry (BEDMap[nullptr]) and
  // report whether anything changed.
  bool Changed = false;
  auto &FnED = BEDMap[nullptr];
  Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
                          FnED.IsReachedFromAlignedBarrierOnly &
                              EntryBBED.IsReachedFromAlignedBarrierOnly);
  Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
                          FnED.IsReachingAlignedBarrierOnly &
                              ExitED.IsReachingAlignedBarrierOnly);
  Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
                          EntryBBED.IsExecutedByInitialThreadOnly);
  return Changed;
}
3077
// Forward RPO traversal computing per-block/per-call execution domains,
// followed by a backward propagation of sync effects. The statement order is
// load-bearing for the fixpoint iteration; do not reorder.
ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {

  bool Changed = false;

  // Helper to deal with an aligned barrier encountered during the forward
  // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
  // it was encountered.
  auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
    Changed |= AlignedBarriers.insert(&CB);
    // First, update the barrier ED kept in the separate CEDMap.
    auto &CallInED = CEDMap[{&CB, PRE}];
    Changed |= mergeInPredecessor(A, CallInED, ED);
    CallInED.IsReachingAlignedBarrierOnly = true;
    // Next adjust the ED we use for the traversal.
    ED.EncounteredNonLocalSideEffect = false;
    ED.IsReachedFromAlignedBarrierOnly = true;
    // Aligned barrier collection has to come last.
    ED.clearAssumeInstAndAlignedBarriers();
    ED.addAlignedBarrier(A, CB);
    auto &CallOutED = CEDMap[{&CB, POST}];
    Changed |= mergeInPredecessor(A, CallOutED, ED);
  };

  auto *LivenessAA =
      A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);

  Function *F = getAnchorScope();
  BasicBlock &EntryBB = F->getEntryBlock();
  bool IsKernel = omp::isOpenMPKernel(*F);

  SmallVector<Instruction *> SyncInstWorklist;
  for (auto &RIt : *RPOT) {
    BasicBlock &BB = *RIt;

    bool IsEntryBB = &BB == &EntryBB;
    // TODO: We use local reasoning since we don't have a divergence analysis
    // running as well. We could basically allow uniform branches here.
    bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
    bool IsExplicitlyAligned = IsEntryBB && IsKernel;
    ExecutionDomainTy ED;
    // Propagate "incoming edges" into information about this block.
    if (IsEntryBB) {
      Changed |= handleCallees(A, ED);
    } else {
      // For live non-entry blocks we only propagate
      // information via live edges.
      if (LivenessAA && LivenessAA->isAssumedDead(&BB))
        continue;

      for (auto *PredBB : predecessors(&BB)) {
        if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
          continue;
        bool InitialEdgeOnly = isInitialThreadOnlyEdge(
            A, dyn_cast<CondBrInst>(PredBB->getTerminator()), BB);
        mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
      }
    }

    // Now we traverse the block, accumulate effects in ED and attach
    // information to calls.
    for (Instruction &I : BB) {
      bool UsedAssumedInformation;
      if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
                          /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
                          /* CheckForDeadStore */ true))
        continue;

      // Assumes and "assume-like" (dbg, lifetime, ...) are handled first, the
      // former is collected the latter is ignored.
      if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
        if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
          ED.addAssumeInst(A, *AI);
          continue;
        }
        // TODO: Should we also collect and delete lifetime markers?
        if (II->isAssumeLikeIntrinsic())
          continue;
      }

      if (auto *FI = dyn_cast<FenceInst>(&I)) {
        if (!ED.EncounteredNonLocalSideEffect) {
          // An aligned fence without non-local side-effects is a no-op.
          if (ED.IsReachedFromAlignedBarrierOnly)
            continue;
          // A non-aligned fence without non-local side-effects is a no-op
          // if the ordering only publishes non-local side-effects (or less).
          switch (FI->getOrdering()) {
          case AtomicOrdering::NotAtomic:
            continue;
          case AtomicOrdering::Unordered:
            continue;
          case AtomicOrdering::Monotonic:
            continue;
          case AtomicOrdering::Acquire:
            break;
          case AtomicOrdering::Release:
            continue;
          case AtomicOrdering::AcquireRelease:
            break;
          case AtomicOrdering::SequentiallyConsistent:
            break;
          };
        }
        // Anything else has to be treated as a real (non-no-op) fence.
        NonNoOpFences.insert(FI);
      }

      auto *CB = dyn_cast<CallBase>(&I);
      bool IsNoSync = AA::isNoSyncInst(A, I, *this);
      bool IsAlignedBarrier =
          !IsNoSync && CB &&
          AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);

      AlignedBarrierLastInBlock &= IsNoSync;
      IsExplicitlyAligned &= IsNoSync;

      // Next we check for calls. Aligned barriers are handled
      // explicitly, everything else is kept for the backward traversal and will
      // also affect our state.
      if (CB) {
        if (IsAlignedBarrier) {
          HandleAlignedBarrier(*CB, ED);
          AlignedBarrierLastInBlock = true;
          IsExplicitlyAligned = true;
          continue;
        }

        // Check the pointer(s) of a memory intrinsic explicitly.
        if (isa<MemIntrinsic>(&I)) {
          // NOTE(review): the second operand of this condition appears to be
          // missing from this rendering; verify against upstream.
          if (!ED.EncounteredNonLocalSideEffect &&
            ED.EncounteredNonLocalSideEffect = true;
          if (!IsNoSync) {
            ED.IsReachedFromAlignedBarrierOnly = false;
            SyncInstWorklist.push_back(&I);
          }
          continue;
        }

        // Record how we entered the call, then accumulate the effect of the
        // call in ED for potential use by the callee.
        auto &CallInED = CEDMap[{CB, PRE}];
        Changed |= mergeInPredecessor(A, CallInED, ED);

        // If we have a sync-definition we can check if it starts/ends in an
        // aligned barrier. If we are unsure we assume any sync breaks
        // alignment.
        // NOTE(review): the definition of `Callee` appears to be missing from
        // this rendering; verify against upstream.
        if (!IsNoSync && Callee && !Callee->isDeclaration()) {
          const auto *EDAA = A.getAAFor<AAExecutionDomain>(
              *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
          if (EDAA && EDAA->getState().isValidState()) {
            const auto &CalleeED = EDAA->getFunctionExecutionDomain();
            ED.IsReachedFromAlignedBarrierOnly =
                CalleeED.IsReachedFromAlignedBarrierOnly;
            AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
            if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
              ED.EncounteredNonLocalSideEffect |=
                  CalleeED.EncounteredNonLocalSideEffect;
            else
              ED.EncounteredNonLocalSideEffect =
                  CalleeED.EncounteredNonLocalSideEffect;
            if (!CalleeED.IsReachingAlignedBarrierOnly) {
              Changed |=
                  setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
              SyncInstWorklist.push_back(&I);
            }
            if (CalleeED.IsReachedFromAlignedBarrierOnly)
              mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
            auto &CallOutED = CEDMap[{CB, POST}];
            Changed |= mergeInPredecessor(A, CallOutED, ED);
            continue;
          }
        }
        // Unknown or syncing callee: conservatively break alignment and queue
        // the call for the backward traversal.
        if (!IsNoSync) {
          ED.IsReachedFromAlignedBarrierOnly = false;
          Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
          SyncInstWorklist.push_back(&I);
        }
        AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
        ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
        auto &CallOutED = CEDMap[{CB, POST}];
        Changed |= mergeInPredecessor(A, CallOutED, ED);
      }

      if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
        continue;

      // If we have a callee we try to use fine-grained information to
      // determine local side-effects.
      if (CB) {
        const auto *MemAA = A.getAAFor<AAMemoryLocation>(
            *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);

        // NOTE(review): the remaining parameter lines of this lambda appear to
        // be missing from this rendering; verify against upstream.
        auto AccessPred = [&](const Instruction *I, const Value *Ptr,
          return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
        };
        // NOTE(review): an argument line of this call appears to be missing
        // from this rendering; verify against upstream.
        if (MemAA && MemAA->getState().isValidState() &&
            MemAA->checkForAllAccessesToMemoryKind(
          continue;
      }

      auto &InfoCache = A.getInfoCache();
      if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
        continue;

      // Invariant loads cannot be affected by barriers.
      if (auto *LI = dyn_cast<LoadInst>(&I))
        if (LI->hasMetadata(LLVMContext::MD_invariant_load))
          continue;

      // NOTE(review): the second operand of this condition appears to be
      // missing from this rendering; verify against upstream.
      if (!ED.EncounteredNonLocalSideEffect &&
        ED.EncounteredNonLocalSideEffect = true;
    }

    // Function-exit blocks feed the inter-procedural and function-level state.
    bool IsEndAndNotReachingAlignedBarriersOnly = false;
    if (!isa<UnreachableInst>(BB.getTerminator()) &&
        !BB.getTerminator()->getNumSuccessors()) {

      Changed |= mergeInPredecessor(A, InterProceduralED, ED);

      auto &FnED = BEDMap[nullptr];
      if (IsKernel && !IsExplicitlyAligned)
        FnED.IsReachingAlignedBarrierOnly = false;
      Changed |= mergeInPredecessor(A, FnED, ED);

      if (!FnED.IsReachingAlignedBarrierOnly) {
        IsEndAndNotReachingAlignedBarriersOnly = true;
        SyncInstWorklist.push_back(BB.getTerminator());
        auto &BBED = BEDMap[&BB];
        Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
      }
    }

    ExecutionDomainTy &StoredED = BEDMap[&BB];
    ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
                                      !IsEndAndNotReachingAlignedBarriersOnly;

    // Check if we computed anything different as part of the forward
    // traversal. We do not take assumptions and aligned barriers into account
    // as they do not influence the state we iterate. Backward traversal values
    // are handled later on.
    if (ED.IsExecutedByInitialThreadOnly !=
            StoredED.IsExecutedByInitialThreadOnly ||
        ED.IsReachedFromAlignedBarrierOnly !=
            StoredED.IsReachedFromAlignedBarrierOnly ||
        ED.EncounteredNonLocalSideEffect !=
            StoredED.EncounteredNonLocalSideEffect)
      Changed = true;

    // Update the state with the new value.
    StoredED = std::move(ED);
  }

  // Propagate (non-aligned) sync instruction effects backwards until the
  // entry is hit or an aligned barrier.
  SmallSetVector<BasicBlock *, 16> Visited;
  while (!SyncInstWorklist.empty()) {
    Instruction *SyncInst = SyncInstWorklist.pop_back_val();
    Instruction *CurInst = SyncInst;
    bool HitAlignedBarrierOrKnownEnd = false;
    // Walk up within the block, clearing "reaching aligned barrier" on calls.
    while ((CurInst = CurInst->getPrevNode())) {
      auto *CB = dyn_cast<CallBase>(CurInst);
      if (!CB)
        continue;
      auto &CallOutED = CEDMap[{CB, POST}];
      Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
      auto &CallInED = CEDMap[{CB, PRE}];
      HitAlignedBarrierOrKnownEnd =
          AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
      if (HitAlignedBarrierOrKnownEnd)
        break;
      Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
    }
    if (HitAlignedBarrierOrKnownEnd)
      continue;
    // Reached the block start: continue into the live predecessors.
    BasicBlock *SyncBB = SyncInst->getParent();
    for (auto *PredBB : predecessors(SyncBB)) {
      if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
        continue;
      if (!Visited.insert(PredBB))
        continue;
      auto &PredED = BEDMap[PredBB];
      if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
        Changed = true;
        SyncInstWorklist.push_back(PredBB->getTerminator());
      }
    }
    if (SyncBB != &EntryBB)
      continue;
    Changed |=
        setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
  }

  return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
}
3376
/// Try to replace memory allocation calls called by a single thread with a
/// static buffer of shared memory.
struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
  using Base = StateWrapper<BooleanState, AbstractAttribute>;
  AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

  /// Create an abstract attribute view for the position \p IRP.
  static AAHeapToShared &createForPosition(const IRPosition &IRP,
                                           Attributor &A);

  /// Returns true if HeapToShared conversion is assumed to be possible for
  /// the allocation call \p CB.
  virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;

  /// Returns true if HeapToShared conversion is assumed and the CB is a
  /// callsite to a free operation to be removed.
  virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;

  /// See AbstractAttribute::getName().
  StringRef getName() const override { return "AAHeapToShared"; }

  /// See AbstractAttribute::getIdAddr().
  const char *getIdAddr() const override { return &ID; }

  /// This function should return true if the type of the \p AA is
  /// AAHeapToShared.
  static bool classof(const AbstractAttribute *AA) {
    return (AA->getIdAddr() == &ID);
  }

  /// Unique ID (due to the unique address).
  static const char ID;
};
3409
/// Function-level implementation of AAHeapToShared: tracks eligible
/// __kmpc_alloc_shared calls and rewrites them into static shared-memory
/// buffers during manifest.
struct AAHeapToSharedFunction : public AAHeapToShared {
  AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
      : AAHeapToShared(IRP, A) {}

  /// See AbstractAttribute::getAsStr().
  const std::string getAsStr(Attributor *) const override {
    return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
           " malloc calls eligible.";
  }

  /// See AbstractAttribute::trackStatistics().
  void trackStatistics() const override {}

  /// This functions finds free calls that will be removed by the
  /// HeapToShared transformation.
  void findPotentialRemovedFreeCalls(Attributor &A) {
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];

    PotentialRemovedFreeCalls.clear();
    // Update free call users of found malloc calls.
    for (CallBase *CB : MallocCalls) {
      // NOTE(review): the declaration of `FreeCalls` appears to be missing
      // from this rendering; verify against upstream.
      for (auto *U : CB->users()) {
        CallBase *C = dyn_cast<CallBase>(U);
        if (C && C->getCalledFunction() == FreeRFI.Declaration)
          FreeCalls.push_back(C);
      }

      // Only allocations with exactly one matching free are candidates.
      if (FreeCalls.size() != 1)
        continue;

      PotentialRemovedFreeCalls.insert(FreeCalls.front());
    }
  }

  /// See AbstractAttribute::initialize(...).
  void initialize(Attributor &A) override {
    // NOTE(review): the guard condition opening this early exit appears to be
    // missing from this rendering; verify against upstream.
      indicatePessimisticFixpoint();
      return;
    }

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
    if (!RFI.Declaration)
      return;

    // NOTE(review): the declaration introducing `SCB` appears to be missing
    // from this rendering; verify against upstream.
        [](const IRPosition &, const AbstractAttribute *,
           bool &) -> std::optional<Value *> { return nullptr; };

    // Collect all __kmpc_alloc_shared calls in this function and keep their
    // return values from being simplified away while we reason about them.
    Function *F = getAnchorScope();
    for (User *U : RFI.Declaration->users())
      if (CallBase *CB = dyn_cast<CallBase>(U)) {
        if (CB->getFunction() != F)
          continue;
        MallocCalls.insert(CB);
        A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
                                         SCB);
      }

    findPotentialRemovedFreeCalls(A);
  }

  /// See AAHeapToShared::isAssumedHeapToShared.
  bool isAssumedHeapToShared(CallBase &CB) const override {
    return isValidState() && MallocCalls.count(&CB);
  }

  /// See AAHeapToShared::isAssumedHeapToSharedRemovedFree.
  bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
    return isValidState() && PotentialRemovedFreeCalls.count(&CB);
  }

  /// See AbstractAttribute::manifest(...).
  ChangeStatus manifest(Attributor &A) override {
    if (MallocCalls.empty())
      return ChangeStatus::UNCHANGED;

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];

    Function *F = getAnchorScope();
    auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
                                            DepClassTy::OPTIONAL);

    ChangeStatus Changed = ChangeStatus::UNCHANGED;
    for (CallBase *CB : MallocCalls) {
      // Skip replacing this if HeapToStack has already claimed it.
      if (HS && HS->isAssumedHeapToStack(*CB))
        continue;

      // Find the unique free call to remove it.
      // NOTE(review): the declaration of `FreeCalls` appears to be missing
      // from this rendering; verify against upstream.
      for (auto *U : CB->users()) {
        CallBase *C = dyn_cast<CallBase>(U);
        if (C && C->getCalledFunction() == FreeCall.Declaration)
          FreeCalls.push_back(C);
      }
      if (FreeCalls.size() != 1)
        continue;

      auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));

      // Respect the overall shared-memory budget.
      if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
        LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
                          << " with shared memory."
                          << " Shared memory usage is limited to "
                          << SharedMemoryLimit << " bytes\n");
        continue;
      }

      LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
                        << " with " << AllocSize->getZExtValue()
                        << " bytes of shared memory\n");

      // Create a new shared memory buffer of the same size as the allocation
      // and replace all the uses of the original allocation with it.
      Module *M = CB->getModule();
      Type *Int8Ty = Type::getInt8Ty(M->getContext());
      Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
      auto *SharedMem = new GlobalVariable(
          *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
          PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
          // NOTE(review): an argument line appears to be missing here in this
          // rendering; verify against upstream.
          static_cast<unsigned>(AddressSpace::Shared));
      auto *NewBuffer = ConstantExpr::getPointerCast(
          SharedMem, PointerType::getUnqual(M->getContext()));

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Replaced globalized variable with "
                  << ore::NV("SharedMemory", AllocSize->getZExtValue())
                  << (AllocSize->isOne() ? " byte " : " bytes ")
                  << "of shared memory.";
      };
      A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);

      // Carry over the allocation's alignment to the new global.
      MaybeAlign Alignment = CB->getRetAlign();
      assert(Alignment &&
             "HeapToShared on allocation without alignment attribute");
      SharedMem->setAlignment(*Alignment);

      A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
      A.deleteAfterManifest(*CB);
      A.deleteAfterManifest(*FreeCalls.front());

      SharedMemoryUsed += AllocSize->getZExtValue();
      NumBytesMovedToSharedMemory = SharedMemoryUsed;
      Changed = ChangeStatus::CHANGED;
    }

    return Changed;
  }

  /// See AbstractAttribute::updateImpl(...).
  ChangeStatus updateImpl(Attributor &A) override {
    if (MallocCalls.empty())
      return indicatePessimisticFixpoint();
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
    if (!RFI.Declaration)
      return ChangeStatus::UNCHANGED;

    Function *F = getAnchorScope();

    auto NumMallocCalls = MallocCalls.size();

    // Only consider malloc calls executed by a single thread with a constant.
    for (User *U : RFI.Declaration->users()) {
      if (CallBase *CB = dyn_cast<CallBase>(U)) {
        if (CB->getCaller() != F)
          continue;
        if (!MallocCalls.count(CB))
          continue;
        if (!isa<ConstantInt>(CB->getArgOperand(0))) {
          MallocCalls.remove(CB);
          continue;
        }
        const auto *ED = A.getAAFor<AAExecutionDomain>(
            *this, IRPosition::function(*F), DepClassTy::REQUIRED);
        if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
          MallocCalls.remove(CB);
      }
    }

    findPotentialRemovedFreeCalls(A);

    // Shrinking the candidate set counts as a state change.
    if (NumMallocCalls != MallocCalls.size())
      return ChangeStatus::CHANGED;

    return ChangeStatus::UNCHANGED;
  }

  /// Collection of all malloc calls in a function.
  SmallSetVector<CallBase *, 4> MallocCalls;
  /// Collection of potentially removed free calls in a function.
  SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
  /// The total amount of shared memory that has been used for HeapToShared.
  unsigned SharedMemoryUsed = 0;
};
3605
/// Abstract attribute interface wrapping a KernelInfoState for a kernel or a
/// function reachable from kernels.
struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

  /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
  /// unknown callees.
  static bool requiresCalleeForCallBase() { return false; }

  /// Statistics are tracked as part of manifest for now.
  void trackStatistics() const override {}

  /// See AbstractAttribute::getAsStr()
  const std::string getAsStr(Attributor *) const override {
    if (!isValidState())
      return "<invalid>";
    // Summarize SPMD-ness, parallel regions, reaching kernels, parallel
    // levels, and nested parallelism in a single human-readable line.
    return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
                                                            : "generic") +
           std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
                                                               : "") +
           std::string(" #PRs: ") +
           (ReachedKnownParallelRegions.isValidState()
                ? std::to_string(ReachedKnownParallelRegions.size())
                : "<invalid>") +
           ", #Unknown PRs: " +
           (ReachedUnknownParallelRegions.isValidState()
                ? std::to_string(ReachedUnknownParallelRegions.size())
                : "<invalid>") +
           ", #Reaching Kernels: " +
           (ReachingKernelEntries.isValidState()
                ? std::to_string(ReachingKernelEntries.size())
                : "<invalid>") +
           ", #ParLevels: " +
           (ParallelLevels.isValidState()
                ? std::to_string(ParallelLevels.size())
                : "<invalid>") +
           ", NestedPar: " + (NestedParallelism ? "yes" : "no");
  }

  /// Create an abstract attribute view for the position \p IRP.
  static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);

  /// See AbstractAttribute::getName()
  StringRef getName() const override { return "AAKernelInfo"; }

  /// See AbstractAttribute::getIdAddr()
  const char *getIdAddr() const override { return &ID; }

  /// This function should return true if the type of the \p AA is AAKernelInfo
  static bool classof(const AbstractAttribute *AA) {
    return (AA->getIdAddr() == &ID);
  }

  /// Unique ID (due to the unique address).
  static const char ID;
};
3660
/// The function kernel info abstract attribute, basically, what can we say
/// about a function with regards to the KernelInfoState.
struct AAKernelInfoFunction : AAKernelInfo {
  AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
      : AAKernelInfo(IRP, A) {}

  // Instructions recorded as guarded for this kernel; populated elsewhere in
  // this attribute (presumably during SPMDization — confirm against upstream).
  SmallPtrSet<Instruction *, 4> GuardedInstructions;

  // Mutable accessor for the guarded-instruction set above.
  SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
    return GuardedInstructions;
  }
3672
  /// Replace the configuration struct inside the cached kernel environment
  /// constant (KernelEnvC) with \p ConfigC.
  void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
    // NOTE(review): the line introducing `NewKernelEnvC` appears to be
    // missing from this rendering; verify against upstream.
        KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
    assert(NewKernelEnvC && "Failed to create new kernel environment");
    KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
  }

/// Generate a setter that constant-folds \p NewVal into the MEMBER slot of the
/// kernel environment configuration struct.
#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER)                        \
  void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) {                 \
    ConstantStruct *ConfigC =                                                  \
        KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC);         \
    Constant *NewConfigC = ConstantFoldInsertValueInstruction(                 \
        ConfigC, NewVal, {KernelInfo::MEMBER##Idx});                           \
    assert(NewConfigC && "Failed to create new configuration environment");    \
    setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC));     \
  }

  KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
  KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
  // NOTE(review): additional setter instantiations appear to be missing from
  // this rendering; verify against upstream.

#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
3699
  /// See AbstractAttribute::initialize(...).
  ///
  /// Locates the kernel's __kmpc_target_init/__kmpc_target_deinit calls,
  /// seeds the kernel environment constant and the assumed (optimistic)
  /// configuration values, and registers simplification and virtual-use
  /// callbacks with the Attributor so later rewrites stay sound.
  void initialize(Attributor &A) override {
    // This is a high-level transform that might change the constant arguments
    // of the init and deinit calls. We need to tell the Attributor about this
    // to avoid other parts using the current constant value for
    // simplification.
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

    Function *Fn = getAnchorScope();

    OMPInformationCache::RuntimeFunctionInfo &InitRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
    OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];

    // For kernels we perform more initialization work, first we find the init
    // and deinit calls.
    auto StoreCallBase = [](Use &U,
                            OMPInformationCache::RuntimeFunctionInfo &RFI,
                            CallBase *&Storage) {
      CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
      assert(CB &&
             "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
      assert(!Storage &&
             "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
      Storage = CB;
      return false;
    };
    InitRFI.foreachUse(
        [&](Use &U, Function &) {
          StoreCallBase(U, InitRFI, KernelInitCB);
          return false;
        },
        Fn);
    DeinitRFI.foreachUse(
        [&](Use &U, Function &) {
          StoreCallBase(U, DeinitRFI, KernelDeinitCB);
          return false;
        },
        Fn);

    // Ignore kernels without initializers such as global constructors.
    if (!KernelInitCB || !KernelDeinitCB)
      return;

    // Add itself to the reaching kernel and set IsKernelEntry.
    ReachingKernelEntries.insert(Fn);
    IsKernelEntry = true;

    // NOTE(review): the initializer expressions of the next two statements
    // appear truncated in this view (they read the kernel environment and its
    // global variable from KernelInitCB) -- confirm against upstream.
    KernelEnvC =
    GlobalVariable *KernelEnvGV =

    // While the fixpoint iteration runs, simplify uses of the environment
    // global to the assumed constant and record an optional dependence on the
    // querying AA so it is re-evaluated if our state changes.
    KernelConfigurationSimplifyCB =
        [&](const GlobalVariable &GV, const AbstractAttribute *AA,
            bool &UsedAssumedInformation) -> std::optional<Constant *> {
      if (!isAtFixpoint()) {
        if (!AA)
          return nullptr;
        UsedAssumedInformation = true;
        A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
      }
      return KernelEnvC;
    };

    A.registerGlobalVariableSimplificationCallback(
        *KernelEnvGV, KernelConfigurationSimplifyCB);

    // We cannot change to SPMD mode if the runtime functions aren't available.
    bool CanChangeToSPMD = OMPInfoCache.runtimeFnsAvailable(
        {OMPRTL___kmpc_get_hardware_thread_id_in_block,
         OMPRTL___kmpc_barrier_simple_spmd});

    // Check if we know we are in SPMD-mode already.
    ConstantInt *ExecModeC =
        KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
    // NOTE(review): the second argument of this ConstantInt::get appears
    // truncated in this view -- confirm against upstream.
    ConstantInt *AssumedExecModeC = ConstantInt::get(
        ExecModeC->getIntegerType(),
    if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
      SPMDCompatibilityTracker.indicateOptimisticFixpoint();
    else if (DisableOpenMPOptSPMDization || !CanChangeToSPMD)
      // This is a generic region but SPMDization is disabled so stop
      // tracking.
      SPMDCompatibilityTracker.indicatePessimisticFixpoint();
    else
      setExecModeOfKernelEnvironment(AssumedExecModeC);

    const Triple T(Fn->getParent()->getTargetTriple());
    auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
    // NOTE(review): the initializers of the two structured bindings below
    // (presumably getMinAndMaxThreads/Teams-style queries) appear truncated
    // in this view -- confirm against upstream.
    auto [MinThreads, MaxThreads] =
    if (MinThreads)
      setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
    if (MaxThreads)
      setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
    auto [MinTeams, MaxTeams] =
    if (MinTeams)
      setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
    if (MaxTeams)
      setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));

    // Assume no nested parallelism until proven otherwise.
    ConstantInt *MayUseNestedParallelismC =
        KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
    ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
        MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
    setMayUseNestedParallelismOfKernelEnvironment(
        AssumedMayUseNestedParallelismC);

    // NOTE(review): an opening `if (...) {` guarding this region appears to
    // be missing from this view (note the dangling closing brace below) --
    // confirm against upstream.
      ConstantInt *UseGenericStateMachineC =
          KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
              KernelEnvC);
      ConstantInt *AssumedUseGenericStateMachineC =
          ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
      setUseGenericStateMachineOfKernelEnvironment(
          AssumedUseGenericStateMachineC);
    }

    // Register virtual uses of functions we might need to preserve.
    // NOTE(review): the second lambda parameter (the callback) appears
    // truncated in this view -- confirm against upstream.
    auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
      if (!OMPInfoCache.RFIs[RFKind].Declaration)
        return;
      A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
    };

    // Add a dependence to ensure updates if the state changes.
    auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
                            const AbstractAttribute *QueryingAA) {
      if (QueryingAA) {
        A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
      }
      return true;
    };

    Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
        [&](Attributor &A, const AbstractAttribute *QueryingAA) {
          // Whenever we create a custom state machine we will insert calls to
          // __kmpc_get_hardware_num_threads_in_block,
          // __kmpc_get_warp_size,
          // __kmpc_barrier_simple_generic,
          // __kmpc_kernel_parallel, and
          // __kmpc_kernel_end_parallel.
          // Not needed if we are on track for SPMDzation.
          if (SPMDCompatibilityTracker.isValidState())
            return AddDependence(A, this, QueryingAA);
          // Not needed if we can't rewrite due to an invalid state.
          if (!ReachedKnownParallelRegions.isValidState())
            return AddDependence(A, this, QueryingAA);
          return false;
        };

    // Not needed if we are pre-runtime merge.
    if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
      RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
                         CustomStateMachineUseCB);
      RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
      RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
                         CustomStateMachineUseCB);
      RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
                         CustomStateMachineUseCB);
      RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
                         CustomStateMachineUseCB);
    }

    // If we do not perform SPMDzation we do not need the virtual uses below.
    if (SPMDCompatibilityTracker.isAtFixpoint())
      return;

    Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
        [&](Attributor &A, const AbstractAttribute *QueryingAA) {
          // Whenever we perform SPMDzation we will insert
          // __kmpc_get_hardware_thread_id_in_block calls.
          if (!SPMDCompatibilityTracker.isValidState())
            return AddDependence(A, this, QueryingAA);
          return false;
        };
    RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
                       HWThreadIdUseCB);

    Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
        [&](Attributor &A, const AbstractAttribute *QueryingAA) {
          // Whenever we perform SPMDzation with guarding we will insert
          // __kmpc_simple_barrier_spmd calls. If SPMDzation failed, there is
          // nothing to guard, or there are no parallel regions, we don't need
          // the calls.
          if (!SPMDCompatibilityTracker.isValidState())
            return AddDependence(A, this, QueryingAA);
          if (SPMDCompatibilityTracker.empty())
            return AddDependence(A, this, QueryingAA);
          if (!mayContainParallelRegion())
            return AddDependence(A, this, QueryingAA);
          return false;
        };
    RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
  }
3899
3900 /// Sanitize the string \p S such that it is a suitable global symbol name.
3901 static std::string sanitizeForGlobalName(std::string S) {
3902 std::replace_if(
3903 S.begin(), S.end(),
3904 [](const char C) {
3905 return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3906 (C >= '0' && C <= '9') || C == '_');
3907 },
3908 '.');
3909 return S;
3910 }
3911
  /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
  /// finished now: try SPMDzation first, otherwise a custom state machine,
  /// then commit the (possibly updated) kernel environment constant.
  ChangeStatus manifest(Attributor &A) override {
    // If we are not looking at a kernel with __kmpc_target_init and
    // __kmpc_target_deinit call we cannot actually manifest the information.
    if (!KernelInitCB || !KernelDeinitCB)
      return ChangeStatus::UNCHANGED;

    ChangeStatus Changed = ChangeStatus::UNCHANGED;

    // SPMDzation takes precedence; a custom state machine is only built when
    // the init runtime function is a definition (post-runtime-merge).
    bool HasBuiltStateMachine = true;
    if (!changeToSPMDMode(A, Changed)) {
      if (!KernelInitCB->getCalledFunction()->isDeclaration())
        HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
      else
        HasBuiltStateMachine = false;
    }

    // We need to reset KernelEnvC if specific rewriting is not done.
    // NOTE(review): the initializer of ExistingKernelEnvC appears truncated
    // in this view (it reads the environment from KernelInitCB) -- confirm
    // against upstream.
    ConstantStruct *ExistingKernelEnvC =
    ConstantInt *OldUseGenericStateMachineVal =
        KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
            ExistingKernelEnvC);
    if (!HasBuiltStateMachine)
      setUseGenericStateMachineOfKernelEnvironment(
          OldUseGenericStateMachineVal);

    // At last, update the KernelEnvC global's initializer if it changed.
    // NOTE(review): the initializer of KernelEnvGV appears truncated in this
    // view -- confirm against upstream.
    GlobalVariable *KernelEnvGV =
    if (KernelEnvGV->getInitializer() != KernelEnvC) {
      KernelEnvGV->setInitializer(KernelEnvC);
      Changed = ChangeStatus::CHANGED;
    }

    return Changed;
  }
3950
  /// Wrap each SPMD-incompatible instruction region in a "main thread only"
  /// guard: only the thread with hardware id 0 executes the region, values
  /// escaping the region are broadcast through shared memory, and simple SPMD
  /// barriers synchronize the workgroup around the region.
  void insertInstructionGuardsHelper(Attributor &A) {
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

    auto CreateGuardedRegion = [&](Instruction *RegionStartI,
                                   Instruction *RegionEndI) {
      LoopInfo *LI = nullptr;
      DominatorTree *DT = nullptr;
      MemorySSAUpdater *MSU = nullptr;
      using InsertPointTy = OpenMPIRBuilder::InsertPointTy;

      BasicBlock *ParentBB = RegionStartI->getParent();
      Function *Fn = ParentBB->getParent();
      Module &M = *Fn->getParent();

      // Create all the blocks and logic.
      // ParentBB:
      //    goto RegionCheckTidBB
      // RegionCheckTidBB:
      //    Tid = __kmpc_hardware_thread_id()
      //    if (Tid != 0)
      //        goto RegionBarrierBB
      // RegionStartBB:
      //    <execute instructions guarded>
      //    goto RegionEndBB
      // RegionEndBB:
      //    <store escaping values to shared mem>
      //    goto RegionBarrierBB
      // RegionBarrierBB:
      //    __kmpc_simple_barrier_spmd()
      //    // second barrier is omitted if lacking escaping values.
      //    <load escaping values from shared mem>
      //    __kmpc_simple_barrier_spmd()
      //    goto RegionExitBB
      // RegionExitBB:
      //    <execute rest of instructions>

      BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
                                           DT, LI, MSU, "region.guarded.end");
      BasicBlock *RegionBarrierBB =
          SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
                     MSU, "region.barrier");
      BasicBlock *RegionExitBB =
          SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
                     DT, LI, MSU, "region.exit");
      BasicBlock *RegionStartBB =
          SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");

      assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
             "Expected a different CFG");

      BasicBlock *RegionCheckTidBB = SplitBlock(
          ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");

      // Register basic blocks with the Attributor.
      A.registerManifestAddedBasicBlock(*RegionEndBB);
      A.registerManifestAddedBasicBlock(*RegionBarrierBB);
      A.registerManifestAddedBasicBlock(*RegionExitBB);
      A.registerManifestAddedBasicBlock(*RegionStartBB);
      A.registerManifestAddedBasicBlock(*RegionCheckTidBB);

      bool HasBroadcastValues = false;
      // Find escaping outputs from the guarded region to outside users and
      // broadcast their values to them.
      for (Instruction &I : *RegionStartBB) {
        SmallVector<Use *, 4> OutsideUses;
        for (Use &U : I.uses()) {
          Instruction &UsrI = *cast<Instruction>(U.getUser());
          if (UsrI.getParent() != RegionStartBB)
            OutsideUses.push_back(&U);
        }

        if (OutsideUses.empty())
          continue;

        HasBroadcastValues = true;

        // Emit a global variable in shared memory to store the broadcasted
        // value.
        // NOTE(review): some GlobalVariable constructor arguments (linkage,
        // initializer, thread-local mode) appear truncated in this view --
        // confirm against upstream.
        auto *SharedMem = new GlobalVariable(
            M, I.getType(), /* IsConstant */ false,
            sanitizeForGlobalName(
                (I.getName() + ".guarded.output.alloc").str()),
            static_cast<unsigned>(AddressSpace::Shared));

        // Emit a store instruction to update the value.
        new StoreInst(&I, SharedMem,
                      RegionEndBB->getTerminator()->getIterator());

        LoadInst *LoadI = new LoadInst(
            I.getType(), SharedMem, I.getName() + ".guarded.output.load",
            RegionBarrierBB->getTerminator()->getIterator());

        // Emit a load instruction and replace uses of the output value.
        for (Use *U : OutsideUses)
          A.changeUseAfterManifest(*U, *LoadI);
      }

      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

      // Go to tid check BB in ParentBB.
      const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
      ParentBB->getTerminator()->eraseFromParent();
      OpenMPIRBuilder::LocationDescription Loc(
          InsertPointTy(ParentBB, ParentBB->end()), DL);
      OMPInfoCache.OMPBuilder.updateToLocation(Loc);
      uint32_t SrcLocStrSize;
      auto *SrcLocStr =
          OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
      Value *Ident =
          OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
      // NOTE(review): `UncondBrInst` is not an LLVM class; upstream creates
      // the unconditional branch via BranchInst::Create -- confirm before
      // building.
      UncondBrInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);

      // Add check for Tid in RegionCheckTidBB
      RegionCheckTidBB->getTerminator()->eraseFromParent();
      OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
          InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
      OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
      FunctionCallee HardwareTidFn =
          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
              M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
      CallInst *Tid =
          OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
      Tid->setDebugLoc(DL);
      OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
      Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
      OMPInfoCache.OMPBuilder.Builder
          .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
          ->setDebugLoc(DL);

      // First barrier for synchronization, ensures main thread has updated
      // values.
      FunctionCallee BarrierFn =
          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
              M, OMPRTL___kmpc_barrier_simple_spmd);
      OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
          RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
      CallInst *Barrier =
          OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
      Barrier->setDebugLoc(DL);
      OMPInfoCache.setCallingConvention(BarrierFn, Barrier);

      // Second barrier ensures workers have read broadcast values.
      if (HasBroadcastValues) {
        CallInst *Barrier =
            CallInst::Create(BarrierFn, {Ident, Tid}, "",
                             RegionBarrierBB->getTerminator()->getIterator());
        Barrier->setDebugLoc(DL);
        OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
      }
    };

    auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
    SmallPtrSet<BasicBlock *, 8> Visited;
    for (Instruction *GuardedI : SPMDCompatibilityTracker) {
      BasicBlock *BB = GuardedI->getParent();
      if (!Visited.insert(BB).second)
        continue;

      // Sink dead-at-exit guarded instructions past unrelated effects so the
      // guarded regions become contiguous.
      // NOTE(review): the declaration of `Reorders` appears truncated in this
      // view (upstream declares a SmallVector of instruction pairs here) --
      // confirm against upstream.
      Instruction *LastEffect = nullptr;
      BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
      while (++IP != IPEnd) {
        if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
          continue;
        Instruction *I = &*IP;
        if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
          continue;
        if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
          LastEffect = nullptr;
          continue;
        }
        if (LastEffect)
          Reorders.push_back({I, LastEffect});
        LastEffect = &*IP;
      }
      for (auto &Reorder : Reorders)
        Reorder.first->moveBefore(Reorder.second->getIterator());
    }

    // NOTE(review): the declaration of `GuardedRegions` appears truncated in
    // this view (upstream declares a SmallVector of instruction pairs here)
    // -- confirm against upstream.

    for (Instruction *GuardedI : SPMDCompatibilityTracker) {
      BasicBlock *BB = GuardedI->getParent();
      auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
          IRPosition::function(*GuardedI->getFunction()), nullptr,
          DepClassTy::NONE);
      assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
      auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
      // Continue if instruction is already guarded.
      if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
        continue;

      Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
      for (Instruction &I : *BB) {
        // If instruction I needs to be guarded update the guarded region
        // bounds.
        if (SPMDCompatibilityTracker.contains(&I)) {
          CalleeAAFunction.getGuardedInstructions().insert(&I);
          if (GuardedRegionStart)
            GuardedRegionEnd = &I;
          else
            GuardedRegionStart = GuardedRegionEnd = &I;

          continue;
        }

        // Instruction I does not need guarding, store
        // any region found and reset bounds.
        if (GuardedRegionStart) {
          GuardedRegions.push_back(
              std::make_pair(GuardedRegionStart, GuardedRegionEnd));
          GuardedRegionStart = nullptr;
          GuardedRegionEnd = nullptr;
        }
      }
    }

    for (auto &GR : GuardedRegions)
      CreateGuardedRegion(GR.first, GR.second);
  }
4173
4174 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4175 // Only allow 1 thread per workgroup to continue executing the user code.
4176 //
4177 // InitCB = __kmpc_target_init(...)
4178 // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4179 // if (ThreadIdInBlock != 0) return;
4180 // UserCode:
4181 // // user code
4182 //
4183 auto &Ctx = getAnchorValue().getContext();
4184 Function *Kernel = getAssociatedFunction();
4185 assert(Kernel && "Expected an associated function!");
4186
4187 // Create block for user code to branch to from initial block.
4188 BasicBlock *InitBB = KernelInitCB->getParent();
4189 BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4190 KernelInitCB->getNextNode(), "main.thread.user_code");
4191 BasicBlock *ReturnBB =
4192 BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
4193
4194 // Register blocks with attributor:
4195 A.registerManifestAddedBasicBlock(*InitBB);
4196 A.registerManifestAddedBasicBlock(*UserCodeBB);
4197 A.registerManifestAddedBasicBlock(*ReturnBB);
4198
4199 // Debug location:
4200 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4201 ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
4202 InitBB->getTerminator()->eraseFromParent();
4203
4204 // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4205 Module &M = *Kernel->getParent();
4206 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4207 FunctionCallee ThreadIdInBlockFn =
4208 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4209 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4210
4211 // Get thread ID in block.
4212 CallInst *ThreadIdInBlock =
4213 CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
4214 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4215 ThreadIdInBlock->setDebugLoc(DLoc);
4216
4217 // Eliminate all threads in the block with ID not equal to 0:
4218 Instruction *IsMainThread =
4219 ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
4220 ConstantInt::get(ThreadIdInBlock->getType(), 0),
4221 "thread.is_main", InitBB);
4222 IsMainThread->setDebugLoc(DLoc);
4223 CondBrInst::Create(IsMainThread, ReturnBB, UserCodeBB, InitBB);
4224 }
4225
  /// Try to convert this generic-mode kernel to SPMD mode. Returns true when
  /// the kernel is (now or already) SPMD; returns false -- after emitting
  /// remarks -- when the tracked side effects make SPMDzation impossible.
  /// Sets \p Changed when the IR is modified.
  bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

    if (!SPMDCompatibilityTracker.isAssumed()) {
      // SPMDzation is off the table; explain why via remarks and bail.
      for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
        if (!NonCompatibleI)
          continue;

        // Skip diagnostics on calls to known OpenMP runtime functions for now.
        if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
          if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
            continue;

        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
          ORA << "Value has potential side effects preventing SPMD-mode "
                 "execution";
          if (isa<CallBase>(NonCompatibleI)) {
            ORA << ". Add `[[omp::assume(\"ompx_spmd_amenable\")]]` to "
                   "the called function to override";
          }
          return ORA << ".";
        };
        A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
                                                 Remark);

        LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
                          << *NonCompatibleI << "\n");
      }

      return false;
    }

    // Get the actual kernel, could be the caller of the anchor scope if we
    // have a debug wrapper.
    Function *Kernel = getAnchorScope();
    if (Kernel->hasLocalLinkage()) {
      assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
      auto *CB = cast<CallBase>(Kernel->user_back());
      Kernel = CB->getCaller();
    }
    assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");

    // Check if the kernel is already in SPMD mode, if so, return success.
    // NOTE(review): the initializer of ExistingKernelEnvC appears truncated
    // in this view (it reads the environment from KernelInitCB) -- confirm
    // against upstream.
    ConstantStruct *ExistingKernelEnvC =
    auto *ExecModeC =
        KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
    const int8_t ExecModeVal = ExecModeC->getSExtValue();
    if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
      return true;

    // We will now unconditionally modify the IR, indicate a change.
    Changed = ChangeStatus::CHANGED;

    // Do not use instruction guards when no parallel is present inside
    // the target region.
    if (mayContainParallelRegion())
      insertInstructionGuardsHelper(A);
    else
      forceSingleThreadPerWorkgroupHelper(A);

    // Adjust the global exec mode flag that tells the runtime what mode this
    // kernel is executed in.
    assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
           "Initially non-SPMD kernel has SPMD exec mode!");
    setExecModeOfKernelEnvironment(
        ConstantInt::get(ExecModeC->getIntegerType(),
                         ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));

    ++NumOpenMPTargetRegionKernelsSPMD;

    auto Remark = [&](OptimizationRemark OR) {
      return OR << "Transformed generic-mode kernel to SPMD-mode.";
    };
    A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
    return true;
  };
4303
4304 bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4305 // If we have disabled state machine rewrites, don't make a custom one
4307 return false;
4308
4309 // Don't rewrite the state machine if we are not in a valid state.
4310 if (!ReachedKnownParallelRegions.isValidState())
4311 return false;
4312
4313 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4314 if (!OMPInfoCache.runtimeFnsAvailable(
4315 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4316 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4317 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4318 return false;
4319
4320 ConstantStruct *ExistingKernelEnvC =
4322
4323 // Check if the current configuration is non-SPMD and generic state machine.
4324 // If we already have SPMD mode or a custom state machine we do not need to
4325 // go any further. If it is anything but a constant something is weird and
4326 // we give up.
4327 ConstantInt *UseStateMachineC =
4328 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4329 ExistingKernelEnvC);
4330 ConstantInt *ModeC =
4331 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4332
4333 // If we are stuck with generic mode, try to create a custom device (=GPU)
4334 // state machine which is specialized for the parallel regions that are
4335 // reachable by the kernel.
4336 if (UseStateMachineC->isZero() ||
4338 return false;
4339
4340 Changed = ChangeStatus::CHANGED;
4341
4342 // If not SPMD mode, indicate we use a custom state machine now.
4343 setUseGenericStateMachineOfKernelEnvironment(
4344 ConstantInt::get(UseStateMachineC->getIntegerType(), false));
4345
4346 // If we don't actually need a state machine we are done here. This can
4347 // happen if there simply are no parallel regions. In the resulting kernel
4348 // all worker threads will simply exit right away, leaving the main thread
4349 // to do the work alone.
4350 if (!mayContainParallelRegion()) {
4351 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4352
4353 auto Remark = [&](OptimizationRemark OR) {
4354 return OR << "Removing unused state machine from generic-mode kernel.";
4355 };
4356 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4357
4358 return true;
4359 }
4360
4361 // Keep track in the statistics of our new shiny custom state machine.
4362 if (ReachedUnknownParallelRegions.empty()) {
4363 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4364
4365 auto Remark = [&](OptimizationRemark OR) {
4366 return OR << "Rewriting generic-mode kernel with a customized state "
4367 "machine.";
4368 };
4369 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4370 } else {
4371 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4372
4373 auto Remark = [&](OptimizationRemarkAnalysis OR) {
4374 return OR << "Generic-mode kernel is executed with a customized state "
4375 "machine that requires a fallback.";
4376 };
4377 A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4378
4379 // Tell the user why we ended up with a fallback.
4380 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4381 if (!UnknownParallelRegionCB)
4382 continue;
4383 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4384 return ORA << "Call may contain unknown parallel regions. Use "
4385 << "`[[omp::assume(\"omp_no_parallelism\")]]` to "
4386 "override.";
4387 };
4388 A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4389 "OMP133", Remark);
4390 }
4391 }
4392
4393 // Create all the blocks:
4394 //
4395 // InitCB = __kmpc_target_init(...)
4396 // BlockHwSize =
4397 // __kmpc_get_hardware_num_threads_in_block();
4398 // WarpSize = __kmpc_get_warp_size();
4399 // BlockSize = BlockHwSize - WarpSize;
4400 // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
4401 // if (IsWorker) {
4402 // if (InitCB >= BlockSize) return;
4403 // SMBeginBB: __kmpc_barrier_simple_generic(...);
4404 // void *WorkFn;
4405 // bool Active = __kmpc_kernel_parallel(&WorkFn);
4406 // if (!WorkFn) return;
4407 // SMIsActiveCheckBB: if (Active) {
4408 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
4409 // ParFn0(...);
4410 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
4411 // ParFn1(...);
4412 // ...
4413 // SMIfCascadeCurrentBB: else
4414 // ((WorkFnTy*)WorkFn)(...);
4415 // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
4416 // }
4417 // SMDoneBB: __kmpc_barrier_simple_generic(...);
4418 // goto SMBeginBB;
4419 // }
4420 // UserCodeEntryBB: // user code
4421 // __kmpc_target_deinit(...)
4422 //
4423 auto &Ctx = getAnchorValue().getContext();
4424 Function *Kernel = getAssociatedFunction();
4425 assert(Kernel && "Expected an associated function!");
4426
4427 BasicBlock *InitBB = KernelInitCB->getParent();
4428 BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4429 KernelInitCB->getNextNode(), "thread.user_code.check");
4430 BasicBlock *IsWorkerCheckBB =
4431 BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4432 BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4433 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4434 BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4435 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4436 BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4437 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4438 BasicBlock *StateMachineIfCascadeCurrentBB =
4439 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4440 Kernel, UserCodeEntryBB);
4441 BasicBlock *StateMachineEndParallelBB =
4442 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4443 Kernel, UserCodeEntryBB);
4444 BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4445 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4446 A.registerManifestAddedBasicBlock(*InitBB);
4447 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4448 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4449 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4450 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4451 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4452 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4453 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4454 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4455
4456 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4457 ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4458 InitBB->getTerminator()->eraseFromParent();
4459
4460 Instruction *IsWorker =
4461 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4462 ConstantInt::getAllOnesValue(KernelInitCB->getType()),
4463 "thread.is_worker", InitBB);
4464 IsWorker->setDebugLoc(DLoc);
4465 CondBrInst::Create(IsWorker, IsWorkerCheckBB, UserCodeEntryBB, InitBB);
4466
4467 Module &M = *Kernel->getParent();
4468 FunctionCallee BlockHwSizeFn =
4469 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4470 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4471 FunctionCallee WarpSizeFn =
4472 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4473 M, OMPRTL___kmpc_get_warp_size);
4474 CallInst *BlockHwSize =
4475 CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4476 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4477 BlockHwSize->setDebugLoc(DLoc);
4478 CallInst *WarpSize =
4479 CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4480 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4481 WarpSize->setDebugLoc(DLoc);
4482 Instruction *BlockSize = BinaryOperator::CreateSub(
4483 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4484 BlockSize->setDebugLoc(DLoc);
4485 Instruction *IsMainOrWorker = ICmpInst::Create(
4486 ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4487 "thread.is_main_or_worker", IsWorkerCheckBB);
4488 IsMainOrWorker->setDebugLoc(DLoc);
4489 CondBrInst::Create(IsMainOrWorker, StateMachineBeginBB,
4490 StateMachineFinishedBB, IsWorkerCheckBB);
4491
4492 // Create local storage for the work function pointer.
4493 const DataLayout &DL = M.getDataLayout();
4494 Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4495 Instruction *WorkFnAI =
4496 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4497 "worker.work_fn.addr", Kernel->getEntryBlock().begin());
4498 WorkFnAI->setDebugLoc(DLoc);
4499
4500 OMPInfoCache.OMPBuilder.updateToLocation(
4501 OpenMPIRBuilder::LocationDescription(
4502 IRBuilder<>::InsertPoint(StateMachineBeginBB,
4503 StateMachineBeginBB->end()),
4504 DLoc));
4505
4506 Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4507 Value *GTid = KernelInitCB;
4508
4509 FunctionCallee BarrierFn =
4510 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4511 M, OMPRTL___kmpc_barrier_simple_generic);
4512 CallInst *Barrier =
4513 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4514 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4515 Barrier->setDebugLoc(DLoc);
4516
4517 if (WorkFnAI->getType()->getPointerAddressSpace() !=
4518 (unsigned int)AddressSpace::Generic) {
4519 WorkFnAI = new AddrSpaceCastInst(
4520 WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4521 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4522 WorkFnAI->setDebugLoc(DLoc);
4523 }
4524
4525 FunctionCallee KernelParallelFn =
4526 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4527 M, OMPRTL___kmpc_kernel_parallel);
4528 CallInst *IsActiveWorker = CallInst::Create(
4529 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4530 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4531 IsActiveWorker->setDebugLoc(DLoc);
4532 Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4533 StateMachineBeginBB);
4534 WorkFn->setDebugLoc(DLoc);
4535
4536 FunctionType *ParallelRegionFnTy = FunctionType::get(
4537 Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
4538 false);
4539
4540 Instruction *IsDone =
4541 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4542 Constant::getNullValue(VoidPtrTy), "worker.is_done",
4543 StateMachineBeginBB);
4544 IsDone->setDebugLoc(DLoc);
4545 CondBrInst::Create(IsDone, StateMachineFinishedBB,
4546 StateMachineIsActiveCheckBB, StateMachineBeginBB)
4547 ->setDebugLoc(DLoc);
4548
4549 CondBrInst::Create(IsActiveWorker, StateMachineIfCascadeCurrentBB,
4550 StateMachineDoneBarrierBB, StateMachineIsActiveCheckBB)
4551 ->setDebugLoc(DLoc);
4552
4553 Value *ZeroArg =
4554 Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4555
4556 const unsigned int WrapperFunctionArgNo = 6;
4557
4558 // Now that we have most of the CFG skeleton it is time for the if-cascade
4559 // that checks the function pointer we got from the runtime against the
4560 // parallel regions we expect, if there are any.
4561 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4562 auto *CB = ReachedKnownParallelRegions[I];
4563 auto *ParallelRegion = dyn_cast<Function>(
4564 CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4565 BasicBlock *PRExecuteBB = BasicBlock::Create(
4566 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4567 StateMachineEndParallelBB);
4568 CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4569 ->setDebugLoc(DLoc);
4570 UncondBrInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4571 ->setDebugLoc(DLoc);
4572
4573 BasicBlock *PRNextBB =
4574 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4575 Kernel, StateMachineEndParallelBB);
4576 A.registerManifestAddedBasicBlock(*PRExecuteBB);
4577 A.registerManifestAddedBasicBlock(*PRNextBB);
4578
4579 // Check if we need to compare the pointer at all or if we can just
4580 // call the parallel region function.
4581 Value *IsPR;
4582 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4583 Instruction *CmpI = ICmpInst::Create(
4584 ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
4585 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4586 CmpI->setDebugLoc(DLoc);
4587 IsPR = CmpI;
4588 } else {
4589 IsPR = ConstantInt::getTrue(Ctx);
4590 }
4591
4592 CondBrInst::Create(IsPR, PRExecuteBB, PRNextBB,
4593 StateMachineIfCascadeCurrentBB)
4594 ->setDebugLoc(DLoc);
4595 StateMachineIfCascadeCurrentBB = PRNextBB;
4596 }
4597
4598 // At the end of the if-cascade we place the indirect function pointer call
4599 // in case we might need it, that is if there can be parallel regions we
4600 // have not handled in the if-cascade above.
4601 if (!ReachedUnknownParallelRegions.empty()) {
4602 StateMachineIfCascadeCurrentBB->setName(
4603 "worker_state_machine.parallel_region.fallback.execute");
4604 CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
4605 StateMachineIfCascadeCurrentBB)
4606 ->setDebugLoc(DLoc);
4607 }
4608 UncondBrInst::Create(StateMachineEndParallelBB,
4609 StateMachineIfCascadeCurrentBB)
4610 ->setDebugLoc(DLoc);
4611
4612 FunctionCallee EndParallelFn =
4613 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4614 M, OMPRTL___kmpc_kernel_end_parallel);
4615 CallInst *EndParallel =
4616 CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4617 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4618 EndParallel->setDebugLoc(DLoc);
4619 UncondBrInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4620 ->setDebugLoc(DLoc);
4621
4622 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4623 ->setDebugLoc(DLoc);
4624 UncondBrInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4625 ->setDebugLoc(DLoc);
4626
4627 return true;
4628 }
4629
4630 /// Fixpoint iteration update function. Will be called every time a dependence
4631 /// changed its state (and in the beginning).
4632 ChangeStatus updateImpl(Attributor &A) override {
4633 KernelInfoState StateBefore = getState();
4634
4635 // When we leave this function this RAII will make sure the member
4636 // KernelEnvC is updated properly depending on the state. That member is
4637 // used for simplification of values and needs to be up to date at all
4638 // times.
4639 struct UpdateKernelEnvCRAII {
4640 AAKernelInfoFunction &AA;
4641
4642 UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4643
4644 ~UpdateKernelEnvCRAII() {
4645 if (!AA.KernelEnvC)
4646 return;
4647
4648 ConstantStruct *ExistingKernelEnvC =
4650
4651 if (!AA.isValidState()) {
4652 AA.KernelEnvC = ExistingKernelEnvC;
4653 return;
4654 }
4655
4656 if (!AA.ReachedKnownParallelRegions.isValidState())
4657 AA.setUseGenericStateMachineOfKernelEnvironment(
4658 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4659 ExistingKernelEnvC));
4660
4661 if (!AA.SPMDCompatibilityTracker.isValidState())
4662 AA.setExecModeOfKernelEnvironment(
4663 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4664
4665 ConstantInt *MayUseNestedParallelismC =
4666 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4667 AA.KernelEnvC);
4668 ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4669 MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4670 AA.setMayUseNestedParallelismOfKernelEnvironment(
4671 NewMayUseNestedParallelismC);
4672 }
4673 } RAII(*this);
4674
4675 // Callback to check a read/write instruction.
4676 auto CheckRWInst = [&](Instruction &I) {
4677 // We handle calls later.
4678 if (isa<CallBase>(I))
4679 return true;
4680 // We only care about write effects.
4681 if (!I.mayWriteToMemory())
4682 return true;
4683 if (auto *SI = dyn_cast<StoreInst>(&I)) {
4684 const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4685 *this, IRPosition::value(*SI->getPointerOperand()),
4686 DepClassTy::OPTIONAL);
4687 auto *HS = A.getAAFor<AAHeapToStack>(
4688 *this, IRPosition::function(*I.getFunction()),
4689 DepClassTy::OPTIONAL);
4690 if (UnderlyingObjsAA &&
4691 UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4692 if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4693 return true;
4694 // Check for AAHeapToStack moved objects which must not be
4695 // guarded.
4696 auto *CB = dyn_cast<CallBase>(&Obj);
4697 return CB && HS && HS->isAssumedHeapToStack(*CB);
4698 }))
4699 return true;
4700 }
4701
4702 // Insert instruction that needs guarding.
4703 SPMDCompatibilityTracker.insert(&I);
4704 return true;
4705 };
4706
4707 bool UsedAssumedInformationInCheckRWInst = false;
4708 if (!SPMDCompatibilityTracker.isAtFixpoint())
4709 if (!A.checkForAllReadWriteInstructions(
4710 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4711 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4712
4713 bool UsedAssumedInformationFromReachingKernels = false;
4714 if (!IsKernelEntry) {
4715 updateParallelLevels(A);
4716
4717 bool AllReachingKernelsKnown = true;
4718 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4719 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4720
4721 if (!SPMDCompatibilityTracker.empty()) {
4722 if (!ParallelLevels.isValidState())
4723 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4724 else if (!ReachingKernelEntries.isValidState())
4725 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4726 else {
4727 // Check if all reaching kernels agree on the mode as we can otherwise
4728 // not guard instructions. We might not be sure about the mode so we
4729 // we cannot fix the internal spmd-zation state either.
4730 int SPMD = 0, Generic = 0;
4731 for (auto *Kernel : ReachingKernelEntries) {
4732 auto *CBAA = A.getAAFor<AAKernelInfo>(
4733 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4734 if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4735 CBAA->SPMDCompatibilityTracker.isAssumed())
4736 ++SPMD;
4737 else
4738 ++Generic;
4739 if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4740 UsedAssumedInformationFromReachingKernels = true;
4741 }
4742 if (SPMD != 0 && Generic != 0)
4743 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4744 }
4745 }
4746 }
4747
4748 // Callback to check a call instruction.
4749 bool AllParallelRegionStatesWereFixed = true;
4750 bool AllSPMDStatesWereFixed = true;
4751 auto CheckCallInst = [&](Instruction &I) {
4752 auto &CB = cast<CallBase>(I);
4753 auto *CBAA = A.getAAFor<AAKernelInfo>(
4754 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4755 if (!CBAA)
4756 return false;
4757 getState() ^= CBAA->getState();
4758 AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4759 AllParallelRegionStatesWereFixed &=
4760 CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4761 AllParallelRegionStatesWereFixed &=
4762 CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4763 return true;
4764 };
4765
4766 bool UsedAssumedInformationInCheckCallInst = false;
4767 if (!A.checkForAllCallLikeInstructions(
4768 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4769 LLVM_DEBUG(dbgs() << TAG
4770 << "Failed to visit all call-like instructions!\n";);
4771 return indicatePessimisticFixpoint();
4772 }
4773
4774 // If we haven't used any assumed information for the reached parallel
4775 // region states we can fix it.
4776 if (!UsedAssumedInformationInCheckCallInst &&
4777 AllParallelRegionStatesWereFixed) {
4778 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4779 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4780 }
4781
4782 // If we haven't used any assumed information for the SPMD state we can fix
4783 // it.
4784 if (!UsedAssumedInformationInCheckRWInst &&
4785 !UsedAssumedInformationInCheckCallInst &&
4786 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4787 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4788
4789 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4790 : ChangeStatus::CHANGED;
4791 }
4792
4793private:
4794 /// Update info regarding reaching kernels.
4795 void updateReachingKernelEntries(Attributor &A,
4796 bool &AllReachingKernelsKnown) {
4797 auto PredCallSite = [&](AbstractCallSite ACS) {
4798 Function *Caller = ACS.getInstruction()->getFunction();
4799
4800 assert(Caller && "Caller is nullptr");
4801
4802 auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4803 IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4804 if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4805 ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4806 return true;
4807 }
4808
4809 // We lost track of the caller of the associated function, any kernel
4810 // could reach now.
4811 ReachingKernelEntries.indicatePessimisticFixpoint();
4812
4813 return true;
4814 };
4815
4816 if (!A.checkForAllCallSites(PredCallSite, *this,
4817 true /* RequireAllCallSites */,
4818 AllReachingKernelsKnown))
4819 ReachingKernelEntries.indicatePessimisticFixpoint();
4820 }
4821
4822 /// Update info regarding parallel levels.
4823 void updateParallelLevels(Attributor &A) {
4824 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4825 OMPInformationCache::RuntimeFunctionInfo &Parallel60RFI =
4826 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_60];
4827
4828 auto PredCallSite = [&](AbstractCallSite ACS) {
4829 Function *Caller = ACS.getInstruction()->getFunction();
4830
4831 assert(Caller && "Caller is nullptr");
4832
4833 auto *CAA =
4834 A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4835 if (CAA && CAA->ParallelLevels.isValidState()) {
4836 // Any function that is called by `__kmpc_parallel_60` will not be
4837 // folded as the parallel level in the function is updated. In order to
4838 // get it right, all the analysis would depend on the implentation. That
4839 // said, if in the future any change to the implementation, the analysis
4840 // could be wrong. As a consequence, we are just conservative here.
4841 if (Caller == Parallel60RFI.Declaration) {
4842 ParallelLevels.indicatePessimisticFixpoint();
4843 return true;
4844 }
4845
4846 ParallelLevels ^= CAA->ParallelLevels;
4847
4848 return true;
4849 }
4850
4851 // We lost track of the caller of the associated function, any kernel
4852 // could reach now.
4853 ParallelLevels.indicatePessimisticFixpoint();
4854
4855 return true;
4856 };
4857
4858 bool AllCallSitesKnown = true;
4859 if (!A.checkForAllCallSites(PredCallSite, *this,
4860 true /* RequireAllCallSites */,
4861 AllCallSitesKnown))
4862 ParallelLevels.indicatePessimisticFixpoint();
4863 }
4864};
4865
/// The call site kernel info abstract attribute, basically, what can we say
/// about a call site with regards to the KernelInfoState. For now this simply
/// forwards the information from the callee.
struct AAKernelInfoCallSite : AAKernelInfo {
  AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
      : AAKernelInfo(IRP, A) {}

  /// See AbstractAttribute::initialize(...).
  ///
  /// Classifies the call site once up front: harmless calls and calls with
  /// suitable assumptions are fixed immediately, known OpenMP runtime calls
  /// are modeled in the switch below, and everything else is deferred to
  /// updateImpl which merges the callee's AAKernelInfo state.
  void initialize(Attributor &A) override {
    AAKernelInfo::initialize(A);

    CallBase &CB = cast<CallBase>(getAssociatedValue());
    auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
        *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);

    // Check for SPMD-mode assumptions.
    if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
      indicateOptimisticFixpoint();
      return;
    }

    // First weed out calls we do not care about, that is readonly/readnone
    // calls, intrinsics, and "no_openmp" calls. None of these can reach a
    // parallel region or anything else we are looking for.
    if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
      indicateOptimisticFixpoint();
      return;
    }

    // Next we check if we know the callee. If it is a known OpenMP function
    // we will handle them explicitly in the switch below. If it is not, we
    // will use an AAKernelInfo object on the callee to gather information and
    // merge that into the current state. The latter happens in the updateImpl.
    auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
      const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
      if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
        // Unknown caller or declarations are not analyzable, we give up.
        if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {

          // Unknown callees might contain parallel regions, except if they have
          // an appropriate assumption attached.
          if (!AssumptionAA ||
              !(AssumptionAA->hasAssumption("omp_no_openmp") ||
                AssumptionAA->hasAssumption("omp_no_parallelism")))
            ReachedUnknownParallelRegions.insert(&CB);

          // If SPMDCompatibilityTracker is not fixed, we need to give up on the
          // idea we can run something unknown in SPMD-mode.
          if (!SPMDCompatibilityTracker.isAtFixpoint()) {
            SPMDCompatibilityTracker.indicatePessimisticFixpoint();
            SPMDCompatibilityTracker.insert(&CB);
          }

          // We have updated the state for this unknown call properly, there
          // won't be any change so we indicate a fixpoint.
          indicateOptimisticFixpoint();
        }
        // If the callee is known and can be used in IPO, we will update the
        // state based on the callee state in updateImpl.
        return;
      }
      // Multiple potential callees cannot be modeled precisely here.
      if (NumCallees > 1) {
        indicatePessimisticFixpoint();
        return;
      }

      RuntimeFunction RF = It->getSecond();
      switch (RF) {
      // All the functions we know are compatible with SPMD mode.
      case OMPRTL___kmpc_is_spmd_exec_mode:
      case OMPRTL___kmpc_distribute_static_fini:
      case OMPRTL___kmpc_for_static_fini:
      case OMPRTL___kmpc_global_thread_num:
      case OMPRTL___kmpc_get_hardware_num_threads_in_block:
      case OMPRTL___kmpc_get_hardware_num_blocks:
      case OMPRTL___kmpc_single:
      case OMPRTL___kmpc_end_single:
      case OMPRTL___kmpc_master:
      case OMPRTL___kmpc_end_master:
      case OMPRTL___kmpc_barrier:
      case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
      case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
      case OMPRTL___kmpc_error:
      case OMPRTL___kmpc_flush:
      case OMPRTL___kmpc_get_hardware_thread_id_in_block:
      case OMPRTL___kmpc_get_warp_size:
      case OMPRTL_omp_get_thread_num:
      case OMPRTL_omp_get_num_threads:
      case OMPRTL_omp_get_max_threads:
      case OMPRTL_omp_in_parallel:
      case OMPRTL_omp_get_dynamic:
      case OMPRTL_omp_get_cancellation:
      case OMPRTL_omp_get_nested:
      case OMPRTL_omp_get_schedule:
      case OMPRTL_omp_get_thread_limit:
      case OMPRTL_omp_get_supported_active_levels:
      case OMPRTL_omp_get_max_active_levels:
      case OMPRTL_omp_get_level:
      case OMPRTL_omp_get_ancestor_thread_num:
      case OMPRTL_omp_get_team_size:
      case OMPRTL_omp_get_active_level:
      case OMPRTL_omp_in_final:
      case OMPRTL_omp_get_proc_bind:
      case OMPRTL_omp_get_num_places:
      case OMPRTL_omp_get_num_procs:
      case OMPRTL_omp_get_place_proc_ids:
      case OMPRTL_omp_get_place_num:
      case OMPRTL_omp_get_partition_num_places:
      case OMPRTL_omp_get_partition_place_nums:
      case OMPRTL_omp_get_wtime:
        break;
      case OMPRTL___kmpc_distribute_static_init_4:
      case OMPRTL___kmpc_distribute_static_init_4u:
      case OMPRTL___kmpc_distribute_static_init_8:
      case OMPRTL___kmpc_distribute_static_init_8u:
      case OMPRTL___kmpc_for_static_init_4:
      case OMPRTL___kmpc_for_static_init_4u:
      case OMPRTL___kmpc_for_static_init_8:
      case OMPRTL___kmpc_for_static_init_8u: {
        // Check the schedule and allow static schedule in SPMD mode.
        unsigned ScheduleArgOpNo = 2;
        auto *ScheduleTypeCI =
            dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
        unsigned ScheduleTypeVal =
            ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
        switch (OMPScheduleType(ScheduleTypeVal)) {
        case OMPScheduleType::UnorderedStatic:
        case OMPScheduleType::UnorderedStaticChunked:
        case OMPScheduleType::OrderedDistribute:
        case OMPScheduleType::OrderedDistributeChunked:
          break;
        default:
          // Non-static schedules require guarding; mark SPMD-incompatible.
          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
          SPMDCompatibilityTracker.insert(&CB);
          break;
        };
      } break;
      case OMPRTL___kmpc_target_init:
        KernelInitCB = &CB;
        break;
      case OMPRTL___kmpc_target_deinit:
        KernelDeinitCB = &CB;
        break;
      case OMPRTL___kmpc_parallel_60:
        if (!handleParallel60(A, CB))
          indicatePessimisticFixpoint();
        return;
      case OMPRTL___kmpc_omp_task:
        // We do not look into tasks right now, just give up.
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        SPMDCompatibilityTracker.insert(&CB);
        ReachedUnknownParallelRegions.insert(&CB);
        break;
      case OMPRTL___kmpc_alloc_shared:
      case OMPRTL___kmpc_free_shared:
        // Return without setting a fixpoint, to be resolved in updateImpl.
        return;
      default:
        // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
        // generally. However, they do not hide parallel regions.
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        SPMDCompatibilityTracker.insert(&CB);
        break;
      }
      // All other OpenMP runtime calls will not reach parallel regions so they
      // can be safely ignored for now. Since it is a known OpenMP runtime call
      // we have now modeled all effects and there is no need for any update.
      indicateOptimisticFixpoint();
    };

    const auto *AACE =
        A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
    if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
      CheckCallee(getAssociatedFunction(), 1);
      return;
    }
    const auto &OptimisticEdges = AACE->getOptimisticEdges();
    for (auto *Callee : OptimisticEdges) {
      CheckCallee(Callee, OptimisticEdges.size());
      if (isAtFixpoint())
        break;
    }
  }

  /// Fixpoint update: forward/merge the state of each potential callee into
  /// this call-site state; runtime allocation calls are re-checked against
  /// AAHeapToStack/AAHeapToShared since those may remove them entirely.
  ChangeStatus updateImpl(Attributor &A) override {
    // TODO: Once we have call site specific value information we can provide
    //       call site specific liveness information and then it makes
    //       sense to specialize attributes for call sites arguments instead of
    //       redirecting requests to the callee argument.
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    KernelInfoState StateBefore = getState();

    auto CheckCallee = [&](Function *F, int NumCallees) {
      const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);

      // If F is not a runtime function, propagate the AAKernelInfo of the
      // callee.
      if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
        const IRPosition &FnPos = IRPosition::function(*F);
        auto *FnAA =
            A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
        if (!FnAA)
          return indicatePessimisticFixpoint();
        if (getState() == FnAA->getState())
          return ChangeStatus::UNCHANGED;
        getState() = FnAA->getState();
        return ChangeStatus::CHANGED;
      }
      if (NumCallees > 1)
        return indicatePessimisticFixpoint();

      CallBase &CB = cast<CallBase>(getAssociatedValue());
      if (It->getSecond() == OMPRTL___kmpc_parallel_60) {
        if (!handleParallel60(A, CB))
          return indicatePessimisticFixpoint();
        return StateBefore == getState() ? ChangeStatus::UNCHANGED
                                         : ChangeStatus::CHANGED;
      }

      // F is a runtime function that allocates or frees memory, check
      // AAHeapToStack and AAHeapToShared.
      assert(
          (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
           It->getSecond() == OMPRTL___kmpc_free_shared) &&
          "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");

      auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
          *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
      auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
          *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);

      RuntimeFunction RF = It->getSecond();

      switch (RF) {
      // If neither HeapToStack nor HeapToShared assume the call is removed,
      // assume SPMD incompatibility.
      case OMPRTL___kmpc_alloc_shared:
        if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
            (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
          SPMDCompatibilityTracker.insert(&CB);
        break;
      case OMPRTL___kmpc_free_shared:
        if ((!HeapToStackAA ||
             !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
            (!HeapToSharedAA ||
             !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
          SPMDCompatibilityTracker.insert(&CB);
        break;
      default:
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        SPMDCompatibilityTracker.insert(&CB);
      }
      return ChangeStatus::CHANGED;
    };

    const auto *AACE =
        A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
    if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
      if (Function *F = getAssociatedFunction())
        CheckCallee(F, /*NumCallees=*/1);
    } else {
      const auto &OptimisticEdges = AACE->getOptimisticEdges();
      for (auto *Callee : OptimisticEdges) {
        CheckCallee(Callee, OptimisticEdges.size());
        if (isAtFixpoint())
          break;
      }
    }

    return StateBefore == getState() ? ChangeStatus::UNCHANGED
                                     : ChangeStatus::CHANGED;
  }

  /// Deal with a __kmpc_parallel_60 call (\p CB). Returns true if the call was
  /// handled, if a problem occurred, false is returned.
  bool handleParallel60(Attributor &A, CallBase &CB) {
    // In SPMD mode the outlined parallel region is passed directly; in
    // generic mode it is passed via the wrapper argument.
    const unsigned int NonWrapperFunctionArgNo = 5;
    const unsigned int WrapperFunctionArgNo = 6;
    auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
                                     ? NonWrapperFunctionArgNo
                                     : WrapperFunctionArgNo;

    auto *ParallelRegion = dyn_cast<Function>(
        CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
    if (!ParallelRegion)
      return false;

    ReachedKnownParallelRegions.insert(&CB);
    /// Check nested parallelism
    auto *FnAA = A.getAAFor<AAKernelInfo>(
        *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
    NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
                         !FnAA->ReachedKnownParallelRegions.empty() ||
                         !FnAA->ReachedKnownParallelRegions.isValidState() ||
                         !FnAA->ReachedUnknownParallelRegions.isValidState() ||
                         !FnAA->ReachedUnknownParallelRegions.empty();
    return true;
  }
};
5166
/// Abstract attribute that tries to fold known OpenMP runtime calls (e.g.,
/// __kmpc_is_spmd_exec_mode) into constants. The folded value, if any, is
/// produced by the concrete subclasses.
struct AAFoldRuntimeCall
    : public StateWrapper<BooleanState, AbstractAttribute> {
  using Base = StateWrapper<BooleanState, AbstractAttribute>;

  AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

  /// Statistics are tracked as part of manifest for now.
  void trackStatistics() const override {}

  /// Create an abstract attribute view for the position \p IRP.
  static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
                                              Attributor &A);

  /// See AbstractAttribute::getName()
  StringRef getName() const override { return "AAFoldRuntimeCall"; }

  /// See AbstractAttribute::getIdAddr()
  const char *getIdAddr() const override { return &ID; }

  /// This function should return true if the type of the \p AA is
  /// AAFoldRuntimeCall
  static bool classof(const AbstractAttribute *AA) {
    return (AA->getIdAddr() == &ID);
  }

  /// Unique ID (due to the unique address).
  static const char ID;
};
5194
5195struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
  // Forwarding constructor; all state lives in the AAFoldRuntimeCall base.
  AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
      : AAFoldRuntimeCall(IRP, A) {}
5198
5199 /// See AbstractAttribute::getAsStr()
5200 const std::string getAsStr(Attributor *) const override {
5201 if (!isValidState())
5202 return "<invalid>";
5203
5204 std::string Str("simplified value: ");
5205
5206 if (!SimplifiedValue)
5207 return Str + std::string("none");
5208
5209 if (!*SimplifiedValue)
5210 return Str + std::string("nullptr");
5211
5212 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5213 return Str + std::to_string(CI->getSExtValue());
5214
5215 return Str + std::string("unknown");
5216 }
5217
5218 void initialize(Attributor &A) override {
5220 indicatePessimisticFixpoint();
5221
5222 Function *Callee = getAssociatedFunction();
5223
5224 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5225 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5226 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5227 "Expected a known OpenMP runtime function");
5228
5229 RFKind = It->getSecond();
5230
5231 CallBase &CB = cast<CallBase>(getAssociatedValue());
5232 A.registerSimplificationCallback(
5234 [&](const IRPosition &IRP, const AbstractAttribute *AA,
5235 bool &UsedAssumedInformation) -> std::optional<Value *> {
5236 assert((isValidState() || SimplifiedValue == nullptr) &&
5237 "Unexpected invalid state!");
5238
5239 if (!isAtFixpoint()) {
5240 UsedAssumedInformation = true;
5241 if (AA)
5242 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5243 }
5244 return SimplifiedValue;
5245 });
5246 }
5247
5248 ChangeStatus updateImpl(Attributor &A) override {
5249 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5250 switch (RFKind) {
5251 case OMPRTL___kmpc_is_spmd_exec_mode:
5252 Changed |= foldIsSPMDExecMode(A);
5253 break;
5254 case OMPRTL___kmpc_parallel_level:
5255 Changed |= foldParallelLevel(A);
5256 break;
5257 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5258 Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
5259 break;
5260 case OMPRTL___kmpc_get_hardware_num_blocks:
5261 Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
5262 break;
5263 default:
5264 llvm_unreachable("Unhandled OpenMP runtime function!");
5265 }
5266
5267 return Changed;
5268 }
5269
5270 ChangeStatus manifest(Attributor &A) override {
5271 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5272
5273 if (SimplifiedValue && *SimplifiedValue) {
5274 Instruction &I = *getCtxI();
5275 A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
5276 A.deleteAfterManifest(I);
5277
5278 CallBase *CB = dyn_cast<CallBase>(&I);
5279 auto Remark = [&](OptimizationRemark OR) {
5280 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5281 return OR << "Replacing OpenMP runtime call "
5282 << CB->getCalledFunction()->getName() << " with "
5283 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5284 return OR << "Replacing OpenMP runtime call "
5285 << CB->getCalledFunction()->getName() << ".";
5286 };
5287
5288 if (CB && EnableVerboseRemarks)
5289 A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
5290
5291 LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5292 << **SimplifiedValue << "\n");
5293
5294 Changed = ChangeStatus::CHANGED;
5295 }
5296
5297 return Changed;
5298 }
5299
  // Giving up on folding: clear any simplified value so manifest does
  // nothing, then delegate to the base class.
  ChangeStatus indicatePessimisticFixpoint() override {
    SimplifiedValue = nullptr;
    return AAFoldRuntimeCall::indicatePessimisticFixpoint();
  }
5304
5305private:
5306 /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
5307 ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5308 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5309
5310 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5311 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5312 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5313 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5314
5315 if (!CallerKernelInfoAA ||
5316 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5317 return indicatePessimisticFixpoint();
5318
5319 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5320 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5321 DepClassTy::REQUIRED);
5322
5323 if (!AA || !AA->isValidState()) {
5324 SimplifiedValue = nullptr;
5325 return indicatePessimisticFixpoint();
5326 }
5327
5328 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5329 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5330 ++KnownSPMDCount;
5331 else
5332 ++AssumedSPMDCount;
5333 } else {
5334 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5335 ++KnownNonSPMDCount;
5336 else
5337 ++AssumedNonSPMDCount;
5338 }
5339 }
5340
5341 if ((AssumedSPMDCount + KnownSPMDCount) &&
5342 (AssumedNonSPMDCount + KnownNonSPMDCount))
5343 return indicatePessimisticFixpoint();
5344
5345 auto &Ctx = getAnchorValue().getContext();
5346 if (KnownSPMDCount || AssumedSPMDCount) {
5347 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5348 "Expected only SPMD kernels!");
5349 // All reaching kernels are in SPMD mode. Update all function calls to
5350 // __kmpc_is_spmd_exec_mode to 1.
5351 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
5352 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5353 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5354 "Expected only non-SPMD kernels!");
5355 // All reaching kernels are in non-SPMD mode. Update all function
5356 // calls to __kmpc_is_spmd_exec_mode to 0.
5357 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
5358 } else {
5359 // We have empty reaching kernels, therefore we cannot tell if the
5360 // associated call site can be folded. At this moment, SimplifiedValue
5361 // must be none.
5362 assert(!SimplifiedValue && "SimplifiedValue should be none");
5363 }
5364
5365 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5366 : ChangeStatus::CHANGED;
5367 }
5368
5369 /// Fold __kmpc_parallel_level into a constant if possible.
5370 ChangeStatus foldParallelLevel(Attributor &A) {
5371 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5372
5373 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5374 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5375
5376 if (!CallerKernelInfoAA ||
5377 !CallerKernelInfoAA->ParallelLevels.isValidState())
5378 return indicatePessimisticFixpoint();
5379
5380 if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5381 return indicatePessimisticFixpoint();
5382
5383 if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5384 assert(!SimplifiedValue &&
5385 "SimplifiedValue should keep none at this point");
5386 return ChangeStatus::UNCHANGED;
5387 }
5388
5389 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5390 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5391 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5392 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5393 DepClassTy::REQUIRED);
5394 if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5395 return indicatePessimisticFixpoint();
5396
5397 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5398 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5399 ++KnownSPMDCount;
5400 else
5401 ++AssumedSPMDCount;
5402 } else {
5403 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5404 ++KnownNonSPMDCount;
5405 else
5406 ++AssumedNonSPMDCount;
5407 }
5408 }
5409
5410 if ((AssumedSPMDCount + KnownSPMDCount) &&
5411 (AssumedNonSPMDCount + KnownNonSPMDCount))
5412 return indicatePessimisticFixpoint();
5413
5414 auto &Ctx = getAnchorValue().getContext();
5415 // If the caller can only be reached by SPMD kernel entries, the parallel
5416 // level is 1. Similarly, if the caller can only be reached by non-SPMD
5417 // kernel entries, it is 0.
5418 if (AssumedSPMDCount || KnownSPMDCount) {
5419 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5420 "Expected only SPMD kernels!");
5421 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
5422 } else {
5423 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5424 "Expected only non-SPMD kernels!");
5425 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
5426 }
5427 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5428 : ChangeStatus::CHANGED;
5429 }
5430
5431 ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
5432 // Specialize only if all the calls agree with the attribute constant value
5433 int32_t CurrentAttrValue = -1;
5434 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5435
5436 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5437 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5438
5439 if (!CallerKernelInfoAA ||
5440 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5441 return indicatePessimisticFixpoint();
5442
5443 // Iterate over the kernels that reach this function
5444 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5445 int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
5446
5447 if (NextAttrVal == -1 ||
5448 (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5449 return indicatePessimisticFixpoint();
5450 CurrentAttrValue = NextAttrVal;
5451 }
5452
5453 if (CurrentAttrValue != -1) {
5454 auto &Ctx = getAnchorValue().getContext();
5455 SimplifiedValue =
5456 ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
5457 }
5458 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5459 : ChangeStatus::CHANGED;
5460 }
5461
5462 /// An optional value the associated value is assumed to fold to. That is, we
5463 /// assume the associated value (which is a call) can be replaced by this
5464 /// simplified value.
5465 std::optional<Value *> SimplifiedValue;
5466
5467 /// The runtime function kind of the callee of the associated call site.
5468 RuntimeFunction RFKind;
5469};
5470
5471} // namespace
5472
5473/// Register folding callsite
5474void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5475 auto &RFI = OMPInfoCache.RFIs[RF];
5476 RFI.foreachUse(SCC, [&](Use &U, Function &F) {
5477 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5478 if (!CI)
5479 return false;
5480 A.getOrCreateAAFor<AAFoldRuntimeCall>(
5481 IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
5482 DepClassTy::NONE, /* ForceUpdate */ false,
5483 /* UpdateAfterInit */ false);
5484 return false;
5485 });
5486}
5487
5488void OpenMPOpt::registerAAs(bool IsModulePass) {
5489 if (SCC.empty())
5490 return;
5491
5492 if (IsModulePass) {
5493 // Ensure we create the AAKernelInfo AAs first and without triggering an
5494 // update. This will make sure we register all value simplification
5495 // callbacks before any other AA has the chance to create an AAValueSimplify
5496 // or similar.
5497 auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5498 A.getOrCreateAAFor<AAKernelInfo>(
5499 IRPosition::function(Kernel), /* QueryingAA */ nullptr,
5500 DepClassTy::NONE, /* ForceUpdate */ false,
5501 /* UpdateAfterInit */ false);
5502 return false;
5503 };
5504 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5505 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5506 InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5507
5508 registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5509 registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5510 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5511 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5512 }
5513
5514 // Create CallSite AA for all Getters.
5515 if (DeduceICVValues) {
5516 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5517 auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5518
5519 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5520
5521 auto CreateAA = [&](Use &U, Function &Caller) {
5522 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5523 if (!CI)
5524 return false;
5525
5526 auto &CB = cast<CallBase>(*CI);
5527
5528 IRPosition CBPos = IRPosition::callsite_function(CB);
5529 A.getOrCreateAAFor<AAICVTracker>(CBPos);
5530 return false;
5531 };
5532
5533 GetterRFI.foreachUse(SCC, CreateAA);
5534 }
5535 }
5536
5537 // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5538 // every function if there is a device kernel.
5539 if (!isOpenMPDevice(M))
5540 return;
5541
5542 for (auto *F : SCC) {
5543 if (F->isDeclaration())
5544 continue;
5545
5546 // We look at internal functions only on-demand but if any use is not a
5547 // direct call or outside the current set of analyzed functions, we have
5548 // to do it eagerly.
5549 if (F->hasLocalLinkage()) {
5550 if (llvm::all_of(F->uses(), [this](const Use &U) {
5551 const auto *CB = dyn_cast<CallBase>(U.getUser());
5552 return CB && CB->isCallee(&U) &&
5553 A.isRunOn(const_cast<Function *>(CB->getCaller()));
5554 }))
5555 continue;
5556 }
5557 registerAAsForFunction(A, *F);
5558 }
5559}
5560
5561void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
5563 A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
5564 A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
5566 A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
5567 if (F.hasFnAttribute(Attribute::Convergent))
5568 A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));
5569
5570 for (auto &I : instructions(F)) {
5571 if (auto *LI = dyn_cast<LoadInst>(&I)) {
5572 bool UsedAssumedInformation = false;
5573 A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
5574 UsedAssumedInformation, AA::Interprocedural);
5575 A.getOrCreateAAFor<AAAddressSpace>(
5576 IRPosition::value(*LI->getPointerOperand()));
5577 continue;
5578 }
5579 if (auto *CI = dyn_cast<CallBase>(&I)) {
5580 if (CI->isIndirectCall())
5581 A.getOrCreateAAFor<AAIndirectCallInfo>(
5583 }
5584 if (auto *SI = dyn_cast<StoreInst>(&I)) {
5585 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
5586 A.getOrCreateAAFor<AAAddressSpace>(
5587 IRPosition::value(*SI->getPointerOperand()));
5588 continue;
5589 }
5590 if (auto *FI = dyn_cast<FenceInst>(&I)) {
5591 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
5592 continue;
5593 }
5594 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5595 if (II->getIntrinsicID() == Intrinsic::assume) {
5596 A.getOrCreateAAFor<AAPotentialValues>(
5597 IRPosition::value(*II->getArgOperand(0)));
5598 continue;
5599 }
5600 }
5601 }
5602}
5603
5604const char AAICVTracker::ID = 0;
5605const char AAKernelInfo::ID = 0;
5606const char AAExecutionDomain::ID = 0;
5607const char AAHeapToShared::ID = 0;
5608const char AAFoldRuntimeCall::ID = 0;
5609
5610AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5611 Attributor &A) {
5612 AAICVTracker *AA = nullptr;
5613 switch (IRP.getPositionKind()) {
5618 llvm_unreachable("ICVTracker can only be created for function position!");
5620 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5621 break;
5623 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5624 break;
5626 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5627 break;
5629 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5630 break;
5631 }
5632
5633 return *AA;
5634}
5635
5637 Attributor &A) {
5638 AAExecutionDomainFunction *AA = nullptr;
5639 switch (IRP.getPositionKind()) {
5648 "AAExecutionDomain can only be created for function position!");
5650 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5651 break;
5652 }
5653
5654 return *AA;
5655}
5656
5657AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5658 Attributor &A) {
5659 AAHeapToSharedFunction *AA = nullptr;
5660 switch (IRP.getPositionKind()) {
5669 "AAHeapToShared can only be created for function position!");
5671 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5672 break;
5673 }
5674
5675 return *AA;
5676}
5677
5678AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5679 Attributor &A) {
5680 AAKernelInfo *AA = nullptr;
5681 switch (IRP.getPositionKind()) {
5688 llvm_unreachable("KernelInfo can only be created for function position!");
5690 AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5691 break;
5693 AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5694 break;
5695 }
5696
5697 return *AA;
5698}
5699
5700AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5701 Attributor &A) {
5702 AAFoldRuntimeCall *AA = nullptr;
5703 switch (IRP.getPositionKind()) {
5711 llvm_unreachable("KernelInfo can only be created for call site position!");
5713 AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5714 break;
5715 }
5716
5717 return *AA;
5718}
5719
5721 if (!containsOpenMP(M))
5722 return PreservedAnalyses::all();
5724 return PreservedAnalyses::all();
5725
5728 KernelSet Kernels = getDeviceKernels(M);
5729
5731 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
5732
5733 auto IsCalled = [&](Function &F) {
5734 if (Kernels.contains(&F))
5735 return true;
5736 return !F.use_empty();
5737 };
5738
5739 auto EmitRemark = [&](Function &F) {
5740 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
5741 ORE.emit([&]() {
5742 OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5743 return ORA << "Could not internalize function. "
5744 << "Some optimizations may not be possible. [OMP140]";
5745 });
5746 };
5747
5748 bool Changed = false;
5749
5750 // Create internal copies of each function if this is a kernel Module. This
5751 // allows iterprocedural passes to see every call edge.
5752 DenseMap<Function *, Function *> InternalizedMap;
5753 if (isOpenMPDevice(M)) {
5754 SmallPtrSet<Function *, 16> InternalizeFns;
5755 for (Function &F : M)
5756 if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
5759 InternalizeFns.insert(&F);
5760 } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5761 EmitRemark(F);
5762 }
5763 }
5764
5765 Changed |=
5766 Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5767 }
5768
5769 // Look at every function in the Module unless it was internalized.
5770 SetVector<Function *> Functions;
5772 for (Function &F : M)
5773 if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
5774 SCC.push_back(&F);
5775 Functions.insert(&F);
5776 }
5777
5778 if (SCC.empty())
5780
5781 AnalysisGetter AG(FAM);
5782
5783 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5784 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5785 };
5786
5787 BumpPtrAllocator Allocator;
5788 CallGraphUpdater CGUpdater;
5789
5790 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5793 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);
5794
5795 unsigned MaxFixpointIterations =
5797
5798 AttributorConfig AC(CGUpdater);
5800 AC.IsModulePass = true;
5801 AC.RewriteSignatures = false;
5802 AC.MaxFixpointIterations = MaxFixpointIterations;
5803 AC.OREGetter = OREGetter;
5804 AC.PassName = DEBUG_TYPE;
5805 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5806 AC.IPOAmendableCB = [](const Function &F) {
5807 return F.hasFnAttribute("kernel");
5808 };
5809
5810 Attributor A(Functions, InfoCache, AC);
5811
5812 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5813 Changed |= OMPOpt.run(true);
5814
5815 // Optionally inline device functions for potentially better performance.
5817 for (Function &F : M)
5818 if (!F.isDeclaration() && !Kernels.contains(&F) &&
5819 !F.hasFnAttribute(Attribute::NoInline))
5820 F.addFnAttr(Attribute::AlwaysInline);
5821
5823 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5824
5825 if (Changed)
5826 return PreservedAnalyses::none();
5827
5828 return PreservedAnalyses::all();
5829}
5830
5833 LazyCallGraph &CG,
5834 CGSCCUpdateResult &UR) {
5835 if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5836 return PreservedAnalyses::all();
5838 return PreservedAnalyses::all();
5839
5841 // If there are kernels in the module, we have to run on all SCC's.
5842 for (LazyCallGraph::Node &N : C) {
5843 Function *Fn = &N.getFunction();
5844 SCC.push_back(Fn);
5845 }
5846
5847 if (SCC.empty())
5848 return PreservedAnalyses::all();
5849
5850 Module &M = *C.begin()->getFunction().getParent();
5851
5853 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5854
5856 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5857
5858 AnalysisGetter AG(FAM);
5859
5860 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5861 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5862 };
5863
5864 BumpPtrAllocator Allocator;
5865 CallGraphUpdater CGUpdater;
5866 CGUpdater.initialize(CG, C, AM, UR);
5867
5868 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5872 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5873 /*CGSCC*/ &Functions, PostLink);
5874
5875 unsigned MaxFixpointIterations =
5877
5878 AttributorConfig AC(CGUpdater);
5880 AC.IsModulePass = false;
5881 AC.RewriteSignatures = false;
5882 AC.MaxFixpointIterations = MaxFixpointIterations;
5883 AC.OREGetter = OREGetter;
5884 AC.PassName = DEBUG_TYPE;
5885 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5886
5887 Attributor A(Functions, InfoCache, AC);
5888
5889 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5890 bool Changed = OMPOpt.run(false);
5891
5893 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5894
5895 if (Changed)
5896 return PreservedAnalyses::none();
5897
5898 return PreservedAnalyses::all();
5899}
5900
5902 return Fn.hasFnAttribute("kernel");
5903}
5904
5906 KernelSet Kernels;
5907
5908 for (Function &F : M)
5909 if (F.hasKernelCallingConv()) {
5910 // We are only interested in OpenMP target regions. Others, such as
5911 // kernels generated by CUDA but linked together, are not interesting to
5912 // this pass.
5913 if (isOpenMPKernel(F)) {
5914 ++NumOpenMPTargetRegionKernels;
5915 Kernels.insert(&F);
5916 } else
5917 ++NumNonOpenMPTargetRegionKernels;
5918 }
5919
5920 return Kernels;
5921}
5922
5924 Metadata *MD = M.getModuleFlag("openmp");
5925 if (!MD)
5926 return false;
5927
5928 return true;
5929}
5930
5932 Metadata *MD = M.getModuleFlag("openmp-device");
5933 if (!MD)
5934 return false;
5935
5936 return true;
5937}
@ Generic
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static cl::opt< unsigned > SetFixpointIterations("attributor-max-iterations", cl::Hidden, cl::desc("Maximal number of fixpoint iterations."), cl::init(32))
static const Function * getParent(const Value *V)
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
This file provides interfaces used to manipulate a call graph, regardless if it is a "old style" Call...
This file provides interfaces used to build and manipulate a call graph, which is a very useful tool ...
This file contains the declarations for the subclasses of Constant, which represent the different fla...
dxil pretty DXIL Metadata Pretty Printer
This file defines the DenseSet and SmallDenseSet classes.
This file defines an array type that can be indexed using scoped enum values.
#define DEBUG_TYPE
static void emitRemark(const Function &F, OptimizationRemarkEmitter &ORE, bool Skip)
Loop::LoopBounds::Direction Direction
Definition LoopInfo.cpp:253
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
Machine Check Debug Module
This file provides utility analysis objects describing memory locations.
#define T
uint64_t IntrinsicInst * II
This file defines constants and helpers used when dealing with OpenMP.
This file defines constants that will be used by both host and device compilation.
static constexpr auto TAG
static cl::opt< bool > HideMemoryTransferLatency("openmp-hide-memory-transfer-latency", cl::desc("[WIP] Tries to hide the latency of host to device memory" " transfers"), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptStateMachineRewrite("openmp-opt-disable-state-machine-rewrite", cl::desc("Disable OpenMP optimizations that replace the state machine."), cl::Hidden, cl::init(false))
static cl::opt< bool > EnableParallelRegionMerging("openmp-opt-enable-merging", cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleAfterOptimizations("openmp-opt-print-module-after", cl::desc("Print the current module after OpenMP optimizations."), cl::Hidden, cl::init(false))
#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)
#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)
#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER)
static cl::opt< bool > PrintOpenMPKernels("openmp-print-gpu-kernels", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptFolding("openmp-opt-disable-folding", cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleBeforeOptimizations("openmp-opt-print-module-before", cl::desc("Print the current module before OpenMP optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden, cl::desc("Maximal number of attributor iterations."), cl::init(256))
static cl::opt< bool > DisableInternalization("openmp-opt-disable-internalization", cl::desc("Disable function internalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintICVValues("openmp-print-icv-values", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptimizations("openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden, cl::desc("Maximum amount of shared memory to use."), cl::init(std::numeric_limits< unsigned >::max()))
static cl::opt< bool > EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::desc("Enables more verbose remarks."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptDeglobalization("openmp-opt-disable-deglobalization", cl::desc("Disable OpenMP optimizations involving deglobalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptBarrierElimination("openmp-opt-disable-barrier-elimination", cl::desc("Disable OpenMP optimizations that eliminate barriers."), cl::Hidden, cl::init(false))
#define DEBUG_TYPE
Definition OpenMPOpt.cpp:67
static cl::opt< bool > DeduceICVValues("openmp-deduce-icv-values", cl::init(false), cl::Hidden)
#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)
#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)
static cl::opt< bool > DisableOpenMPOptSPMDization("openmp-opt-disable-spmdization", cl::desc("Disable OpenMP optimizations involving SPMD-ization."), cl::Hidden, cl::init(false))
static cl::opt< bool > AlwaysInlineDeviceFunctions("openmp-opt-inline-device", cl::desc("Inline all applicable functions on the device."), cl::Hidden, cl::init(false))
#define P(N)
FunctionAnalysisManager FAM
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
static StringRef getName(Value *V)
R600 Clause Merge
Basic Register Allocator
std::pair< BasicBlock *, BasicBlock * > Edge
static bool contains(SmallPtrSetImpl< ConstantExpr * > &Cache, ConstantExpr *Expr, Constant *C)
Definition Value.cpp:487
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
static const int BlockSize
Definition TarWriter.cpp:33
static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, const llvm::StringTable &StandardNames, VectorLibrary VecLib)
Initialize the set of available library functions based on the specified target triple.
Value * RHS
static cl::opt< unsigned > MaxThreads("xcore-max-threads", cl::Optional, cl::desc("Maximum number of threads (for emulation thread-local storage)"), cl::Hidden, cl::value_desc("number"), cl::init(8))
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
size_t size() const
size - Get the array size.
Definition ArrayRef.h:142
iterator end()
Definition BasicBlock.h:462
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:449
LLVM_ABI const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
LLVM_ABI BasicBlock * splitBasicBlock(iterator I, const Twine &BBName="")
Split the basic block into two basic blocks at the specified instruction.
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
reverse_iterator rbegin()
Definition BasicBlock.h:465
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition BasicBlock.h:206
LLVM_ABI const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
InstListType::reverse_iterator reverse_iterator
Definition BasicBlock.h:172
reverse_iterator rend()
Definition BasicBlock.h:467
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
void setCallingConv(CallingConv::ID CC)
bool arg_empty() const
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
bool doesNotAccessMemory(unsigned OpNo) const
LLVM_ABI bool isIndirectCall() const
Return true if the callsite is an indirect call.
bool isCallee(Value::const_user_iterator UI) const
Determine whether the passed iterator points to the callee operand's Use.
Value * getArgOperand(unsigned i) const
void setArgOperand(unsigned i, Value *v)
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
unsigned getArgOperandNo(const Use *U) const
Given a use for a arg operand, get the arg operand number that corresponds to it.
unsigned arg_size() const
AttributeList getAttributes() const
Return the attributes for this call.
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind)
Adds the attribute to the indicated argument.
bool isArgOperand(const Use *U) const
bool hasOperandBundles() const
Return true if this User has any operand bundles.
LLVM_ABI Function * getCaller()
Helper to get the caller (the parent function).
Wrapper to unify "old style" CallGraph and "new style" LazyCallGraph.
void initialize(LazyCallGraph &LCG, LazyCallGraph::SCC &SCC, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR)
Initializers for usage outside of a CGSCC pass, inside a CGSCC pass in the old and new pass manager (...
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
@ ICMP_SLT
signed less than
Definition InstrTypes.h:705
@ ICMP_NE
not equal
Definition InstrTypes.h:698
static CondBrInst * Create(Value *Cond, BasicBlock *IfTrue, BasicBlock *IfFalse, InsertPosition InsertBefore=nullptr)
static LLVM_ABI Constant * getPointerCast(Constant *C, Type *Ty)
Create a BitCast, AddrSpaceCast, or a PtrToInt cast constant expression.
static LLVM_ABI Constant * getPointerBitCastOrAddrSpaceCast(Constant *C, Type *Ty)
Create a BitCast or AddrSpaceCast for a pointer type depending on the address space.
This is the shared class of boolean and integer constants.
Definition Constants.h:87
IntegerType * getIntegerType() const
Variant of the getType() method to always return an IntegerType, which reduces the amount of casting ...
Definition Constants.h:198
static LLVM_ABI ConstantInt * getTrue(LLVMContext &Context)
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
Definition Constants.h:174
This is an important base class in LLVM.
Definition Constant.h:43
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:205
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:241
LLVM_ABI Instruction * findNearestCommonDominator(Instruction *I1, Instruction *I2) const
Find the nearest instruction I that dominates both I1 and I2, in the sense that a result produced bef...
static ErrorSuccess success()
Create a success value.
Definition Error.h:336
AtomicOrdering getOrdering() const
Returns the ordering constraint of this fence instruction.
A proxy from a FunctionAnalysisManager to an SCC.
const BasicBlock & getEntryBlock() const
Definition Function.h:809
const BasicBlock & front() const
Definition Function.h:860
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition Function.cpp:358
Argument * getArg(unsigned i) const
Definition Function.h:886
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
Definition Function.cpp:728
LLVM_ABI bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
Definition Globals.cpp:329
bool hasLocalLinkage() const
Module * getParent()
Get the module that this global value is contained inside of...
@ PrivateLinkage
Like Internal, but omit from symbol table.
Definition GlobalValue.h:61
@ InternalLinkage
Rename collisions when linking (static functions).
Definition GlobalValue.h:60
const Constant * getInitializer() const
getInitializer - Return the initializer for this global variable.
LLVM_ABI void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
Definition Globals.cpp:534
BasicBlock * getBlock() const
Definition IRBuilder.h:306
CondBrInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a conditional 'br Cond, TrueDest, FalseDest' instruction.
Definition IRBuilder.h:1223
CallInst * CreateCall(FunctionType *FTy, Value *Callee, ArrayRef< Value * > Args={}, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:2510
Value * CreateIsNull(Value *Arg, const Twine &Name="")
Return a boolean value testing if Arg == 0.
Definition IRBuilder.h:2659
LLVM_ABI bool isLifetimeStartOrEnd() const LLVM_READONLY
Return true if the instruction is a llvm.lifetime.start or llvm.lifetime.end marker.
LLVM_ABI bool mayWriteToMemory() const LLVM_READONLY
Return true if this instruction may modify memory.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
LLVM_ABI const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr it the function does not...
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool mayHaveSideEffects() const LLVM_READONLY
Return true if the instruction may have side effects.
LLVM_ABI bool mayReadFromMemory() const LLVM_READONLY
Return true if this instruction may read memory.
void setDebugLoc(DebugLoc Loc)
Set the debug location information for this instruction.
LLVM_ABI void setSuccessor(unsigned Idx, BasicBlock *BB)
Update the specified successor to point at the provided block.
A node in the call graph.
An SCC of the call graph.
A lazily constructed view of the call graph of a module.
LLVM_ABI void eraseFromParent()
This method unlinks 'this' from the containing function and deletes it.
LLVM_ABI StringRef getName() const
Return the name of the corresponding LLVM basic block, or an empty string.
Root of the metadata hierarchy.
Definition Metadata.h:64
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
const Triple & getTargetTriple() const
Get the target triple which is a string describing the target host.
Definition Module.h:281
LLVM_ABI Constant * getOrCreateIdent(Constant *SrcLocStr, uint32_t SrcLocStrSize, omp::IdentFlag Flags=omp::IdentFlag(0), unsigned Reserve2Flags=0)
Return an ident_t* encoding the source location SrcLocStr and Flags.
LLVM_ABI FunctionCallee getOrCreateRuntimeFunction(Module &M, omp::RuntimeFunction FnID)
Return the function declaration for the runtime function with FnID.
static LLVM_ABI std::pair< int32_t, int32_t > readThreadBoundsForKernel(const Triple &T, Function &Kernel)
}
LLVM_ABI Constant * getOrCreateSrcLocStr(StringRef LocStr, uint32_t &SrcLocStrSize)
Return the (LLVM-IR) string describing the source location LocStr.
IRBuilder<>::InsertPoint InsertPointTy
Type used throughout for insertion points.
IRBuilder Builder
The LLVM-IR Builder used to create IR.
static LLVM_ABI std::pair< int32_t, int32_t > readTeamBoundsForKernel(const Triple &T, Function &Kernel)
Read/write a bounds on teams for Kernel.
bool updateToLocation(const LocationDescription &Loc)
Update the internal location to Loc.
PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR)
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
Diagnostic information for optimization analysis remarks.
The optimization diagnostic interface.
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition Analysis.h:115
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, InsertPosition InsertBefore=nullptr)
A vector that has set insertion semantics.
Definition SetVector.h:57
size_type size() const
Determine the number of elements in the SetVector.
Definition SetVector.h:103
size_type count(const_arg_type key) const
Count the number of elements of a given key in the SetVector.
Definition SetVector.h:262
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition SetVector.h:151
size_type size() const
Definition SmallPtrSet.h:99
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
iterator begin() const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
reference emplace_back(ArgTypes &&... Args)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition StringRef.h:258
Triple - Helper class for working with autoconf configuration names.
Definition Triple.h:47
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:313
LLVM_ABI unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
static UncondBrInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
static LLVM_ABI UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
LLVM_ABI bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Definition User.cpp:25
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
LLVM_ABI void setName(const Twine &Name)
Change the name of the value.
Definition Value.cpp:397
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition Value.h:440
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:553
iterator_range< user_iterator > users()
Definition Value.h:427
User * user_back()
Definition Value.h:413
LLVM_ABI const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs and address space casts.
Definition Value.cpp:713
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
const ParentTy * getParent() const
Definition ilist_node.h:34
self_iterator getIterator()
Definition ilist_node.h:123
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
GlobalVariable * getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB)
ConstantStruct * getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB)
Abstract Attribute helper functions.
Definition Attributor.h:165
LLVM_ABI bool isValidAtPosition(const ValueAndContext &VAC, InformationCache &InfoCache)
Return true if the value of VAC is a valid at the position of VAC, that is a constant,...
LLVM_ABI bool isPotentiallyAffectedByBarrier(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is potentially affected by a barrier.
@ Interprocedural
Definition Attributor.h:184
LLVM_ABI bool isNoSyncInst(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is a nosync instruction.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
E & operator^=(E &LHS, E RHS)
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
initializer< Ty > init(const Ty &Val)
PointerTypeMap run(const Module &M)
Compute the PointerTypeMap for the module M.
bool isOpenMPDevice(Module &M)
Helper to determine if M is a OpenMP target offloading device module.
bool containsOpenMP(Module &M)
Helper to determine if M contains OpenMP.
InternalControlVar
IDs for all Internal Control Variables (ICVs).
RuntimeFunction
IDs for all omp runtime library (RTL) functions.
KernelSet getDeviceKernels(Module &M)
Get OpenMP device kernels in M.
@ OMP_TGT_EXEC_MODE_GENERIC_SPMD
SetVector< Kernel > KernelSet
Set of kernels in the module.
Definition OpenMPOpt.h:24
Function * Kernel
Summary of a kernel (=entry point for target offloading).
Definition OpenMPOpt.h:21
bool isOpenMPKernel(Function &Fn)
Return true iff Fn is an OpenMP GPU kernel; Fn has the "kernel" attribute.
DiagnosticInfoOptimizationBase::Argument NV
NodeAddr< UseNode * > Use
Definition RDFGraph.h:385
bool empty() const
Definition BasicBlock.h:101
iterator end() const
Definition BasicBlock.h:89
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
LLVM_ABI iterator begin() const
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:316
@ Offset
Definition DWP.cpp:532
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1739
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition STLExtras.h:1669
bool succ_empty(const Instruction *I)
Definition CFG.h:153
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
FunctionAddr VTableAddr uintptr_t uintptr_t Int32Ty
Definition InstrProf.h:296
bool operator!=(uint64_t V1, const APInt &V2)
Definition APInt.h:2128
constexpr from_range_t from_range
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
AnalysisManager< LazyCallGraph::SCC, LazyCallGraph & > CGSCCAnalysisManager
The CGSCC analysis manager.
@ ThinLTOPostLink
ThinLTO postlink (backend compile) phase.
Definition Pass.h:83
@ FullLTOPostLink
Full LTO postlink (backend compile) phase.
Definition Pass.h:87
@ ThinLTOPreLink
ThinLTO prelink (summary) phase.
Definition Pass.h:81
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
MutableArrayRef(T &OneElt) -> MutableArrayRef< T >
void cantFail(Error Err, const char *Msg=nullptr)
Report a fatal error if Err is a failure value.
Definition Error.h:769
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
bool operator&=(SparseBitVector< ElementSize > *LHS, const SparseBitVector< ElementSize > &RHS)
LLVM_ABI BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="")
Split the specified block at the specified instruction.
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2012
ArrayRef(const T &OneElt) -> ArrayRef< T >
std::string toString(const APInt &I, unsigned Radix, bool Signed, bool formatAsCLiteral=false, bool UpperCase=true, bool InsertSeparators=false)
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
auto predecessors(const MachineBasicBlock *BB)
ChangeStatus
{
Definition Attributor.h:496
LLVM_ABI Constant * ConstantFoldInsertValueInstruction(Constant *Agg, Constant *Val, ArrayRef< unsigned > Idxs)
Attempt to constant fold an insertvalue instruction with the specified operands and indices.
@ OPTIONAL
The target may be valid if the source is not.
Definition Attributor.h:508
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
BumpPtrAllocatorImpl<> BumpPtrAllocator
The standard BumpPtrAllocator which just uses the default template parameters.
Definition Allocator.h:383
LLVM_ABI const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=MaxLookupSearchDepth)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
Definition MIRParser.h:39
#define N
static LLVM_ABI AAExecutionDomain & createForPosition(const IRPosition &IRP, Attributor &A)
Create an abstract attribute view for the position IRP.
AAExecutionDomain(const IRPosition &IRP, Attributor &A)
static LLVM_ABI const char ID
Unique ID (due to the unique address)
AccessKind
Simple enum to distinguish read/write/read-write accesses.
StateType::base_t MemoryLocationsKind
static LLVM_ABI bool isAlignedBarrier(const CallBase &CB, bool ExecutedAligned)
Helper function to determine if CB is an aligned (GPU) barrier.
Base struct for all "concrete attribute" deductions.
virtual const char * getIdAddr() const =0
This function should return the address of the ID of the AbstractAttribute.
An interface to query the internal state of an abstract attribute.
Wrapper for FunctionAnalysisManager.
Configuration for the Attributor.
std::function< void(Attributor &A, const Function &F)> InitializationCallback
Callback function to be invoked on internal functions marked live.
std::optional< unsigned > MaxFixpointIterations
Maximum number of iterations to run until fixpoint.
bool RewriteSignatures
Flag to determine if we rewrite function signatures.
const char * PassName
}
OptimizationRemarkGetter OREGetter
IPOAmendableCBTy IPOAmendableCB
bool IsModulePass
Is the user of the Attributor a module pass or not.
bool DefaultInitializeLiveInternals
Flag to determine if we want to initialize all default AAs for an internal function marked live.
The fixpoint analysis framework that orchestrates the attribute deduction.
static LLVM_ABI bool isInternalizable(Function &F)
Returns true if the function F can be internalized.
std::function< std::optional< Value * >( const IRPosition &, const AbstractAttribute *, bool &)> SimplifictionCallbackTy
Register CB as a simplification callback.
std::function< std::optional< Constant * >( const GlobalVariable &, const AbstractAttribute *, bool &)> GlobalVariableSimplifictionCallbackTy
Register CB as a simplification callback.
std::function< bool(Attributor &, const AbstractAttribute *)> VirtualUseCallbackTy
static LLVM_ABI bool internalizeFunctions(SmallPtrSetImpl< Function * > &FnSet, DenseMap< Function *, Function * > &FnMap)
Make copies of each function in the set FnSet such that the copied version has internal linkage after...
Simple wrapper for a single bit (boolean) state.
Support structure for SCC passes to communicate updates the call graph back to the CGSCC pass manager...
Helper to describe and deal with positions in the LLVM-IR.
Definition Attributor.h:593
static const IRPosition callsite_returned(const CallBase &CB)
Create a position describing the returned value of CB.
Definition Attributor.h:661
static const IRPosition returned(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the returned value of F.
Definition Attributor.h:643
static const IRPosition value(const Value &V, const CallBaseContext *CBContext=nullptr)
Create a position describing the value of V.
Definition Attributor.h:617
static const IRPosition inst(const Instruction &I, const CallBaseContext *CBContext=nullptr)
Create a position describing the instruction I.
Definition Attributor.h:629
@ IRP_ARGUMENT
An attribute for a function argument.
Definition Attributor.h:607
@ IRP_RETURNED
An attribute for the function return value.
Definition Attributor.h:603
@ IRP_CALL_SITE
An attribute for a call site (function scope).
Definition Attributor.h:606
@ IRP_CALL_SITE_RETURNED
An attribute for a call site return value.
Definition Attributor.h:604
@ IRP_FUNCTION
An attribute for a function (scope).
Definition Attributor.h:605
@ IRP_FLOAT
A position that is not associated with a spot suitable for attributes.
Definition Attributor.h:601
@ IRP_CALL_SITE_ARGUMENT
An attribute for a call site argument.
Definition Attributor.h:608
@ IRP_INVALID
An invalid position.
Definition Attributor.h:600
static const IRPosition function(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the function scope of F.
Definition Attributor.h:636
Kind getPositionKind() const
Return the associated position kind.
Definition Attributor.h:889
static const IRPosition callsite_function(const CallBase &CB)
Create a position describing the function scope of CB.
Definition Attributor.h:656
Data structure to hold cached (LLVM-IR) information.
Defines various target-specific GPU grid values that must be consistent between host RTL (plugin),...