1//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpenMP specific optimizations:
10//
11// - Deduplication of runtime calls, e.g., omp_get_thread_num.
12// - Replacing globalized device memory with stack memory.
13// - Replacing globalized device memory with shared memory.
14// - Parallel region merging.
15// - Transforming generic-mode device kernels to SPMD mode.
16// - Specializing the state machine for generic-mode device kernels.
17//
18//===----------------------------------------------------------------------===//
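// Illustrative sketch of the first optimization listed above (hypothetical
// `before`/`after` functions; the pass performs the equivalent rewrite on
// LLVM IR, not on source code): repeated calls to a deterministic runtime
// query such as omp_get_thread_num() are deduplicated so that later uses
// reuse the first call's result.
//
//   #include <omp.h>
//   int before(void) { return omp_get_thread_num() + omp_get_thread_num(); }
//   // After deduplication the second query reuses the first result:
//   int after(void) { int TId = omp_get_thread_num(); return TId + TId; }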
19
21
24#include "llvm/ADT/SetVector.h"
27#include "llvm/ADT/Statistic.h"
29#include "llvm/ADT/StringRef.h"
37#include "llvm/IR/Assumptions.h"
38#include "llvm/IR/BasicBlock.h"
39#include "llvm/IR/Constants.h"
41#include "llvm/IR/Dominators.h"
42#include "llvm/IR/Function.h"
43#include "llvm/IR/GlobalValue.h"
45#include "llvm/IR/InstrTypes.h"
46#include "llvm/IR/Instruction.h"
49#include "llvm/IR/IntrinsicsAMDGPU.h"
50#include "llvm/IR/IntrinsicsNVPTX.h"
51#include "llvm/IR/LLVMContext.h"
54#include "llvm/Support/Debug.h"
58
59#include <algorithm>
60#include <optional>
61#include <string>
62
63using namespace llvm;
64using namespace omp;
65
66#define DEBUG_TYPE "openmp-opt"
67
69 "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
70 cl::Hidden, cl::init(false));
71
73 "openmp-opt-enable-merging",
74 cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
75 cl::init(false));
76
77static cl::opt<bool>
78 DisableInternalization("openmp-opt-disable-internalization",
79 cl::desc("Disable function internalization."),
80 cl::Hidden, cl::init(false));
81
82static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
83 cl::init(false), cl::Hidden);
84static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
85                                    cl::Hidden);
86static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
87 cl::init(false), cl::Hidden);
88
90 "openmp-hide-memory-transfer-latency",
91 cl::desc("[WIP] Tries to hide the latency of host to device memory"
92 " transfers"),
93 cl::Hidden, cl::init(false));
94
96 "openmp-opt-disable-deglobalization",
97 cl::desc("Disable OpenMP optimizations involving deglobalization."),
98 cl::Hidden, cl::init(false));
99
101 "openmp-opt-disable-spmdization",
102 cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
103 cl::Hidden, cl::init(false));
104
106 "openmp-opt-disable-folding",
107 cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
108 cl::init(false));
109
111 "openmp-opt-disable-state-machine-rewrite",
112 cl::desc("Disable OpenMP optimizations that replace the state machine."),
113 cl::Hidden, cl::init(false));
114
116 "openmp-opt-disable-barrier-elimination",
117 cl::desc("Disable OpenMP optimizations that eliminate barriers."),
118 cl::Hidden, cl::init(false));
119
121 "openmp-opt-print-module-after",
122 cl::desc("Print the current module after OpenMP optimizations."),
123 cl::Hidden, cl::init(false));
124
126 "openmp-opt-print-module-before",
127 cl::desc("Print the current module before OpenMP optimizations."),
128 cl::Hidden, cl::init(false));
129
131 "openmp-opt-inline-device",
132 cl::desc("Inline all applicible functions on the device."), cl::Hidden,
133 cl::init(false));
134
135static cl::opt<bool>
136 EnableVerboseRemarks("openmp-opt-verbose-remarks",
137 cl::desc("Enables more verbose remarks."), cl::Hidden,
138 cl::init(false));
139
140static cl::opt<unsigned>
141    SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
142 cl::desc("Maximal number of attributor iterations."),
143 cl::init(256));
144
145static cl::opt<unsigned>
146    SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
147 cl::desc("Maximum amount of shared memory to use."),
148 cl::init(std::numeric_limits<unsigned>::max()));
149
150STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
151 "Number of OpenMP runtime calls deduplicated");
152STATISTIC(NumOpenMPParallelRegionsDeleted,
153 "Number of OpenMP parallel regions deleted");
154STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
155 "Number of OpenMP runtime functions identified");
156STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
157 "Number of OpenMP runtime function uses identified");
158STATISTIC(NumOpenMPTargetRegionKernels,
159 "Number of OpenMP target region entry points (=kernels) identified");
160STATISTIC(NumNonOpenMPTargetRegionKernels,
161 "Number of non-OpenMP target region kernels identified");
162STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
163 "Number of OpenMP target region entry points (=kernels) executed in "
164 "SPMD-mode instead of generic-mode");
165STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
166 "Number of OpenMP target region entry points (=kernels) executed in "
167 "generic-mode without a state machines");
168STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
169 "Number of OpenMP target region entry points (=kernels) executed in "
170 "generic-mode with customized state machines with fallback");
171STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
172 "Number of OpenMP target region entry points (=kernels) executed in "
173 "generic-mode with customized state machines without fallback");
174STATISTIC(
175    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
176 "Number of OpenMP parallel regions replaced with ID in GPU state machines");
177STATISTIC(NumOpenMPParallelRegionsMerged,
178 "Number of OpenMP parallel regions merged");
179STATISTIC(NumBytesMovedToSharedMemory,
180 "Amount of memory pushed to shared memory");
181STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
182
183#if !defined(NDEBUG)
184static constexpr auto TAG = "[" DEBUG_TYPE "]";
185#endif
186
187namespace KernelInfo {
188
189// struct ConfigurationEnvironmentTy {
190// uint8_t UseGenericStateMachine;
191// uint8_t MayUseNestedParallelism;
192// llvm::omp::OMPTgtExecModeFlags ExecMode;
193// int32_t MinThreads;
194// int32_t MaxThreads;
195// int32_t MinTeams;
196// int32_t MaxTeams;
197// };
198
199// struct DynamicEnvironmentTy {
200// uint16_t DebugIndentionLevel;
201// };
202
203// struct KernelEnvironmentTy {
204// ConfigurationEnvironmentTy Configuration;
205// IdentTy *Ident;
206// DynamicEnvironmentTy *DynamicEnv;
207// };
208
209#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX) \
210 constexpr const unsigned MEMBER##Idx = IDX;
211
212KERNEL_ENVIRONMENT_IDX(Configuration, 0)
213KERNEL_ENVIRONMENT_IDX(Ident, 1)
214
215#undef KERNEL_ENVIRONMENT_IDX
216
217#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX) \
218 constexpr const unsigned MEMBER##Idx = IDX;
219
220KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
221KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
222KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2)
223KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3)
224KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4)
225KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5)
226KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6)
227
228#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX
229
230#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE) \
231 RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
232 return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx)); \
233 }
234
235KERNEL_ENVIRONMENT_GETTER(Ident, Constant)
236KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)
237
238#undef KERNEL_ENVIRONMENT_GETTER
239
240#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER) \
241 ConstantInt *get##MEMBER##FromKernelEnvironment( \
242 ConstantStruct *KernelEnvC) { \
243 ConstantStruct *ConfigC = \
244 getConfigurationFromKernelEnvironment(KernelEnvC); \
245 return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx)); \
246 }
247
248KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
249KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
250KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode)
251KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads)
252KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads)
253KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams)
254KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams)
255
256#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
257
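// Usage sketch (illustrative; `KernelInitCB` and the result variables are
// hypothetical names): the macro-generated getters above are used to read
// configuration fields out of the constant kernel environment, e.g.
//
//   ConstantStruct *KernelEnvC =
//       KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
//   ConstantInt *UseGenericSMC =
//       KernelInfo::getUseGenericStateMachineFromKernelEnvironment(KernelEnvC);
//   ConstantInt *ExecModeC =
//       KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);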
258GlobalVariable *
259getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) {
260  constexpr const int InitKernelEnvironmentArgNo = 0;
261  return cast<GlobalVariable>(
262      KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
263          ->stripPointerCasts());
264}
265
266ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) {
267  GlobalVariable *KernelEnvGV =
268      getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
269  return cast<ConstantStruct>(KernelEnvGV->getInitializer());
270}
271} // namespace KernelInfo
272
273namespace {
274
275struct AAHeapToShared;
276
277struct AAICVTracker;
278
279/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
280/// Attributor runs.
281struct OMPInformationCache : public InformationCache {
282  OMPInformationCache(Module &M, AnalysisGetter &AG,
283                      BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
284                      bool OpenMPPostLink)
285 : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
286 OpenMPPostLink(OpenMPPostLink) {
287
288 OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
289 OMPBuilder.initialize();
290 initializeRuntimeFunctions(M);
291 initializeInternalControlVars();
292 }
293
294 /// Generic information that describes an internal control variable.
295 struct InternalControlVarInfo {
296 /// The kind, as described by InternalControlVar enum.
297    InternalControlVar Kind;
298
299 /// The name of the ICV.
300    StringRef Name;
301
302 /// Environment variable associated with this ICV.
303 StringRef EnvVarName;
304
305 /// Initial value kind.
306 ICVInitValue InitKind;
307
308 /// Initial value.
309 ConstantInt *InitValue;
310
311 /// Setter RTL function associated with this ICV.
312 RuntimeFunction Setter;
313
314 /// Getter RTL function associated with this ICV.
315 RuntimeFunction Getter;
316
317 /// RTL Function corresponding to the override clause of this ICV
318    RuntimeFunction Clause;
319  };
320
321 /// Generic information that describes a runtime function
322 struct RuntimeFunctionInfo {
323
324 /// The kind, as described by the RuntimeFunction enum.
325    RuntimeFunction Kind;
326
327 /// The name of the function.
328    StringRef Name;
329
330 /// Flag to indicate a variadic function.
331 bool IsVarArg;
332
333 /// The return type of the function.
334    Type *ReturnType;
335
336 /// The argument types of the function.
337 SmallVector<Type *, 8> ArgumentTypes;
338
339 /// The declaration if available.
340 Function *Declaration = nullptr;
341
342 /// Uses of this runtime function per function containing the use.
343 using UseVector = SmallVector<Use *, 16>;
344
345 /// Clear UsesMap for runtime function.
346 void clearUsesMap() { UsesMap.clear(); }
347
348 /// Boolean conversion that is true if the runtime function was found.
349 operator bool() const { return Declaration; }
350
351 /// Return the vector of uses in function \p F.
352 UseVector &getOrCreateUseVector(Function *F) {
353 std::shared_ptr<UseVector> &UV = UsesMap[F];
354 if (!UV)
355 UV = std::make_shared<UseVector>();
356 return *UV;
357 }
358
359 /// Return the vector of uses in function \p F or `nullptr` if there are
360 /// none.
361 const UseVector *getUseVector(Function &F) const {
362 auto I = UsesMap.find(&F);
363 if (I != UsesMap.end())
364 return I->second.get();
365 return nullptr;
366 }
367
368 /// Return how many functions contain uses of this runtime function.
369 size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
370
371 /// Return the number of arguments (or the minimal number for variadic
372 /// functions).
373 size_t getNumArgs() const { return ArgumentTypes.size(); }
374
375 /// Run the callback \p CB on each use and forget the use if the result is
376 /// true. The callback will be fed the function in which the use was
377 /// encountered as second argument.
378 void foreachUse(SmallVectorImpl<Function *> &SCC,
379 function_ref<bool(Use &, Function &)> CB) {
380 for (Function *F : SCC)
381 foreachUse(CB, F);
382 }
383
384 /// Run the callback \p CB on each use within the function \p F and forget
385 /// the use if the result is true.
386 void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
387 SmallVector<unsigned, 8> ToBeDeleted;
388 ToBeDeleted.clear();
389
390 unsigned Idx = 0;
391 UseVector &UV = getOrCreateUseVector(F);
392
393 for (Use *U : UV) {
394 if (CB(*U, *F))
395 ToBeDeleted.push_back(Idx);
396 ++Idx;
397 }
398
399 // Remove the to-be-deleted indices in reverse order as prior
400 // modifications will not modify the smaller indices.
401 while (!ToBeDeleted.empty()) {
402 unsigned Idx = ToBeDeleted.pop_back_val();
403 UV[Idx] = UV.back();
404 UV.pop_back();
405 }
406 }
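    // Typical driver pattern (illustrative; the lambda body is hypothetical):
    // the callback inspects each use and returns true to forget it.
    //
    //   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
    //     CallInst *CI = getCallIfRegularCall(U, &RFI);
    //     if (!CI)
    //       return false; // Keep the use.
    //     // ... rewrite or erase CI ...
    //     return true;    // Forget the use.
    //   });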
407
408 private:
409 /// Map from functions to all uses of this runtime function contained in
410 /// them.
411    DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
412
413 public:
414 /// Iterators for the uses of this runtime function.
415 decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
416 decltype(UsesMap)::iterator end() { return UsesMap.end(); }
417 };
418
419 /// An OpenMP-IR-Builder instance
420 OpenMPIRBuilder OMPBuilder;
421
422 /// Map from runtime function kind to the runtime function description.
423 EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
424 RuntimeFunction::OMPRTL___last>
425 RFIs;
426
427 /// Map from function declarations/definitions to their runtime enum type.
428 DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
429
430 /// Map from ICV kind to the ICV description.
431 EnumeratedArray<InternalControlVarInfo, InternalControlVar,
432 InternalControlVar::ICV___last>
433 ICVs;
434
435 /// Helper to initialize all internal control variable information for those
436 /// defined in OMPKinds.def.
437 void initializeInternalControlVars() {
438#define ICV_RT_SET(_Name, RTL) \
439 { \
440 auto &ICV = ICVs[_Name]; \
441 ICV.Setter = RTL; \
442 }
443#define ICV_RT_GET(Name, RTL) \
444 { \
445 auto &ICV = ICVs[Name]; \
446 ICV.Getter = RTL; \
447 }
448#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
449 { \
450 auto &ICV = ICVs[Enum]; \
451 ICV.Name = _Name; \
452 ICV.Kind = Enum; \
453 ICV.InitKind = Init; \
454 ICV.EnvVarName = _EnvVarName; \
455 switch (ICV.InitKind) { \
456 case ICV_IMPLEMENTATION_DEFINED: \
457 ICV.InitValue = nullptr; \
458 break; \
459 case ICV_ZERO: \
460 ICV.InitValue = ConstantInt::get( \
461 Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
462 break; \
463 case ICV_FALSE: \
464 ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
465 break; \
466 case ICV_LAST: \
467 break; \
468 } \
469 }
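// For illustration only (hedged; not the literal OMPKinds.def text):
// expanding ICV_DATA_ENV for the nthreads ICV roughly yields
//
//   ICVs[ICV_nthreads].Name = "nthreads";
//   ICVs[ICV_nthreads].EnvVarName = "OMP_NUM_THREADS";
//   ICVs[ICV_nthreads].InitKind = ICV_IMPLEMENTATION_DEFINED;
//
// so its initial value stays nullptr (implementation defined).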
470#include "llvm/Frontend/OpenMP/OMPKinds.def"
471 }
472
473 /// Returns true if the function declaration \p F matches the runtime
474 /// function types, that is, return type \p RTFRetType, and argument types
475 /// \p RTFArgTypes.
476 static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
477 SmallVector<Type *, 8> &RTFArgTypes) {
478 // TODO: We should output information to the user (under debug output
479 // and via remarks).
480
481 if (!F)
482 return false;
483 if (F->getReturnType() != RTFRetType)
484 return false;
485 if (F->arg_size() != RTFArgTypes.size())
486 return false;
487
488 auto *RTFTyIt = RTFArgTypes.begin();
489 for (Argument &Arg : F->args()) {
490 if (Arg.getType() != *RTFTyIt)
491 return false;
492
493 ++RTFTyIt;
494 }
495
496 return true;
497 }
498
499 // Helper to collect all uses of the declaration in the UsesMap.
500 unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
501 unsigned NumUses = 0;
502 if (!RFI.Declaration)
503 return NumUses;
504 OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
505
506 if (CollectStats) {
507 NumOpenMPRuntimeFunctionsIdentified += 1;
508 NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
509 }
510
511 // TODO: We directly convert uses into proper calls and unknown uses.
512 for (Use &U : RFI.Declaration->uses()) {
513 if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
514 if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
515 RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
516 ++NumUses;
517 }
518 } else {
519 RFI.getOrCreateUseVector(nullptr).push_back(&U);
520 ++NumUses;
521 }
522 }
523 return NumUses;
524 }
525
526 // Helper function to recollect uses of a runtime function.
527 void recollectUsesForFunction(RuntimeFunction RTF) {
528 auto &RFI = RFIs[RTF];
529 RFI.clearUsesMap();
530 collectUses(RFI, /*CollectStats*/ false);
531 }
532
533 // Helper function to recollect uses of all runtime functions.
534 void recollectUses() {
535 for (int Idx = 0; Idx < RFIs.size(); ++Idx)
536 recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
537 }
538
539 // Helper function to inherit the calling convention of the function callee.
540 void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
541 if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
542 CI->setCallingConv(Fn->getCallingConv());
543 }
544
545 // Helper function to determine if it's legal to create a call to the runtime
546 // functions.
547 bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
548 // We can always emit calls if we haven't yet linked in the runtime.
549 if (!OpenMPPostLink)
550 return true;
551
552    // Once the runtime has already been linked in, we cannot emit calls to
553 // any undefined functions.
554 for (RuntimeFunction Fn : Fns) {
555 RuntimeFunctionInfo &RFI = RFIs[Fn];
556
557 if (RFI.Declaration && RFI.Declaration->isDeclaration())
558 return false;
559 }
560 return true;
561 }
562
563 /// Helper to initialize all runtime function information for those defined
564 /// in OpenMPKinds.def.
565 void initializeRuntimeFunctions(Module &M) {
566
567 // Helper macros for handling __VA_ARGS__ in OMP_RTL
568#define OMP_TYPE(VarName, ...) \
569 Type *VarName = OMPBuilder.VarName; \
570 (void)VarName;
571
572#define OMP_ARRAY_TYPE(VarName, ...) \
573 ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
574 (void)VarName##Ty; \
575 PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
576 (void)VarName##PtrTy;
577
578#define OMP_FUNCTION_TYPE(VarName, ...) \
579 FunctionType *VarName = OMPBuilder.VarName; \
580 (void)VarName; \
581 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
582 (void)VarName##Ptr;
583
584#define OMP_STRUCT_TYPE(VarName, ...) \
585 StructType *VarName = OMPBuilder.VarName; \
586 (void)VarName; \
587 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
588 (void)VarName##Ptr;
589
590#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
591 { \
592 SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
593 Function *F = M.getFunction(_Name); \
594 RTLFunctions.insert(F); \
595 if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
596 RuntimeFunctionIDMap[F] = _Enum; \
597 auto &RFI = RFIs[_Enum]; \
598 RFI.Kind = _Enum; \
599 RFI.Name = _Name; \
600 RFI.IsVarArg = _IsVarArg; \
601 RFI.ReturnType = OMPBuilder._ReturnType; \
602 RFI.ArgumentTypes = std::move(ArgsTypes); \
603 RFI.Declaration = F; \
604 unsigned NumUses = collectUses(RFI); \
605 (void)NumUses; \
606 LLVM_DEBUG({ \
607 dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
608 << " found\n"; \
609 if (RFI.Declaration) \
610 dbgs() << TAG << "-> got " << NumUses << " uses in " \
611 << RFI.getNumFunctionsWithUses() \
612 << " different functions.\n"; \
613 }); \
614 } \
615 }
616#include "llvm/Frontend/OpenMP/OMPKinds.def"
617
618 // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
619 // functions, except if `optnone` is present.
620 if (isOpenMPDevice(M)) {
621 for (Function &F : M) {
622 for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
623 if (F.hasFnAttribute(Attribute::NoInline) &&
624 F.getName().starts_with(Prefix) &&
625 !F.hasFnAttribute(Attribute::OptimizeNone))
626 F.removeFnAttr(Attribute::NoInline);
627 }
628 }
629
630 // TODO: We should attach the attributes defined in OMPKinds.def.
631 }
632
633  /// Collection of known OpenMP runtime functions.
634 DenseSet<const Function *> RTLFunctions;
635
636 /// Indicates if we have already linked in the OpenMP device library.
637 bool OpenMPPostLink = false;
638};
639
640template <typename Ty, bool InsertInvalidates = true>
641struct BooleanStateWithSetVector : public BooleanState {
642 bool contains(const Ty &Elem) const { return Set.contains(Elem); }
643 bool insert(const Ty &Elem) {
644 if (InsertInvalidates)
645      BooleanState::indicatePessimisticFixpoint();
646    return Set.insert(Elem);
647 }
648
649 const Ty &operator[](int Idx) const { return Set[Idx]; }
650 bool operator==(const BooleanStateWithSetVector &RHS) const {
651 return BooleanState::operator==(RHS) && Set == RHS.Set;
652 }
653 bool operator!=(const BooleanStateWithSetVector &RHS) const {
654 return !(*this == RHS);
655 }
656
657 bool empty() const { return Set.empty(); }
658 size_t size() const { return Set.size(); }
659
660 /// "Clamp" this state with \p RHS.
661 BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
662 BooleanState::operator^=(RHS);
663 Set.insert(RHS.Set.begin(), RHS.Set.end());
664 return *this;
665 }
666
667private:
668 /// A set to keep track of elements.
669  SetVector<Ty> Set;
670
671public:
672 typename decltype(Set)::iterator begin() { return Set.begin(); }
673 typename decltype(Set)::iterator end() { return Set.end(); }
674 typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
675 typename decltype(Set)::const_iterator end() const { return Set.end(); }
676};
677
678template <typename Ty, bool InsertInvalidates = true>
679using BooleanStateWithPtrSetVector =
680 BooleanStateWithSetVector<Ty *, InsertInvalidates>;
681
682struct KernelInfoState : AbstractState {
683 /// Flag to track if we reached a fixpoint.
684 bool IsAtFixpoint = false;
685
686 /// The parallel regions (identified by the outlined parallel functions) that
687 /// can be reached from the associated function.
688 BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
689 ReachedKnownParallelRegions;
690
691 /// State to track what parallel region we might reach.
692 BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
693
694  /// State to track if we are in SPMD-mode, assumed or known, and why we decided
695 /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
696 /// false.
697 BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
698
699 /// The __kmpc_target_init call in this kernel, if any. If we find more than
700 /// one we abort as the kernel is malformed.
701 CallBase *KernelInitCB = nullptr;
702
703  /// The constant kernel environment as taken from and passed to
704 /// __kmpc_target_init.
705 ConstantStruct *KernelEnvC = nullptr;
706
707 /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
708 /// one we abort as the kernel is malformed.
709 CallBase *KernelDeinitCB = nullptr;
710
711 /// Flag to indicate if the associated function is a kernel entry.
712 bool IsKernelEntry = false;
713
714 /// State to track what kernel entries can reach the associated function.
715 BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
716
717 /// State to indicate if we can track parallel level of the associated
718 /// function. We will give up tracking if we encounter unknown caller or the
719 /// caller is __kmpc_parallel_51.
720 BooleanStateWithSetVector<uint8_t> ParallelLevels;
721
722  /// Flag that indicates if the kernel has nested parallelism
723 bool NestedParallelism = false;
724
725 /// Abstract State interface
726 ///{
727
728 KernelInfoState() = default;
729 KernelInfoState(bool BestState) {
730 if (!BestState)
731      indicatePessimisticFixpoint();
732  }
733
734 /// See AbstractState::isValidState(...)
735 bool isValidState() const override { return true; }
736
737 /// See AbstractState::isAtFixpoint(...)
738 bool isAtFixpoint() const override { return IsAtFixpoint; }
739
740 /// See AbstractState::indicatePessimisticFixpoint(...)
741  ChangeStatus indicatePessimisticFixpoint() override {
742    IsAtFixpoint = true;
743 ParallelLevels.indicatePessimisticFixpoint();
744 ReachingKernelEntries.indicatePessimisticFixpoint();
745 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
746 ReachedKnownParallelRegions.indicatePessimisticFixpoint();
747 ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
748 NestedParallelism = true;
749    return ChangeStatus::CHANGED;
750  }
751
752 /// See AbstractState::indicateOptimisticFixpoint(...)
753  ChangeStatus indicateOptimisticFixpoint() override {
754    IsAtFixpoint = true;
755 ParallelLevels.indicateOptimisticFixpoint();
756 ReachingKernelEntries.indicateOptimisticFixpoint();
757 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
758 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
759 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
760    return ChangeStatus::UNCHANGED;
761  }
762
763 /// Return the assumed state
764 KernelInfoState &getAssumed() { return *this; }
765 const KernelInfoState &getAssumed() const { return *this; }
766
767 bool operator==(const KernelInfoState &RHS) const {
768 if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
769 return false;
770 if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
771 return false;
772 if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
773 return false;
774 if (ReachingKernelEntries != RHS.ReachingKernelEntries)
775 return false;
776 if (ParallelLevels != RHS.ParallelLevels)
777 return false;
778 if (NestedParallelism != RHS.NestedParallelism)
779 return false;
780 return true;
781 }
782
783 /// Returns true if this kernel contains any OpenMP parallel regions.
784 bool mayContainParallelRegion() {
785 return !ReachedKnownParallelRegions.empty() ||
786 !ReachedUnknownParallelRegions.empty();
787 }
788
789 /// Return empty set as the best state of potential values.
790 static KernelInfoState getBestState() { return KernelInfoState(true); }
791
792 static KernelInfoState getBestState(KernelInfoState &KIS) {
793 return getBestState();
794 }
795
796 /// Return full set as the worst state of potential values.
797 static KernelInfoState getWorstState() { return KernelInfoState(false); }
798
799 /// "Clamp" this state with \p KIS.
800 KernelInfoState operator^=(const KernelInfoState &KIS) {
801 // Do not merge two different _init and _deinit call sites.
802 if (KIS.KernelInitCB) {
803 if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
804 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
805 "assumptions.");
806 KernelInitCB = KIS.KernelInitCB;
807 }
808 if (KIS.KernelDeinitCB) {
809 if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
810 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
811 "assumptions.");
812 KernelDeinitCB = KIS.KernelDeinitCB;
813 }
814 if (KIS.KernelEnvC) {
815 if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
816 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
817 "assumptions.");
818 KernelEnvC = KIS.KernelEnvC;
819 }
820 SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
821 ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
822 ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
823 NestedParallelism |= KIS.NestedParallelism;
824 return *this;
825 }
826
827 KernelInfoState operator&=(const KernelInfoState &KIS) {
828 return (*this ^= KIS);
829 }
830
831 ///}
832};
833
834/// Used to map the values physically (in the IR) stored in an offload
835/// array, to a vector in memory.
836struct OffloadArray {
837 /// Physical array (in the IR).
838 AllocaInst *Array = nullptr;
839 /// Mapped values.
840 SmallVector<Value *, 8> StoredValues;
841 /// Last stores made in the offload array.
842 SmallVector<StoreInst *, 8> LastAccesses;
843
844 OffloadArray() = default;
845
846 /// Initializes the OffloadArray with the values stored in \p Array before
847 /// instruction \p Before is reached. Returns false if the initialization
848 /// fails.
849 /// This MUST be used immediately after the construction of the object.
850 bool initialize(AllocaInst &Array, Instruction &Before) {
851 if (!Array.getAllocatedType()->isArrayTy())
852 return false;
853
854 if (!getValues(Array, Before))
855 return false;
856
857 this->Array = &Array;
858 return true;
859 }
860
861 static const unsigned DeviceIDArgNum = 1;
862 static const unsigned BasePtrsArgNum = 3;
863 static const unsigned PtrsArgNum = 4;
864 static const unsigned SizesArgNum = 5;
865
866private:
867 /// Traverses the BasicBlock where \p Array is, collecting the stores made to
868 /// \p Array, leaving StoredValues with the values stored before the
869 /// instruction \p Before is reached.
870 bool getValues(AllocaInst &Array, Instruction &Before) {
871 // Initialize container.
872 const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
873 StoredValues.assign(NumValues, nullptr);
874 LastAccesses.assign(NumValues, nullptr);
875
876 // TODO: This assumes the instruction \p Before is in the same
877 // BasicBlock as Array. Make it general, for any control flow graph.
878 BasicBlock *BB = Array.getParent();
879 if (BB != Before.getParent())
880 return false;
881
882 const DataLayout &DL = Array.getDataLayout();
883 const unsigned int PointerSize = DL.getPointerSize();
884
885 for (Instruction &I : *BB) {
886 if (&I == &Before)
887 break;
888
889 if (!isa<StoreInst>(&I))
890 continue;
891
892 auto *S = cast<StoreInst>(&I);
893 int64_t Offset = -1;
894 auto *Dst =
895 GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
896 if (Dst == &Array) {
897 int64_t Idx = Offset / PointerSize;
898 StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
899 LastAccesses[Idx] = S;
900 }
901 }
902
903 return isFilled();
904 }
905
906 /// Returns true if all values in StoredValues and
907 /// LastAccesses are not nullptrs.
908 bool isFilled() {
909 const unsigned NumValues = StoredValues.size();
910 for (unsigned I = 0; I < NumValues; ++I) {
911 if (!StoredValues[I] || !LastAccesses[I])
912 return false;
913 }
914
915 return true;
916 }
917};
918
919struct OpenMPOpt {
920
921 using OptimizationRemarkGetter =
922      function_ref<OptimizationRemarkEmitter &(Function *)>;
923
924 OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
925 OptimizationRemarkGetter OREGetter,
926 OMPInformationCache &OMPInfoCache, Attributor &A)
927 : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
928 OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
929
930 /// Check if any remarks are enabled for openmp-opt
931 bool remarksEnabled() {
932 auto &Ctx = M.getContext();
933 return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
934 }
935
936 /// Run all OpenMP optimizations on the underlying SCC.
937 bool run(bool IsModulePass) {
938 if (SCC.empty())
939 return false;
940
941 bool Changed = false;
942
943 LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
944 << " functions\n");
945
946 if (IsModulePass) {
947 Changed |= runAttributor(IsModulePass);
948
949 // Recollect uses, in case Attributor deleted any.
950 OMPInfoCache.recollectUses();
951
952 // TODO: This should be folded into buildCustomStateMachine.
953 Changed |= rewriteDeviceCodeStateMachine();
954
955 if (remarksEnabled())
956 analysisGlobalization();
957 } else {
958 if (PrintICVValues)
959 printICVs();
960      if (PrintOpenMPKernels)
961        printKernels();
962
963 Changed |= runAttributor(IsModulePass);
964
965 // Recollect uses, in case Attributor deleted any.
966 OMPInfoCache.recollectUses();
967
968 Changed |= deleteParallelRegions();
969
970      if (HideMemoryTransferLatency)
971        Changed |= hideMemTransfersLatency();
972 Changed |= deduplicateRuntimeCalls();
973      if (EnableParallelRegionMerging) {
974        if (mergeParallelRegions()) {
975 deduplicateRuntimeCalls();
976 Changed = true;
977 }
978 }
979 }
980
981 if (OMPInfoCache.OpenMPPostLink)
982 Changed |= removeRuntimeSymbols();
983
984 return Changed;
985 }
986
987 /// Print initial ICV values for testing.
988 /// FIXME: This should be done from the Attributor once it is added.
989 void printICVs() const {
990 InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
991 ICV_proc_bind};
992
993 for (Function *F : SCC) {
994 for (auto ICV : ICVs) {
995 auto ICVInfo = OMPInfoCache.ICVs[ICV];
996 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
997 return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
998 << " Value: "
999 << (ICVInfo.InitValue
1000 ? toString(ICVInfo.InitValue->getValue(), 10, true)
1001 : "IMPLEMENTATION_DEFINED");
1002 };
1003
1004 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
1005 }
1006 }
1007 }
1008
1009 /// Print OpenMP GPU kernels for testing.
1010 void printKernels() const {
1011 for (Function *F : SCC) {
1012 if (!omp::isOpenMPKernel(*F))
1013 continue;
1014
1015 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1016 return ORA << "OpenMP GPU kernel "
1017 << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
1018 };
1019
1020 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
1021 }
1022 }
1023
1024 /// Return the call if \p U is a callee use in a regular call. If \p RFI is
1025  /// given, it has to be the callee; otherwise nullptr is returned.
1026 static CallInst *getCallIfRegularCall(
1027 Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1028 CallInst *CI = dyn_cast<CallInst>(U.getUser());
1029 if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
1030 (!RFI ||
1031 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1032 return CI;
1033 return nullptr;
1034 }
1035
1036  /// Return the call if \p V is a regular call. If \p RFI is given, it has to
1037  /// be the callee; otherwise nullptr is returned.
1038 static CallInst *getCallIfRegularCall(
1039 Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1040 CallInst *CI = dyn_cast<CallInst>(&V);
1041 if (CI && !CI->hasOperandBundles() &&
1042 (!RFI ||
1043 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1044 return CI;
1045 return nullptr;
1046 }
1047
1048private:
1049 /// Merge parallel regions when it is safe.
1050 bool mergeParallelRegions() {
1051 const unsigned CallbackCalleeOperand = 2;
1052 const unsigned CallbackFirstArgOperand = 3;
1053 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1054
1055 // Check if there are any __kmpc_fork_call calls to merge.
1056 OMPInformationCache::RuntimeFunctionInfo &RFI =
1057 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1058
1059 if (!RFI.Declaration)
1060 return false;
1061
1062 // Unmergable calls that prevent merging a parallel region.
1063 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
1064 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
1065 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
1066 };
1067
1068 bool Changed = false;
1069 LoopInfo *LI = nullptr;
1070 DominatorTree *DT = nullptr;
1071
1072    SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
1073
1074 BasicBlock *StartBB = nullptr, *EndBB = nullptr;
1075 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1076 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1077 BasicBlock *CGEndBB =
1078 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1079 assert(StartBB != nullptr && "StartBB should not be null");
1080 CGStartBB->getTerminator()->setSuccessor(0, StartBB);
1081 assert(EndBB != nullptr && "EndBB should not be null");
1082 EndBB->getTerminator()->setSuccessor(0, CGEndBB);
1083 };
1084
1085 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
1086 Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
1087 ReplacementValue = &Inner;
1088 return CodeGenIP;
1089 };
1090
1091 auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1092
1093 /// Create a sequential execution region within a merged parallel region,
1094 /// encapsulated in a master construct with a barrier for synchronization.
1095 auto CreateSequentialRegion = [&](Function *OuterFn,
1096 BasicBlock *OuterPredBB,
1097 Instruction *SeqStartI,
1098 Instruction *SeqEndI) {
1099 // Isolate the instructions of the sequential region to a separate
1100 // block.
1101 BasicBlock *ParentBB = SeqStartI->getParent();
1102 BasicBlock *SeqEndBB =
1103 SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
1104 BasicBlock *SeqAfterBB =
1105 SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
1106 BasicBlock *SeqStartBB =
1107 SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
1108
1109 assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
1110 "Expected a different CFG");
1111 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
1112 ParentBB->getTerminator()->eraseFromParent();
1113
1114 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1115 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1116 BasicBlock *CGEndBB =
1117 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1118 assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
1119 CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
1120 assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
1121 SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
1122 };
1123 auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1124
1125 // Find outputs from the sequential region to outside users and
1126 // broadcast their values to them.
1127 for (Instruction &I : *SeqStartBB) {
1128 SmallPtrSet<Instruction *, 4> OutsideUsers;
1129 for (User *Usr : I.users()) {
1130 Instruction &UsrI = *cast<Instruction>(Usr);
1131          // Ignore outputs to lifetime intrinsics, code extraction for the merged
1132 // parallel region will fix them.
1133 if (UsrI.isLifetimeStartOrEnd())
1134 continue;
1135
1136 if (UsrI.getParent() != SeqStartBB)
1137 OutsideUsers.insert(&UsrI);
1138 }
1139
1140 if (OutsideUsers.empty())
1141 continue;
1142
1143 // Emit an alloca in the outer region to store the broadcasted
1144 // value.
1145 const DataLayout &DL = M.getDataLayout();
1146 AllocaInst *AllocaI = new AllocaInst(
1147 I.getType(), DL.getAllocaAddrSpace(), nullptr,
1148 I.getName() + ".seq.output.alloc", OuterFn->front().begin());
1149
1150 // Emit a store instruction in the sequential BB to update the
1151 // value.
1152 new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());
1153
1154 // Emit a load instruction and replace the use of the output value
1155 // with it.
1156 for (Instruction *UsrI : OutsideUsers) {
1157 LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,
1158 I.getName() + ".seq.output.load",
1159 UsrI->getIterator());
1160 UsrI->replaceUsesOfWith(&I, LoadI);
1161 }
1162 }
1163
1164      OpenMPIRBuilder::LocationDescription Loc(
1165          InsertPointTy(ParentBB, ParentBB->end()), DL);
1166 InsertPointTy SeqAfterIP =
1167 OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
1168
1169 OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
1170
1171 BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1172
1173 LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1174 << "\n");
1175 };
1176
1177 // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1178 // contained in BB and only separated by instructions that can be
1179 // redundantly executed in parallel. The block BB is split before the first
1180 // call (in MergableCIs) and after the last so the entire region we merge
1181 // into a single parallel region is contained in a single basic block
1182 // without any other instructions. We use the OpenMPIRBuilder to outline
1183 // that block and call the resulting function via __kmpc_fork_call.
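    // Rough user-level illustration of the merge (hedged; `work1`/`work2` are
    // hypothetical and this pass operates on IR, not on source):
    //
    //   #pragma omp parallel
    //   { work1(); }
    //   #pragma omp parallel
    //   { work2(); }
    //
    // becomes a single region in which the former fork/join boundary turns
    // into an explicit barrier:
    //
    //   #pragma omp parallel
    //   {
    //     work1();
    //   #pragma omp barrier
    //     work2();
    //   }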
1184 auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1185 BasicBlock *BB) {
1186 // TODO: Change the interface to allow single CIs expanded, e.g, to
1187 // include an outer loop.
1188 assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1189
1190 auto Remark = [&](OptimizationRemark OR) {
1191 OR << "Parallel region merged with parallel region"
1192 << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1193 for (auto *CI : llvm::drop_begin(MergableCIs)) {
1194 OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1195 if (CI != MergableCIs.back())
1196 OR << ", ";
1197 }
1198 return OR << ".";
1199 };
1200
1201 emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1202
1203 Function *OriginalFn = BB->getParent();
1204 LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1205 << " parallel regions in " << OriginalFn->getName()
1206 << "\n");
1207
1208 // Isolate the calls to merge in a separate block.
1209 EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1210 BasicBlock *AfterBB =
1211 SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1212 StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1213 "omp.par.merged");
1214
1215 assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1216 const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1217 BB->getTerminator()->eraseFromParent();
1218
1219 // Create sequential regions for sequential instructions that are
1220 // in-between mergable parallel regions.
1221 for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1222 It != End; ++It) {
1223 Instruction *ForkCI = *It;
1224 Instruction *NextForkCI = *(It + 1);
1225
1226        // Continue if there are no in-between instructions.
1227 if (ForkCI->getNextNode() == NextForkCI)
1228 continue;
1229
1230 CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1231 NextForkCI->getPrevNode());
1232 }
1233
1234 OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1235 DL);
1236 IRBuilder<>::InsertPoint AllocaIP(
1237 &OriginalFn->getEntryBlock(),
1238 OriginalFn->getEntryBlock().getFirstInsertionPt());
1239 // Create the merged parallel region with default proc binding, to
1240 // avoid overriding binding settings, and without explicit cancellation.
1241 InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1242 Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1243 OMP_PROC_BIND_default, /* IsCancellable */ false);
1244 BranchInst::Create(AfterBB, AfterIP.getBlock());
1245
1246 // Perform the actual outlining.
1247 OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1248
1249 Function *OutlinedFn = MergableCIs.front()->getCaller();
1250
1251 // Replace the __kmpc_fork_call calls with direct calls to the outlined
1252 // callbacks.
1253      SmallVector<Value *, 8> Args;
1254      for (auto *CI : MergableCIs) {
1255 Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1256 FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1257 Args.clear();
1258 Args.push_back(OutlinedFn->getArg(0));
1259 Args.push_back(OutlinedFn->getArg(1));
1260 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1261 ++U)
1262 Args.push_back(CI->getArgOperand(U));
1263
1264 CallInst *NewCI =
1265 CallInst::Create(FT, Callee, Args, "", CI->getIterator());
1266 if (CI->getDebugLoc())
1267 NewCI->setDebugLoc(CI->getDebugLoc());
1268
1269 // Forward parameter attributes from the callback to the callee.
1270 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1271 ++U)
1272 for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1273 NewCI->addParamAttr(
1274 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1275
1276 // Emit an explicit barrier to replace the implicit fork-join barrier.
1277 if (CI != MergableCIs.back()) {
1278 // TODO: Remove barrier if the merged parallel region includes the
1279 // 'nowait' clause.
1280 OMPInfoCache.OMPBuilder.createBarrier(
1281 InsertPointTy(NewCI->getParent(),
1282 NewCI->getNextNode()->getIterator()),
1283 OMPD_parallel);
1284 }
1285
1286 CI->eraseFromParent();
1287 }
1288
1289 assert(OutlinedFn != OriginalFn && "Outlining failed");
1290 CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1291 CGUpdater.reanalyzeFunction(*OriginalFn);
1292
1293 NumOpenMPParallelRegionsMerged += MergableCIs.size();
1294
1295 return true;
1296 };
1297
1298    // Helper function that identifies sequences of
1299 // __kmpc_fork_call uses in a basic block.
1300 auto DetectPRsCB = [&](Use &U, Function &F) {
1301 CallInst *CI = getCallIfRegularCall(U, &RFI);
1302 BB2PRMap[CI->getParent()].insert(CI);
1303
1304 return false;
1305 };
1306
1307 BB2PRMap.clear();
1308 RFI.foreachUse(SCC, DetectPRsCB);
1309 SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1310 // Find mergable parallel regions within a basic block that are
1311 // safe to merge, that is any in-between instructions can safely
1312 // execute in parallel after merging.
1313 // TODO: support merging across basic-blocks.
1314 for (auto &It : BB2PRMap) {
1315 auto &CIs = It.getSecond();
1316 if (CIs.size() < 2)
1317 continue;
1318
1319 BasicBlock *BB = It.getFirst();
1320 SmallVector<CallInst *, 4> MergableCIs;
1321
1322 /// Returns true if the instruction is mergable, false otherwise.
1323 /// A terminator instruction is unmergable by definition since merging
1324 /// works within a BB. Instructions before the mergable region are
1325 /// mergable if they are not calls to OpenMP runtime functions that may
1326 /// set different execution parameters for subsequent parallel regions.
1327 /// Instructions in-between parallel regions are mergable if they are not
1328 /// calls to any non-intrinsic function since that may call a non-mergable
1329 /// OpenMP runtime function.
1330 auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1331 // We do not merge across BBs, hence return false (unmergable) if the
1332 // instruction is a terminator.
1333 if (I.isTerminator())
1334 return false;
1335
1336 if (!isa<CallInst>(&I))
1337 return true;
1338
1339 CallInst *CI = cast<CallInst>(&I);
1340 if (IsBeforeMergableRegion) {
1341 Function *CalledFunction = CI->getCalledFunction();
1342 if (!CalledFunction)
1343 return false;
1344 // Return false (unmergable) if the call before the parallel
1345 // region calls an explicit affinity (proc_bind) or number of
1346 // threads (num_threads) compiler-generated function. Those settings
1347 // may be incompatible with following parallel regions.
1348 // TODO: ICV tracking to detect compatibility.
1349 for (const auto &RFI : UnmergableCallsInfo) {
1350 if (CalledFunction == RFI.Declaration)
1351 return false;
1352 }
1353 } else {
1354 // Return false (unmergable) if there is a call instruction
1355 // in-between parallel regions when it is not an intrinsic. It
1356 // may call an unmergable OpenMP runtime function in its callpath.
1357 // TODO: Keep track of possible OpenMP calls in the callpath.
1358 if (!isa<IntrinsicInst>(CI))
1359 return false;
1360 }
1361
1362 return true;
1363 };
1364 // Find maximal number of parallel region CIs that are safe to merge.
1365 for (auto It = BB->begin(), End = BB->end(); It != End;) {
1366 Instruction &I = *It;
1367 ++It;
1368
1369 if (CIs.count(&I)) {
1370 MergableCIs.push_back(cast<CallInst>(&I));
1371 continue;
1372 }
1373
1374 // Continue expanding if the instruction is mergable.
1375 if (IsMergable(I, MergableCIs.empty()))
1376 continue;
1377
1378 // Forward the instruction iterator to skip the next parallel region
1379 // since there is an unmergable instruction which can affect it.
1380 for (; It != End; ++It) {
1381 Instruction &SkipI = *It;
1382 if (CIs.count(&SkipI)) {
1383 LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1384 << " due to " << I << "\n");
1385 ++It;
1386 break;
1387 }
1388 }
1389
1390 // Store mergable regions found.
1391 if (MergableCIs.size() > 1) {
1392 MergableCIsVector.push_back(MergableCIs);
1393 LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1394 << " parallel regions in block " << BB->getName()
1395 << " of function " << BB->getParent()->getName()
1396 << "\n";);
1397 }
1398
1399 MergableCIs.clear();
1400 }
1401
1402 if (!MergableCIsVector.empty()) {
1403 Changed = true;
1404
1405 for (auto &MergableCIs : MergableCIsVector)
1406 Merge(MergableCIs, BB);
1407 MergableCIsVector.clear();
1408 }
1409 }
1410
1411 if (Changed) {
1412 /// Re-collect use for fork calls, emitted barrier calls, and
1413 /// any emitted master/end_master calls.
1414 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1415 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1416 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1417 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1418 }
1419
1420 return Changed;
1421 }
1422
1423 /// Try to delete parallel regions if possible.
1424 bool deleteParallelRegions() {
1425 const unsigned CallbackCalleeOperand = 2;
1426
1427 OMPInformationCache::RuntimeFunctionInfo &RFI =
1428 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1429
1430 if (!RFI.Declaration)
1431 return false;
1432
1433 bool Changed = false;
1434 auto DeleteCallCB = [&](Use &U, Function &) {
1435 CallInst *CI = getCallIfRegularCall(U);
1436 if (!CI)
1437 return false;
1438 auto *Fn = dyn_cast<Function>(
1439 CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1440 if (!Fn)
1441 return false;
1442 if (!Fn->onlyReadsMemory())
1443 return false;
1444 if (!Fn->hasFnAttribute(Attribute::WillReturn))
1445 return false;
1446
1447 LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1448 << CI->getCaller()->getName() << "\n");
1449
1450 auto Remark = [&](OptimizationRemark OR) {
1451 return OR << "Removing parallel region with no side-effects.";
1452 };
1453 emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1454
1455 CI->eraseFromParent();
1456 Changed = true;
1457 ++NumOpenMPParallelRegionsDeleted;
1458 return true;
1459 };
1460
1461 RFI.foreachUse(SCC, DeleteCallCB);
1462
1463 return Changed;
1464 }
1465
1466 /// Try to eliminate runtime calls by reusing existing ones.
1467 bool deduplicateRuntimeCalls() {
1468 bool Changed = false;
1469
1470 RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1471 OMPRTL_omp_get_num_threads,
1472 OMPRTL_omp_in_parallel,
1473 OMPRTL_omp_get_cancellation,
1474 OMPRTL_omp_get_supported_active_levels,
1475 OMPRTL_omp_get_level,
1476 OMPRTL_omp_get_ancestor_thread_num,
1477 OMPRTL_omp_get_team_size,
1478 OMPRTL_omp_get_active_level,
1479 OMPRTL_omp_in_final,
1480 OMPRTL_omp_get_proc_bind,
1481 OMPRTL_omp_get_num_places,
1482 OMPRTL_omp_get_num_procs,
1483 OMPRTL_omp_get_place_num,
1484 OMPRTL_omp_get_partition_num_places,
1485 OMPRTL_omp_get_partition_place_nums};
1486
1487 // Global-tid is handled separately.
1488    SmallSetVector<Value *, 16> GTIdArgs;
1489    collectGlobalThreadIdArguments(GTIdArgs);
1490 LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1491 << " global thread ID arguments\n");
1492
1493 for (Function *F : SCC) {
1494 for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1495 Changed |= deduplicateRuntimeCalls(
1496 *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1497
1498 // __kmpc_global_thread_num is special as we can replace it with an
1499 // argument in enough cases to make it worth trying.
1500 Value *GTIdArg = nullptr;
1501 for (Argument &Arg : F->args())
1502 if (GTIdArgs.count(&Arg)) {
1503 GTIdArg = &Arg;
1504 break;
1505 }
1506 Changed |= deduplicateRuntimeCalls(
1507 *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1508 }
1509
1510 return Changed;
1511 }
1512
1513  /// Tries to remove known, optional runtime symbols from the module.
1514 bool removeRuntimeSymbols() {
1515 // The RPC client symbol is defined in `libc` and indicates that something
1516 // required an RPC server. If its users were all optimized out then we can
1517 // safely remove it.
1518 // TODO: This should be somewhere more common in the future.
1519 if (GlobalVariable *GV = M.getNamedGlobal("__llvm_libc_rpc_client")) {
1520 if (!GV->getType()->isPointerTy())
1521 return false;
1522
1523 Constant *C = GV->getInitializer();
1524 if (!C)
1525 return false;
1526
1527 // Check to see if the only user of the RPC client is the external handle.
1528 GlobalVariable *Client = dyn_cast<GlobalVariable>(C->stripPointerCasts());
1529 if (!Client || Client->getNumUses() > 1 ||
1530 Client->user_back() != GV->getInitializer())
1531 return false;
1532
1533 Client->replaceAllUsesWith(PoisonValue::get(Client->getType()));
1534 Client->eraseFromParent();
1535
1536 GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
1537 GV->eraseFromParent();
1538
1539 return true;
1540 }
1541 return false;
1542 }
1543
1544 /// Tries to hide the latency of runtime calls that involve host to
1545 /// device memory transfers by splitting them into their "issue" and "wait"
1546 /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1547  /// moved downwards as much as possible. The "issue" issues the memory transfer
1548  /// asynchronously, returning a handle. The "wait" waits on the returned
1549  /// handle for the memory transfer to finish.
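  // Conceptual sketch (not the exact IR this function emits; `unrelated_work`
  // is hypothetical): a blocking mapper call followed by unrelated work, e.g.
  //
  //   __tgt_target_data_begin_mapper(loc, device_id, ...);
  //   unrelated_work();
  //
  // is split so the transfer overlaps that work:
  //
  //   __tgt_target_data_begin_mapper_issue(loc, device_id, ..., &handle);
  //   unrelated_work();
  //   __tgt_target_data_begin_mapper_wait(device_id, &handle);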
1550 bool hideMemTransfersLatency() {
1551 auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1552 bool Changed = false;
1553 auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1554 auto *RTCall = getCallIfRegularCall(U, &RFI);
1555 if (!RTCall)
1556 return false;
1557
1558 OffloadArray OffloadArrays[3];
1559 if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1560 return false;
1561
1562 LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1563
1564 // TODO: Check if can be moved upwards.
1565 bool WasSplit = false;
1566 Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1567 if (WaitMovementPoint)
1568 WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1569
1570 Changed |= WasSplit;
1571 return WasSplit;
1572 };
1573 if (OMPInfoCache.runtimeFnsAvailable(
1574 {OMPRTL___tgt_target_data_begin_mapper_issue,
1575 OMPRTL___tgt_target_data_begin_mapper_wait}))
1576 RFI.foreachUse(SCC, SplitMemTransfers);
1577
1578 return Changed;
1579 }
1580
1581 void analysisGlobalization() {
1582 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1583
1584 auto CheckGlobalization = [&](Use &U, Function &Decl) {
1585 if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1586 auto Remark = [&](OptimizationRemarkMissed ORM) {
1587 return ORM
1588 << "Found thread data sharing on the GPU. "
1589 << "Expect degraded performance due to data globalization.";
1590 };
1591 emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1592 }
1593
1594 return false;
1595 };
1596
1597 RFI.foreachUse(SCC, CheckGlobalization);
1598 }
1599
1600 /// Maps the values stored in the offload arrays passed as arguments to
1601 /// \p RuntimeCall into the offload arrays in \p OAs.
1602 bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1603                                MutableArrayRef<OffloadArray> OAs) {
1604    assert(OAs.size() == 3 && "Need space for three offload arrays!");
1605
1606 // A runtime call that involves memory offloading looks something like:
1607 // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1608 // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1609 // ...)
1610 // So, the idea is to access the allocas that allocate space for these
1611 // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1612 // Therefore:
1613 // i8** %offload_baseptrs.
1614 Value *BasePtrsArg =
1615 RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1616 // i8** %offload_ptrs.
1617 Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1618    // i64* %offload_sizes.
1619 Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1620
1621 // Get values stored in **offload_baseptrs.
1622 auto *V = getUnderlyingObject(BasePtrsArg);
1623 if (!isa<AllocaInst>(V))
1624 return false;
1625 auto *BasePtrsArray = cast<AllocaInst>(V);
1626 if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1627 return false;
1628
1629    // Get values stored in **offload_ptrs.
1630 V = getUnderlyingObject(PtrsArg);
1631 if (!isa<AllocaInst>(V))
1632 return false;
1633 auto *PtrsArray = cast<AllocaInst>(V);
1634 if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1635 return false;
1636
1637 // Get values stored in **offload_sizes.
1638 V = getUnderlyingObject(SizesArg);
1639 // If it's a [constant] global array don't analyze it.
1640 if (isa<GlobalValue>(V))
1641 return isa<Constant>(V);
1642 if (!isa<AllocaInst>(V))
1643 return false;
1644
1645 auto *SizesArray = cast<AllocaInst>(V);
1646 if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1647 return false;
1648
1649 return true;
1650 }
1651
1652 /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1653 /// For now this is a way to test that the function getValuesInOffloadArrays
1654 /// is working properly.
1655 /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1656 void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1657 assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1658
1659 LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1660 std::string ValuesStr;
1661 raw_string_ostream Printer(ValuesStr);
1662 std::string Separator = " --- ";
1663
1664 for (auto *BP : OAs[0].StoredValues) {
1665 BP->print(Printer);
1666 Printer << Separator;
1667 }
1668 LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << ValuesStr << "\n");
1669 ValuesStr.clear();
1670
1671 for (auto *P : OAs[1].StoredValues) {
1672 P->print(Printer);
1673 Printer << Separator;
1674 }
1675 LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << ValuesStr << "\n");
1676 ValuesStr.clear();
1677
1678 for (auto *S : OAs[2].StoredValues) {
1679 S->print(Printer);
1680 Printer << Separator;
1681 }
1682 LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << ValuesStr << "\n");
1683 }
1684
1685 /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
1686 /// moved. Returns nullptr if the movement is not possible, or not worth it.
1687 Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1688 // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1689 // Make it traverse the CFG.
1690
1691 Instruction *CurrentI = &RuntimeCall;
1692 bool IsWorthIt = false;
1693 while ((CurrentI = CurrentI->getNextNode())) {
1694
1695 // TODO: Once we detect the regions to be offloaded we should use the
1696 // alias analysis manager to check if CurrentI may modify one of
1697 // the offloaded regions.
1698 if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1699 if (IsWorthIt)
1700 return CurrentI;
1701
1702 return nullptr;
1703 }
1704
1705 // FIXME: For now, moving it over anything without side effects is
1706 // considered worth it.
1707 IsWorthIt = true;
1708 }
1709
1710 // Return end of BasicBlock.
1711 return RuntimeCall.getParent()->getTerminator();
1712 }
1713
1714 /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
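/// Schematically (simplified, illustrative IR only):
///   call void @__tgt_target_data_begin_mapper(<args>)
/// becomes
///   %handle = alloca %struct.__tgt_async_info
///   call void @__tgt_target_data_begin_mapper_issue(<args>, ptr %handle)
///   ...                                   ; independent code may run here
///   call void @__tgt_target_data_begin_mapper_wait(i64 %device_id, ptr %handle)
/// with the wait call placed at \p WaitMovementPoint.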
1715 bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1716 Instruction &WaitMovementPoint) {
1717 // Create a stack-allocated handle (__tgt_async_info) at the beginning of the
1718 // function. It stores information about the async transfer, allowing us to
1719 // wait on it later.
1720 auto &IRBuilder = OMPInfoCache.OMPBuilder;
1721 Function *F = RuntimeCall.getCaller();
1722 BasicBlock &Entry = F->getEntryBlock();
1723 IRBuilder.Builder.SetInsertPoint(&Entry,
1724 Entry.getFirstNonPHIOrDbgOrAlloca());
1725 Value *Handle = IRBuilder.Builder.CreateAlloca(
1726 IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1727 Handle =
1728 IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1729
1730 // Add "issue" runtime call declaration:
1731 // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1732 // i8**, i8**, i64*, i64*)
1733 FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1734 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1735
1736 // Change RuntimeCall call site for its asynchronous version.
1737 SmallVector<Value *, 16> Args;
1738 for (auto &Arg : RuntimeCall.args())
1739 Args.push_back(Arg.get());
1740 Args.push_back(Handle);
1741
1742 CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"",
1743 RuntimeCall.getIterator());
1744 OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1745 RuntimeCall.eraseFromParent();
1746
1747 // Add "wait" runtime call declaration:
1748 // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1749 FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1750 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1751
1752 Value *WaitParams[2] = {
1753 IssueCallsite->getArgOperand(
1754 OffloadArray::DeviceIDArgNum), // device_id.
1755 Handle // handle to wait on.
1756 };
1757 CallInst *WaitCallsite = CallInst::Create(
1758 WaitDecl, WaitParams, /*NameStr=*/"", WaitMovementPoint.getIterator());
1759 OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1760
1761 return true;
1762 }
1763
1764 static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1765 bool GlobalOnly, bool &SingleChoice) {
1766 if (CurrentIdent == NextIdent)
1767 return CurrentIdent;
1768
1769 // TODO: Figure out how to actually combine multiple debug locations. For
1770 // now we just keep an existing one if there is a single choice.
1771 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1772 SingleChoice = !CurrentIdent;
1773 return NextIdent;
1774 }
1775 return nullptr;
1776 }
1777
1778 /// Return a `struct ident_t*` value that represents the ones used in the
1779 /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1780 /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1781 /// return value we create one from scratch. We also do not yet combine
1782 /// information, e.g., the source locations, see combinedIdentStruct.
1783 Value *
1784 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1785 Function &F, bool GlobalOnly) {
1786 bool SingleChoice = true;
1787 Value *Ident = nullptr;
1788 auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1789 CallInst *CI = getCallIfRegularCall(U, &RFI);
1790 if (!CI || &F != &Caller)
1791 return false;
1792 Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1793 /* GlobalOnly */ true, SingleChoice);
1794 return false;
1795 };
1796 RFI.foreachUse(SCC, CombineIdentStruct);
1797
1798 if (!Ident || !SingleChoice) {
1799 // The IRBuilder uses the insertion block to get to the module; this is
1800 // unfortunate, but we work around it for now.
1801 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1802 OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1803 &F.getEntryBlock(), F.getEntryBlock().begin()));
1804 // Create a fallback location if none was found.
1805 // TODO: Use the debug locations of the calls instead.
1806 uint32_t SrcLocStrSize;
1807 Constant *Loc =
1808 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1809 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1810 }
1811 return Ident;
1812 }
1813
1814 /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1815 /// \p ReplVal if given.
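/// For example (roughly), two calls to a deduplicable runtime function such as
///   %tid0 = call i32 @omp_get_thread_num()
///   ...
///   %tid1 = call i32 @omp_get_thread_num()
/// are collapsed onto a single surviving call whose result replaces both,
/// provided that call (or \p ReplVal) is valid at every use site.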
1816 bool deduplicateRuntimeCalls(Function &F,
1817 OMPInformationCache::RuntimeFunctionInfo &RFI,
1818 Value *ReplVal = nullptr) {
1819 auto *UV = RFI.getUseVector(F);
1820 if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1821 return false;
1822
1823 LLVM_DEBUG(
1824 dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1825 << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1826
1827 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1828 cast<Argument>(ReplVal)->getParent() == &F)) &&
1829 "Unexpected replacement value!");
1830
1831 // TODO: Use dominance to find a good position instead.
1832 auto CanBeMoved = [this](CallBase &CB) {
1833 unsigned NumArgs = CB.arg_size();
1834 if (NumArgs == 0)
1835 return true;
1836 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1837 return false;
1838 for (unsigned U = 1; U < NumArgs; ++U)
1839 if (isa<Instruction>(CB.getArgOperand(U)))
1840 return false;
1841 return true;
1842 };
1843
1844 if (!ReplVal) {
1845 auto *DT =
1846 OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1847 if (!DT)
1848 return false;
1849 Instruction *IP = nullptr;
1850 for (Use *U : *UV) {
1851 if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1852 if (IP)
1853 IP = DT->findNearestCommonDominator(IP, CI);
1854 else
1855 IP = CI;
1856 if (!CanBeMoved(*CI))
1857 continue;
1858 if (!ReplVal)
1859 ReplVal = CI;
1860 }
1861 }
1862 if (!ReplVal)
1863 return false;
1864 assert(IP && "Expected insertion point!");
1865 cast<Instruction>(ReplVal)->moveBefore(IP);
1866 }
1867
1868 // If we use a call as a replacement value we need to make sure the ident is
1869 // valid at the new location. For now we just pick a global one, either
1870 // existing and used by one of the calls, or created from scratch.
1871 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1872 if (!CI->arg_empty() &&
1873 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1874 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1875 /* GlobalOnly */ true);
1876 CI->setArgOperand(0, Ident);
1877 }
1878 }
1879
1880 bool Changed = false;
1881 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1882 CallInst *CI = getCallIfRegularCall(U, &RFI);
1883 if (!CI || CI == ReplVal || &F != &Caller)
1884 return false;
1885 assert(CI->getCaller() == &F && "Unexpected call!");
1886
1887 auto Remark = [&](OptimizationRemark OR) {
1888 return OR << "OpenMP runtime call "
1889 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1890 };
1891 if (CI->getDebugLoc())
1892 emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1893 else
1894 emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1895
1896 CI->replaceAllUsesWith(ReplVal);
1897 CI->eraseFromParent();
1898 ++NumOpenMPRuntimeCallsDeduplicated;
1899 Changed = true;
1900 return true;
1901 };
1902 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1903
1904 return Changed;
1905 }
1906
1907 /// Collect arguments that represent the global thread id in \p GTIdArgs.
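/// For example (sketch): if the result of a __kmpc_global_thread_num call is
/// only ever passed as, say, the first argument of an internal helper
/// function, that parameter is itself a global thread id argument; the search
/// below then continues transitively from such arguments.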
1908 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1909 // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1910 // initialization. We could define an AbstractAttribute instead and
1911 // run the Attributor here once it can be run as an SCC pass.
1912
1913 // Helper to check the argument \p ArgNo at all call sites of \p F for
1914 // a GTId.
1915 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1916 if (!F.hasLocalLinkage())
1917 return false;
1918 for (Use &U : F.uses()) {
1919 if (CallInst *CI = getCallIfRegularCall(U)) {
1920 Value *ArgOp = CI->getArgOperand(ArgNo);
1921 if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1922 getCallIfRegularCall(
1923 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1924 continue;
1925 }
1926 return false;
1927 }
1928 return true;
1929 };
1930
1931 // Helper to identify uses of a GTId as GTId arguments.
1932 auto AddUserArgs = [&](Value &GTId) {
1933 for (Use &U : GTId.uses())
1934 if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1935 if (CI->isArgOperand(&U))
1936 if (Function *Callee = CI->getCalledFunction())
1937 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1938 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1939 };
1940
1941 // The argument users of __kmpc_global_thread_num calls are GTIds.
1942 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1943 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1944
1945 GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1946 if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1947 AddUserArgs(*CI);
1948 return false;
1949 });
1950
1951 // Transitively search for more arguments by looking at the users of the
1952 // ones we know already. During the search the GTIdArgs vector is extended,
1953 // so we cannot cache the size nor use a range-based for loop.
1954 for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1955 AddUserArgs(*GTIdArgs[U]);
1956 }
1957
1958 /// Kernel (=GPU) optimizations and utility functions
1959 ///
1960 ///{{
1961
1962 /// Cache to remember the unique kernel for a function.
1963 DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
1964
1965 /// Find the unique kernel that will execute \p F, if any.
1966 Kernel getUniqueKernelFor(Function &F);
1967
1968 /// Find the unique kernel that will execute \p I, if any.
1969 Kernel getUniqueKernelFor(Instruction &I) {
1970 return getUniqueKernelFor(*I.getFunction());
1971 }
1972
1973 /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
1974 /// the cases where we can avoid taking the address of a function.
1975 bool rewriteDeviceCodeStateMachine();
1976
1977 ///
1978 ///}}
1979
1980 /// Emit a remark generically
1981 ///
1982 /// This template function can be used to generically emit a remark. The
1983 /// RemarkKind should be one of the following:
1984 /// - OptimizationRemark to indicate a successful optimization attempt
1985 /// - OptimizationRemarkMissed to report a failed optimization attempt
1986 /// - OptimizationRemarkAnalysis to provide additional information about an
1987 /// optimization attempt
1988 ///
1989 /// The remark is built using a callback function provided by the caller that
1990 /// takes a RemarkKind as input and returns a RemarkKind.
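/// Typical usage, mirroring the call sites elsewhere in this file:
/// \code
/// auto Remark = [&](OptimizationRemark OR) {
/// return OR << "OpenMP runtime call deduplicated.";
/// };
/// emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
/// \endcode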
1991 template <typename RemarkKind, typename RemarkCallBack>
1992 void emitRemark(Instruction *I, StringRef RemarkName,
1993 RemarkCallBack &&RemarkCB) const {
1994 Function *F = I->getParent()->getParent();
1995 auto &ORE = OREGetter(F);
1996
1997 if (RemarkName.starts_with("OMP"))
1998 ORE.emit([&]() {
1999 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2000 << " [" << RemarkName << "]";
2001 });
2002 else
2003 ORE.emit(
2004 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2005 }
2006
2007 /// Emit a remark on a function.
2008 template <typename RemarkKind, typename RemarkCallBack>
2009 void emitRemark(Function *F, StringRef RemarkName,
2010 RemarkCallBack &&RemarkCB) const {
2011 auto &ORE = OREGetter(F);
2012
2013 if (RemarkName.starts_with("OMP"))
2014 ORE.emit([&]() {
2015 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2016 << " [" << RemarkName << "]";
2017 });
2018 else
2019 ORE.emit(
2020 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2021 }
2022
2023 /// The underlying module.
2024 Module &M;
2025
2026 /// The SCC we are operating on.
2027 SmallVectorImpl<Function *> &SCC;
2028
2029 /// Callback to update the call graph, the first argument is a removed call,
2030 /// the second an optional replacement call.
2031 CallGraphUpdater &CGUpdater;
2032
2033 /// Callback to get an OptimizationRemarkEmitter from a Function *
2034 OptimizationRemarkGetter OREGetter;
2035
2036 /// OpenMP-specific information cache. Also used for Attributor runs.
2037 OMPInformationCache &OMPInfoCache;
2038
2039 /// Attributor instance.
2040 Attributor &A;
2041
2042 /// Helper function to run Attributor on SCC.
2043 bool runAttributor(bool IsModulePass) {
2044 if (SCC.empty())
2045 return false;
2046
2047 registerAAs(IsModulePass);
2048
2049 ChangeStatus Changed = A.run();
2050
2051 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2052 << " functions, result: " << Changed << ".\n");
2053
2054 if (Changed == ChangeStatus::CHANGED)
2055 OMPInfoCache.invalidateAnalyses();
2056
2057 return Changed == ChangeStatus::CHANGED;
2058 }
2059
2060 void registerFoldRuntimeCall(RuntimeFunction RF);
2061
2062 /// Populate the Attributor with abstract attribute opportunities in the
2063 /// functions.
2064 void registerAAs(bool IsModulePass);
2065
2066public:
2067 /// Callback to register AAs for live functions, including internal functions
2068 /// marked live during the traversal.
2069 static void registerAAsForFunction(Attributor &A, const Function &F);
2070};
2071
2072Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2073 if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2074 !OMPInfoCache.CGSCC->contains(&F))
2075 return nullptr;
2076
2077 // Use a scope to keep the lifetime of the CachedKernel short.
2078 {
2079 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2080 if (CachedKernel)
2081 return *CachedKernel;
2082
2083 // TODO: We should use an AA to create an (optimistic and callback
2084 // call-aware) call graph. For now we stick to simple patterns that
2085 // are less powerful, basically the worst fixpoint.
2086 if (isOpenMPKernel(F)) {
2087 CachedKernel = Kernel(&F);
2088 return *CachedKernel;
2089 }
2090
2091 CachedKernel = nullptr;
2092 if (!F.hasLocalLinkage()) {
2093
2094 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2095 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2096 return ORA << "Potentially unknown OpenMP target region caller.";
2097 };
2098 emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2099
2100 return nullptr;
2101 }
2102 }
2103
2104 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2105 if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2106 // Allow use in equality comparisons.
2107 if (Cmp->isEquality())
2108 return getUniqueKernelFor(*Cmp);
2109 return nullptr;
2110 }
2111 if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2112 // Allow direct calls.
2113 if (CB->isCallee(&U))
2114 return getUniqueKernelFor(*CB);
2115
2116 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2117 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2118 // Allow the use in __kmpc_parallel_51 calls.
2119 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2120 return getUniqueKernelFor(*CB);
2121 return nullptr;
2122 }
2123 // Disallow every other use.
2124 return nullptr;
2125 };
2126
2127 // TODO: In the future we want to track more than just a unique kernel.
2128 SmallPtrSet<Kernel, 2> PotentialKernels;
2129 OMPInformationCache::foreachUse(F, [&](const Use &U) {
2130 PotentialKernels.insert(GetUniqueKernelForUse(U));
2131 });
2132
2133 Kernel K = nullptr;
2134 if (PotentialKernels.size() == 1)
2135 K = *PotentialKernels.begin();
2136
2137 // Cache the result.
2138 UniqueKernelMap[&F] = K;
2139
2140 return K;
2141}
2142
2143bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2144 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2145 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2146
2147 bool Changed = false;
2148 if (!KernelParallelRFI)
2149 return Changed;
2150
2151 // If we have disabled state machine changes, exit
2152 if (DisableOpenMPOptStateMachineRewrite)
2153 return Changed;
2154
2155 for (Function *F : SCC) {
2156
2157 // Check if the function is used in a __kmpc_parallel_51 call at
2158 // all.
2159 bool UnknownUse = false;
2160 bool KernelParallelUse = false;
2161 unsigned NumDirectCalls = 0;
2162
2163 SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2164 OMPInformationCache::foreachUse(*F, [&](Use &U) {
2165 if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2166 if (CB->isCallee(&U)) {
2167 ++NumDirectCalls;
2168 return;
2169 }
2170
2171 if (isa<ICmpInst>(U.getUser())) {
2172 ToBeReplacedStateMachineUses.push_back(&U);
2173 return;
2174 }
2175
2176 // Find wrapper functions that represent parallel kernels.
2177 CallInst *CI =
2178 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2179 const unsigned int WrapperFunctionArgNo = 6;
2180 if (!KernelParallelUse && CI &&
2181 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2182 KernelParallelUse = true;
2183 ToBeReplacedStateMachineUses.push_back(&U);
2184 return;
2185 }
2186 UnknownUse = true;
2187 });
2188
2189 // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2190 // use.
2191 if (!KernelParallelUse)
2192 continue;
2193
2194 // If this ever hits, we should investigate.
2195 // TODO: Checking the number of uses is not a necessary restriction and
2196 // should be lifted.
2197 if (UnknownUse || NumDirectCalls != 1 ||
2198 ToBeReplacedStateMachineUses.size() > 2) {
2199 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2200 return ORA << "Parallel region is used in "
2201 << (UnknownUse ? "unknown" : "unexpected")
2202 << " ways. Will not attempt to rewrite the state machine.";
2203 };
2204 emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2205 continue;
2206 }
2207
2208 // Even if we have __kmpc_parallel_51 calls, we (for now) give
2209 // up if the function is not called from a unique kernel.
2210 Kernel K = getUniqueKernelFor(*F);
2211 if (!K) {
2212 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2213 return ORA << "Parallel region is not called from a unique kernel. "
2214 "Will not attempt to rewrite the state machine.";
2215 };
2216 emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2217 continue;
2218 }
2219
2220 // We now know F is a parallel body function called only from the kernel K.
2221 // We also identified the state machine uses in which we replace the
2222 // function pointer by a new global symbol for identification purposes. This
2223 // ensures only direct calls to the function are left.
2224
2225 Module &M = *F->getParent();
2226 Type *Int8Ty = Type::getInt8Ty(M.getContext());
2227
2228 auto *ID = new GlobalVariable(
2229 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2230 UndefValue::get(Int8Ty), F->getName() + ".ID");
2231
2232 for (Use *U : ToBeReplacedStateMachineUses)
2233 A.changeUseAfterManifest(*U, *ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2234 ID, U->get()->getType()));
2235
2236 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2237
2238 Changed = true;
2239 }
2240
2241 return Changed;
2242}
2243
2244/// Abstract Attribute for tracking ICV values.
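/// For example (simplified): for the nthreads ICV this means remembering the
/// value passed to its setter (e.g. omp_set_num_threads(N)) so that a later
/// getter call reached only through that setter can be assumed to return N
/// and be folded by the call-site AAs below.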
2245struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2246 using Base = StateWrapper<BooleanState, AbstractAttribute>;
2247 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2248
2249 /// Returns true if value is assumed to be tracked.
2250 bool isAssumedTracked() const { return getAssumed(); }
2251
2252 /// Returns true if value is known to be tracked.
2253 bool isKnownTracked() const { return getAssumed(); }
2254
2255 /// Create an abstract attribute view for the position \p IRP.
2256 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2257
2258 /// Return the value with which \p I can be replaced for specific \p ICV.
2259 virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2260 const Instruction *I,
2261 Attributor &A) const {
2262 return std::nullopt;
2263 }
2264
2265 /// Return an assumed unique ICV value if a single candidate is found. If
2266 /// there cannot be one, return a nullptr. If it is not clear yet, return
2267 /// std::nullopt.
2268 virtual std::optional<Value *>
2269 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2270
2271 // Currently only nthreads is being tracked.
2272 // This array will only grow with time.
2273 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2274
2275 /// See AbstractAttribute::getName()
2276 const std::string getName() const override { return "AAICVTracker"; }
2277
2278 /// See AbstractAttribute::getIdAddr()
2279 const char *getIdAddr() const override { return &ID; }
2280
2281 /// This function should return true if the type of the \p AA is AAICVTracker
2282 static bool classof(const AbstractAttribute *AA) {
2283 return (AA->getIdAddr() == &ID);
2284 }
2285
2286 static const char ID;
2287};
2288
2289struct AAICVTrackerFunction : public AAICVTracker {
2290 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2291 : AAICVTracker(IRP, A) {}
2292
2293 // FIXME: come up with better string.
2294 const std::string getAsStr(Attributor *) const override {
2295 return "ICVTrackerFunction";
2296 }
2297
2298 // FIXME: come up with some stats.
2299 void trackStatistics() const override {}
2300
2301 /// We don't manifest anything for this AA.
2302 ChangeStatus manifest(Attributor &A) override {
2303 return ChangeStatus::UNCHANGED;
2304 }
2305
2306 // Map of ICVs to their values at specific program points.
2307 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2308 InternalControlVar::ICV___last>
2309 ICVReplacementValuesMap;
2310
2311 ChangeStatus updateImpl(Attributor &A) override {
2312 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2313
2314 Function *F = getAnchorScope();
2315
2316 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2317
2318 for (InternalControlVar ICV : TrackableICVs) {
2319 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2320
2321 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2322 auto TrackValues = [&](Use &U, Function &) {
2323 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2324 if (!CI)
2325 return false;
2326
2327 // FIXME: handle setters with more than one argument.
2328 /// Track new value.
2329 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2330 HasChanged = ChangeStatus::CHANGED;
2331
2332 return false;
2333 };
2334
2335 auto CallCheck = [&](Instruction &I) {
2336 std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2337 if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2338 HasChanged = ChangeStatus::CHANGED;
2339
2340 return true;
2341 };
2342
2343 // Track all changes of an ICV.
2344 SetterRFI.foreachUse(TrackValues, F);
2345
2346 bool UsedAssumedInformation = false;
2347 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2348 UsedAssumedInformation,
2349 /* CheckBBLivenessOnly */ true);
2350
2351 /// TODO: Figure out a way to avoid adding an entry in
2352 /// ICVReplacementValuesMap.
2353 Instruction *Entry = &F->getEntryBlock().front();
2354 if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2355 ValuesMap.insert(std::make_pair(Entry, nullptr));
2356 }
2357
2358 return HasChanged;
2359 }
2360
2361 /// Helper to check if \p I is a call and get the value for it if it is
2362 /// unique.
2363 std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2364 InternalControlVar &ICV) const {
2365
2366 const auto *CB = dyn_cast<CallBase>(&I);
2367 if (!CB || CB->hasFnAttr("no_openmp") ||
2368 CB->hasFnAttr("no_openmp_routines"))
2369 return std::nullopt;
2370
2371 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2372 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2373 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2374 Function *CalledFunction = CB->getCalledFunction();
2375
2376 // Indirect call, assume ICV changes.
2377 if (CalledFunction == nullptr)
2378 return nullptr;
2379 if (CalledFunction == GetterRFI.Declaration)
2380 return std::nullopt;
2381 if (CalledFunction == SetterRFI.Declaration) {
2382 if (ICVReplacementValuesMap[ICV].count(&I))
2383 return ICVReplacementValuesMap[ICV].lookup(&I);
2384
2385 return nullptr;
2386 }
2387
2388 // Since we don't know, assume it changes the ICV.
2389 if (CalledFunction->isDeclaration())
2390 return nullptr;
2391
2392 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2393 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2394
2395 if (ICVTrackingAA->isAssumedTracked()) {
2396 std::optional<Value *> URV =
2397 ICVTrackingAA->getUniqueReplacementValue(ICV);
2398 if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2399 OMPInfoCache)))
2400 return URV;
2401 }
2402
2403 // If we don't know, assume it changes.
2404 return nullptr;
2405 }
2406
2407 // We don't check for a unique value for a function, so return std::nullopt.
2408 std::optional<Value *>
2409 getUniqueReplacementValue(InternalControlVar ICV) const override {
2410 return std::nullopt;
2411 }
2412
2413 /// Return the value with which \p I can be replaced for specific \p ICV.
2414 std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2415 const Instruction *I,
2416 Attributor &A) const override {
2417 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2418 if (ValuesMap.count(I))
2419 return ValuesMap.lookup(I);
2420
2421 SmallVector<const Instruction *, 16> Worklist;
2422 SmallPtrSet<const Instruction *, 16> Visited;
2423 Worklist.push_back(I);
2424
2425 std::optional<Value *> ReplVal;
2426
2427 while (!Worklist.empty()) {
2428 const Instruction *CurrInst = Worklist.pop_back_val();
2429 if (!Visited.insert(CurrInst).second)
2430 continue;
2431
2432 const BasicBlock *CurrBB = CurrInst->getParent();
2433
2434 // Go up and look for all potential setters/calls that might change the
2435 // ICV.
2436 while ((CurrInst = CurrInst->getPrevNode())) {
2437 if (ValuesMap.count(CurrInst)) {
2438 std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2439 // Unknown value, track new.
2440 if (!ReplVal) {
2441 ReplVal = NewReplVal;
2442 break;
2443 }
2444
2445 // If we found a new value, we can't know the ICV value anymore.
2446 if (NewReplVal)
2447 if (ReplVal != NewReplVal)
2448 return nullptr;
2449
2450 break;
2451 }
2452
2453 std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2454 if (!NewReplVal)
2455 continue;
2456
2457 // Unknown value, track new.
2458 if (!ReplVal) {
2459 ReplVal = NewReplVal;
2460 break;
2461 }
2462
2463 // We found a new value (NewReplVal is known to have a value here), so we
2464 // can't know the ICV value anymore.
2465 if (ReplVal != NewReplVal)
2466 return nullptr;
2467 }
2468
2469 // If we are in the same BB and we have a value, we are done.
2470 if (CurrBB == I->getParent() && ReplVal)
2471 return ReplVal;
2472
2473 // Go through all predecessors and add terminators for analysis.
2474 for (const BasicBlock *Pred : predecessors(CurrBB))
2475 if (const Instruction *Terminator = Pred->getTerminator())
2476 Worklist.push_back(Terminator);
2477 }
2478
2479 return ReplVal;
2480 }
2481};
2482
2483struct AAICVTrackerFunctionReturned : AAICVTracker {
2484 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2485 : AAICVTracker(IRP, A) {}
2486
2487 // FIXME: come up with better string.
2488 const std::string getAsStr(Attributor *) const override {
2489 return "ICVTrackerFunctionReturned";
2490 }
2491
2492 // FIXME: come up with some stats.
2493 void trackStatistics() const override {}
2494
2495 /// We don't manifest anything for this AA.
2496 ChangeStatus manifest(Attributor &A) override {
2497 return ChangeStatus::UNCHANGED;
2498 }
2499
2500 // Map of ICVs to their values at specific program points.
2501 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2502 InternalControlVar::ICV___last>
2503 ICVReplacementValuesMap;
2504
2505 /// Return the assumed unique replacement value for \p ICV, if any.
2506 std::optional<Value *>
2507 getUniqueReplacementValue(InternalControlVar ICV) const override {
2508 return ICVReplacementValuesMap[ICV];
2509 }
2510
2511 ChangeStatus updateImpl(Attributor &A) override {
2512 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2513 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2514 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2515
2516 if (!ICVTrackingAA->isAssumedTracked())
2517 return indicatePessimisticFixpoint();
2518
2519 for (InternalControlVar ICV : TrackableICVs) {
2520 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2521 std::optional<Value *> UniqueICVValue;
2522
2523 auto CheckReturnInst = [&](Instruction &I) {
2524 std::optional<Value *> NewReplVal =
2525 ICVTrackingAA->getReplacementValue(ICV, &I, A);
2526
2527 // If we found a second ICV value there is no unique returned value.
2528 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2529 return false;
2530
2531 UniqueICVValue = NewReplVal;
2532
2533 return true;
2534 };
2535
2536 bool UsedAssumedInformation = false;
2537 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2538 UsedAssumedInformation,
2539 /* CheckBBLivenessOnly */ true))
2540 UniqueICVValue = nullptr;
2541
2542 if (UniqueICVValue == ReplVal)
2543 continue;
2544
2545 ReplVal = UniqueICVValue;
2546 Changed = ChangeStatus::CHANGED;
2547 }
2548
2549 return Changed;
2550 }
2551};
2552
2553struct AAICVTrackerCallSite : AAICVTracker {
2554 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2555 : AAICVTracker(IRP, A) {}
2556
2557 void initialize(Attributor &A) override {
2558 assert(getAnchorScope() && "Expected anchor function");
2559
2560 // We only initialize this AA for getters, so we need to know which ICV it
2561 // gets.
2562 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2563 for (InternalControlVar ICV : TrackableICVs) {
2564 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2565 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2566 if (Getter.Declaration == getAssociatedFunction()) {
2567 AssociatedICV = ICVInfo.Kind;
2568 return;
2569 }
2570 }
2571
2572 /// Unknown ICV.
2573 indicatePessimisticFixpoint();
2574 }
2575
2576 ChangeStatus manifest(Attributor &A) override {
2577 if (!ReplVal || !*ReplVal)
2578 return ChangeStatus::UNCHANGED;
2579
2580 A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2581 A.deleteAfterManifest(*getCtxI());
2582
2583 return ChangeStatus::CHANGED;
2584 }
2585
2586 // FIXME: come up with better string.
2587 const std::string getAsStr(Attributor *) const override {
2588 return "ICVTrackerCallSite";
2589 }
2590
2591 // FIXME: come up with some stats.
2592 void trackStatistics() const override {}
2593
2594 InternalControlVar AssociatedICV;
2595 std::optional<Value *> ReplVal;
2596
2597 ChangeStatus updateImpl(Attributor &A) override {
2598 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2599 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2600
2601 // We don't have any information, so we assume it changes the ICV.
2602 if (!ICVTrackingAA->isAssumedTracked())
2603 return indicatePessimisticFixpoint();
2604
2605 std::optional<Value *> NewReplVal =
2606 ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
2607
2608 if (ReplVal == NewReplVal)
2609 return ChangeStatus::UNCHANGED;
2610
2611 ReplVal = NewReplVal;
2612 return ChangeStatus::CHANGED;
2613 }
2614
2615 // Return the value with which the associated value can be replaced for a
2616 // specific \p ICV.
2617 std::optional<Value *>
2618 getUniqueReplacementValue(InternalControlVar ICV) const override {
2619 return ReplVal;
2620 }
2621};
2622
2623struct AAICVTrackerCallSiteReturned : AAICVTracker {
2624 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2625 : AAICVTracker(IRP, A) {}
2626
2627 // FIXME: come up with better string.
2628 const std::string getAsStr(Attributor *) const override {
2629 return "ICVTrackerCallSiteReturned";
2630 }
2631
2632 // FIXME: come up with some stats.
2633 void trackStatistics() const override {}
2634
2635 /// We don't manifest anything for this AA.
2636 ChangeStatus manifest(Attributor &A) override {
2637 return ChangeStatus::UNCHANGED;
2638 }
2639
2640 // Map of ICVs to their values at specific program points.
2641 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2642 InternalControlVar::ICV___last>
2643 ICVReplacementValuesMap;
2644
2645 /// Return the value with which the associated value can be replaced for a
2646 /// specific \p ICV.
2647 std::optional<Value *>
2648 getUniqueReplacementValue(InternalControlVar ICV) const override {
2649 return ICVReplacementValuesMap[ICV];
2650 }
2651
2652 ChangeStatus updateImpl(Attributor &A) override {
2653 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2654 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2655 *this, IRPosition::returned(*getAssociatedFunction()),
2656 DepClassTy::REQUIRED);
2657
2658 // We don't have any information, so we assume it changes the ICV.
2659 if (!ICVTrackingAA->isAssumedTracked())
2660 return indicatePessimisticFixpoint();
2661
2662 for (InternalControlVar ICV : TrackableICVs) {
2663 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2664 std::optional<Value *> NewReplVal =
2665 ICVTrackingAA->getUniqueReplacementValue(ICV);
2666
2667 if (ReplVal == NewReplVal)
2668 continue;
2669
2670 ReplVal = NewReplVal;
2671 Changed = ChangeStatus::CHANGED;
2672 }
2673 return Changed;
2674 }
2675};
2676
2677/// Determines if \p BB exits the function unconditionally itself or reaches a
2678/// block that does so through a chain of unique successors.
2679static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2680 if (succ_empty(BB))
2681 return true;
2682 const BasicBlock *const Successor = BB->getUniqueSuccessor();
2683 if (!Successor)
2684 return false;
2685 return hasFunctionEndAsUniqueSuccessor(Successor);
2686}
2687
2688struct AAExecutionDomainFunction : public AAExecutionDomain {
2689 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2690 : AAExecutionDomain(IRP, A) {}
2691
2692 ~AAExecutionDomainFunction() { delete RPOT; }
2693
2694 void initialize(Attributor &A) override {
2695 Function *F = getAnchorScope();
2696 assert(F && "Expected anchor function");
2697 RPOT = new ReversePostOrderTraversal<Function *>(F);
2698 }
2699
2700 const std::string getAsStr(Attributor *) const override {
2701 unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2702 for (auto &It : BEDMap) {
2703 if (!It.getFirst())
2704 continue;
2705 TotalBlocks++;
2706 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2707 AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2708 It.getSecond().IsReachingAlignedBarrierOnly;
2709 }
2710 return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
2711 std::to_string(AlignedBlocks) + " of " +
2712 std::to_string(TotalBlocks) +
2713 " executed by initial thread / aligned";
2714 }
2715
2716 /// See AbstractAttribute::trackStatistics().
2717 void trackStatistics() const override {}
2718
2719 ChangeStatus manifest(Attributor &A) override {
2720 LLVM_DEBUG({
2721 for (const BasicBlock &BB : *getAnchorScope()) {
2722 if (!isExecutedByInitialThreadOnly(BB))
2723 continue;
2724 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2725 << BB.getName() << " is executed by a single thread.\n";
2726 }
2727 });
2728
2729 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2730
2731 if (DisableOpenMPOptBarrierElimination)
2732 return Changed;
2733
2734 SmallPtrSet<CallBase *, 16> DeletedBarriers;
2735 auto HandleAlignedBarrier = [&](CallBase *CB) {
2736 const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
2737 if (!ED.IsReachedFromAlignedBarrierOnly ||
2738 ED.EncounteredNonLocalSideEffect)
2739 return;
2740 if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
2741 return;
2742
2743 // We can remove this barrier, if it is one, or aligned barriers reaching
2744 // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
2745 // end should only be removed if the kernel end is their unique successor;
2746 // otherwise, they may have side-effects that aren't accounted for in the
2747 // kernel end in their other successors. If those barriers have other
2748 // barriers reaching them, those can be transitively removed as well as
2749 // long as the kernel end is also their unique successor.
2750 if (CB) {
2751 DeletedBarriers.insert(CB);
2752 A.deleteAfterManifest(*CB);
2753 ++NumBarriersEliminated;
2754 Changed = ChangeStatus::CHANGED;
2755 } else if (!ED.AlignedBarriers.empty()) {
2756 Changed = ChangeStatus::CHANGED;
2757 SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
2758 ED.AlignedBarriers.end());
2759 SmallSetVector<CallBase *, 16> Visited;
2760 while (!Worklist.empty()) {
2761 CallBase *LastCB = Worklist.pop_back_val();
2762 if (!Visited.insert(LastCB))
2763 continue;
2764 if (LastCB->getFunction() != getAnchorScope())
2765 continue;
2766 if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
2767 continue;
2768 if (!DeletedBarriers.count(LastCB)) {
2769 ++NumBarriersEliminated;
2770 A.deleteAfterManifest(*LastCB);
2771 continue;
2772 }
2773 // The final aligned barrier (LastCB) reaching the kernel end was
2774 // removed already. This means we can go one step further and remove
2775 // the barriers encountered last before (LastCB).
2776 const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2777 Worklist.append(LastED.AlignedBarriers.begin(),
2778 LastED.AlignedBarriers.end());
2779 }
2780 }
2781
2782 // If we actually eliminated a barrier we need to eliminate the associated
2783 // llvm.assumes as well to avoid creating UB.
2784 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2785 for (auto *AssumeCB : ED.EncounteredAssumes)
2786 A.deleteAfterManifest(*AssumeCB);
2787 };
2788
2789 for (auto *CB : AlignedBarriers)
2790 HandleAlignedBarrier(CB);
2791
2792 // Handle the "kernel end barrier" for kernels too.
2793 if (omp::isOpenMPKernel(*getAnchorScope()))
2794 HandleAlignedBarrier(nullptr);
2795
2796 return Changed;
2797 }
2798
2799 bool isNoOpFence(const FenceInst &FI) const override {
2800 return getState().isValidState() && !NonNoOpFences.count(&FI);
2801 }
2802
2803 /// Merge barrier and assumption information from \p PredED into the successor
2804 /// \p ED.
2805 void
2806 mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
2807 const ExecutionDomainTy &PredED);
2808
2809 /// Merge all information from \p PredED into the successor \p ED. If
2810 /// \p InitialEdgeOnly is set, only the initial edge will enter the block
2811 /// represented by \p ED from this predecessor.
2812 bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
2813 const ExecutionDomainTy &PredED,
2814 bool InitialEdgeOnly = false);
2815
2816 /// Accumulate information for the entry block in \p EntryBBED.
2817 bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);
2818
2819 /// See AbstractAttribute::updateImpl.
2820 ChangeStatus updateImpl(Attributor &A) override;
2821
2822 /// Query interface, see AAExecutionDomain
2823 ///{
2824 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2825 if (!isValidState())
2826 return false;
2827 assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
2828 return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2829 }
2830
2831 bool isExecutedInAlignedRegion(Attributor &A,
2832 const Instruction &I) const override {
2833 assert(I.getFunction() == getAnchorScope() &&
2834 "Instruction is out of scope!");
2835 if (!isValidState())
2836 return false;
2837
2838 bool ForwardIsOk = true;
2839 const Instruction *CurI;
2840
2841 // Check forward until a call or the block end is reached.
2842 CurI = &I;
2843 do {
2844 auto *CB = dyn_cast<CallBase>(CurI);
2845 if (!CB)
2846 continue;
2847 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2848 return true;
2849 const auto &It = CEDMap.find({CB, PRE});
2850 if (It == CEDMap.end())
2851 continue;
2852 if (!It->getSecond().IsReachingAlignedBarrierOnly)
2853 ForwardIsOk = false;
2854 break;
2855 } while ((CurI = CurI->getNextNonDebugInstruction()));
2856
2857 if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
2858 ForwardIsOk = false;
2859
2860 // Check backward until a call or the block beginning is reached.
2861 CurI = &I;
2862 do {
2863 auto *CB = dyn_cast<CallBase>(CurI);
2864 if (!CB)
2865 continue;
2866 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2867 return true;
2868 const auto &It = CEDMap.find({CB, POST});
2869 if (It == CEDMap.end())
2870 continue;
2871 if (It->getSecond().IsReachedFromAlignedBarrierOnly)
2872 break;
2873 return false;
2874 } while ((CurI = CurI->getPrevNonDebugInstruction()));
2875
2876 // Delayed decision on the forward pass to allow aligned barrier detection
2877 // in the backwards traversal.
2878 if (!ForwardIsOk)
2879 return false;
2880
2881 if (!CurI) {
2882 const BasicBlock *BB = I.getParent();
2883 if (BB == &BB->getParent()->getEntryBlock())
2884 return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
2885 if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
2886 return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
2887 })) {
2888 return false;
2889 }
2890 }
2891
2892 // On neither traversal did we find anything but aligned barriers.
2893 return true;
2894 }
2895
2896 ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
2897 assert(isValidState() &&
2898 "No request should be made against an invalid state!");
2899 return BEDMap.lookup(&BB);
2900 }
2901 std::pair<ExecutionDomainTy, ExecutionDomainTy>
2902 getExecutionDomain(const CallBase &CB) const override {
2903 assert(isValidState() &&
2904 "No request should be made against an invalid state!");
2905 return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
2906 }
2907 ExecutionDomainTy getFunctionExecutionDomain() const override {
2908 assert(isValidState() &&
2909 "No request should be made against an invalid state!");
2910 return InterProceduralED;
2911 }
2912 ///}
2913
2914 // Check if the edge into the successor block contains a condition that only
2915 // lets the main thread execute it.
2916 static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,
2917 BasicBlock &SuccessorBB) {
2918 if (!Edge || !Edge->isConditional())
2919 return false;
2920 if (Edge->getSuccessor(0) != &SuccessorBB)
2921 return false;
2922
2923 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2924 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2925 return false;
2926
2927 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2928 if (!C)
2929 return false;
2930
2931 // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2932 if (C->isAllOnesValue()) {
2933 auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2934 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2935 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2936 CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2937 if (!CB)
2938 return false;
2939 ConstantStruct *KernelEnvC =
2940 KernelInfo::getKernelEnvironementFromKernelInitCB(CB);
2941 ConstantInt *ExecModeC =
2942 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
2943 return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
2944 }
2945
2946 if (C->isZero()) {
2947 // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2948 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2949 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2950 return true;
2951
2952 // Match: 0 == llvm.amdgcn.workitem.id.x()
2953 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2954 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2955 return true;
2956 }
2957
2958 return false;
2959 };
2960
2961 /// Mapping containing information about the function for other AAs.
2962 ExecutionDomainTy InterProceduralED;
2963
2964 enum Direction { PRE = 0, POST = 1 };
2965 /// Mapping containing information per block.
2966 DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
2967 DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>
2968 CEDMap;
2969 SmallSetVector<CallBase *, 16> AlignedBarriers;
2970
2971 ReversePostOrderTraversal<Function *> *RPOT = nullptr;
2972
2973 /// Set \p R to \p V and report true if that changed \p R.
2974 static bool setAndRecord(bool &R, bool V) {
2975 bool Eq = (R == V);
2976 R = V;
2977 return !Eq;
2978 }
2979
2980 /// Collection of fences known to be non-no-ops. All fences not in this set
2981 /// can be assumed to be no-ops.
2982 SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
2983};
2984
2985void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2986 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2987 for (auto *EA : PredED.EncounteredAssumes)
2988 ED.addAssumeInst(A, *EA);
2989
2990 for (auto *AB : PredED.AlignedBarriers)
2991 ED.addAlignedBarrier(A, *AB);
2992}
2993
2994bool AAExecutionDomainFunction::mergeInPredecessor(
2995 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
2996 bool InitialEdgeOnly) {
2997
2998 bool Changed = false;
2999 Changed |=
3000 setAndRecord(ED.IsExecutedByInitialThreadOnly,
3001 InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3002 ED.IsExecutedByInitialThreadOnly));
3003
3004 Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3005 ED.IsReachedFromAlignedBarrierOnly &&
3006 PredED.IsReachedFromAlignedBarrierOnly);
3007 Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3008 ED.EncounteredNonLocalSideEffect |
3009 PredED.EncounteredNonLocalSideEffect);
3010 // Do not track assumptions and barriers as part of Changed.
3011 if (ED.IsReachedFromAlignedBarrierOnly)
3012 mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3013 else
3014 ED.clearAssumeInstAndAlignedBarriers();
3015 return Changed;
3016}
3017
3018bool AAExecutionDomainFunction::handleCallees(Attributor &A,
3019 ExecutionDomainTy &EntryBBED) {
3020 SmallVector<std::pair<ExecutionDomainTy, ExecutionDomainTy>, 4> CallSiteEDs;
3021 auto PredForCallSite = [&](AbstractCallSite ACS) {
3022 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3023 *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
3024 DepClassTy::OPTIONAL);
3025 if (!EDAA || !EDAA->getState().isValidState())
3026 return false;
3027 CallSiteEDs.emplace_back(
3028 EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
3029 return true;
3030 };
3031
3032 ExecutionDomainTy ExitED;
3033 bool AllCallSitesKnown;
3034 if (A.checkForAllCallSites(PredForCallSite, *this,
3035 /* RequiresAllCallSites */ true,
3036 AllCallSitesKnown)) {
3037 for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
3038 mergeInPredecessor(A, EntryBBED, CSInED);
3039 ExitED.IsReachingAlignedBarrierOnly &=
3040 CSOutED.IsReachingAlignedBarrierOnly;
3041 }
3042
3043 } else {
3044 // We could not find all predecessors, so this is either a kernel or a
3045 // function with external linkage (or with some other weird uses).
3046 if (omp::isOpenMPKernel(*getAnchorScope())) {
3047 EntryBBED.IsExecutedByInitialThreadOnly = false;
3048 EntryBBED.IsReachedFromAlignedBarrierOnly = true;
3049 EntryBBED.EncounteredNonLocalSideEffect = false;
3050 ExitED.IsReachingAlignedBarrierOnly = false;
3051 } else {
3052 EntryBBED.IsExecutedByInitialThreadOnly = false;
3053 EntryBBED.IsReachedFromAlignedBarrierOnly = false;
3054 EntryBBED.EncounteredNonLocalSideEffect = true;
3055 ExitED.IsReachingAlignedBarrierOnly = false;
3056 }
3057 }
3058
3059 bool Changed = false;
3060 auto &FnED = BEDMap[nullptr];
3061 Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
3062 FnED.IsReachedFromAlignedBarrierOnly &
3063 EntryBBED.IsReachedFromAlignedBarrierOnly);
3064 Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
3065 FnED.IsReachingAlignedBarrierOnly &
3066 ExitED.IsReachingAlignedBarrierOnly);
3067 Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
3068 EntryBBED.IsExecutedByInitialThreadOnly);
3069 return Changed;
3070}
3071
3072ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
3073
3074 bool Changed = false;
3075
3076 // Helper to deal with an aligned barrier encountered during the forward
3077 // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
3078 // it was encountered.
3079 auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
3080 Changed |= AlignedBarriers.insert(&CB);
3081 // First, update the barrier ED kept in the separate CEDMap.
3082 auto &CallInED = CEDMap[{&CB, PRE}];
3083 Changed |= mergeInPredecessor(A, CallInED, ED);
3084 CallInED.IsReachingAlignedBarrierOnly = true;
3085 // Next adjust the ED we use for the traversal.
3086 ED.EncounteredNonLocalSideEffect = false;
3087 ED.IsReachedFromAlignedBarrierOnly = true;
3088 // Aligned barrier collection has to come last.
3089 ED.clearAssumeInstAndAlignedBarriers();
3090 ED.addAlignedBarrier(A, CB);
3091 auto &CallOutED = CEDMap[{&CB, POST}];
3092 Changed |= mergeInPredecessor(A, CallOutED, ED);
3093 };
3094
3095 auto *LivenessAA =
3096 A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
3097
3098 Function *F = getAnchorScope();
3099 BasicBlock &EntryBB = F->getEntryBlock();
3100 bool IsKernel = omp::isOpenMPKernel(*F);
3101
3102 SmallVector<Instruction *> SyncInstWorklist;
3103 for (auto &RIt : *RPOT) {
3104 BasicBlock &BB = *RIt;
3105
3106 bool IsEntryBB = &BB == &EntryBB;
3107 // TODO: We use local reasoning since we don't have a divergence analysis
3108 // running as well. We could basically allow uniform branches here.
3109 bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
3110 bool IsExplicitlyAligned = IsEntryBB && IsKernel;
3111 ExecutionDomainTy ED;
3112 // Propagate "incoming edges" into information about this block.
3113 if (IsEntryBB) {
3114 Changed |= handleCallees(A, ED);
3115 } else {
3116 // For live non-entry blocks we only propagate
3117 // information via live edges.
3118 if (LivenessAA && LivenessAA->isAssumedDead(&BB))
3119 continue;
3120
3121 for (auto *PredBB : predecessors(&BB)) {
3122 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
3123 continue;
3124 bool InitialEdgeOnly = isInitialThreadOnlyEdge(
3125 A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
3126 mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
3127 }
3128 }
3129
3130 // Now we traverse the block, accumulate effects in ED and attach
3131 // information to calls.
3132 for (Instruction &I : BB) {
3133 bool UsedAssumedInformation;
3134 if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
3135 /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
3136 /* CheckForDeadStore */ true))
3137 continue;
3138
3139 // Assumes and "assume-like" intrinsics (dbg, lifetime, ...) are handled
3140 // first; the former are collected, the latter are ignored.
3141 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
3142 if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
3143 ED.addAssumeInst(A, *AI);
3144 continue;
3145 }
3146 // TODO: Should we also collect and delete lifetime markers?
3147 if (II->isAssumeLikeIntrinsic())
3148 continue;
3149 }
3150
3151 if (auto *FI = dyn_cast<FenceInst>(&I)) {
3152 if (!ED.EncounteredNonLocalSideEffect) {
3153 // An aligned fence without non-local side-effects is a no-op.
3154 if (ED.IsReachedFromAlignedBarrierOnly)
3155 continue;
3156 // A non-aligned fence without non-local side-effects is a no-op
3157 // if the ordering only publishes non-local side-effects (or less).
3158 switch (FI->getOrdering()) {
3159 case AtomicOrdering::NotAtomic:
3160 continue;
3161 case AtomicOrdering::Unordered:
3162 continue;
3163 case AtomicOrdering::Monotonic:
3164 continue;
3165 case AtomicOrdering::Acquire:
3166 break;
3167 case AtomicOrdering::Release:
3168 continue;
3169 case AtomicOrdering::AcquireRelease:
3170 break;
3171 case AtomicOrdering::SequentiallyConsistent:
3172 break;
3173 };
3174 }
3175 NonNoOpFences.insert(FI);
3176 }
3177
3178 auto *CB = dyn_cast<CallBase>(&I);
3179 bool IsNoSync = AA::isNoSyncInst(A, I, *this);
3180 bool IsAlignedBarrier =
3181 !IsNoSync && CB &&
3182 AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);
3183
3184 AlignedBarrierLastInBlock &= IsNoSync;
3185 IsExplicitlyAligned &= IsNoSync;
3186
3187 // Next we check for calls. Aligned barriers are handled
3188 // explicitly; everything else is kept for the backward traversal and will
3189 // also affect our state.
3190 if (CB) {
3191 if (IsAlignedBarrier) {
3192 HandleAlignedBarrier(*CB, ED);
3193 AlignedBarrierLastInBlock = true;
3194 IsExplicitlyAligned = true;
3195 continue;
3196 }
3197
3198 // Check the pointer(s) of a memory intrinsic explicitly.
3199 if (isa<MemIntrinsic>(&I)) {
3200 if (!ED.EncounteredNonLocalSideEffect &&
3201 AA::isPotentiallyAffectedByBarrier(A, I, *this))
3202 ED.EncounteredNonLocalSideEffect = true;
3203 if (!IsNoSync) {
3204 ED.IsReachedFromAlignedBarrierOnly = false;
3205 SyncInstWorklist.push_back(&I);
3206 }
3207 continue;
3208 }
3209
3210 // Record how we entered the call, then accumulate the effect of the
3211 // call in ED for potential use by the callee.
3212 auto &CallInED = CEDMap[{CB, PRE}];
3213 Changed |= mergeInPredecessor(A, CallInED, ED);
3214
3215 // If we have a sync-definition we can check if it starts/ends in an
3216 // aligned barrier. If we are unsure we assume any sync breaks
3217 // alignment.
3218 Function *Callee = CB->getCalledFunction();
3219 if (!IsNoSync && Callee && !Callee->isDeclaration()) {
3220 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3221 *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
3222 if (EDAA && EDAA->getState().isValidState()) {
3223 const auto &CalleeED = EDAA->getFunctionExecutionDomain();
3224 ED.IsReachedFromAlignedBarrierOnly =
3225 CalleeED.IsReachedFromAlignedBarrierOnly;
3226 AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3227 if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3228 ED.EncounteredNonLocalSideEffect |=
3229 CalleeED.EncounteredNonLocalSideEffect;
3230 else
3231 ED.EncounteredNonLocalSideEffect =
3232 CalleeED.EncounteredNonLocalSideEffect;
3233 if (!CalleeED.IsReachingAlignedBarrierOnly) {
3234 Changed |=
3235 setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3236 SyncInstWorklist.push_back(&I);
3237 }
3238 if (CalleeED.IsReachedFromAlignedBarrierOnly)
3239 mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
3240 auto &CallOutED = CEDMap[{CB, POST}];
3241 Changed |= mergeInPredecessor(A, CallOutED, ED);
3242 continue;
3243 }
3244 }
3245 if (!IsNoSync) {
3246 ED.IsReachedFromAlignedBarrierOnly = false;
3247 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3248 SyncInstWorklist.push_back(&I);
3249 }
3250 AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3251 ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
3252 auto &CallOutED = CEDMap[{CB, POST}];
3253 Changed |= mergeInPredecessor(A, CallOutED, ED);
3254 }
3255
3256 if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
3257 continue;
3258
3259 // If we have a callee we try to use fine-grained information to
3260 // determine local side-effects.
3261 if (CB) {
3262 const auto *MemAA = A.getAAFor<AAMemoryLocation>(
3263 *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
3264
3265 auto AccessPred = [&](const Instruction *I, const Value *Ptr,
3266 AAMemoryLocation::AccessKind,
3267 AAMemoryLocation::MemoryLocationsKind) {
3268 return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
3269 };
3270 if (MemAA && MemAA->getState().isValidState() &&
3271 MemAA->checkForAllAccessesToMemoryKind(
3272 AccessPred, AAMemoryLocation::ALL_LOCATIONS))
3273 continue;
3274 }
3275
3276 auto &InfoCache = A.getInfoCache();
3277 if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
3278 continue;
3279
3280 if (auto *LI = dyn_cast<LoadInst>(&I))
3281 if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3282 continue;
3283
3284 if (!ED.EncounteredNonLocalSideEffect &&
3285 AA::isPotentiallyAffectedByBarrier(A, I, *this))
3286 ED.EncounteredNonLocalSideEffect = true;
3287 }
3288
3289 bool IsEndAndNotReachingAlignedBarriersOnly = false;
3290 if (!isa<UnreachableInst>(BB.getTerminator()) &&
3291 !BB.getTerminator()->getNumSuccessors()) {
3292
3293 Changed |= mergeInPredecessor(A, InterProceduralED, ED);
3294
3295 auto &FnED = BEDMap[nullptr];
3296 if (IsKernel && !IsExplicitlyAligned)
3297 FnED.IsReachingAlignedBarrierOnly = false;
3298 Changed |= mergeInPredecessor(A, FnED, ED);
3299
3300 if (!FnED.IsReachingAlignedBarrierOnly) {
3301 IsEndAndNotReachingAlignedBarriersOnly = true;
3302 SyncInstWorklist.push_back(BB.getTerminator());
3303 auto &BBED = BEDMap[&BB];
3304 Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
3305 }
3306 }
3307
3308 ExecutionDomainTy &StoredED = BEDMap[&BB];
3309 ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3310 !IsEndAndNotReachingAlignedBarriersOnly;
3311
3312 // Check if we computed anything different as part of the forward
3313 // traversal. We do not take assumptions and aligned barriers into account
3314 // as they do not influence the state we iterate. Backward traversal values
3315 // are handled later on.
3316 if (ED.IsExecutedByInitialThreadOnly !=
3317 StoredED.IsExecutedByInitialThreadOnly ||
3318 ED.IsReachedFromAlignedBarrierOnly !=
3319 StoredED.IsReachedFromAlignedBarrierOnly ||
3320 ED.EncounteredNonLocalSideEffect !=
3321 StoredED.EncounteredNonLocalSideEffect)
3322 Changed = true;
3323
3324 // Update the state with the new value.
3325 StoredED = std::move(ED);
3326 }
3327
3328 // Propagate (non-aligned) sync instruction effects backwards until the
3329 // entry is hit or an aligned barrier.
3330 SmallSetVector<BasicBlock *, 16> Visited;
3331 while (!SyncInstWorklist.empty()) {
3332 Instruction *SyncInst = SyncInstWorklist.pop_back_val();
3333 Instruction *CurInst = SyncInst;
3334 bool HitAlignedBarrierOrKnownEnd = false;
3335 while ((CurInst = CurInst->getPrevNode())) {
3336 auto *CB = dyn_cast<CallBase>(CurInst);
3337 if (!CB)
3338 continue;
3339 auto &CallOutED = CEDMap[{CB, POST}];
3340 Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
3341 auto &CallInED = CEDMap[{CB, PRE}];
3342 HitAlignedBarrierOrKnownEnd =
3343 AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
3344 if (HitAlignedBarrierOrKnownEnd)
3345 break;
3346 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3347 }
3348 if (HitAlignedBarrierOrKnownEnd)
3349 continue;
3350 BasicBlock *SyncBB = SyncInst->getParent();
3351 for (auto *PredBB : predecessors(SyncBB)) {
3352 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
3353 continue;
3354 if (!Visited.insert(PredBB))
3355 continue;
3356 auto &PredED = BEDMap[PredBB];
3357 if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
3358 Changed = true;
3359 SyncInstWorklist.push_back(PredBB->getTerminator());
3360 }
3361 }
3362 if (SyncBB != &EntryBB)
3363 continue;
3364 Changed |=
3365 setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
3366 }
3367
3368 return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
3369}
3370
3371/// Try to replace memory allocation calls called by a single thread with a
3372/// static buffer of shared memory.
3373struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
3374 using Base = StateWrapper<BooleanState, AbstractAttribute>;
3375 AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3376
3377 /// Create an abstract attribute view for the position \p IRP.
3378 static AAHeapToShared &createForPosition(const IRPosition &IRP,
3379 Attributor &A);
3380
3381 /// Returns true if HeapToShared conversion is assumed to be possible.
3382 virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3383
3384 /// Returns true if HeapToShared conversion is assumed and the CB is a
3385 /// callsite to a free operation to be removed.
3386 virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3387
3388 /// See AbstractAttribute::getName().
3389 const std::string getName() const override { return "AAHeapToShared"; }
3390
3391 /// See AbstractAttribute::getIdAddr().
3392 const char *getIdAddr() const override { return &ID; }
3393
3394 /// This function should return true if the type of the \p AA is
3395 /// AAHeapToShared.
3396 static bool classof(const AbstractAttribute *AA) {
3397 return (AA->getIdAddr() == &ID);
3398 }
3399
3400 /// Unique ID (due to the unique address)
3401 static const char ID;
3402};
3403
3404struct AAHeapToSharedFunction : public AAHeapToShared {
3405 AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
3406 : AAHeapToShared(IRP, A) {}
3407
3408 const std::string getAsStr(Attributor *) const override {
3409 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3410 " malloc calls eligible.";
3411 }
3412
3413 /// See AbstractAttribute::trackStatistics().
3414 void trackStatistics() const override {}
3415
3416 /// This function finds free calls that will be removed by the
3417 /// HeapToShared transformation.
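/// Only allocations with exactly one identifiable free call are recorded;
/// that unique free is later deleted together with the allocation when the
/// shared-memory buffer is created in manifest().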
3418 void findPotentialRemovedFreeCalls(Attributor &A) {
3419 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3420 auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3421
3422 PotentialRemovedFreeCalls.clear();
3423 // Update free call users of found malloc calls.
3424 for (CallBase *CB : MallocCalls) {
3425 SmallVector<CallBase *, 4> FreeCalls;
3426 for (auto *U : CB->users()) {
3427 CallBase *C = dyn_cast<CallBase>(U);
3428 if (C && C->getCalledFunction() == FreeRFI.Declaration)
3429 FreeCalls.push_back(C);
3430 }
3431
3432 if (FreeCalls.size() != 1)
3433 continue;
3434
3435 PotentialRemovedFreeCalls.insert(FreeCalls.front());
3436 }
3437 }
3438
3439 void initialize(Attributor &A) override {
3440 if (DisableOpenMPOptDeglobalization) {
3441 indicatePessimisticFixpoint();
3442 return;
3443 }
3444
3445 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3446 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3447 if (!RFI.Declaration)
3448 return;
3449
3450 Attributor::SimplifictionCallbackTy SCB =
3451 [](const IRPosition &, const AbstractAttribute *,
3452 bool &) -> std::optional<Value *> { return nullptr; };
3453
3454 Function *F = getAnchorScope();
3455 for (User *U : RFI.Declaration->users())
3456 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3457 if (CB->getFunction() != F)
3458 continue;
3459 MallocCalls.insert(CB);
3460 A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
3461 SCB);
3462 }
3463
3464 findPotentialRemovedFreeCalls(A);
3465 }
3466
3467 bool isAssumedHeapToShared(CallBase &CB) const override {
3468 return isValidState() && MallocCalls.count(&CB);
3469 }
3470
3471 bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
3472 return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3473 }
3474
3475 ChangeStatus manifest(Attributor &A) override {
3476 if (MallocCalls.empty())
3477 return ChangeStatus::UNCHANGED;
3478
3479 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3480 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3481
3482 Function *F = getAnchorScope();
3483 auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3484 DepClassTy::OPTIONAL);
3485
3486 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3487 for (CallBase *CB : MallocCalls) {
3488 // Skip replacing this if HeapToStack has already claimed it.
3489 if (HS && HS->isAssumedHeapToStack(*CB))
3490 continue;
3491
3492 // Find the unique free call to remove it.
3493 SmallVector<CallBase *, 4> FreeCalls;
3494 for (auto *U : CB->users()) {
3495 CallBase *C = dyn_cast<CallBase>(U);
3496 if (C && C->getCalledFunction() == FreeCall.Declaration)
3497 FreeCalls.push_back(C);
3498 }
3499 if (FreeCalls.size() != 1)
3500 continue;
3501
3502 auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3503
3504 if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3505 LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3506 << " with shared memory."
3507 << " Shared memory usage is limited to "
3508 << SharedMemoryLimit << " bytes\n");
3509 continue;
3510 }
3511
3512 LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3513 << " with " << AllocSize->getZExtValue()
3514 << " bytes of shared memory\n");
3515
3516 // Create a new shared memory buffer of the same size as the allocation
3517 // and replace all the uses of the original allocation with it.
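// Illustrative sketch (names and sizes made up): a call such as
//   %p = call ptr @__kmpc_alloc_shared(i64 16)
//   ...
//   call void @__kmpc_free_shared(ptr %p, i64 16)
// becomes a module-level buffer in the Shared address space (addrspace(3)
// on AMDGPU/NVPTX), e.g.
//   @p_shared = internal addrspace(3) global [16 x i8] poison
// with both runtime calls deleted and uses of %p redirected to the buffer.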
3518 Module *M = CB->getModule();
3519 Type *Int8Ty = Type::getInt8Ty(M->getContext());
3520 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3521 auto *SharedMem = new GlobalVariable(
3522 *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3523 PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3524 GlobalValue::NotThreadLocal,
3525 static_cast<unsigned>(AddressSpace::Shared));
3526 auto *NewBuffer =
3527 ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
3528
3529 auto Remark = [&](OptimizationRemark OR) {
3530 return OR << "Replaced globalized variable with "
3531 << ore::NV("SharedMemory", AllocSize->getZExtValue())
3532 << (AllocSize->isOne() ? " byte " : " bytes ")
3533 << "of shared memory.";
3534 };
3535 A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3536
3537 MaybeAlign Alignment = CB->getRetAlign();
3538 assert(Alignment &&
3539 "HeapToShared on allocation without alignment attribute");
3540 SharedMem->setAlignment(*Alignment);
3541
3542 A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3543 A.deleteAfterManifest(*CB);
3544 A.deleteAfterManifest(*FreeCalls.front());
3545
3546 SharedMemoryUsed += AllocSize->getZExtValue();
3547 NumBytesMovedToSharedMemory = SharedMemoryUsed;
3548 Changed = ChangeStatus::CHANGED;
3549 }
3550
3551 return Changed;
3552 }
3553
3554 ChangeStatus updateImpl(Attributor &A) override {
3555 if (MallocCalls.empty())
3556 return indicatePessimisticFixpoint();
3557 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3558 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3559 if (!RFI.Declaration)
3560 return ChangeStatus::UNCHANGED;
3561
3562 Function *F = getAnchorScope();
3563
3564 auto NumMallocCalls = MallocCalls.size();
3565
3566 // Only consider malloc calls executed by a single thread with a constant size.
3567 for (User *U : RFI.Declaration->users()) {
3568 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3569 if (CB->getCaller() != F)
3570 continue;
3571 if (!MallocCalls.count(CB))
3572 continue;
3573 if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3574 MallocCalls.remove(CB);
3575 continue;
3576 }
3577 const auto *ED = A.getAAFor<AAExecutionDomain>(
3578 *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3579 if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
3580 MallocCalls.remove(CB);
3581 }
3582 }
3583
3584 findPotentialRemovedFreeCalls(A);
3585
3586 if (NumMallocCalls != MallocCalls.size())
3587 return ChangeStatus::CHANGED;
3588
3589 return ChangeStatus::UNCHANGED;
3590 }
3591
3592 /// Collection of all malloc calls in a function.
3593 SmallSetVector<CallBase *, 4> MallocCalls;
3594 /// Collection of potentially removed free calls in a function.
3595 SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3596 /// The total amount of shared memory that has been used for HeapToShared.
3597 unsigned SharedMemoryUsed = 0;
3598};
3599
3600struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3601 using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3602 AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3603
3604 /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
3605 /// unknown callees.
3606 static bool requiresCalleeForCallBase() { return false; }
3607
3608 /// Statistics are tracked as part of manifest for now.
3609 void trackStatistics() const override {}
3610
3611 /// See AbstractAttribute::getAsStr()
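/// For illustration, a typical (made-up) result looks like:
///   "generic [FIX] #PRs: 2, #Unknown PRs: 0, #Reaching Kernels: 1,
///   #ParLevels: 0, NestedPar: no"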
3612 const std::string getAsStr(Attributor *) const override {
3613 if (!isValidState())
3614 return "<invalid>";
3615 return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3616 : "generic") +
3617 std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3618 : "") +
3619 std::string(" #PRs: ") +
3620 (ReachedKnownParallelRegions.isValidState()
3621 ? std::to_string(ReachedKnownParallelRegions.size())
3622 : "<invalid>") +
3623 ", #Unknown PRs: " +
3624 (ReachedUnknownParallelRegions.isValidState()
3625 ? std::to_string(ReachedUnknownParallelRegions.size())
3626 : "<invalid>") +
3627 ", #Reaching Kernels: " +
3628 (ReachingKernelEntries.isValidState()
3629 ? std::to_string(ReachingKernelEntries.size())
3630 : "<invalid>") +
3631 ", #ParLevels: " +
3632 (ParallelLevels.isValidState()
3633 ? std::to_string(ParallelLevels.size())
3634 : "<invalid>") +
3635 ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3636 }
3637
3638 /// Create an abstract attribute view for the position \p IRP.
3639 static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3640
3641 /// See AbstractAttribute::getName()
3642 const std::string getName() const override { return "AAKernelInfo"; }
3643
3644 /// See AbstractAttribute::getIdAddr()
3645 const char *getIdAddr() const override { return &ID; }
3646
3647 /// This function should return true if the type of the \p AA is AAKernelInfo
3648 static bool classof(const AbstractAttribute *AA) {
3649 return (AA->getIdAddr() == &ID);
3650 }
3651
3652 static const char ID;
3653};
3654
3655/// The function kernel info abstract attribute, basically, what can we say
3656/// about a function with regards to the KernelInfoState.
3657struct AAKernelInfoFunction : AAKernelInfo {
3658 AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3659 : AAKernelInfo(IRP, A) {}
3660
3661 SmallPtrSet<Instruction *, 4> GuardedInstructions;
3662
3663 SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3664 return GuardedInstructions;
3665 }
3666
3667 void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3668 Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction(
3669 KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3670 assert(NewKernelEnvC && "Failed to create new kernel environment");
3671 KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3672 }
3673
3674#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \
3675 void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \
3676 ConstantStruct *ConfigC = \
3677 KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \
3678 Constant *NewConfigC = ConstantFoldInsertValueInstruction( \
3679 ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \
3680 assert(NewConfigC && "Failed to create new configuration environment"); \
3681 setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \
3682 }
3683
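// Each invocation below declares one setter; for example,
// KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode) expands to
// setExecModeOfKernelEnvironment(ConstantInt *NewVal), which folds the new
// value into the configuration struct of KernelEnvC.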
3684 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
3685 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
3686 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)
3687 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads)
3688 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads)
3689 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams)
3690 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams)
3691
3692#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
3693
3694 /// See AbstractAttribute::initialize(...).
3695 void initialize(Attributor &A) override {
3696 // This is a high-level transform that might change the constant arguments
3697 // of the init and deinit calls. We need to tell the Attributor about this
3698 // to avoid other parts using the current constant value for simplification.
3699 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3700
3701 Function *Fn = getAnchorScope();
3702
3703 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3704 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3705 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3706 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3707
3708 // For kernels we perform more initialization work. First we find the init
3709 // and deinit calls.
3710 auto StoreCallBase = [](Use &U,
3711 OMPInformationCache::RuntimeFunctionInfo &RFI,
3712 CallBase *&Storage) {
3713 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3714 assert(CB &&
3715 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3716 assert(!Storage &&
3717 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3718 Storage = CB;
3719 return false;
3720 };
3721 InitRFI.foreachUse(
3722 [&](Use &U, Function &) {
3723 StoreCallBase(U, InitRFI, KernelInitCB);
3724 return false;
3725 },
3726 Fn);
3727 DeinitRFI.foreachUse(
3728 [&](Use &U, Function &) {
3729 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3730 return false;
3731 },
3732 Fn);
3733
3734 // Ignore kernels without initializers such as global constructors.
3735 if (!KernelInitCB || !KernelDeinitCB)
3736 return;
3737
3738 // Add this kernel to its own ReachingKernelEntries and set IsKernelEntry.
3739 ReachingKernelEntries.insert(Fn);
3740 IsKernelEntry = true;
3741
3742 KernelEnvC =
3743 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3744 GlobalVariable *KernelEnvGV =
3745 KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3746
3747 Attributor::GlobalVariableSimplifictionCallbackTy
3748 KernelConfigurationSimplifyCB =
3749 [&](const GlobalVariable &GV, const AbstractAttribute *AA,
3750 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3751 if (!isAtFixpoint()) {
3752 if (!AA)
3753 return nullptr;
3754 UsedAssumedInformation = true;
3755 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3756 }
3757 return KernelEnvC;
3758 };
3759
3760 A.registerGlobalVariableSimplificationCallback(
3761 *KernelEnvGV, KernelConfigurationSimplifyCB);
3762
3763 // We cannot change to SPMD mode if the runtime functions aren't available.
3764 bool CanChangeToSPMD = OMPInfoCache.runtimeFnsAvailable(
3765 {OMPRTL___kmpc_get_hardware_thread_id_in_block,
3766 OMPRTL___kmpc_barrier_simple_spmd});
3767
3768 // Check if we know we are in SPMD-mode already.
3769 ConstantInt *ExecModeC =
3770 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3771 ConstantInt *AssumedExecModeC = ConstantInt::get(
3772 ExecModeC->getIntegerType(),
3773 ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
3774 if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
3775 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3776 else if (DisableOpenMPOptSPMDization || !CanChangeToSPMD)
3777 // This is a generic region but SPMDization is disabled so stop
3778 // tracking.
3779 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3780 else
3781 setExecModeOfKernelEnvironment(AssumedExecModeC);
3782
3783 const Triple T(Fn->getParent()->getTargetTriple());
3784 auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
3785 auto [MinThreads, MaxThreads] =
3786 OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn);
3787 if (MinThreads)
3788 setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3789 if (MaxThreads)
3790 setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3791 auto [MinTeams, MaxTeams] =
3792 OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn);
3793 if (MinTeams)
3794 setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3795 if (MaxTeams)
3796 setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3797
3798 ConstantInt *MayUseNestedParallelismC =
3799 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3800 ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3801 MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
3802 setMayUseNestedParallelismOfKernelEnvironment(
3803 AssumedMayUseNestedParallelismC);
3804
3805 if (!DisableOpenMPOptStateMachineRewrite) {
3806 ConstantInt *UseGenericStateMachineC =
3807 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3808 KernelEnvC);
3809 ConstantInt *AssumedUseGenericStateMachineC =
3810 ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3811 setUseGenericStateMachineOfKernelEnvironment(
3812 AssumedUseGenericStateMachineC);
3813 }
3814
3815 // Register virtual uses of functions we might need to preserve.
3816 auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3817 Attributor::VirtualUseCallbackTy &CB) {
3818 if (!OMPInfoCache.RFIs[RFKind].Declaration)
3819 return;
3820 A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3821 };
3822
3823 // Add a dependence to ensure updates if the state changes.
3824 auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3825 const AbstractAttribute *QueryingAA) {
3826 if (QueryingAA) {
3827 A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3828 }
3829 return true;
3830 };
3831
3832 Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3833 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3834 // Whenever we create a custom state machine we will insert calls to
3835 // __kmpc_get_hardware_num_threads_in_block,
3836 // __kmpc_get_warp_size,
3837 // __kmpc_barrier_simple_generic,
3838 // __kmpc_kernel_parallel, and
3839 // __kmpc_kernel_end_parallel.
3840 // Not needed if we are on track for SPMDzation.
3841 if (SPMDCompatibilityTracker.isValidState())
3842 return AddDependence(A, this, QueryingAA);
3843 // Not needed if we can't rewrite due to an invalid state.
3844 if (!ReachedKnownParallelRegions.isValidState())
3845 return AddDependence(A, this, QueryingAA);
3846 return false;
3847 };
3848
3849 // Not needed before the device runtime has been linked in (pre-runtime merge).
3850 if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3851 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3852 CustomStateMachineUseCB);
3853 RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3854 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3855 CustomStateMachineUseCB);
3856 RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3857 CustomStateMachineUseCB);
3858 RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3859 CustomStateMachineUseCB);
3860 }
3861
3862 // If we do not perform SPMDzation we do not need the virtual uses below.
3863 if (SPMDCompatibilityTracker.isAtFixpoint())
3864 return;
3865
3866 Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3867 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3868 // Whenever we perform SPMDzation we will insert
3869 // __kmpc_get_hardware_thread_id_in_block calls.
3870 if (!SPMDCompatibilityTracker.isValidState())
3871 return AddDependence(A, this, QueryingAA);
3872 return false;
3873 };
3874 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3875 HWThreadIdUseCB);
3876
3877 Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3878 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3879 // Whenever we perform SPMDzation with guarding we will insert
3880 // __kmpc_simple_barrier_spmd calls. If SPMDzation failed, there is
3881 // nothing to guard, or there are no parallel regions, we don't need
3882 // the calls.
3883 if (!SPMDCompatibilityTracker.isValidState())
3884 return AddDependence(A, this, QueryingAA);
3885 if (SPMDCompatibilityTracker.empty())
3886 return AddDependence(A, this, QueryingAA);
3887 if (!mayContainParallelRegion())
3888 return AddDependence(A, this, QueryingAA);
3889 return false;
3890 };
3891 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3892 }
3893
3894 /// Sanitize the string \p S such that it is a suitable global symbol name.
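/// For example (illustrative), "kernel arg-0" becomes "kernel.arg.0".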
3895 static std::string sanitizeForGlobalName(std::string S) {
3896 std::replace_if(
3897 S.begin(), S.end(),
3898 [](const char C) {
3899 return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3900 (C >= '0' && C <= '9') || C == '_');
3901 },
3902 '.');
3903 return S;
3904 }
3905
3906 /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3907 /// finished now.
3908 ChangeStatus manifest(Attributor &A) override {
3909 // If we are not looking at a kernel with __kmpc_target_init and
3910 // __kmpc_target_deinit call we cannot actually manifest the information.
3911 if (!KernelInitCB || !KernelDeinitCB)
3912 return ChangeStatus::UNCHANGED;
3913
3914 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3915
3916 bool HasBuiltStateMachine = true;
3917 if (!changeToSPMDMode(A, Changed)) {
3918 if (!KernelInitCB->getCalledFunction()->isDeclaration())
3919 HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3920 else
3921 HasBuiltStateMachine = false;
3922 }
3923
3924 // We need to reset KernelEnvC if specific rewriting is not done.
3925 ConstantStruct *ExistingKernelEnvC =
3926 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3927 ConstantInt *OldUseGenericStateMachineVal =
3928 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3929 ExistingKernelEnvC);
3930 if (!HasBuiltStateMachine)
3931 setUseGenericStateMachineOfKernelEnvironment(
3932 OldUseGenericStateMachineVal);
3933
3934 // Finally, update the KernelEnvC.
3935 GlobalVariable *KernelEnvGV =
3936 KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3937 if (KernelEnvGV->getInitializer() != KernelEnvC) {
3938 KernelEnvGV->setInitializer(KernelEnvC);
3939 Changed = ChangeStatus::CHANGED;
3940 }
3941
3942 return Changed;
3943 }
3944
3945 void insertInstructionGuardsHelper(Attributor &A) {
3946 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3947
3948 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3949 Instruction *RegionEndI) {
3950 LoopInfo *LI = nullptr;
3951 DominatorTree *DT = nullptr;
3952 MemorySSAUpdater *MSU = nullptr;
3953 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3954
3955 BasicBlock *ParentBB = RegionStartI->getParent();
3956 Function *Fn = ParentBB->getParent();
3957 Module &M = *Fn->getParent();
3958
3959 // Create all the blocks and logic.
3960 // ParentBB:
3961 // goto RegionCheckTidBB
3962 // RegionCheckTidBB:
3963 // Tid = __kmpc_hardware_thread_id()
3964 // if (Tid != 0)
3965 // goto RegionBarrierBB
3966 // RegionStartBB:
3967 // <execute instructions guarded>
3968 // goto RegionEndBB
3969 // RegionEndBB:
3970 // <store escaping values to shared mem>
3971 // goto RegionBarrierBB
3972 // RegionBarrierBB:
3973 // __kmpc_simple_barrier_spmd()
3974 // // second barrier is omitted if lacking escaping values.
3975 // <load escaping values from shared mem>
3976 // __kmpc_simple_barrier_spmd()
3977 // goto RegionExitBB
3978 // RegionExitBB:
3979 // <execute rest of instructions>
3980
3981 BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3982 DT, LI, MSU, "region.guarded.end");
3983 BasicBlock *RegionBarrierBB =
3984 SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3985 MSU, "region.barrier");
3986 BasicBlock *RegionExitBB =
3987 SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3988 DT, LI, MSU, "region.exit");
3989 BasicBlock *RegionStartBB =
3990 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3991
3992 assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3993 "Expected a different CFG");
3994
3995 BasicBlock *RegionCheckTidBB = SplitBlock(
3996 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3997
3998 // Register basic blocks with the Attributor.
3999 A.registerManifestAddedBasicBlock(*RegionEndBB);
4000 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
4001 A.registerManifestAddedBasicBlock(*RegionExitBB);
4002 A.registerManifestAddedBasicBlock(*RegionStartBB);
4003 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4004
4005 bool HasBroadcastValues = false;
4006 // Find escaping outputs from the guarded region to outside users and
4007 // broadcast their values to them.
4008 for (Instruction &I : *RegionStartBB) {
4009 SmallVector<Use *, 4> OutsideUses;
4010 for (Use &U : I.uses()) {
4011 Instruction &UsrI = *cast<Instruction>(U.getUser());
4012 if (UsrI.getParent() != RegionStartBB)
4013 OutsideUses.push_back(&U);
4014 }
4015
4016 if (OutsideUses.empty())
4017 continue;
4018
4019 HasBroadcastValues = true;
4020
4021 // Emit a global variable in shared memory to store the broadcasted
4022 // value.
4023 auto *SharedMem = new GlobalVariable(
4024 M, I.getType(), /* IsConstant */ false,
4026 sanitizeForGlobalName(
4027 (I.getName() + ".guarded.output.alloc").str()),
4029 static_cast<unsigned>(AddressSpace::Shared));
4030
4031 // Emit a store instruction to update the value.
4032 new StoreInst(&I, SharedMem,
4033 RegionEndBB->getTerminator()->getIterator());
4034
4035 LoadInst *LoadI = new LoadInst(
4036 I.getType(), SharedMem, I.getName() + ".guarded.output.load",
4037 RegionBarrierBB->getTerminator()->getIterator());
4038
4039 // Emit a load instruction and replace uses of the output value.
4040 for (Use *U : OutsideUses)
4041 A.changeUseAfterManifest(*U, *LoadI);
4042 }
4043
4044 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4045
4046 // Go to tid check BB in ParentBB.
4047 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
4048 ParentBB->getTerminator()->eraseFromParent();
4049 OpenMPIRBuilder::LocationDescription Loc(
4050 InsertPointTy(ParentBB, ParentBB->end()), DL);
4051 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4052 uint32_t SrcLocStrSize;
4053 auto *SrcLocStr =
4054 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4055 Value *Ident =
4056 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4057 BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
4058
4059 // Add check for Tid in RegionCheckTidBB
4060 RegionCheckTidBB->getTerminator()->eraseFromParent();
4061 OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
4062 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4063 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4064 FunctionCallee HardwareTidFn =
4065 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4066 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4067 CallInst *Tid =
4068 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4069 Tid->setDebugLoc(DL);
4070 OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4071 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4072 OMPInfoCache.OMPBuilder.Builder
4073 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4074 ->setDebugLoc(DL);
4075
4076 // First barrier for synchronization, ensures main thread has updated
4077 // values.
4078 FunctionCallee BarrierFn =
4079 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4080 M, OMPRTL___kmpc_barrier_simple_spmd);
4081 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4082 RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
4083 CallInst *Barrier =
4084 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4085 Barrier->setDebugLoc(DL);
4086 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4087
4088 // Second barrier ensures workers have read broadcast values.
4089 if (HasBroadcastValues) {
4090 CallInst *Barrier =
4091 CallInst::Create(BarrierFn, {Ident, Tid}, "",
4092 RegionBarrierBB->getTerminator()->getIterator());
4093 Barrier->setDebugLoc(DL);
4094 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4095 }
4096 };
4097
4098 auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4099 SmallPtrSet<BasicBlock *, 8> Visited;
4100 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4101 BasicBlock *BB = GuardedI->getParent();
4102 if (!Visited.insert(BB).second)
4103 continue;
4104
4105 SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
4106 Instruction *LastEffect = nullptr;
4107 BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
4108 while (++IP != IPEnd) {
4109 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4110 continue;
4111 Instruction *I = &*IP;
4112 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4113 continue;
4114 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4115 LastEffect = nullptr;
4116 continue;
4117 }
4118 if (LastEffect)
4119 Reorders.push_back({I, LastEffect});
4120 LastEffect = &*IP;
4121 }
4122 for (auto &Reorder : Reorders)
4123 Reorder.first->moveBefore(Reorder.second);
4124 }
4125
4126 SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
4127
4128 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4129 BasicBlock *BB = GuardedI->getParent();
4130 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4131 IRPosition::function(*GuardedI->getFunction()), nullptr,
4132 DepClassTy::NONE);
4133 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4134 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4135 // Continue if instruction is already guarded.
4136 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4137 continue;
4138
4139 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4140 for (Instruction &I : *BB) {
4141 // If instruction I needs to be guarded update the guarded region
4142 // bounds.
4143 if (SPMDCompatibilityTracker.contains(&I)) {
4144 CalleeAAFunction.getGuardedInstructions().insert(&I);
4145 if (GuardedRegionStart)
4146 GuardedRegionEnd = &I;
4147 else
4148 GuardedRegionStart = GuardedRegionEnd = &I;
4149
4150 continue;
4151 }
4152
4153 // Instruction I does not need guarding, store
4154 // any region found and reset bounds.
4155 if (GuardedRegionStart) {
4156 GuardedRegions.push_back(
4157 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4158 GuardedRegionStart = nullptr;
4159 GuardedRegionEnd = nullptr;
4160 }
4161 }
4162 }
4163
4164 for (auto &GR : GuardedRegions)
4165 CreateGuardedRegion(GR.first, GR.second);
4166 }
4167
4168 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4169 // Only allow 1 thread per workgroup to continue executing the user code.
4170 //
4171 // InitCB = __kmpc_target_init(...)
4172 // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4173 // if (ThreadIdInBlock != 0) return;
4174 // UserCode:
4175 // // user code
4176 //
4177 auto &Ctx = getAnchorValue().getContext();
4178 Function *Kernel = getAssociatedFunction();
4179 assert(Kernel && "Expected an associated function!");
4180
4181 // Create block for user code to branch to from initial block.
4182 BasicBlock *InitBB = KernelInitCB->getParent();
4183 BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4184 KernelInitCB->getNextNode(), "main.thread.user_code");
4185 BasicBlock *ReturnBB =
4186 BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
4187
4188 // Register blocks with attributor:
4189 A.registerManifestAddedBasicBlock(*InitBB);
4190 A.registerManifestAddedBasicBlock(*UserCodeBB);
4191 A.registerManifestAddedBasicBlock(*ReturnBB);
4192
4193 // Debug location:
4194 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4195 ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
4196 InitBB->getTerminator()->eraseFromParent();
4197
4198 // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4199 Module &M = *Kernel->getParent();
4200 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4201 FunctionCallee ThreadIdInBlockFn =
4202 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4203 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4204
4205 // Get thread ID in block.
4206 CallInst *ThreadIdInBlock =
4207 CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
4208 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4209 ThreadIdInBlock->setDebugLoc(DLoc);
4210
4211 // Eliminate all threads in the block with ID not equal to 0:
4212 Instruction *IsMainThread =
4213 ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
4214 ConstantInt::get(ThreadIdInBlock->getType(), 0),
4215 "thread.is_main", InitBB);
4216 IsMainThread->setDebugLoc(DLoc);
4217 BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
4218 }
4219
4220 bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
4221 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4222
4223 if (!SPMDCompatibilityTracker.isAssumed()) {
4224 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4225 if (!NonCompatibleI)
4226 continue;
4227
4228 // Skip diagnostics on calls to known OpenMP runtime functions for now.
4229 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4230 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4231 continue;
4232
4233 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4234 ORA << "Value has potential side effects preventing SPMD-mode "
4235 "execution";
4236 if (isa<CallBase>(NonCompatibleI)) {
4237 ORA << ". Add `[[omp::assume(\"ompx_spmd_amenable\")]]` to "
4238 "the called function to override";
4239 }
4240 return ORA << ".";
4241 };
4242 A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
4243 Remark);
4244
4245 LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
4246 << *NonCompatibleI << "\n");
4247 }
4248
4249 return false;
4250 }
4251
4252 // Get the actual kernel, could be the caller of the anchor scope if we have
4253 // a debug wrapper.
4254 Function *Kernel = getAnchorScope();
4255 if (Kernel->hasLocalLinkage()) {
4256 assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
4257 auto *CB = cast<CallBase>(Kernel->user_back());
4258 Kernel = CB->getCaller();
4259 }
4260 assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
4261
4262 // Check if the kernel is already in SPMD mode, if so, return success.
4263 ConstantStruct *ExistingKernelEnvC =
4264 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4265 auto *ExecModeC =
4266 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4267 const int8_t ExecModeVal = ExecModeC->getSExtValue();
4268 if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4269 return true;
4270
4271 // We will now unconditionally modify the IR, indicate a change.
4272 Changed = ChangeStatus::CHANGED;
4273
4274 // Do not use instruction guards when no parallel region is present inside
4275 // the target region.
4276 if (mayContainParallelRegion())
4277 insertInstructionGuardsHelper(A);
4278 else
4279 forceSingleThreadPerWorkgroupHelper(A);
4280
4281 // Adjust the global exec mode flag that tells the runtime what mode this
4282 // kernel is executed in.
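// Note: the OMP_TGT_EXEC_MODE_* values are bit flags;
// OMP_TGT_EXEC_MODE_GENERIC_SPMD combines the generic and SPMD bits to mark
// a generic kernel that has been transformed to execute in SPMD mode.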
4283 assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4284 "Initially non-SPMD kernel has SPMD exec mode!");
4285 setExecModeOfKernelEnvironment(
4286 ConstantInt::get(ExecModeC->getIntegerType(),
4287 ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
4288
4289 ++NumOpenMPTargetRegionKernelsSPMD;
4290
4291 auto Remark = [&](OptimizationRemark OR) {
4292 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4293 };
4294 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
4295 return true;
4296 };
4297
4298 bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4299 // If we have disabled state machine rewrites, don't make a custom one.
4300 if (DisableOpenMPOptStateMachineRewrite)
4301 return false;
4302
4303 // Don't rewrite the state machine if we are not in a valid state.
4304 if (!ReachedKnownParallelRegions.isValidState())
4305 return false;
4306
4307 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4308 if (!OMPInfoCache.runtimeFnsAvailable(
4309 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4310 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4311 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4312 return false;
4313
4314 ConstantStruct *ExistingKernelEnvC =
4315 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4316
4317 // Check if the current configuration is non-SPMD and generic state machine.
4318 // If we already have SPMD mode or a custom state machine we do not need to
4319 // go any further. If it is anything but a constant, something is weird and
4320 // we give up.
4321 ConstantInt *UseStateMachineC =
4322 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4323 ExistingKernelEnvC);
4324 ConstantInt *ModeC =
4325 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4326
4327 // If we are stuck with generic mode, try to create a custom device (=GPU)
4328 // state machine which is specialized for the parallel regions that are
4329 // reachable by the kernel.
4330 if (UseStateMachineC->isZero() ||
4331 (ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
4332 return false;
4333
4334 Changed = ChangeStatus::CHANGED;
4335
4336 // If not SPMD mode, indicate we use a custom state machine now.
4337 setUseGenericStateMachineOfKernelEnvironment(
4338 ConstantInt::get(UseStateMachineC->getIntegerType(), false));
4339
4340 // If we don't actually need a state machine we are done here. This can
4341 // happen if there simply are no parallel regions. In the resulting kernel
4342 // all worker threads will simply exit right away, leaving the main thread
4343 // to do the work alone.
4344 if (!mayContainParallelRegion()) {
4345 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4346
4347 auto Remark = [&](OptimizationRemark OR) {
4348 return OR << "Removing unused state machine from generic-mode kernel.";
4349 };
4350 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4351
4352 return true;
4353 }
4354
4355 // Keep track in the statistics of our new shiny custom state machine.
4356 if (ReachedUnknownParallelRegions.empty()) {
4357 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4358
4359 auto Remark = [&](OptimizationRemark OR) {
4360 return OR << "Rewriting generic-mode kernel with a customized state "
4361 "machine.";
4362 };
4363 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4364 } else {
4365 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4366
4367 auto Remark = [&](OptimizationRemarkAnalysis OR) {
4368 return OR << "Generic-mode kernel is executed with a customized state "
4369 "machine that requires a fallback.";
4370 };
4371 A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4372
4373 // Tell the user why we ended up with a fallback.
4374 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4375 if (!UnknownParallelRegionCB)
4376 continue;
4377 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4378 return ORA << "Call may contain unknown parallel regions. Use "
4379 << "`[[omp::assume(\"omp_no_parallelism\")]]` to "
4380 "override.";
4381 };
4382 A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4383 "OMP133", Remark);
4384 }
4385 }
4386
4387 // Create all the blocks:
4388 //
4389 // InitCB = __kmpc_target_init(...)
4390 // BlockHwSize =
4391 // __kmpc_get_hardware_num_threads_in_block();
4392 // WarpSize = __kmpc_get_warp_size();
4393 // BlockSize = BlockHwSize - WarpSize;
4394 // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
4395 // if (IsWorker) {
4396 // if (InitCB >= BlockSize) return;
4397 // SMBeginBB: __kmpc_barrier_simple_generic(...);
4398 // void *WorkFn;
4399 // bool Active = __kmpc_kernel_parallel(&WorkFn);
4400 // if (!WorkFn) return;
4401 // SMIsActiveCheckBB: if (Active) {
4402 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
4403 // ParFn0(...);
4404 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
4405 // ParFn1(...);
4406 // ...
4407 // SMIfCascadeCurrentBB: else
4408 // ((WorkFnTy*)WorkFn)(...);
4409 // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
4410 // }
4411 // SMDoneBB: __kmpc_barrier_simple_generic(...);
4412 // goto SMBeginBB;
4413 // }
4414 // UserCodeEntryBB: // user code
4415 // __kmpc_target_deinit(...)
4416 //
4417 auto &Ctx = getAnchorValue().getContext();
4418 Function *Kernel = getAssociatedFunction();
4419 assert(Kernel && "Expected an associated function!");
4420
4421 BasicBlock *InitBB = KernelInitCB->getParent();
4422 BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4423 KernelInitCB->getNextNode(), "thread.user_code.check");
4424 BasicBlock *IsWorkerCheckBB =
4425 BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4426 BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4427 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4428 BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4429 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4430 BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4431 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4432 BasicBlock *StateMachineIfCascadeCurrentBB =
4433 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4434 Kernel, UserCodeEntryBB);
4435 BasicBlock *StateMachineEndParallelBB =
4436 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4437 Kernel, UserCodeEntryBB);
4438 BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4439 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4440 A.registerManifestAddedBasicBlock(*InitBB);
4441 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4442 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4443 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4444 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4445 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4446 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4447 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4448 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4449
4450 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4451 ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4452 InitBB->getTerminator()->eraseFromParent();
4453
4454 Instruction *IsWorker =
4455 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4456 ConstantInt::get(KernelInitCB->getType(), -1),
4457 "thread.is_worker", InitBB);
4458 IsWorker->setDebugLoc(DLoc);
4459 BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
4460
4461 Module &M = *Kernel->getParent();
4462 FunctionCallee BlockHwSizeFn =
4463 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4464 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4465 FunctionCallee WarpSizeFn =
4466 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4467 M, OMPRTL___kmpc_get_warp_size);
4468 CallInst *BlockHwSize =
4469 CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4470 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4471 BlockHwSize->setDebugLoc(DLoc);
4472 CallInst *WarpSize =
4473 CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4474 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4475 WarpSize->setDebugLoc(DLoc);
4476 Instruction *BlockSize = BinaryOperator::CreateSub(
4477 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4478 BlockSize->setDebugLoc(DLoc);
4479 Instruction *IsMainOrWorker = ICmpInst::Create(
4480 ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4481 "thread.is_main_or_worker", IsWorkerCheckBB);
4482 IsMainOrWorker->setDebugLoc(DLoc);
4483 BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
4484 IsMainOrWorker, IsWorkerCheckBB);
4485
4486 // Create local storage for the work function pointer.
4487 const DataLayout &DL = M.getDataLayout();
4488 Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4489 Instruction *WorkFnAI =
4490 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4491 "worker.work_fn.addr", Kernel->getEntryBlock().begin());
4492 WorkFnAI->setDebugLoc(DLoc);
4493
4494 OMPInfoCache.OMPBuilder.updateToLocation(
4495 OpenMPIRBuilder::LocationDescription(
4496 IRBuilder<>::InsertPoint(StateMachineBeginBB,
4497 StateMachineBeginBB->end()),
4498 DLoc));
4499
4500 Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4501 Value *GTid = KernelInitCB;
4502
4503 FunctionCallee BarrierFn =
4504 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4505 M, OMPRTL___kmpc_barrier_simple_generic);
4506 CallInst *Barrier =
4507 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4508 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4509 Barrier->setDebugLoc(DLoc);
4510
4511 if (WorkFnAI->getType()->getPointerAddressSpace() !=
4512 (unsigned int)AddressSpace::Generic) {
4513 WorkFnAI = new AddrSpaceCastInst(
4514 WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4515 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4516 WorkFnAI->setDebugLoc(DLoc);
4517 }
4518
4519 FunctionCallee KernelParallelFn =
4520 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4521 M, OMPRTL___kmpc_kernel_parallel);
4522 CallInst *IsActiveWorker = CallInst::Create(
4523 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4524 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4525 IsActiveWorker->setDebugLoc(DLoc);
4526 Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4527 StateMachineBeginBB);
4528 WorkFn->setDebugLoc(DLoc);
4529
4530 FunctionType *ParallelRegionFnTy = FunctionType::get(
4531 Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
4532 false);
4533
4534 Instruction *IsDone =
4535 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4536 Constant::getNullValue(VoidPtrTy), "worker.is_done",
4537 StateMachineBeginBB);
4538 IsDone->setDebugLoc(DLoc);
4539 BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
4540 IsDone, StateMachineBeginBB)
4541 ->setDebugLoc(DLoc);
4542
4543 BranchInst::Create(StateMachineIfCascadeCurrentBB,
4544 StateMachineDoneBarrierBB, IsActiveWorker,
4545 StateMachineIsActiveCheckBB)
4546 ->setDebugLoc(DLoc);
4547
4548 Value *ZeroArg =
4549 Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4550
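// In a __kmpc_parallel_51 call site the outlined wrapper function for the
// parallel region is (by the device runtime's convention) the seventh
// argument; WrapperFunctionArgNo selects it for the comparisons below.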
4551 const unsigned int WrapperFunctionArgNo = 6;
4552
4553 // Now that we have most of the CFG skeleton it is time for the if-cascade
4554 // that checks the function pointer we got from the runtime against the
4555 // parallel regions we expect, if there are any.
4556 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4557 auto *CB = ReachedKnownParallelRegions[I];
4558 auto *ParallelRegion = dyn_cast<Function>(
4559 CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4560 BasicBlock *PRExecuteBB = BasicBlock::Create(
4561 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4562 StateMachineEndParallelBB);
4563 CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4564 ->setDebugLoc(DLoc);
4565 BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4566 ->setDebugLoc(DLoc);
4567
4568 BasicBlock *PRNextBB =
4569 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4570 Kernel, StateMachineEndParallelBB);
4571 A.registerManifestAddedBasicBlock(*PRExecuteBB);
4572 A.registerManifestAddedBasicBlock(*PRNextBB);
4573
4574 // Check if we need to compare the pointer at all or if we can just
4575 // call the parallel region function.
4576 Value *IsPR;
4577 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4578 Instruction *CmpI = ICmpInst::Create(
4579 ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
4580 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4581 CmpI->setDebugLoc(DLoc);
4582 IsPR = CmpI;
4583 } else {
4584 IsPR = ConstantInt::getTrue(Ctx);
4585 }
4586
4587 BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
4588 StateMachineIfCascadeCurrentBB)
4589 ->setDebugLoc(DLoc);
4590 StateMachineIfCascadeCurrentBB = PRNextBB;
4591 }
4592
4593 // At the end of the if-cascade we place the indirect function pointer call
4594 // in case we might need it, that is if there can be parallel regions we
4595 // have not handled in the if-cascade above.
4596 if (!ReachedUnknownParallelRegions.empty()) {
4597 StateMachineIfCascadeCurrentBB->setName(
4598 "worker_state_machine.parallel_region.fallback.execute");
4599 CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
4600 StateMachineIfCascadeCurrentBB)
4601 ->setDebugLoc(DLoc);
4602 }
4603 BranchInst::Create(StateMachineEndParallelBB,
4604 StateMachineIfCascadeCurrentBB)
4605 ->setDebugLoc(DLoc);
4606
4607 FunctionCallee EndParallelFn =
4608 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4609 M, OMPRTL___kmpc_kernel_end_parallel);
4610 CallInst *EndParallel =
4611 CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4612 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4613 EndParallel->setDebugLoc(DLoc);
4614 BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4615 ->setDebugLoc(DLoc);
4616
4617 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4618 ->setDebugLoc(DLoc);
4619 BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4620 ->setDebugLoc(DLoc);
4621
4622 return true;
4623 }
4624
4625 /// Fixpoint iteration update function. Will be called every time a dependence
4626 /// changed its state (and in the beginning).
4627 ChangeStatus updateImpl(Attributor &A) override {
4628 KernelInfoState StateBefore = getState();
4629
4630 // When we leave this function this RAII will make sure the member
4631 // KernelEnvC is updated properly depending on the state. That member is
4632 // used for simplification of values and needs to be up to date at all
4633 // times.
4634 struct UpdateKernelEnvCRAII {
4635 AAKernelInfoFunction &AA;
4636
4637 UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4638
4639 ~UpdateKernelEnvCRAII() {
4640 if (!AA.KernelEnvC)
4641 return;
4642
4643 ConstantStruct *ExistingKernelEnvC =
4644 KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB);
4645
4646 if (!AA.isValidState()) {
4647 AA.KernelEnvC = ExistingKernelEnvC;
4648 return;
4649 }
4650
4651 if (!AA.ReachedKnownParallelRegions.isValidState())
4652 AA.setUseGenericStateMachineOfKernelEnvironment(
4653 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4654 ExistingKernelEnvC));
4655
4656 if (!AA.SPMDCompatibilityTracker.isValidState())
4657 AA.setExecModeOfKernelEnvironment(
4658 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4659
4660 ConstantInt *MayUseNestedParallelismC =
4661 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4662 AA.KernelEnvC);
4663 ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4664 MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4665 AA.setMayUseNestedParallelismOfKernelEnvironment(
4666 NewMayUseNestedParallelismC);
4667 }
4668 } RAII(*this);
4669
4670 // Callback to check a read/write instruction.
4671 auto CheckRWInst = [&](Instruction &I) {
4672 // We handle calls later.
4673 if (isa<CallBase>(I))
4674 return true;
4675 // We only care about write effects.
4676 if (!I.mayWriteToMemory())
4677 return true;
4678 if (auto *SI = dyn_cast<StoreInst>(&I)) {
4679 const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4680 *this, IRPosition::value(*SI->getPointerOperand()),
4681 DepClassTy::OPTIONAL);
4682 auto *HS = A.getAAFor<AAHeapToStack>(
4683 *this, IRPosition::function(*I.getFunction()),
4684 DepClassTy::OPTIONAL);
4685 if (UnderlyingObjsAA &&
4686 UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4687 if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4688 return true;
4689 // Check for AAHeapToStack moved objects which must not be
4690 // guarded.
4691 auto *CB = dyn_cast<CallBase>(&Obj);
4692 return CB && HS && HS->isAssumedHeapToStack(*CB);
4693 }))
4694 return true;
4695 }
4696
4697 // Insert instruction that needs guarding.
4698 SPMDCompatibilityTracker.insert(&I);
4699 return true;
4700 };
4701
4702 bool UsedAssumedInformationInCheckRWInst = false;
4703 if (!SPMDCompatibilityTracker.isAtFixpoint())
4704 if (!A.checkForAllReadWriteInstructions(
4705 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4706 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4707
4708 bool UsedAssumedInformationFromReachingKernels = false;
4709 if (!IsKernelEntry) {
4710 updateParallelLevels(A);
4711
4712 bool AllReachingKernelsKnown = true;
4713 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4714 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4715
4716 if (!SPMDCompatibilityTracker.empty()) {
4717 if (!ParallelLevels.isValidState())
4718 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4719 else if (!ReachingKernelEntries.isValidState())
4720 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4721 else {
4722 // Check if all reaching kernels agree on the mode as we can otherwise
4723 // not guard instructions. We might not be sure about the mode, so we
4724 // cannot fix the internal spmd-zation state either.
4725 int SPMD = 0, Generic = 0;
4726 for (auto *Kernel : ReachingKernelEntries) {
4727 auto *CBAA = A.getAAFor<AAKernelInfo>(
4728 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4729 if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4730 CBAA->SPMDCompatibilityTracker.isAssumed())
4731 ++SPMD;
4732 else
4733 ++Generic;
4734 if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4735 UsedAssumedInformationFromReachingKernels = true;
4736 }
4737 if (SPMD != 0 && Generic != 0)
4738 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4739 }
4740 }
4741 }
4742
4743 // Callback to check a call instruction.
4744 bool AllParallelRegionStatesWereFixed = true;
4745 bool AllSPMDStatesWereFixed = true;
4746 auto CheckCallInst = [&](Instruction &I) {
4747 auto &CB = cast<CallBase>(I);
4748 auto *CBAA = A.getAAFor<AAKernelInfo>(
4749 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4750 if (!CBAA)
4751 return false;
4752 getState() ^= CBAA->getState();
4753 AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4754 AllParallelRegionStatesWereFixed &=
4755 CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4756 AllParallelRegionStatesWereFixed &=
4757 CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4758 return true;
4759 };
4760
4761 bool UsedAssumedInformationInCheckCallInst = false;
4762 if (!A.checkForAllCallLikeInstructions(
4763 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4764 LLVM_DEBUG(dbgs() << TAG
4765 << "Failed to visit all call-like instructions!\n";);
4766 return indicatePessimisticFixpoint();
4767 }
4768
4769 // If we haven't used any assumed information for the reached parallel
4770 // region states we can fix it.
4771 if (!UsedAssumedInformationInCheckCallInst &&
4772 AllParallelRegionStatesWereFixed) {
4773 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4774 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4775 }
4776
4777 // If we haven't used any assumed information for the SPMD state we can fix
4778 // it.
4779 if (!UsedAssumedInformationInCheckRWInst &&
4780 !UsedAssumedInformationInCheckCallInst &&
4781 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4782 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4783
4784 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4785 : ChangeStatus::CHANGED;
4786 }
4787
4788private:
4789 /// Update info regarding reaching kernels.
4790 void updateReachingKernelEntries(Attributor &A,
4791 bool &AllReachingKernelsKnown) {
4792 auto PredCallSite = [&](AbstractCallSite ACS) {
4793 Function *Caller = ACS.getInstruction()->getFunction();
4794
4795 assert(Caller && "Caller is nullptr");
4796
4797 auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4798 IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4799 if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4800 ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4801 return true;
4802 }
4803
4804 // We lost track of the caller of the associated function; any kernel
4805 // could reach it now.
4806 ReachingKernelEntries.indicatePessimisticFixpoint();
4807
4808 return true;
4809 };
4810
4811 if (!A.checkForAllCallSites(PredCallSite, *this,
4812 true /* RequireAllCallSites */,
4813 AllReachingKernelsKnown))
4814 ReachingKernelEntries.indicatePessimisticFixpoint();
4815 }
4816
4817 /// Update info regarding parallel levels.
4818 void updateParallelLevels(Attributor &A) {
4819 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4820 OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4821 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4822
4823 auto PredCallSite = [&](AbstractCallSite ACS) {
4824 Function *Caller = ACS.getInstruction()->getFunction();
4825
4826 assert(Caller && "Caller is nullptr");
4827
4828 auto *CAA =
4829 A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4830 if (CAA && CAA->ParallelLevels.isValidState()) {
 4831 // Any function that is called by `__kmpc_parallel_51` will not be
 4832 // folded, as the parallel level inside it is updated. Getting this
 4833 // right would make the analysis depend on the runtime implementation,
 4834 // and any future change to that implementation could invalidate the
 4835 // result. As a consequence, we are just conservative here.
4836 if (Caller == Parallel51RFI.Declaration) {
4837 ParallelLevels.indicatePessimisticFixpoint();
4838 return true;
4839 }
4840
4841 ParallelLevels ^= CAA->ParallelLevels;
4842
4843 return true;
4844 }
4845
 4846 // We lost track of the callers of the associated function, so any
 4847 // kernel could reach it now.
4848 ParallelLevels.indicatePessimisticFixpoint();
4849
4850 return true;
4851 };
4852
4853 bool AllCallSitesKnown = true;
4854 if (!A.checkForAllCallSites(PredCallSite, *this,
4855 true /* RequireAllCallSites */,
4856 AllCallSitesKnown))
4857 ParallelLevels.indicatePessimisticFixpoint();
4858 }
4859};
4860
4861/// The call site kernel info abstract attribute, basically, what can we say
4862/// about a call site with regards to the KernelInfoState. For now this simply
4863/// forwards the information from the callee.
4864struct AAKernelInfoCallSite : AAKernelInfo {
4865 AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4866 : AAKernelInfo(IRP, A) {}
4867
4868 /// See AbstractAttribute::initialize(...).
4869 void initialize(Attributor &A) override {
4870 AAKernelInfo::initialize(A);
4871
4872 CallBase &CB = cast<CallBase>(getAssociatedValue());
4873 auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4874 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4875
4876 // Check for SPMD-mode assumptions.
4877 if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4878 indicateOptimisticFixpoint();
4879 return;
4880 }
4881
 4882 // First weed out calls we do not care about, that is, readonly/readnone
 4883 // calls, intrinsics, and "no_openmp" calls. None of these can reach a
 4884 // parallel region or anything else we are looking for.
4885 if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4886 indicateOptimisticFixpoint();
4887 return;
4888 }
4889
 4890 // Next we check if we know the callee. If it is a known OpenMP function
 4891 // we will handle it explicitly in the switch below. If it is not, we
 4892 // will use an AAKernelInfo object on the callee to gather information and
 4893 // merge that into the current state. The latter happens in updateImpl.
4894 auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
4895 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4896 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4897 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
 4898 // Unknown callees or mere declarations are not analyzable, we give up.
4899 if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4900
4901 // Unknown callees might contain parallel regions, except if they have
4902 // an appropriate assumption attached.
4903 if (!AssumptionAA ||
4904 !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4905 AssumptionAA->hasAssumption("omp_no_parallelism")))
4906 ReachedUnknownParallelRegions.insert(&CB);
4907
4908 // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4909 // idea we can run something unknown in SPMD-mode.
4910 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4911 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4912 SPMDCompatibilityTracker.insert(&CB);
4913 }
4914
4915 // We have updated the state for this unknown call properly, there
4916 // won't be any change so we indicate a fixpoint.
4917 indicateOptimisticFixpoint();
4918 }
4919 // If the callee is known and can be used in IPO, we will update the
4920 // state based on the callee state in updateImpl.
4921 return;
4922 }
4923 if (NumCallees > 1) {
4924 indicatePessimisticFixpoint();
4925 return;
4926 }
4927
4928 RuntimeFunction RF = It->getSecond();
4929 switch (RF) {
4930 // All the functions we know are compatible with SPMD mode.
4931 case OMPRTL___kmpc_is_spmd_exec_mode:
4932 case OMPRTL___kmpc_distribute_static_fini:
4933 case OMPRTL___kmpc_for_static_fini:
4934 case OMPRTL___kmpc_global_thread_num:
4935 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4936 case OMPRTL___kmpc_get_hardware_num_blocks:
4937 case OMPRTL___kmpc_single:
4938 case OMPRTL___kmpc_end_single:
4939 case OMPRTL___kmpc_master:
4940 case OMPRTL___kmpc_end_master:
4941 case OMPRTL___kmpc_barrier:
4942 case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4943 case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4944 case OMPRTL___kmpc_error:
4945 case OMPRTL___kmpc_flush:
4946 case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4947 case OMPRTL___kmpc_get_warp_size:
4948 case OMPRTL_omp_get_thread_num:
4949 case OMPRTL_omp_get_num_threads:
4950 case OMPRTL_omp_get_max_threads:
4951 case OMPRTL_omp_in_parallel:
4952 case OMPRTL_omp_get_dynamic:
4953 case OMPRTL_omp_get_cancellation:
4954 case OMPRTL_omp_get_nested:
4955 case OMPRTL_omp_get_schedule:
4956 case OMPRTL_omp_get_thread_limit:
4957 case OMPRTL_omp_get_supported_active_levels:
4958 case OMPRTL_omp_get_max_active_levels:
4959 case OMPRTL_omp_get_level:
4960 case OMPRTL_omp_get_ancestor_thread_num:
4961 case OMPRTL_omp_get_team_size:
4962 case OMPRTL_omp_get_active_level:
4963 case OMPRTL_omp_in_final:
4964 case OMPRTL_omp_get_proc_bind:
4965 case OMPRTL_omp_get_num_places:
4966 case OMPRTL_omp_get_num_procs:
4967 case OMPRTL_omp_get_place_proc_ids:
4968 case OMPRTL_omp_get_place_num:
4969 case OMPRTL_omp_get_partition_num_places:
4970 case OMPRTL_omp_get_partition_place_nums:
4971 case OMPRTL_omp_get_wtime:
4972 break;
4973 case OMPRTL___kmpc_distribute_static_init_4:
4974 case OMPRTL___kmpc_distribute_static_init_4u:
4975 case OMPRTL___kmpc_distribute_static_init_8:
4976 case OMPRTL___kmpc_distribute_static_init_8u:
4977 case OMPRTL___kmpc_for_static_init_4:
4978 case OMPRTL___kmpc_for_static_init_4u:
4979 case OMPRTL___kmpc_for_static_init_8:
4980 case OMPRTL___kmpc_for_static_init_8u: {
4981 // Check the schedule and allow static schedule in SPMD mode.
4982 unsigned ScheduleArgOpNo = 2;
4983 auto *ScheduleTypeCI =
4984 dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4985 unsigned ScheduleTypeVal =
4986 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4987 switch (OMPScheduleType(ScheduleTypeVal)) {
4988 case OMPScheduleType::UnorderedStatic:
4989 case OMPScheduleType::UnorderedStaticChunked:
4990 case OMPScheduleType::OrderedDistribute:
4991 case OMPScheduleType::OrderedDistributeChunked:
4992 break;
4993 default:
4994 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4995 SPMDCompatibilityTracker.insert(&CB);
4996 break;
4997 };
4998 } break;
4999 case OMPRTL___kmpc_target_init:
5000 KernelInitCB = &CB;
5001 break;
5002 case OMPRTL___kmpc_target_deinit:
5003 KernelDeinitCB = &CB;
5004 break;
5005 case OMPRTL___kmpc_parallel_51:
5006 if (!handleParallel51(A, CB))
5007 indicatePessimisticFixpoint();
5008 return;
5009 case OMPRTL___kmpc_omp_task:
5010 // We do not look into tasks right now, just give up.
5011 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5012 SPMDCompatibilityTracker.insert(&CB);
5013 ReachedUnknownParallelRegions.insert(&CB);
5014 break;
5015 case OMPRTL___kmpc_alloc_shared:
5016 case OMPRTL___kmpc_free_shared:
5017 // Return without setting a fixpoint, to be resolved in updateImpl.
5018 return;
5019 default:
5020 // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
5021 // generally. However, they do not hide parallel regions.
5022 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5023 SPMDCompatibilityTracker.insert(&CB);
5024 break;
5025 }
5026 // All other OpenMP runtime calls will not reach parallel regions so they
5027 // can be safely ignored for now. Since it is a known OpenMP runtime call
5028 // we have now modeled all effects and there is no need for any update.
5029 indicateOptimisticFixpoint();
5030 };
5031
5032 const auto *AACE =
5033 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5034 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5035 CheckCallee(getAssociatedFunction(), 1);
5036 return;
5037 }
5038 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5039 for (auto *Callee : OptimisticEdges) {
5040 CheckCallee(Callee, OptimisticEdges.size());
5041 if (isAtFixpoint())
5042 break;
5043 }
5044 }
5045
5046 ChangeStatus updateImpl(Attributor &A) override {
 5047 // TODO: Once we have call site specific value information we can provide
 5048 // call site specific liveness information, and then it makes
 5049 // sense to specialize attributes for call site arguments instead of
 5050 // redirecting requests to the callee argument.
5051 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5052 KernelInfoState StateBefore = getState();
5053
5054 auto CheckCallee = [&](Function *F, int NumCallees) {
5055 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5056
5057 // If F is not a runtime function, propagate the AAKernelInfo of the
5058 // callee.
5059 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5060 const IRPosition &FnPos = IRPosition::function(*F);
5061 auto *FnAA =
5062 A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5063 if (!FnAA)
5064 return indicatePessimisticFixpoint();
5065 if (getState() == FnAA->getState())
5066 return ChangeStatus::UNCHANGED;
5067 getState() = FnAA->getState();
5068 return ChangeStatus::CHANGED;
5069 }
5070 if (NumCallees > 1)
5071 return indicatePessimisticFixpoint();
5072
5073 CallBase &CB = cast<CallBase>(getAssociatedValue());
5074 if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
5075 if (!handleParallel51(A, CB))
5076 return indicatePessimisticFixpoint();
5077 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5078 : ChangeStatus::CHANGED;
5079 }
5080
5081 // F is a runtime function that allocates or frees memory, check
5082 // AAHeapToStack and AAHeapToShared.
5083 assert(
5084 (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5085 It->getSecond() == OMPRTL___kmpc_free_shared) &&
5086 "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5087
5088 auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
5089 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5090 auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5091 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5092
5093 RuntimeFunction RF = It->getSecond();
5094
5095 switch (RF) {
 5096 // If neither HeapToStack nor HeapToShared assumes the call will be
 5097 // removed, assume SPMD incompatibility.
5098 case OMPRTL___kmpc_alloc_shared:
5099 if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5100 (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5101 SPMDCompatibilityTracker.insert(&CB);
5102 break;
5103 case OMPRTL___kmpc_free_shared:
5104 if ((!HeapToStackAA ||
5105 !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5106 (!HeapToSharedAA ||
5107 !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5108 SPMDCompatibilityTracker.insert(&CB);
5109 break;
5110 default:
5111 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5112 SPMDCompatibilityTracker.insert(&CB);
5113 }
5114 return ChangeStatus::CHANGED;
5115 };
5116
5117 const auto *AACE =
5118 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5119 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5120 if (Function *F = getAssociatedFunction())
5121 CheckCallee(F, /*NumCallees=*/1);
5122 } else {
5123 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5124 for (auto *Callee : OptimisticEdges) {
5125 CheckCallee(Callee, OptimisticEdges.size());
5126 if (isAtFixpoint())
5127 break;
5128 }
5129 }
5130
5131 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5132 : ChangeStatus::CHANGED;
5133 }
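  // For example (an illustrative IR sketch, not taken from a specific test):
  // a small team-local buffer such as
  //
  //   %buf = call ptr @__kmpc_alloc_shared(i64 8)
  //   ...
  //   call void @__kmpc_free_shared(ptr %buf, i64 8)
  //
  // does not block SPMD-ization as long as AAHeapToStack or AAHeapToShared
  // assumes the allocation (and its matching free) will be rewritten away.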
5134
 5135 /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call
 5136 /// was handled, or false if a problem occurred.
5137 bool handleParallel51(Attributor &A, CallBase &CB) {
5138 const unsigned int NonWrapperFunctionArgNo = 5;
5139 const unsigned int WrapperFunctionArgNo = 6;
5140 auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5141 ? NonWrapperFunctionArgNo
5142 : WrapperFunctionArgNo;
5143
5144 auto *ParallelRegion = dyn_cast<Function>(
5145 CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
5146 if (!ParallelRegion)
5147 return false;
5148
5149 ReachedKnownParallelRegions.insert(&CB);
5150 /// Check nested parallelism
5151 auto *FnAA = A.getAAFor<AAKernelInfo>(
5152 *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
5153 NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5154 !FnAA->ReachedKnownParallelRegions.empty() ||
5155 !FnAA->ReachedKnownParallelRegions.isValidState() ||
5156 !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5157 !FnAA->ReachedUnknownParallelRegions.empty();
5158 return true;
5159 }
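  // An illustrative sketch of the call this helper matches (argument names
  // are made up and types abbreviated; the exact runtime signature may differ
  // between versions):
  //
  //   call void @__kmpc_parallel_51(ptr %ident, i32 %gtid, i32 %if_expr,
  //                                 i32 %num_threads, i32 %proc_bind,
  //                                 ptr @outlined_fn, ptr @outlined_fn_wrapper,
  //                                 ptr %args, i64 %nargs)
  //
  // Operand 5 is the outlined parallel region and operand 6 its wrapper;
  // which one we inspect depends on whether SPMD mode is currently assumed.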
5160};
5161
5162struct AAFoldRuntimeCall
5163 : public StateWrapper<BooleanState, AbstractAttribute> {
 5164 using Base = StateWrapper<BooleanState, AbstractAttribute>;
 5165
5166 AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
5167
5168 /// Statistics are tracked as part of manifest for now.
5169 void trackStatistics() const override {}
5170
 5171 /// Create an abstract attribute view for the position \p IRP.
5172 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5173 Attributor &A);
5174
5175 /// See AbstractAttribute::getName()
5176 const std::string getName() const override { return "AAFoldRuntimeCall"; }
5177
5178 /// See AbstractAttribute::getIdAddr()
5179 const char *getIdAddr() const override { return &ID; }
5180
5181 /// This function should return true if the type of the \p AA is
5182 /// AAFoldRuntimeCall
5183 static bool classof(const AbstractAttribute *AA) {
5184 return (AA->getIdAddr() == &ID);
5185 }
5186
5187 static const char ID;
5188};
5189
5190struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5191 AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
5192 : AAFoldRuntimeCall(IRP, A) {}
5193
5194 /// See AbstractAttribute::getAsStr()
5195 const std::string getAsStr(Attributor *) const override {
5196 if (!isValidState())
5197 return "<invalid>";
5198
5199 std::string Str("simplified value: ");
5200
5201 if (!SimplifiedValue)
5202 return Str + std::string("none");
5203
5204 if (!*SimplifiedValue)
5205 return Str + std::string("nullptr");
5206
5207 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5208 return Str + std::to_string(CI->getSExtValue());
5209
5210 return Str + std::string("unknown");
5211 }
5212
5213 void initialize(Attributor &A) override {
 5214 if (DisableOpenMPOptFolding)
 5215 indicatePessimisticFixpoint();
5216
5217 Function *Callee = getAssociatedFunction();
5218
5219 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5220 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5221 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5222 "Expected a known OpenMP runtime function");
5223
5224 RFKind = It->getSecond();
5225
5226 CallBase &CB = cast<CallBase>(getAssociatedValue());
5227 A.registerSimplificationCallback(
 5228 IRPosition::callsite_returned(CB),
 5229 [&](const IRPosition &IRP, const AbstractAttribute *AA,
5230 bool &UsedAssumedInformation) -> std::optional<Value *> {
5231 assert((isValidState() ||
5232 (SimplifiedValue && *SimplifiedValue == nullptr)) &&
5233 "Unexpected invalid state!");
5234
5235 if (!isAtFixpoint()) {
5236 UsedAssumedInformation = true;
5237 if (AA)
5238 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5239 }
5240 return SimplifiedValue;
5241 });
5242 }
5243
5244 ChangeStatus updateImpl(Attributor &A) override {
5245 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5246 switch (RFKind) {
5247 case OMPRTL___kmpc_is_spmd_exec_mode:
5248 Changed |= foldIsSPMDExecMode(A);
5249 break;
5250 case OMPRTL___kmpc_parallel_level:
5251 Changed |= foldParallelLevel(A);
5252 break;
5253 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5254 Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
5255 break;
5256 case OMPRTL___kmpc_get_hardware_num_blocks:
5257 Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
5258 break;
5259 default:
5260 llvm_unreachable("Unhandled OpenMP runtime function!");
5261 }
5262
5263 return Changed;
5264 }
5265
5266 ChangeStatus manifest(Attributor &A) override {
5267 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5268
5269 if (SimplifiedValue && *SimplifiedValue) {
5270 Instruction &I = *getCtxI();
5271 A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
5272 A.deleteAfterManifest(I);
5273
5274 CallBase *CB = dyn_cast<CallBase>(&I);
5275 auto Remark = [&](OptimizationRemark OR) {
5276 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5277 return OR << "Replacing OpenMP runtime call "
5278 << CB->getCalledFunction()->getName() << " with "
5279 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5280 return OR << "Replacing OpenMP runtime call "
5281 << CB->getCalledFunction()->getName() << ".";
5282 };
5283
5284 if (CB && EnableVerboseRemarks)
5285 A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
5286
5287 LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5288 << **SimplifiedValue << "\n");
5289
5290 Changed = ChangeStatus::CHANGED;
5291 }
5292
5293 return Changed;
5294 }
5295
5296 ChangeStatus indicatePessimisticFixpoint() override {
5297 SimplifiedValue = nullptr;
5298 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5299 }
5300
5301private:
5302 /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
5303 ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5304 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5305
5306 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5307 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5308 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5309 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5310
5311 if (!CallerKernelInfoAA ||
5312 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5313 return indicatePessimisticFixpoint();
5314
5315 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5316 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5317 DepClassTy::REQUIRED);
5318
5319 if (!AA || !AA->isValidState()) {
5320 SimplifiedValue = nullptr;
5321 return indicatePessimisticFixpoint();
5322 }
5323
5324 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5325 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5326 ++KnownSPMDCount;
5327 else
5328 ++AssumedSPMDCount;
5329 } else {
5330 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5331 ++KnownNonSPMDCount;
5332 else
5333 ++AssumedNonSPMDCount;
5334 }
5335 }
5336
5337 if ((AssumedSPMDCount + KnownSPMDCount) &&
5338 (AssumedNonSPMDCount + KnownNonSPMDCount))
5339 return indicatePessimisticFixpoint();
5340
5341 auto &Ctx = getAnchorValue().getContext();
5342 if (KnownSPMDCount || AssumedSPMDCount) {
5343 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5344 "Expected only SPMD kernels!");
5345 // All reaching kernels are in SPMD mode. Update all function calls to
5346 // __kmpc_is_spmd_exec_mode to 1.
5347 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
5348 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5349 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5350 "Expected only non-SPMD kernels!");
5351 // All reaching kernels are in non-SPMD mode. Update all function
5352 // calls to __kmpc_is_spmd_exec_mode to 0.
5353 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
5354 } else {
5355 // We have empty reaching kernels, therefore we cannot tell if the
5356 // associated call site can be folded. At this moment, SimplifiedValue
5357 // must be none.
5358 assert(!SimplifiedValue && "SimplifiedValue should be none");
5359 }
5360
5361 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5362 : ChangeStatus::CHANGED;
5363 }
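  // Roughly, if every kernel that can reach the caller is known or assumed
  // to execute in SPMD mode, a call such as
  //
  //   %spmd = call i8 @__kmpc_is_spmd_exec_mode()
  //
  // is folded to the i8 constant 1 (and to 0 in the all-generic case); a mix
  // of SPMD and generic reaching kernels prevents the fold.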
5364
5365 /// Fold __kmpc_parallel_level into a constant if possible.
5366 ChangeStatus foldParallelLevel(Attributor &A) {
5367 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5368
5369 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5370 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5371
5372 if (!CallerKernelInfoAA ||
5373 !CallerKernelInfoAA->ParallelLevels.isValidState())
5374 return indicatePessimisticFixpoint();
5375
5376 if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5377 return indicatePessimisticFixpoint();
5378
5379 if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5380 assert(!SimplifiedValue &&
5381 "SimplifiedValue should keep none at this point");
5382 return ChangeStatus::UNCHANGED;
5383 }
5384
5385 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5386 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5387 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5388 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5389 DepClassTy::REQUIRED);
5390 if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5391 return indicatePessimisticFixpoint();
5392
5393 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5394 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5395 ++KnownSPMDCount;
5396 else
5397 ++AssumedSPMDCount;
5398 } else {
5399 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5400 ++KnownNonSPMDCount;
5401 else
5402 ++AssumedNonSPMDCount;
5403 }
5404 }
5405
5406 if ((AssumedSPMDCount + KnownSPMDCount) &&
5407 (AssumedNonSPMDCount + KnownNonSPMDCount))
5408 return indicatePessimisticFixpoint();
5409
5410 auto &Ctx = getAnchorValue().getContext();
5411 // If the caller can only be reached by SPMD kernel entries, the parallel
5412 // level is 1. Similarly, if the caller can only be reached by non-SPMD
5413 // kernel entries, it is 0.
5414 if (AssumedSPMDCount || KnownSPMDCount) {
5415 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5416 "Expected only SPMD kernels!");
5417 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
5418 } else {
5419 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5420 "Expected only non-SPMD kernels!");
5421 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
5422 }
5423 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5424 : ChangeStatus::CHANGED;
5425 }
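  // In the same spirit: a function only reachable from SPMD kernels gets its
  // __kmpc_parallel_level() calls folded to 1, one only reachable from
  // generic-mode kernels gets 0, and mixed reachability leaves the call alone.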
5426
5427 ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
 5428 // Specialize only if all reaching kernels agree on the attribute's value.
5429 int32_t CurrentAttrValue = -1;
5430 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5431
5432 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5433 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5434
5435 if (!CallerKernelInfoAA ||
5436 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5437 return indicatePessimisticFixpoint();
5438
5439 // Iterate over the kernels that reach this function
5440 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5441 int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
5442
5443 if (NextAttrVal == -1 ||
5444 (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5445 return indicatePessimisticFixpoint();
5446 CurrentAttrValue = NextAttrVal;
5447 }
5448
5449 if (CurrentAttrValue != -1) {
5450 auto &Ctx = getAnchorValue().getContext();
5451 SimplifiedValue =
5452 ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
5453 }
5454 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5455 : ChangeStatus::CHANGED;
5456 }
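  // For example, assuming every reaching kernel carries
  // "omp_target_thread_limit"="128" (the value 128 is illustrative), a call
  // to __kmpc_get_hardware_num_threads_in_block() in this function folds to
  // the i32 constant 128; disagreeing or missing values prevent the fold.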
5457
5458 /// An optional value the associated value is assumed to fold to. That is, we
5459 /// assume the associated value (which is a call) can be replaced by this
5460 /// simplified value.
5461 std::optional<Value *> SimplifiedValue;
5462
5463 /// The runtime function kind of the callee of the associated call site.
5464 RuntimeFunction RFKind;
5465};
5466
5467} // namespace
5468
5469/// Register folding callsites for the runtime function \p RF.
5470void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5471 auto &RFI = OMPInfoCache.RFIs[RF];
5472 RFI.foreachUse(SCC, [&](Use &U, Function &F) {
5473 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5474 if (!CI)
5475 return false;
5476 A.getOrCreateAAFor<AAFoldRuntimeCall>(
5477 IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
5478 DepClassTy::NONE, /* ForceUpdate */ false,
5479 /* UpdateAfterInit */ false);
5480 return false;
5481 });
5482}
5483
5484void OpenMPOpt::registerAAs(bool IsModulePass) {
5485 if (SCC.empty())
5486 return;
5487
5488 if (IsModulePass) {
5489 // Ensure we create the AAKernelInfo AAs first and without triggering an
5490 // update. This will make sure we register all value simplification
5491 // callbacks before any other AA has the chance to create an AAValueSimplify
5492 // or similar.
5493 auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5494 A.getOrCreateAAFor<AAKernelInfo>(
5495 IRPosition::function(Kernel), /* QueryingAA */ nullptr,
5496 DepClassTy::NONE, /* ForceUpdate */ false,
5497 /* UpdateAfterInit */ false);
5498 return false;
5499 };
5500 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5501 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5502 InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5503
5504 registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5505 registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5506 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5507 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5508 }
5509
5510 // Create CallSite AA for all Getters.
5511 if (DeduceICVValues) {
5512 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5513 auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5514
5515 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5516
5517 auto CreateAA = [&](Use &U, Function &Caller) {
5518 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5519 if (!CI)
5520 return false;
5521
5522 auto &CB = cast<CallBase>(*CI);
5523
 5524 IRPosition CBPos = IRPosition::callsite_function(CB);
 5525 A.getOrCreateAAFor<AAICVTracker>(CBPos);
5526 return false;
5527 };
5528
5529 GetterRFI.foreachUse(SCC, CreateAA);
5530 }
5531 }
5532
5533 // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5534 // every function if there is a device kernel.
5535 if (!isOpenMPDevice(M))
5536 return;
5537
5538 for (auto *F : SCC) {
5539 if (F->isDeclaration())
5540 continue;
5541
 5542 // We look at internal functions only on-demand, but if any use is not a
 5543 // direct call or lies outside the current set of analyzed functions, we
 5544 // have to do it eagerly.
5545 if (F->hasLocalLinkage()) {
5546 if (llvm::all_of(F->uses(), [this](const Use &U) {
5547 const auto *CB = dyn_cast<CallBase>(U.getUser());
5548 return CB && CB->isCallee(&U) &&
5549 A.isRunOn(const_cast<Function *>(CB->getCaller()));
5550 }))
5551 continue;
5552 }
5553 registerAAsForFunction(A, *F);
5554 }
5555}
5556
5557void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
 5558 if (!DisableOpenMPOptDeglobalization)
 5559 A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
 5560 A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
 5561 if (!DisableOpenMPOptDeglobalization)
 5562 A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
5563 if (F.hasFnAttribute(Attribute::Convergent))
5564 A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));
5565
5566 for (auto &I : instructions(F)) {
5567 if (auto *LI = dyn_cast<LoadInst>(&I)) {
5568 bool UsedAssumedInformation = false;
5569 A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
5570 UsedAssumedInformation, AA::Interprocedural);
5571 A.getOrCreateAAFor<AAAddressSpace>(
5572 IRPosition::value(*LI->getPointerOperand()));
5573 continue;
5574 }
5575 if (auto *CI = dyn_cast<CallBase>(&I)) {
5576 if (CI->isIndirectCall())
5577 A.getOrCreateAAFor<AAIndirectCallInfo>(
 5578 IRPosition::callsite_function(*CI));
 5579 }
5580 if (auto *SI = dyn_cast<StoreInst>(&I)) {
5581 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
5582 A.getOrCreateAAFor<AAAddressSpace>(
5583 IRPosition::value(*SI->getPointerOperand()));
5584 continue;
5585 }
5586 if (auto *FI = dyn_cast<FenceInst>(&I)) {
5587 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
5588 continue;
5589 }
5590 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5591 if (II->getIntrinsicID() == Intrinsic::assume) {
5592 A.getOrCreateAAFor<AAPotentialValues>(
5593 IRPosition::value(*II->getArgOperand(0)));
5594 continue;
5595 }
5596 }
5597 }
5598}
5599
5600const char AAICVTracker::ID = 0;
5601const char AAKernelInfo::ID = 0;
5602const char AAExecutionDomain::ID = 0;
5603const char AAHeapToShared::ID = 0;
5604const char AAFoldRuntimeCall::ID = 0;
5605
5606AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5607 Attributor &A) {
5608 AAICVTracker *AA = nullptr;
5609 switch (IRP.getPositionKind()) {
5614 llvm_unreachable("ICVTracker can only be created for function position!");
5616 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5617 break;
5619 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5620 break;
5622 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5623 break;
5625 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5626 break;
5627 }
5628
5629 return *AA;
5630}
5631
5632AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
 5633 Attributor &A) {
5634 AAExecutionDomainFunction *AA = nullptr;
5635 switch (IRP.getPositionKind()) {
5644 "AAExecutionDomain can only be created for function position!");
5646 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5647 break;
5648 }
5649
5650 return *AA;
5651}
5652
5653AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5654 Attributor &A) {
5655 AAHeapToSharedFunction *AA = nullptr;
5656 switch (IRP.getPositionKind()) {
5665 "AAHeapToShared can only be created for function position!");
5667 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5668 break;
5669 }
5670
5671 return *AA;
5672}
5673
5674AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5675 Attributor &A) {
5676 AAKernelInfo *AA = nullptr;
5677 switch (IRP.getPositionKind()) {
5684 llvm_unreachable("KernelInfo can only be created for function position!");
5686 AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5687 break;
5689 AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5690 break;
5691 }
5692
5693 return *AA;
5694}
5695
5696AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5697 Attributor &A) {
5698 AAFoldRuntimeCall *AA = nullptr;
5699 switch (IRP.getPositionKind()) {
5707 llvm_unreachable("KernelInfo can only be created for call site position!");
5709 AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5710 break;
5711 }
5712
5713 return *AA;
5714}
5715
5716PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
 5717 if (!containsOpenMP(M))
 5718 return PreservedAnalyses::all();
 5719 if (DisableOpenMPOptimizations)
 5720 return PreservedAnalyses::all();
5721
 5722 FunctionAnalysisManager &FAM =
 5723 AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
 5724 KernelSet Kernels = getDeviceKernels(M);
5725
 5726 if (PrintModuleBeforeOptimizations)
 5727 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
5728
5729 auto IsCalled = [&](Function &F) {
5730 if (Kernels.contains(&F))
5731 return true;
5732 for (const User *U : F.users())
5733 if (!isa<BlockAddress>(U))
5734 return true;
5735 return false;
5736 };
5737
5738 auto EmitRemark = [&](Function &F) {
 5739 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
 5740 ORE.emit([&]() {
5741 OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5742 return ORA << "Could not internalize function. "
5743 << "Some optimizations may not be possible. [OMP140]";
5744 });
5745 };
5746
5747 bool Changed = false;
5748
5749 // Create internal copies of each function if this is a kernel Module. This
 5750 // allows interprocedural passes to see every call edge.
5751 DenseMap<Function *, Function *> InternalizedMap;
5752 if (isOpenMPDevice(M)) {
5753 SmallPtrSet<Function *, 16> InternalizeFns;
5754 for (Function &F : M)
 5755 if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
 5756 !DisableInternalization) {
 5757 if (Attributor::isInternalizable(F)) {
 5758 InternalizeFns.insert(&F);
5759 } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5760 EmitRemark(F);
5761 }
5762 }
5763
5764 Changed |=
5765 Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5766 }
5767
5768 // Look at every function in the Module unless it was internalized.
5769 SetVector<Function *> Functions;
 5770 SmallVector<Function *, 16> SCC;
 5771 for (Function &F : M)
5772 if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
5773 SCC.push_back(&F);
5774 Functions.insert(&F);
5775 }
5776
5777 if (SCC.empty())
5778 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
5779
5780 AnalysisGetter AG(FAM);
5781
5782 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
 5783 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
 5784 };
5785
5787 CallGraphUpdater CGUpdater;
5788
5789 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5791 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);
5792
5793 unsigned MaxFixpointIterations =
5795
5796 AttributorConfig AC(CGUpdater);
5798 AC.IsModulePass = true;
5799 AC.RewriteSignatures = false;
5800 AC.MaxFixpointIterations = MaxFixpointIterations;
5801 AC.OREGetter = OREGetter;
5802 AC.PassName = DEBUG_TYPE;
5803 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5804 AC.IPOAmendableCB = [](const Function &F) {
5805 return F.hasFnAttribute("kernel");
5806 };
5807
5808 Attributor A(Functions, InfoCache, AC);
5809
5810 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5811 Changed |= OMPOpt.run(true);
5812
5813 // Optionally inline device functions for potentially better performance.
 5814 if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
 5815 for (Function &F : M)
5816 if (!F.isDeclaration() && !Kernels.contains(&F) &&
5817 !F.hasFnAttribute(Attribute::NoInline))
5818 F.addFnAttr(Attribute::AlwaysInline);
5819
 5820 if (PrintModuleAfterOptimizations)
 5821 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5822
5823 if (Changed)
5824 return PreservedAnalyses::none();
5825
5826 return PreservedAnalyses::all();
5827}
5828
5829PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
 5830 CGSCCAnalysisManager &AM,
 5831 LazyCallGraph &CG,
 5832 CGSCCUpdateResult &UR) {
5833 if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5834 return PreservedAnalyses::all();
 5835 if (DisableOpenMPOptimizations)
 5836 return PreservedAnalyses::all();
5837
 5838 SmallVector<Function *, 16> SCC;
 5839 // If there are kernels in the module, we have to run on all SCCs.
5840 for (LazyCallGraph::Node &N : C) {
5841 Function *Fn = &N.getFunction();
5842 SCC.push_back(Fn);
5843 }
5844
5845 if (SCC.empty())
5846 return PreservedAnalyses::all();
5847
5848 Module &M = *C.begin()->getFunction().getParent();
5849
 5850 if (PrintModuleBeforeOptimizations)
 5851 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5852
5853 KernelSet Kernels = getDeviceKernels(M);
5854
 5855 FunctionAnalysisManager &FAM =
 5856 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5857
5858 AnalysisGetter AG(FAM);
5859
5860 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
 5861 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
 5862 };
5863
5865 CallGraphUpdater CGUpdater;
5866 CGUpdater.initialize(CG, C, AM, UR);
5867
5868 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5870 SetVector<Function *> Functions(SCC.begin(), SCC.end());
5871 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5872 /*CGSCC*/ &Functions, PostLink);
5873
5874 unsigned MaxFixpointIterations =
5876
5877 AttributorConfig AC(CGUpdater);
5879 AC.IsModulePass = false;
5880 AC.RewriteSignatures = false;
5881 AC.MaxFixpointIterations = MaxFixpointIterations;
5882 AC.OREGetter = OREGetter;
5883 AC.PassName = DEBUG_TYPE;
5884 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5885
5886 Attributor A(Functions, InfoCache, AC);
5887
5888 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5889 bool Changed = OMPOpt.run(false);
5890
 5891 if (PrintModuleAfterOptimizations)
 5892 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5893
5894 if (Changed)
5895 return PreservedAnalyses::none();
5896
5897 return PreservedAnalyses::all();
5898}
5899
5900bool llvm::omp::isOpenMPKernel(Function &Fn) {
 5901 return Fn.hasFnAttribute("kernel");
5902}
5903
5904KernelSet llvm::omp::getDeviceKernels(Module &M) {
 5905 // TODO: Create a more cross-platform way of determining device kernels.
5906 NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");
5907 KernelSet Kernels;
5908
5909 if (!MD)
5910 return Kernels;
5911
5912 for (auto *Op : MD->operands()) {
5913 if (Op->getNumOperands() < 2)
5914 continue;
5915 MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
5916 if (!KindID || KindID->getString() != "kernel")
5917 continue;
5918
5919 Function *KernelFn =
5920 mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
5921 if (!KernelFn)
5922 continue;
5923
5924 // We are only interested in OpenMP target regions. Others, such as kernels
5925 // generated by CUDA but linked together, are not interesting to this pass.
5926 if (isOpenMPKernel(*KernelFn)) {
5927 ++NumOpenMPTargetRegionKernels;
5928 Kernels.insert(KernelFn);
5929 } else
5930 ++NumNonOpenMPTargetRegionKernels;
5931 }
5932
5933 return Kernels;
5934}
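// The metadata walked above typically has this shape (the kernel name is
// illustrative):
//
//   !nvvm.annotations = !{!0}
//   !0 = !{ptr @__omp_offloading_example_kernel, !"kernel", i32 1}
//
// Only entries whose function also satisfies isOpenMPKernel, i.e. carries the
// "kernel" function attribute, are counted as OpenMP target region kernels.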
5935
5936bool llvm::omp::containsOpenMP(Module &M) {
 5937 Metadata *MD = M.getModuleFlag("openmp");
5938 if (!MD)
5939 return false;
5940
5941 return true;
5942}
5943
5944bool llvm::omp::isOpenMPDevice(Module &M) {
 5945 Metadata *MD = M.getModuleFlag("openmp-device");
5946 if (!MD)
5947 return false;
5948
5949 return true;
5950}
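// Both checks above key off module flags emitted by the OpenMP frontend,
// roughly of the form (the version value is illustrative):
//
//   !llvm.module.flags = !{!0, !1}
//   !0 = !{i32 7, !"openmp", i32 51}
//   !1 = !{i32 7, !"openmp-device", i32 51}
//
// Host modules typically carry only the "openmp" flag, while device-side
// modules carry both.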
@ Generic
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static cl::opt< unsigned > SetFixpointIterations("attributor-max-iterations", cl::Hidden, cl::desc("Maximal number of fixpoint iterations."), cl::init(32))
static const Function * getParent(const Value *V)
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
This file provides interfaces used to manipulate a call graph, regardless if it is a "old style" Call...
This file provides interfaces used to build and manipulate a call graph, which is a very useful tool ...
This file contains the declarations for the subclasses of Constant, which represent the different fla...
dxil pretty DXIL Metadata Pretty Printer
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(X)
Definition: Debug.h:101
bool End
Definition: ELF_riscv.cpp:480
This file defines an array type that can be indexed using scoped enum values.
#define DEBUG_TYPE
static void emitRemark(const Function &F, OptimizationRemarkEmitter &ORE, bool Skip)
static bool lookup(const GsymReader &GR, DataExtractor &Data, uint64_t &Offset, uint64_t BaseAddr, uint64_t Addr, SourceLocations &SrcLocs, llvm::Error &Err)
A Lookup helper functions.
Definition: InlineInfo.cpp:109
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
This file provides utility analysis objects describing memory locations.
uint64_t IntrinsicInst * II
This file defines constans and helpers used when dealing with OpenMP.
This file defines constans that will be used by both host and device compilation.
static constexpr auto TAG
Definition: OpenMPOpt.cpp:184
static cl::opt< bool > HideMemoryTransferLatency("openmp-hide-memory-transfer-latency", cl::desc("[WIP] Tries to hide the latency of host to device memory" " transfers"), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptStateMachineRewrite("openmp-opt-disable-state-machine-rewrite", cl::desc("Disable OpenMP optimizations that replace the state machine."), cl::Hidden, cl::init(false))
static cl::opt< bool > EnableParallelRegionMerging("openmp-opt-enable-merging", cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden, cl::init(false))
static cl::opt< bool > AlwaysInlineDeviceFunctions("openmp-opt-inline-device", cl::desc("Inline all applicible functions on the device."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleAfterOptimizations("openmp-opt-print-module-after", cl::desc("Print the current module after OpenMP optimizations."), cl::Hidden, cl::init(false))
#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)
Definition: OpenMPOpt.cpp:240
#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)
Definition: OpenMPOpt.cpp:217
#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER)
Definition: OpenMPOpt.cpp:3674
static cl::opt< bool > PrintOpenMPKernels("openmp-print-gpu-kernels", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptFolding("openmp-opt-disable-folding", cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleBeforeOptimizations("openmp-opt-print-module-before", cl::desc("Print the current module before OpenMP optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden, cl::desc("Maximal number of attributor iterations."), cl::init(256))
static cl::opt< bool > DisableInternalization("openmp-opt-disable-internalization", cl::desc("Disable function internalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintICVValues("openmp-print-icv-values", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptimizations("openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden, cl::desc("Maximum amount of shared memory to use."), cl::init(std::numeric_limits< unsigned >::max()))
static cl::opt< bool > EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::desc("Enables more verbose remarks."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptDeglobalization("openmp-opt-disable-deglobalization", cl::desc("Disable OpenMP optimizations involving deglobalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptBarrierElimination("openmp-opt-disable-barrier-elimination", cl::desc("Disable OpenMP optimizations that eliminate barriers."), cl::Hidden, cl::init(false))
#define DEBUG_TYPE
Definition: OpenMPOpt.cpp:66
static cl::opt< bool > DeduceICVValues("openmp-deduce-icv-values", cl::init(false), cl::Hidden)
#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)
Definition: OpenMPOpt.cpp:209
#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)
Definition: OpenMPOpt.cpp:230
static cl::opt< bool > DisableOpenMPOptSPMDization("openmp-opt-disable-spmdization", cl::desc("Disable OpenMP optimizations involving SPMD-ization."), cl::Hidden, cl::init(false))
#define P(N)
FunctionAnalysisManager FAM
if(PassOpts->AAPipeline)
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
static StringRef getName(Value *V)
R600 Clause Merge
Basic Register Allocator
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:166
This file contains some functions that are useful when dealing with strings.
static const int BlockSize
Definition: TarWriter.cpp:33
static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, ArrayRef< StringLiteral > StandardNames)
Initialize the set of available library functions based on the specified target triple.
static bool contains(SmallPtrSetImpl< ConstantExpr * > &Cache, ConstantExpr *Expr, Constant *C)
Definition: Value.cpp:469
Value * RHS
static cl::opt< unsigned > MaxThreads("xcore-max-threads", cl::Optional, cl::desc("Maximum number of threads (for emulation thread-local storage)"), cl::Hidden, cl::value_desc("number"), cl::init(8))
AbstractCallSite.
This class represents a conversion between pointers from one address space to another.
an instruction to allocate memory on the stack
Definition: Instructions.h:61
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:405
This class represents an incoming formal argument to a Function.
Definition: Argument.h:31
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:165
AttributeSet getParamAttrs(unsigned ArgNo) const
The attributes for the argument or parameter at the given index are returned.
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
iterator end()
Definition: BasicBlock.h:461
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:448
const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
Definition: BasicBlock.cpp:416
reverse_iterator rbegin()
Definition: BasicBlock.h:464
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:212
BasicBlock * splitBasicBlock(iterator I, const Twine &BBName="", bool Before=false)
Split the basic block into two basic blocks at the specified instruction.
Definition: BasicBlock.cpp:577
const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
Definition: BasicBlock.cpp:497
InstListType::reverse_iterator reverse_iterator
Definition: BasicBlock.h:179
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:219
reverse_iterator rend()
Definition: BasicBlock.h:466
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:239
Conditional or Unconditional Branch instruction.
bool isConditional() const
static BranchInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
BasicBlock * getSuccessor(unsigned i) const
Value * getCondition() const
Allocate memory in an ever growing pool, as if by bump-pointer.
Definition: Allocator.h:66
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Definition: InstrTypes.h:1236
void setCallingConv(CallingConv::ID CC)
Definition: InstrTypes.h:1527
bool arg_empty() const
Definition: InstrTypes.h:1407
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1465
bool doesNotAccessMemory(unsigned OpNo) const
Definition: InstrTypes.h:1810
bool isIndirectCall() const
Return true if the callsite is an indirect call.
bool isCallee(Value::const_user_iterator UI) const
Determine whether the passed iterator points to the callee operand's Use.
Definition: InstrTypes.h:1476
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1410
void setArgOperand(unsigned i, Value *v)
Definition: InstrTypes.h:1415
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
Definition: InstrTypes.h:1401
unsigned getArgOperandNo(const Use *U) const
Given a use for a arg operand, get the arg operand number that corresponds to it.
Definition: InstrTypes.h:1441
unsigned arg_size() const
Definition: InstrTypes.h:1408
AttributeList getAttributes() const
Return the parameter attributes for this call.
Definition: InstrTypes.h:1542
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind)
Adds the attribute to the indicated argument.
Definition: InstrTypes.h:1594
bool isArgOperand(const Use *U) const
Definition: InstrTypes.h:1430
bool hasOperandBundles() const
Return true if this User has any operand bundles.
Definition: InstrTypes.h:2061
Function * getCaller()
Helper to get the caller (the parent function).
Wrapper to unify "old style" CallGraph and "new style" LazyCallGraph.
void initialize(LazyCallGraph &LCG, LazyCallGraph::SCC &SCC, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR)
Initializers for usage outside of a CGSCC pass, inside a CGSCC pass in the old and new pass manager (...
This class represents a function call, abstracting a target machine's calling convention.
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:786
@ ICMP_EQ
equal
Definition: InstrTypes.h:778
@ ICMP_NE
not equal
Definition: InstrTypes.h:779
static Constant * getPointerCast(Constant *C, Type *Ty)
Create a BitCast, AddrSpaceCast, or a PtrToInt cast constant expression.
Definition: Constants.cpp:2227
static Constant * getPointerBitCastOrAddrSpaceCast(Constant *C, Type *Ty)
Create a BitCast or AddrSpaceCast for a pointer type depending on the address space.
Definition: Constants.cpp:2242
This is the shared class of boolean and integer constants.
Definition: Constants.h:81
IntegerType * getIntegerType() const
Variant of the getType() method to always return an IntegerType, which reduces the amount of casting ...
Definition: Constants.h:185
static ConstantInt * getTrue(LLVMContext &Context)
Definition: Constants.cpp:850
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition: Constants.h:206
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
Definition: Constants.h:161
This is an important base class in LLVM.
Definition: Constant.h:42
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Definition: Constants.cpp:370
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
A debug info location.
Definition: DebugLoc.h:33
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:194
iterator begin()
Definition: DenseMap.h:75
iterator end()
Definition: DenseMap.h:84
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:211
Implements a dense probed hash-table based set.
Definition: DenseSet.h:271
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
Instruction * findNearestCommonDominator(Instruction *I1, Instruction *I2) const
Find the nearest instruction I that dominates both I1 and I2, in the sense that a result produced bef...
Definition: Dominators.cpp:344
An instruction for ordering other memory operations.
Definition: Instructions.h:420
AtomicOrdering getOrdering() const
Returns the ordering constraint of this fence instruction.
Definition: Instructions.h:443
A proxy from a FunctionAnalysisManager to an SCC.
A handy container for a FunctionType+Callee-pointer pair, which can be passed around as a single enti...
Definition: DerivedTypes.h:168
const BasicBlock & getEntryBlock() const
Definition: Function.h:807
const BasicBlock & front() const
Definition: Function.h:858
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition: Function.cpp:380
Argument * getArg(unsigned i) const
Definition: Function.h:884
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
Definition: Function.cpp:743
bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
Definition: Globals.cpp:290
bool hasLocalLinkage() const
Definition: GlobalValue.h:528
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:656
PointerType * getType() const
Global values are always pointers.
Definition: GlobalValue.h:294
@ PrivateLinkage
Like Internal, but omit from symbol table.
Definition: GlobalValue.h:60
@ InternalLinkage
Rename collisions when linking (static functions).
Definition: GlobalValue.h:59
const Constant * getInitializer() const
getInitializer - Return the initializer for this global variable.
void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
Definition: Globals.cpp:485
void eraseFromParent()
eraseFromParent - This method unlinks 'this' from the containing module and deletes it.
Definition: Globals.cpp:481
InsertPoint - A saved insertion point.
Definition: IRBuilder.h:254
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Definition: IRBuilder.h:1790
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:177
Value * CreateAddrSpaceCast(Value *V, Type *DestTy, const Twine &Name="")
Definition: IRBuilder.h:2152
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2686
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
Definition: PassManager.h:563
bool isLifetimeStartOrEnd() const LLVM_READONLY
Return true if the instruction is a llvm.lifetime.start or llvm.lifetime.end marker.
bool mayWriteToMemory() const LLVM_READONLY
Return true if this instruction may modify memory.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
Definition: Instruction.h:466
const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr it the function does not...
Definition: Instruction.cpp:66
const Instruction * getPrevNonDebugInstruction(bool SkipPseudoOp=false) const
Return a pointer to the previous non-debug instruction in the same basic block as 'this',...
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:92
const Function * getFunction() const
Return the function this instruction belongs to.
Definition: Instruction.cpp:70
bool mayHaveSideEffects() const LLVM_READONLY
Return true if the instruction may have side effects.
bool mayReadFromMemory() const LLVM_READONLY
Return true if this instruction may read memory.
const Instruction * getNextNonDebugInstruction(bool SkipPseudoOp=false) const
Return a pointer to the next non-debug instruction in the same basic block as 'this',...
void setDebugLoc(DebugLoc Loc)
Set the debug location information for this instruction.
Definition: Instruction.h:463
void setSuccessor(unsigned Idx, BasicBlock *BB)
Update the specified successor to point at the provided block.
void moveBefore(Instruction *MovePos)
Unlink this instruction from its current basic block and insert it into the basic block that MovePos ...
A node in the call graph.
An SCC of the call graph.
A lazily constructed view of the call graph of a module.
An instruction for reading from memory.
Definition: Instructions.h:174
A single uniqued string.
Definition: Metadata.h:720
StringRef getString() const
Definition: Metadata.cpp:616
void eraseFromParent()
This method unlinks 'this' from the containing function and deletes it.
Root of the metadata hierarchy.
Definition: Metadata.h:62
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
const std::string & getTargetTriple() const
Get the target triple which is a string describing the target host.
Definition: Module.h:295
MutableArrayRef - Represent a mutable reference to an array (0 or more elements consecutively in memo...
Definition: ArrayRef.h:307
A tuple of MDNodes.
Definition: Metadata.h:1730
iterator_range< op_iterator > operands()
Definition: Metadata.h:1826
An interface to create LLVM-IR for OpenMP directives.
Definition: OMPIRBuilder.h:473
static std::pair< int32_t, int32_t > readThreadBoundsForKernel(const Triple &T, Function &Kernel)
}
void addAttributes(omp::RuntimeFunction FnID, Function &Fn)
Add attributes known for FnID to Fn.
IRBuilder<>::InsertPoint InsertPointTy
Type used throughout for insertion points.
Definition: OMPIRBuilder.h:499
static std::pair< int32_t, int32_t > readTeamBoundsForKernel(const Triple &T, Function &Kernel)
Read/write a bounds on teams for Kernel.
PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR)
Definition: OpenMPOpt.cpp:5829
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
Definition: OpenMPOpt.cpp:5716
Diagnostic information for optimization analysis remarks.
The optimization diagnostic interface.
Diagnostic information for missed-optimization remarks.
Diagnostic information for applied optimization remarks.
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
Definition: Constants.cpp:1852
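A short sketch of the common "detach, then delete" pattern built from PoisonValue::get, replaceAllUsesWith, and eraseFromParent; the helper name is illustrative:

#include "llvm/IR/Constants.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

static void dropInstruction(Instruction &I) {
  // Rewrite any remaining uses to a poison value of the same type before the
  // instruction is unlinked and deleted.
  if (!I.use_empty())
    I.replaceAllUsesWith(PoisonValue::get(I.getType()));
  I.eraseFromParent();
}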
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition: Analysis.h:114
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
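A sketch of the usual way a new-pass-manager pass reports its effect with PreservedAnalyses; MyExamplePass is a hypothetical pass, not one defined in this file:

#include "llvm/IR/PassManager.h"
using namespace llvm;

struct MyExamplePass : PassInfoMixin<MyExamplePass> { // hypothetical pass
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
    bool Changed = false;
    // ... transform M, setting Changed whenever the IR was modified ...
    return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
  }
};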
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, InsertPosition InsertBefore=nullptr)
A vector that has set insertion semantics.
Definition: SetVector.h:57
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:98
const value_type & back() const
Return the last element of the SetVector.
Definition: SetVector.h:149
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:264
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:162
size_type size() const
Definition: SmallPtrSet.h:95
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:346
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:435
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:367
iterator begin() const
Definition: SmallPtrSet.h:455
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:502
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:370
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:586
void assign(size_type NumElts, ValueParamT Elt)
Definition: SmallVector.h:717
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:950
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:696
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
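A sketch combining SmallPtrSet (visited set) and SmallVector (worklist) for a once-per-user traversal; the function and the Visit callback are illustrative:

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/User.h"
using namespace llvm;

static void visitTransitiveUsers(Value &Root, function_ref<void(User &)> Visit) {
  SmallPtrSet<User *, 8> Seen;
  SmallVector<User *, 16> Worklist(Root.users().begin(), Root.users().end());
  while (!Worklist.empty()) {
    User *U = Worklist.pop_back_val();
    // insert() returns {iterator, bool}; the bool is false if the element was
    // already present, so each user is visited exactly once.
    if (!Seen.insert(U).second)
      continue;
    Visit(*U);
    for (User *Next : U->users())
      Worklist.push_back(Next);
  }
}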
An instruction for storing to memory.
Definition: Instructions.h:290
StringRef - Represent a constant reference to a string, i.e. a character array and a length, which need not be null terminated.
Definition: StringRef.h:50
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition: StringRef.h:250
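A small example of StringRef::starts_with used as a name-prefix check; treating the "__kmpc_" prefix as the marker of an OpenMP runtime entry point is an illustrative convention, not a guarantee made by this file:

#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Function.h"
using namespace llvm;

static bool looksLikeOMPRuntimeCall(const Function &Callee) {
  // OpenMP device runtime entry points conventionally start with "__kmpc_".
  return Callee.getName().starts_with("__kmpc_");
}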
Triple - Helper class for working with autoconf configuration names.
Definition: Triple.h:44
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
PointerType * getPointerTo(unsigned AddrSpace=0) const
Return a pointer to the current type.
unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
static Type * getVoidTy(LLVMContext &C)
static IntegerType * getInt16Ty(LLVMContext &C)
static IntegerType * getInt8Ty(LLVMContext &C)
static IntegerType * getInt32Ty(LLVMContext &C)
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1833
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Definition: User.cpp:21
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
void setName(const Twine &Name)
Change the name of the value.
Definition: Value.cpp:377
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
iterator_range< user_iterator > users()
Definition: Value.h:421
User * user_back()
Definition: Value.h:407
const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs and address space casts.
Definition: Value.cpp:694
unsigned getNumUses() const
This method computes the number of uses of this Value.
Definition: Value.cpp:255
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition: ilist_node.h:32
self_iterator getIterator()
Definition: ilist_node.h:132
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition: ilist_node.h:353
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:661
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
GlobalVariable * getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB)
Definition: OpenMPOpt.cpp:259
ConstantStruct * getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB)
Definition: OpenMPOpt.cpp:266
bool isValidAtPosition(const ValueAndContext &VAC, InformationCache &InfoCache)
Return true if the value of VAC is valid at the position of VAC, that is, a constant,...
Definition: Attributor.cpp:291
bool isPotentiallyAffectedByBarrier(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is potentially affected by a barrier.
Definition: Attributor.cpp:890
@ Interprocedural
Definition: Attributor.h:182
bool isNoSyncInst(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is a nosync instruction.
Definition: Attributor.cpp:206
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
E & operator^=(E &LHS, E RHS)
Definition: BitmaskEnum.h:181
@ Entry
Definition: COFF.h:826
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
std::optional< const char * > toString(const std::optional< DWARFFormValue > &V)
Take an optional DWARFFormValue and try to extract a string value from it.
PointerTypeMap run(const Module &M)
Compute the PointerTypeMap for the module M.
bool isOpenMPDevice(Module &M)
Helper to determine if M is an OpenMP target offloading device module.
Definition: OpenMPOpt.cpp:5944
bool containsOpenMP(Module &M)
Helper to determine if M contains OpenMP.
Definition: OpenMPOpt.cpp:5936
InternalControlVar
IDs for all Internal Control Variables (ICVs).
Definition: OMPConstants.h:26
RuntimeFunction
IDs for all omp runtime library (RTL) functions.
Definition: OMPConstants.h:45
KernelSet getDeviceKernels(Module &M)
Get OpenMP device kernels in M.
Definition: OpenMPOpt.cpp:5904
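A sketch of how a client might gate on containsOpenMP/isOpenMPDevice and then walk getDeviceKernels; this shows typical use of the helpers declared in OpenMPOpt.h, not the exact flow of the pass itself, and the wrapper function is hypothetical:

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/IPO/OpenMPOpt.h"
using namespace llvm;
using namespace omp;

static void forEachDeviceKernel(Module &M, function_ref<void(Function &)> Fn) {
  if (!containsOpenMP(M) || !isOpenMPDevice(M))
    return;                       // nothing to do for host-only modules
  for (Kernel K : getDeviceKernels(M))
    Fn(*K);                       // Kernel is Function *, see OpenMPOpt.h
}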
Function * Kernel
Summary of a kernel (=entry point for target offloading).
Definition: OpenMPOpt.h:21
@ OMP_TGT_EXEC_MODE_GENERIC_SPMD
bool isOpenMPKernel(Function &Fn)
Return true iff Fn is an OpenMP GPU kernel; Fn has the "kernel" attribute.
Definition: OpenMPOpt.cpp:5900
DiagnosticInfoOptimizationBase::Argument NV
const_iterator begin(StringRef path, Style style=Style::native)
Get begin iterator over path.
Definition: Path.cpp:227
const_iterator end(StringRef path)
Get end iterator over path.
Definition: Path.cpp:236
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:480
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1722
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1680
bool succ_empty(const Instruction *I)
Definition: CFG.h:255
std::string to_string(const T &Value)
Definition: ScopedPrinter.h:85
bool operator!=(uint64_t V1, const APInt &V2)
Definition: APInt.h:2060
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....
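A sketch of decomposing a pointer into a base plus constant byte offset, falling back to the underlying object when no constant offset can be folded; the helper name is illustrative:

#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
using namespace llvm;

static Value *baseAndOffset(Value *Ptr, const DataLayout &DL, int64_t &Offset) {
  Offset = 0;
  // Folds simple GEP chains into Base + Offset; returns Ptr itself (Offset 0)
  // when nothing could be peeled off.
  Value *Base = GetPointerBaseWithConstantOffset(Ptr, Offset, DL);
  return Base == Ptr ? getUnderlyingObject(Ptr) : Base;
}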
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
@ CGSCC
Definition: Attributor.h:6420
@ FullLTOPostLink
Full LTO postlink (backend compile) phase.
@ ThinLTOPreLink
ThinLTO prelink (summary) phase.
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
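A small example of the standard debug-output idiom built on dbgs(); the debug type "my-example-pass" is a hypothetical name:

#include "llvm/IR/Function.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

#define DEBUG_TYPE "my-example-pass" // hypothetical debug type

static void traceFunction(const Function &F) {
  // Printed only in asserts builds, and only when -debug or
  // -debug-only=my-example-pass is given.
  LLVM_DEBUG(dbgs() << "visiting " << F.getName() << "\n");
}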
bool operator&=(SparseBitVector< ElementSize > *LHS, const SparseBitVector< ElementSize > &RHS)
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition: STLExtras.h:1921
BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
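A sketch of splitting a basic block so that a given instruction begins a new successor block; the wrapper and the block name "region.after" are illustrative:

#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
using namespace llvm;

static BasicBlock *splitAt(Instruction &SplitPoint, DominatorTree *DT) {
  // Everything from SplitPoint onwards moves into the returned block; DT is
  // updated in place (pass nullptr if no dominator tree is available).
  return SplitBlock(SplitPoint.getParent(), SplitPoint.getIterator(), DT,
                    /*LI=*/nullptr, /*MSSAU=*/nullptr, "region.after");
}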
auto predecessors(const MachineBasicBlock *BB)
ChangeStatus
Definition: Attributor.h:484
Constant * ConstantFoldInsertValueInstruction(Constant *Agg, Constant *Val, ArrayRef< unsigned > Idxs)
ConstantFoldInsertValueInstruction - Attempt to constant fold an insertvalue instruction with the spe...
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
#define N
An abstract interface for address space information.
Definition: Attributor.h:6229
An abstract attribute for getting assumption information.
Definition: Attributor.h:6155
An abstract state for querying live call edges.
Definition: Attributor.h:5479
virtual const SetVector< Function * > & getOptimisticEdges() const =0
Get the optimistic edges.
bool isExecutedByInitialThreadOnly(const Instruction &I) const
Check if an instruction is executed only by the initial thread.
Definition: Attributor.h:5639
static AAExecutionDomain & createForPosition(const IRPosition &IRP, Attributor &A)
Create an abstract attribute view for the position IRP.
Definition: OpenMPOpt.cpp:5632
virtual ExecutionDomainTy getFunctionExecutionDomain() const =0
virtual ExecutionDomainTy getExecutionDomain(const BasicBlock &) const =0
virtual bool isExecutedInAlignedRegion(Attributor &A, const Instruction &I) const =0
Check if the instruction I is executed in an aligned region, that is, the synchronizing effects befor...
virtual bool isNoOpFence(const FenceInst &FI) const =0
Helper function to determine if FI is a no-op given the information about its execution from ExecDoma...
static const char ID
Unique ID (due to the unique address)
Definition: Attributor.h:5670
An abstract interface for indirect call information interference.
Definition: Attributor.h:6347
An abstract interface for liveness abstract attribute.
Definition: Attributor.h:3968
An abstract interface for all memory location attributes (readnone/argmemonly/inaccessiblememonly/ina...
Definition: Attributor.h:4697
AccessKind
Simple enum to distinguish read/write/read-write accesses.
Definition: Attributor.h:4833
StateType::base_t MemoryLocationsKind
Definition: Attributor.h:4698
static bool isAlignedBarrier(const CallBase &CB, bool ExecutedAligned)
Helper function to determine if CB is an aligned (GPU) barrier.
An abstract Attribute for determining the necessity of the convergent attribute.
Definition: Attributor.h:5715
An abstract attribute for getting all assumption underlying objects.
Definition: Attributor.h:6187
Base struct for all "concrete attribute" deductions.
Definition: Attributor.h:3276
virtual ChangeStatus manifest(Attributor &A)
Hook for the Attributor to trigger the manifestation of the information represented by the abstract a...
Definition: Attributor.h:3391
virtual void initialize(Attributor &A)
Initialize the state with the information in the Attributor A.
Definition: Attributor.h:3340
virtual const std::string getAsStr(Attributor *A) const =0
This function should return the "summarized" assumed state as string.
virtual ChangeStatus updateImpl(Attributor &A)=0
The actual update/transfer function which has to be implemented by the derived classes.
virtual void trackStatistics() const =0
Hook to enable custom statistic tracking, called after manifest that resulted in a change if statisti...
virtual const char * getIdAddr() const =0
This function should return the address of the ID of the AbstractAttribute.
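A hedged sketch of the shape a concrete attribute's updateImpl typically takes: snapshot the assumed state, merge in new facts, and report whether anything changed. This fragment would live inside a derived attribute class; StateTy stands in for whatever copyable, equality-comparable state type the attribute uses:

// Inside a class deriving from AbstractAttribute (via a state wrapper):
ChangeStatus updateImpl(Attributor &A) override {
  StateTy StateBefore = getState();
  // ... query other abstract attributes via A and merge facts into getState() ...
  return StateBefore == getState() ? ChangeStatus::UNCHANGED
                                   : ChangeStatus::CHANGED;
}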
An interface to query the internal state of an abstract attribute.
Definition: Attributor.h:2596
virtual ChangeStatus indicatePessimisticFixpoint()=0
Indicate that the abstract state should converge to the pessimistic state.
virtual bool isAtFixpoint() const =0
Return if this abstract state is fixed, thus does not need to be updated if information changes as it...
virtual bool isValidState() const =0
Return if this abstract state is in a valid state.
virtual ChangeStatus indicateOptimisticFixpoint()=0
Indicate that the abstract state should converge to the optimistic state.
Wrapper for FunctionAnalysisManager.
Definition: Attributor.h:1122
Configuration for the Attributor.
Definition: Attributor.h:1414
std::function< void(Attributor &A, const Function &F)> InitializationCallback
Callback function to be invoked on internal functions marked live.
Definition: Attributor.h:1445
std::optional< unsigned > MaxFixpointIterations
Maximum number of iterations to run until fixpoint.
Definition: Attributor.h:1461
bool RewriteSignatures
Flag to determine if we rewrite function signatures.
Definition: Attributor.h:1431
const char * PassName
Definition: Attributor.h:1471
OptimizationRemarkGetter OREGetter
Definition: Attributor.h:1467
IPOAmendableCBTy IPOAmendableCB
Definition: Attributor.h:1474
bool IsModulePass
Is the user of the Attributor a module pass or not.
Definition: Attributor.h:1425
bool DefaultInitializeLiveInternals
Flag to determine if we want to initialize all default AAs for an internal function marked live.
Definition: Attributor.h:1435
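A sketch of populating an AttributorConfig and driving one fixpoint run. The field names are the ones documented above; the function itself, the iteration budget, and the pass name are illustrative, and the caller is assumed to have already built Functions, InfoCache, CGUpdater, and an OptimizationRemarkGetter as a real pass would:

#include "llvm/Transforms/IPO/Attributor.h"
using namespace llvm;

static ChangeStatus runExampleAttributor(SetVector<Function *> &Functions,
                                         InformationCache &InfoCache,
                                         CallGraphUpdater &CGUpdater,
                                         OptimizationRemarkGetter OREGetter) {
  AttributorConfig AC(CGUpdater);
  AC.IsModulePass = true;                    // driven from a module pass
  AC.RewriteSignatures = false;              // keep function signatures stable
  AC.DefaultInitializeLiveInternals = false; // only seed the AAs we request
  AC.MaxFixpointIterations = 32;             // illustrative iteration budget
  AC.OREGetter = OREGetter;
  AC.PassName = "my-example-pass";           // hypothetical pass name

  Attributor A(Functions, InfoCache, AC);
  return A.run();
}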
The fixpoint analysis framework that orchestrates the attribute deduction.
Definition: Attributor.h:1508
static bool isInternalizable(Function &F)
Returns true if the function F can be internalized.
std::function< std::optional< Constant * >(const GlobalVariable &, const AbstractAttribute *, bool &)> GlobalVariableSimplifictionCallbackTy
Register CB as a simplification callback.
Definition: Attributor.h:2025
std::function< bool(Attributor &, const AbstractAttribute *)> VirtualUseCallbackTy
Definition: Attributor.h:2055
static bool internalizeFunctions(SmallPtrSetImpl< Function * > &FnSet, DenseMap< Function *, Function * > &FnMap)
Make copies of each function in the set FnSet such that the copied version has internal linkage after...
std::function< std::optional< Value * >(const IRPosition &, const AbstractAttribute *, bool &)> SimplifictionCallbackTy
Register CB as a simplification callback.
Definition: Attributor.h:2009
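A sketch of registering a value simplification callback with an Attributor so that every query for a value's simplified form yields a fixed replacement; the helper and its arguments are illustrative:

#include "llvm/Transforms/IPO/Attributor.h"
#include <optional>
using namespace llvm;

static void pinSimplifiedValue(Attributor &A, Value &V, Value &Repl) {
  // Returning std::nullopt would mean "not simplified yet"; returning nullptr
  // would mean "simplified, but to nothing useful".
  A.registerSimplificationCallback(
      IRPosition::value(V),
      [&Repl](const IRPosition &, const AbstractAttribute *,
              bool &UsedAssumedInformation) -> std::optional<Value *> {
        return &Repl;
      });
}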
Simple wrapper for a single bit (boolean) state.
Definition: Attributor.h:2880
Support structure for SCC passes to communicate updates the call graph back to the CGSCC pass manager...
Helper to describe and deal with positions in the LLVM-IR.
Definition: Attributor.h:581
static const IRPosition callsite_returned(const CallBase &CB)
Create a position describing the returned value of CB.
Definition: Attributor.h:649
static const IRPosition returned(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the returned value of F.
Definition: Attributor.h:631
static const IRPosition value(const Value &V, const CallBaseContext *CBContext=nullptr)
Create a position describing the value of V.
Definition: Attributor.h:605
static const IRPosition inst(const Instruction &I, const CallBaseContext *CBContext=nullptr)
Create a position describing the instruction I.
Definition: Attributor.h:617
@ IRP_ARGUMENT
An attribute for a function argument.
Definition: Attributor.h:595
@ IRP_RETURNED
An attribute for the function return value.
Definition: Attributor.h:591
@ IRP_CALL_SITE
An attribute for a call site (function scope).
Definition: Attributor.h:594
@ IRP_CALL_SITE_RETURNED
An attribute for a call site return value.
Definition: Attributor.h:592
@ IRP_FUNCTION
An attribute for a function (scope).
Definition: Attributor.h:593
@ IRP_FLOAT
A position that is not associated with a spot suitable for attributes.
Definition: Attributor.h:589
@ IRP_CALL_SITE_ARGUMENT
An attribute for a call site argument.
Definition: Attributor.h:596
@ IRP_INVALID
An invalid position.
Definition: Attributor.h:588
static const IRPosition function(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the function scope of F.
Definition: Attributor.h:624
Kind getPositionKind() const
Return the associated position kind.
Definition: Attributor.h:877
static const IRPosition callsite_function(const CallBase &CB)
Create a position describing the function scope of CB.
Definition: Attributor.h:644
Function * getAnchorScope() const
Return the Function surrounding the anchor value.
Definition: Attributor.h:753
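A sketch of building IRPositions and querying another abstract attribute on behalf of a querying one; it assumes the pointer-returning form of Attributor::getAAFor used by recent Attributor versions, and the helper name is illustrative:

#include "llvm/IR/Instruction.h"
#include "llvm/Transforms/IPO/Attributor.h"
using namespace llvm;

static bool onlyInitialThread(Attributor &A,
                              const AbstractAttribute &QueryingAA,
                              Instruction &I) {
  // Positions name where a fact is attached: a value, an instruction, a
  // function scope, a call site, and so on.
  const IRPosition FnPos = IRPosition::function(*I.getFunction());
  const auto *EDAA =
      A.getAAFor<AAExecutionDomain>(QueryingAA, FnPos, DepClassTy::OPTIONAL);
  return EDAA && EDAA->isExecutedByInitialThreadOnly(I);
}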
Data structure to hold cached (LLVM-IR) information.
Definition: Attributor.h:1198
bool isValidState() const override
See AbstractState::isValidState() NOTE: For now we simply pretend that the worst possible state is in...
Definition: Attributor.h:2655
ChangeStatus indicatePessimisticFixpoint() override
See AbstractState::indicatePessimisticFixpoint(...)
Definition: Attributor.h:2667
Direction
An enum for the direction of the loop.
Definition: LoopInfo.h:215
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:117
Description of a LLVM-IR insertion point (IP) and a debug/source location (filename,...
Definition: OMPIRBuilder.h:615
Helper to tie an abstract state implementation to an abstract attribute.
Definition: Attributor.h:3165
StateType & getState() override
See AbstractAttribute::getState(...).
Definition: Attributor.h:3173
Defines various target-specific GPU grid values that must be consistent between host RTL (plugin),...
Definition: OMPGridValues.h:57