LLVM 20.0.0git
OpenMPOpt.cpp
Go to the documentation of this file.
1//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpenMP specific optimizations:
10//
11// - Deduplication of runtime calls, e.g., omp_get_thread_num.
12// - Replacing globalized device memory with stack memory.
13// - Replacing globalized device memory with shared memory.
14// - Parallel region merging.
15// - Transforming generic-mode device kernels to SPMD mode.
16// - Specializing the state machine for generic-mode device kernels.
17//
18//===----------------------------------------------------------------------===//
19
21
24#include "llvm/ADT/SetVector.h"
27#include "llvm/ADT/Statistic.h"
29#include "llvm/ADT/StringRef.h"
37#include "llvm/IR/Assumptions.h"
38#include "llvm/IR/BasicBlock.h"
39#include "llvm/IR/Constants.h"
41#include "llvm/IR/Dominators.h"
42#include "llvm/IR/Function.h"
43#include "llvm/IR/GlobalValue.h"
45#include "llvm/IR/InstrTypes.h"
46#include "llvm/IR/Instruction.h"
49#include "llvm/IR/IntrinsicsAMDGPU.h"
50#include "llvm/IR/IntrinsicsNVPTX.h"
51#include "llvm/IR/LLVMContext.h"
54#include "llvm/Support/Debug.h"
58
59#include <algorithm>
60#include <optional>
61#include <string>
62
63using namespace llvm;
64using namespace omp;
65
66#define DEBUG_TYPE "openmp-opt"
67
69 "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
70 cl::Hidden, cl::init(false));
71
73 "openmp-opt-enable-merging",
74 cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
75 cl::init(false));
76
77static cl::opt<bool>
78 DisableInternalization("openmp-opt-disable-internalization",
79 cl::desc("Disable function internalization."),
80 cl::Hidden, cl::init(false));
81
82static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
83 cl::init(false), cl::Hidden);
84static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
86static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
87 cl::init(false), cl::Hidden);
88
90 "openmp-hide-memory-transfer-latency",
91 cl::desc("[WIP] Tries to hide the latency of host to device memory"
92 " transfers"),
93 cl::Hidden, cl::init(false));
94
96 "openmp-opt-disable-deglobalization",
97 cl::desc("Disable OpenMP optimizations involving deglobalization."),
98 cl::Hidden, cl::init(false));
99
101 "openmp-opt-disable-spmdization",
102 cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
103 cl::Hidden, cl::init(false));
104
106 "openmp-opt-disable-folding",
107 cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
108 cl::init(false));
109
111 "openmp-opt-disable-state-machine-rewrite",
112 cl::desc("Disable OpenMP optimizations that replace the state machine."),
113 cl::Hidden, cl::init(false));
114
116 "openmp-opt-disable-barrier-elimination",
117 cl::desc("Disable OpenMP optimizations that eliminate barriers."),
118 cl::Hidden, cl::init(false));
119
121 "openmp-opt-print-module-after",
122 cl::desc("Print the current module after OpenMP optimizations."),
123 cl::Hidden, cl::init(false));
124
126 "openmp-opt-print-module-before",
127 cl::desc("Print the current module before OpenMP optimizations."),
128 cl::Hidden, cl::init(false));
129
131 "openmp-opt-inline-device",
132 cl::desc("Inline all applicible functions on the device."), cl::Hidden,
133 cl::init(false));
134
135static cl::opt<bool>
136 EnableVerboseRemarks("openmp-opt-verbose-remarks",
137 cl::desc("Enables more verbose remarks."), cl::Hidden,
138 cl::init(false));
139
141 SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
142 cl::desc("Maximal number of attributor iterations."),
143 cl::init(256));
144
146 SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
147 cl::desc("Maximum amount of shared memory to use."),
148 cl::init(std::numeric_limits<unsigned>::max()));
149
150STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
151 "Number of OpenMP runtime calls deduplicated");
152STATISTIC(NumOpenMPParallelRegionsDeleted,
153 "Number of OpenMP parallel regions deleted");
154STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
155 "Number of OpenMP runtime functions identified");
156STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
157 "Number of OpenMP runtime function uses identified");
158STATISTIC(NumOpenMPTargetRegionKernels,
159 "Number of OpenMP target region entry points (=kernels) identified");
160STATISTIC(NumNonOpenMPTargetRegionKernels,
161 "Number of non-OpenMP target region kernels identified");
162STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
163 "Number of OpenMP target region entry points (=kernels) executed in "
164 "SPMD-mode instead of generic-mode");
165STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
166 "Number of OpenMP target region entry points (=kernels) executed in "
167 "generic-mode without a state machines");
168STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
169 "Number of OpenMP target region entry points (=kernels) executed in "
170 "generic-mode with customized state machines with fallback");
171STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
172 "Number of OpenMP target region entry points (=kernels) executed in "
173 "generic-mode with customized state machines without fallback");
175 NumOpenMPParallelRegionsReplacedInGPUStateMachine,
176 "Number of OpenMP parallel regions replaced with ID in GPU state machines");
177STATISTIC(NumOpenMPParallelRegionsMerged,
178 "Number of OpenMP parallel regions merged");
179STATISTIC(NumBytesMovedToSharedMemory,
180 "Amount of memory pushed to shared memory");
181STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
182
183#if !defined(NDEBUG)
184static constexpr auto TAG = "[" DEBUG_TYPE "]";
185#endif
186
187namespace KernelInfo {
188
189// struct ConfigurationEnvironmentTy {
190// uint8_t UseGenericStateMachine;
191// uint8_t MayUseNestedParallelism;
192// llvm::omp::OMPTgtExecModeFlags ExecMode;
193// int32_t MinThreads;
194// int32_t MaxThreads;
195// int32_t MinTeams;
196// int32_t MaxTeams;
197// };
198
199// struct DynamicEnvironmentTy {
200// uint16_t DebugIndentionLevel;
201// };
202
203// struct KernelEnvironmentTy {
204// ConfigurationEnvironmentTy Configuration;
205// IdentTy *Ident;
206// DynamicEnvironmentTy *DynamicEnv;
207// };
208
209#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX) \
210 constexpr const unsigned MEMBER##Idx = IDX;
211
212KERNEL_ENVIRONMENT_IDX(Configuration, 0)
214
215#undef KERNEL_ENVIRONMENT_IDX
216
217#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX) \
218 constexpr const unsigned MEMBER##Idx = IDX;
219
220KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
221KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
227
228#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX
229
230#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE) \
231 RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
232 return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx)); \
233 }
234
237
238#undef KERNEL_ENVIRONMENT_GETTER
239
240#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER) \
241 ConstantInt *get##MEMBER##FromKernelEnvironment( \
242 ConstantStruct *KernelEnvC) { \
243 ConstantStruct *ConfigC = \
244 getConfigurationFromKernelEnvironment(KernelEnvC); \
245 return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx)); \
246 }
247
248KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
249KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
255
256#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
257
260 constexpr const int InitKernelEnvironmentArgNo = 0;
261 return cast<GlobalVariable>(
262 KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
264}
265
267 GlobalVariable *KernelEnvGV =
269 return cast<ConstantStruct>(KernelEnvGV->getInitializer());
270}
271} // namespace KernelInfo
272
273namespace {
274
275struct AAHeapToShared;
276
277struct AAICVTracker;
278
279/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
280/// Attributor runs.
281struct OMPInformationCache : public InformationCache {
282 OMPInformationCache(Module &M, AnalysisGetter &AG,
284 bool OpenMPPostLink)
285 : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
286 OpenMPPostLink(OpenMPPostLink) {
287
288 OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
289 const Triple T(OMPBuilder.M.getTargetTriple());
290 switch (T.getArch()) {
294 assert(OMPBuilder.Config.IsTargetDevice &&
295 "OpenMP AMDGPU/NVPTX is only prepared to deal with device code.");
296 OMPBuilder.Config.IsGPU = true;
297 break;
298 default:
299 OMPBuilder.Config.IsGPU = false;
300 break;
301 }
302 OMPBuilder.initialize();
303 initializeRuntimeFunctions(M);
304 initializeInternalControlVars();
305 }
306
307 /// Generic information that describes an internal control variable.
308 struct InternalControlVarInfo {
309 /// The kind, as described by InternalControlVar enum.
311
312 /// The name of the ICV.
314
315 /// Environment variable associated with this ICV.
316 StringRef EnvVarName;
317
318 /// Initial value kind.
319 ICVInitValue InitKind;
320
321 /// Initial value.
322 ConstantInt *InitValue;
323
324 /// Setter RTL function associated with this ICV.
325 RuntimeFunction Setter;
326
327 /// Getter RTL function associated with this ICV.
328 RuntimeFunction Getter;
329
330 /// RTL Function corresponding to the override clause of this ICV
332 };
333
334 /// Generic information that describes a runtime function
335 struct RuntimeFunctionInfo {
336
337 /// The kind, as described by the RuntimeFunction enum.
339
340 /// The name of the function.
342
343 /// Flag to indicate a variadic function.
344 bool IsVarArg;
345
346 /// The return type of the function.
348
349 /// The argument types of the function.
350 SmallVector<Type *, 8> ArgumentTypes;
351
352 /// The declaration if available.
353 Function *Declaration = nullptr;
354
355 /// Uses of this runtime function per function containing the use.
356 using UseVector = SmallVector<Use *, 16>;
357
358 /// Clear UsesMap for runtime function.
359 void clearUsesMap() { UsesMap.clear(); }
360
361 /// Boolean conversion that is true if the runtime function was found.
362 operator bool() const { return Declaration; }
363
364 /// Return the vector of uses in function \p F.
365 UseVector &getOrCreateUseVector(Function *F) {
366 std::shared_ptr<UseVector> &UV = UsesMap[F];
367 if (!UV)
368 UV = std::make_shared<UseVector>();
369 return *UV;
370 }
371
372 /// Return the vector of uses in function \p F or `nullptr` if there are
373 /// none.
374 const UseVector *getUseVector(Function &F) const {
375 auto I = UsesMap.find(&F);
376 if (I != UsesMap.end())
377 return I->second.get();
378 return nullptr;
379 }
380
381 /// Return how many functions contain uses of this runtime function.
382 size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
383
384 /// Return the number of arguments (or the minimal number for variadic
385 /// functions).
386 size_t getNumArgs() const { return ArgumentTypes.size(); }
387
388 /// Run the callback \p CB on each use and forget the use if the result is
389 /// true. The callback will be fed the function in which the use was
390 /// encountered as second argument.
391 void foreachUse(SmallVectorImpl<Function *> &SCC,
392 function_ref<bool(Use &, Function &)> CB) {
393 for (Function *F : SCC)
394 foreachUse(CB, F);
395 }
396
397 /// Run the callback \p CB on each use within the function \p F and forget
398 /// the use if the result is true.
399 void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
400 SmallVector<unsigned, 8> ToBeDeleted;
401 ToBeDeleted.clear();
402
403 unsigned Idx = 0;
404 UseVector &UV = getOrCreateUseVector(F);
405
406 for (Use *U : UV) {
407 if (CB(*U, *F))
408 ToBeDeleted.push_back(Idx);
409 ++Idx;
410 }
411
412 // Remove the to-be-deleted indices in reverse order as prior
413 // modifications will not modify the smaller indices.
414 while (!ToBeDeleted.empty()) {
415 unsigned Idx = ToBeDeleted.pop_back_val();
416 UV[Idx] = UV.back();
417 UV.pop_back();
418 }
419 }
420
421 private:
422 /// Map from functions to all uses of this runtime function contained in
423 /// them.
425
426 public:
427 /// Iterators for the uses of this runtime function.
428 decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
429 decltype(UsesMap)::iterator end() { return UsesMap.end(); }
430 };
431
432 /// An OpenMP-IR-Builder instance
433 OpenMPIRBuilder OMPBuilder;
434
435 /// Map from runtime function kind to the runtime function description.
436 EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
437 RuntimeFunction::OMPRTL___last>
438 RFIs;
439
440 /// Map from function declarations/definitions to their runtime enum type.
441 DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
442
443 /// Map from ICV kind to the ICV description.
444 EnumeratedArray<InternalControlVarInfo, InternalControlVar,
445 InternalControlVar::ICV___last>
446 ICVs;
447
448 /// Helper to initialize all internal control variable information for those
449 /// defined in OMPKinds.def.
450 void initializeInternalControlVars() {
451#define ICV_RT_SET(_Name, RTL) \
452 { \
453 auto &ICV = ICVs[_Name]; \
454 ICV.Setter = RTL; \
455 }
456#define ICV_RT_GET(Name, RTL) \
457 { \
458 auto &ICV = ICVs[Name]; \
459 ICV.Getter = RTL; \
460 }
461#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
462 { \
463 auto &ICV = ICVs[Enum]; \
464 ICV.Name = _Name; \
465 ICV.Kind = Enum; \
466 ICV.InitKind = Init; \
467 ICV.EnvVarName = _EnvVarName; \
468 switch (ICV.InitKind) { \
469 case ICV_IMPLEMENTATION_DEFINED: \
470 ICV.InitValue = nullptr; \
471 break; \
472 case ICV_ZERO: \
473 ICV.InitValue = ConstantInt::get( \
474 Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
475 break; \
476 case ICV_FALSE: \
477 ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
478 break; \
479 case ICV_LAST: \
480 break; \
481 } \
482 }
483#include "llvm/Frontend/OpenMP/OMPKinds.def"
484 }
485
486 /// Returns true if the function declaration \p F matches the runtime
487 /// function types, that is, return type \p RTFRetType, and argument types
488 /// \p RTFArgTypes.
489 static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
490 SmallVector<Type *, 8> &RTFArgTypes) {
491 // TODO: We should output information to the user (under debug output
492 // and via remarks).
493
494 if (!F)
495 return false;
496 if (F->getReturnType() != RTFRetType)
497 return false;
498 if (F->arg_size() != RTFArgTypes.size())
499 return false;
500
501 auto *RTFTyIt = RTFArgTypes.begin();
502 for (Argument &Arg : F->args()) {
503 if (Arg.getType() != *RTFTyIt)
504 return false;
505
506 ++RTFTyIt;
507 }
508
509 return true;
510 }
511
512 // Helper to collect all uses of the declaration in the UsesMap.
513 unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
514 unsigned NumUses = 0;
515 if (!RFI.Declaration)
516 return NumUses;
517 OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
518
519 if (CollectStats) {
520 NumOpenMPRuntimeFunctionsIdentified += 1;
521 NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
522 }
523
524 // TODO: We directly convert uses into proper calls and unknown uses.
525 for (Use &U : RFI.Declaration->uses()) {
526 if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
527 if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
528 RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
529 ++NumUses;
530 }
531 } else {
532 RFI.getOrCreateUseVector(nullptr).push_back(&U);
533 ++NumUses;
534 }
535 }
536 return NumUses;
537 }
538
539 // Helper function to recollect uses of a runtime function.
540 void recollectUsesForFunction(RuntimeFunction RTF) {
541 auto &RFI = RFIs[RTF];
542 RFI.clearUsesMap();
543 collectUses(RFI, /*CollectStats*/ false);
544 }
545
546 // Helper function to recollect uses of all runtime functions.
547 void recollectUses() {
548 for (int Idx = 0; Idx < RFIs.size(); ++Idx)
549 recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
550 }
551
552 // Helper function to inherit the calling convention of the function callee.
553 void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
554 if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
555 CI->setCallingConv(Fn->getCallingConv());
556 }
557
558 // Helper function to determine if it's legal to create a call to the runtime
559 // functions.
560 bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
561 // We can always emit calls if we haven't yet linked in the runtime.
562 if (!OpenMPPostLink)
563 return true;
564
565 // Once the runtime has been already been linked in we cannot emit calls to
566 // any undefined functions.
567 for (RuntimeFunction Fn : Fns) {
568 RuntimeFunctionInfo &RFI = RFIs[Fn];
569
570 if (RFI.Declaration && RFI.Declaration->isDeclaration())
571 return false;
572 }
573 return true;
574 }
575
576 /// Helper to initialize all runtime function information for those defined
577 /// in OpenMPKinds.def.
578 void initializeRuntimeFunctions(Module &M) {
579
580 // Helper macros for handling __VA_ARGS__ in OMP_RTL
581#define OMP_TYPE(VarName, ...) \
582 Type *VarName = OMPBuilder.VarName; \
583 (void)VarName;
584
585#define OMP_ARRAY_TYPE(VarName, ...) \
586 ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
587 (void)VarName##Ty; \
588 PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
589 (void)VarName##PtrTy;
590
591#define OMP_FUNCTION_TYPE(VarName, ...) \
592 FunctionType *VarName = OMPBuilder.VarName; \
593 (void)VarName; \
594 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
595 (void)VarName##Ptr;
596
597#define OMP_STRUCT_TYPE(VarName, ...) \
598 StructType *VarName = OMPBuilder.VarName; \
599 (void)VarName; \
600 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
601 (void)VarName##Ptr;
602
603#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
604 { \
605 SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
606 Function *F = M.getFunction(_Name); \
607 RTLFunctions.insert(F); \
608 if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
609 RuntimeFunctionIDMap[F] = _Enum; \
610 auto &RFI = RFIs[_Enum]; \
611 RFI.Kind = _Enum; \
612 RFI.Name = _Name; \
613 RFI.IsVarArg = _IsVarArg; \
614 RFI.ReturnType = OMPBuilder._ReturnType; \
615 RFI.ArgumentTypes = std::move(ArgsTypes); \
616 RFI.Declaration = F; \
617 unsigned NumUses = collectUses(RFI); \
618 (void)NumUses; \
619 LLVM_DEBUG({ \
620 dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
621 << " found\n"; \
622 if (RFI.Declaration) \
623 dbgs() << TAG << "-> got " << NumUses << " uses in " \
624 << RFI.getNumFunctionsWithUses() \
625 << " different functions.\n"; \
626 }); \
627 } \
628 }
629#include "llvm/Frontend/OpenMP/OMPKinds.def"
630
631 // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
632 // functions, except if `optnone` is present.
633 if (isOpenMPDevice(M)) {
634 for (Function &F : M) {
635 for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
636 if (F.hasFnAttribute(Attribute::NoInline) &&
637 F.getName().starts_with(Prefix) &&
638 !F.hasFnAttribute(Attribute::OptimizeNone))
639 F.removeFnAttr(Attribute::NoInline);
640 }
641 }
642
643 // TODO: We should attach the attributes defined in OMPKinds.def.
644 }
645
646 /// Collection of known OpenMP runtime functions..
647 DenseSet<const Function *> RTLFunctions;
648
649 /// Indicates if we have already linked in the OpenMP device library.
650 bool OpenMPPostLink = false;
651};
652
653template <typename Ty, bool InsertInvalidates = true>
654struct BooleanStateWithSetVector : public BooleanState {
655 bool contains(const Ty &Elem) const { return Set.contains(Elem); }
656 bool insert(const Ty &Elem) {
657 if (InsertInvalidates)
659 return Set.insert(Elem);
660 }
661
662 const Ty &operator[](int Idx) const { return Set[Idx]; }
663 bool operator==(const BooleanStateWithSetVector &RHS) const {
664 return BooleanState::operator==(RHS) && Set == RHS.Set;
665 }
666 bool operator!=(const BooleanStateWithSetVector &RHS) const {
667 return !(*this == RHS);
668 }
669
670 bool empty() const { return Set.empty(); }
671 size_t size() const { return Set.size(); }
672
673 /// "Clamp" this state with \p RHS.
674 BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
675 BooleanState::operator^=(RHS);
676 Set.insert(RHS.Set.begin(), RHS.Set.end());
677 return *this;
678 }
679
680private:
681 /// A set to keep track of elements.
683
684public:
685 typename decltype(Set)::iterator begin() { return Set.begin(); }
686 typename decltype(Set)::iterator end() { return Set.end(); }
687 typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
688 typename decltype(Set)::const_iterator end() const { return Set.end(); }
689};
690
691template <typename Ty, bool InsertInvalidates = true>
692using BooleanStateWithPtrSetVector =
693 BooleanStateWithSetVector<Ty *, InsertInvalidates>;
694
695struct KernelInfoState : AbstractState {
696 /// Flag to track if we reached a fixpoint.
697 bool IsAtFixpoint = false;
698
699 /// The parallel regions (identified by the outlined parallel functions) that
700 /// can be reached from the associated function.
701 BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
702 ReachedKnownParallelRegions;
703
704 /// State to track what parallel region we might reach.
705 BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
706
707 /// State to track if we are in SPMD-mode, assumed or know, and why we decided
708 /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
709 /// false.
710 BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
711
712 /// The __kmpc_target_init call in this kernel, if any. If we find more than
713 /// one we abort as the kernel is malformed.
714 CallBase *KernelInitCB = nullptr;
715
716 /// The constant kernel environement as taken from and passed to
717 /// __kmpc_target_init.
718 ConstantStruct *KernelEnvC = nullptr;
719
720 /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
721 /// one we abort as the kernel is malformed.
722 CallBase *KernelDeinitCB = nullptr;
723
724 /// Flag to indicate if the associated function is a kernel entry.
725 bool IsKernelEntry = false;
726
727 /// State to track what kernel entries can reach the associated function.
728 BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
729
730 /// State to indicate if we can track parallel level of the associated
731 /// function. We will give up tracking if we encounter unknown caller or the
732 /// caller is __kmpc_parallel_51.
733 BooleanStateWithSetVector<uint8_t> ParallelLevels;
734
735 /// Flag that indicates if the kernel has nested Parallelism
736 bool NestedParallelism = false;
737
738 /// Abstract State interface
739 ///{
740
741 KernelInfoState() = default;
742 KernelInfoState(bool BestState) {
743 if (!BestState)
745 }
746
747 /// See AbstractState::isValidState(...)
748 bool isValidState() const override { return true; }
749
750 /// See AbstractState::isAtFixpoint(...)
751 bool isAtFixpoint() const override { return IsAtFixpoint; }
752
753 /// See AbstractState::indicatePessimisticFixpoint(...)
755 IsAtFixpoint = true;
756 ParallelLevels.indicatePessimisticFixpoint();
757 ReachingKernelEntries.indicatePessimisticFixpoint();
758 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
759 ReachedKnownParallelRegions.indicatePessimisticFixpoint();
760 ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
761 NestedParallelism = true;
763 }
764
765 /// See AbstractState::indicateOptimisticFixpoint(...)
767 IsAtFixpoint = true;
768 ParallelLevels.indicateOptimisticFixpoint();
769 ReachingKernelEntries.indicateOptimisticFixpoint();
770 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
771 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
772 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
774 }
775
776 /// Return the assumed state
777 KernelInfoState &getAssumed() { return *this; }
778 const KernelInfoState &getAssumed() const { return *this; }
779
780 bool operator==(const KernelInfoState &RHS) const {
781 if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
782 return false;
783 if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
784 return false;
785 if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
786 return false;
787 if (ReachingKernelEntries != RHS.ReachingKernelEntries)
788 return false;
789 if (ParallelLevels != RHS.ParallelLevels)
790 return false;
791 if (NestedParallelism != RHS.NestedParallelism)
792 return false;
793 return true;
794 }
795
796 /// Returns true if this kernel contains any OpenMP parallel regions.
797 bool mayContainParallelRegion() {
798 return !ReachedKnownParallelRegions.empty() ||
799 !ReachedUnknownParallelRegions.empty();
800 }
801
802 /// Return empty set as the best state of potential values.
803 static KernelInfoState getBestState() { return KernelInfoState(true); }
804
805 static KernelInfoState getBestState(KernelInfoState &KIS) {
806 return getBestState();
807 }
808
809 /// Return full set as the worst state of potential values.
810 static KernelInfoState getWorstState() { return KernelInfoState(false); }
811
812 /// "Clamp" this state with \p KIS.
813 KernelInfoState operator^=(const KernelInfoState &KIS) {
814 // Do not merge two different _init and _deinit call sites.
815 if (KIS.KernelInitCB) {
816 if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
817 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
818 "assumptions.");
819 KernelInitCB = KIS.KernelInitCB;
820 }
821 if (KIS.KernelDeinitCB) {
822 if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
823 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
824 "assumptions.");
825 KernelDeinitCB = KIS.KernelDeinitCB;
826 }
827 if (KIS.KernelEnvC) {
828 if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
829 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
830 "assumptions.");
831 KernelEnvC = KIS.KernelEnvC;
832 }
833 SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
834 ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
835 ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
836 NestedParallelism |= KIS.NestedParallelism;
837 return *this;
838 }
839
840 KernelInfoState operator&=(const KernelInfoState &KIS) {
841 return (*this ^= KIS);
842 }
843
844 ///}
845};
846
847/// Used to map the values physically (in the IR) stored in an offload
848/// array, to a vector in memory.
849struct OffloadArray {
850 /// Physical array (in the IR).
851 AllocaInst *Array = nullptr;
852 /// Mapped values.
853 SmallVector<Value *, 8> StoredValues;
854 /// Last stores made in the offload array.
855 SmallVector<StoreInst *, 8> LastAccesses;
856
857 OffloadArray() = default;
858
859 /// Initializes the OffloadArray with the values stored in \p Array before
860 /// instruction \p Before is reached. Returns false if the initialization
861 /// fails.
862 /// This MUST be used immediately after the construction of the object.
863 bool initialize(AllocaInst &Array, Instruction &Before) {
864 if (!Array.getAllocatedType()->isArrayTy())
865 return false;
866
867 if (!getValues(Array, Before))
868 return false;
869
870 this->Array = &Array;
871 return true;
872 }
873
874 static const unsigned DeviceIDArgNum = 1;
875 static const unsigned BasePtrsArgNum = 3;
876 static const unsigned PtrsArgNum = 4;
877 static const unsigned SizesArgNum = 5;
878
879private:
880 /// Traverses the BasicBlock where \p Array is, collecting the stores made to
881 /// \p Array, leaving StoredValues with the values stored before the
882 /// instruction \p Before is reached.
883 bool getValues(AllocaInst &Array, Instruction &Before) {
884 // Initialize container.
885 const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
886 StoredValues.assign(NumValues, nullptr);
887 LastAccesses.assign(NumValues, nullptr);
888
889 // TODO: This assumes the instruction \p Before is in the same
890 // BasicBlock as Array. Make it general, for any control flow graph.
891 BasicBlock *BB = Array.getParent();
892 if (BB != Before.getParent())
893 return false;
894
895 const DataLayout &DL = Array.getDataLayout();
896 const unsigned int PointerSize = DL.getPointerSize();
897
898 for (Instruction &I : *BB) {
899 if (&I == &Before)
900 break;
901
902 if (!isa<StoreInst>(&I))
903 continue;
904
905 auto *S = cast<StoreInst>(&I);
906 int64_t Offset = -1;
907 auto *Dst =
908 GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
909 if (Dst == &Array) {
910 int64_t Idx = Offset / PointerSize;
911 StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
912 LastAccesses[Idx] = S;
913 }
914 }
915
916 return isFilled();
917 }
918
919 /// Returns true if all values in StoredValues and
920 /// LastAccesses are not nullptrs.
921 bool isFilled() {
922 const unsigned NumValues = StoredValues.size();
923 for (unsigned I = 0; I < NumValues; ++I) {
924 if (!StoredValues[I] || !LastAccesses[I])
925 return false;
926 }
927
928 return true;
929 }
930};
931
932struct OpenMPOpt {
933
934 using OptimizationRemarkGetter =
936
937 OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
938 OptimizationRemarkGetter OREGetter,
939 OMPInformationCache &OMPInfoCache, Attributor &A)
940 : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
941 OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
942
943 /// Check if any remarks are enabled for openmp-opt
944 bool remarksEnabled() {
945 auto &Ctx = M.getContext();
946 return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
947 }
948
949 /// Run all OpenMP optimizations on the underlying SCC.
950 bool run(bool IsModulePass) {
951 if (SCC.empty())
952 return false;
953
954 bool Changed = false;
955
956 LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
957 << " functions\n");
958
959 if (IsModulePass) {
960 Changed |= runAttributor(IsModulePass);
961
962 // Recollect uses, in case Attributor deleted any.
963 OMPInfoCache.recollectUses();
964
965 // TODO: This should be folded into buildCustomStateMachine.
966 Changed |= rewriteDeviceCodeStateMachine();
967
968 if (remarksEnabled())
969 analysisGlobalization();
970 } else {
971 if (PrintICVValues)
972 printICVs();
974 printKernels();
975
976 Changed |= runAttributor(IsModulePass);
977
978 // Recollect uses, in case Attributor deleted any.
979 OMPInfoCache.recollectUses();
980
981 Changed |= deleteParallelRegions();
982
984 Changed |= hideMemTransfersLatency();
985 Changed |= deduplicateRuntimeCalls();
987 if (mergeParallelRegions()) {
988 deduplicateRuntimeCalls();
989 Changed = true;
990 }
991 }
992 }
993
994 if (OMPInfoCache.OpenMPPostLink)
995 Changed |= removeRuntimeSymbols();
996
997 return Changed;
998 }
999
1000 /// Print initial ICV values for testing.
1001 /// FIXME: This should be done from the Attributor once it is added.
1002 void printICVs() const {
1003 InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
1004 ICV_proc_bind};
1005
1006 for (Function *F : SCC) {
1007 for (auto ICV : ICVs) {
1008 auto ICVInfo = OMPInfoCache.ICVs[ICV];
1009 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1010 return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
1011 << " Value: "
1012 << (ICVInfo.InitValue
1013 ? toString(ICVInfo.InitValue->getValue(), 10, true)
1014 : "IMPLEMENTATION_DEFINED");
1015 };
1016
1017 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
1018 }
1019 }
1020 }
1021
1022 /// Print OpenMP GPU kernels for testing.
1023 void printKernels() const {
1024 for (Function *F : SCC) {
1025 if (!omp::isOpenMPKernel(*F))
1026 continue;
1027
1028 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1029 return ORA << "OpenMP GPU kernel "
1030 << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
1031 };
1032
1033 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
1034 }
1035 }
1036
1037 /// Return the call if \p U is a callee use in a regular call. If \p RFI is
1038 /// given it has to be the callee or a nullptr is returned.
1039 static CallInst *getCallIfRegularCall(
1040 Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1041 CallInst *CI = dyn_cast<CallInst>(U.getUser());
1042 if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
1043 (!RFI ||
1044 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1045 return CI;
1046 return nullptr;
1047 }
1048
1049 /// Return the call if \p V is a regular call. If \p RFI is given it has to be
1050 /// the callee or a nullptr is returned.
1051 static CallInst *getCallIfRegularCall(
1052 Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1053 CallInst *CI = dyn_cast<CallInst>(&V);
1054 if (CI && !CI->hasOperandBundles() &&
1055 (!RFI ||
1056 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1057 return CI;
1058 return nullptr;
1059 }
1060
1061private:
1062 /// Merge parallel regions when it is safe.
1063 bool mergeParallelRegions() {
1064 const unsigned CallbackCalleeOperand = 2;
1065 const unsigned CallbackFirstArgOperand = 3;
1066 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1067
1068 // Check if there are any __kmpc_fork_call calls to merge.
1069 OMPInformationCache::RuntimeFunctionInfo &RFI =
1070 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1071
1072 if (!RFI.Declaration)
1073 return false;
1074
1075 // Unmergable calls that prevent merging a parallel region.
1076 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
1077 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
1078 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
1079 };
1080
1081 bool Changed = false;
1082 LoopInfo *LI = nullptr;
1083 DominatorTree *DT = nullptr;
1084
1086
1087 BasicBlock *StartBB = nullptr, *EndBB = nullptr;
1088 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1089 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1090 BasicBlock *CGEndBB =
1091 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1092 assert(StartBB != nullptr && "StartBB should not be null");
1093 CGStartBB->getTerminator()->setSuccessor(0, StartBB);
1094 assert(EndBB != nullptr && "EndBB should not be null");
1095 EndBB->getTerminator()->setSuccessor(0, CGEndBB);
1096 return Error::success();
1097 };
1098
1099 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
1100 Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
1101 ReplacementValue = &Inner;
1102 return CodeGenIP;
1103 };
1104
1105 auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };
1106
1107 /// Create a sequential execution region within a merged parallel region,
1108 /// encapsulated in a master construct with a barrier for synchronization.
1109 auto CreateSequentialRegion = [&](Function *OuterFn,
1110 BasicBlock *OuterPredBB,
1111 Instruction *SeqStartI,
1112 Instruction *SeqEndI) {
1113 // Isolate the instructions of the sequential region to a separate
1114 // block.
1115 BasicBlock *ParentBB = SeqStartI->getParent();
1116 BasicBlock *SeqEndBB =
1117 SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
1118 BasicBlock *SeqAfterBB =
1119 SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
1120 BasicBlock *SeqStartBB =
1121 SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
1122
1123 assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
1124 "Expected a different CFG");
1125 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
1126 ParentBB->getTerminator()->eraseFromParent();
1127
1128 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1129 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1130 BasicBlock *CGEndBB =
1131 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1132 assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
1133 CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
1134 assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
1135 SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
1136 return Error::success();
1137 };
1138 auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };
1139
1140 // Find outputs from the sequential region to outside users and
1141 // broadcast their values to them.
1142 for (Instruction &I : *SeqStartBB) {
1143 SmallPtrSet<Instruction *, 4> OutsideUsers;
1144 for (User *Usr : I.users()) {
1145 Instruction &UsrI = *cast<Instruction>(Usr);
1146 // Ignore outputs to LT intrinsics, code extraction for the merged
1147 // parallel region will fix them.
1148 if (UsrI.isLifetimeStartOrEnd())
1149 continue;
1150
1151 if (UsrI.getParent() != SeqStartBB)
1152 OutsideUsers.insert(&UsrI);
1153 }
1154
1155 if (OutsideUsers.empty())
1156 continue;
1157
1158 // Emit an alloca in the outer region to store the broadcasted
1159 // value.
1160 const DataLayout &DL = M.getDataLayout();
1161 AllocaInst *AllocaI = new AllocaInst(
1162 I.getType(), DL.getAllocaAddrSpace(), nullptr,
1163 I.getName() + ".seq.output.alloc", OuterFn->front().begin());
1164
1165 // Emit a store instruction in the sequential BB to update the
1166 // value.
1167 new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());
1168
1169 // Emit a load instruction and replace the use of the output value
1170 // with it.
1171 for (Instruction *UsrI : OutsideUsers) {
1172 LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,
1173 I.getName() + ".seq.output.load",
1174 UsrI->getIterator());
1175 UsrI->replaceUsesOfWith(&I, LoadI);
1176 }
1177 }
1178
1180 InsertPointTy(ParentBB, ParentBB->end()), DL);
1182 OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
1183 assert(SeqAfterIP && "Unexpected error creating master");
1184
1186 OMPInfoCache.OMPBuilder.createBarrier(*SeqAfterIP, OMPD_parallel);
1187 assert(BarrierAfterIP && "Unexpected error creating barrier");
1188
1189 BranchInst::Create(SeqAfterBB, SeqAfterIP->getBlock());
1190
1191 LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1192 << "\n");
1193 };
1194
1195 // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1196 // contained in BB and only separated by instructions that can be
1197 // redundantly executed in parallel. The block BB is split before the first
1198 // call (in MergableCIs) and after the last so the entire region we merge
1199 // into a single parallel region is contained in a single basic block
1200 // without any other instructions. We use the OpenMPIRBuilder to outline
1201 // that block and call the resulting function via __kmpc_fork_call.
1202 auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1203 BasicBlock *BB) {
1204 // TODO: Change the interface to allow single CIs expanded, e.g, to
1205 // include an outer loop.
1206 assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1207
1208 auto Remark = [&](OptimizationRemark OR) {
1209 OR << "Parallel region merged with parallel region"
1210 << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1211 for (auto *CI : llvm::drop_begin(MergableCIs)) {
1212 OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1213 if (CI != MergableCIs.back())
1214 OR << ", ";
1215 }
1216 return OR << ".";
1217 };
1218
1219 emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1220
1221 Function *OriginalFn = BB->getParent();
1222 LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1223 << " parallel regions in " << OriginalFn->getName()
1224 << "\n");
1225
1226 // Isolate the calls to merge in a separate block.
1227 EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1228 BasicBlock *AfterBB =
1229 SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1230 StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1231 "omp.par.merged");
1232
1233 assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1234 const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1235 BB->getTerminator()->eraseFromParent();
1236
1237 // Create sequential regions for sequential instructions that are
1238 // in-between mergable parallel regions.
1239 for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1240 It != End; ++It) {
1241 Instruction *ForkCI = *It;
1242 Instruction *NextForkCI = *(It + 1);
1243
1244 // Continue if there are not in-between instructions.
1245 if (ForkCI->getNextNode() == NextForkCI)
1246 continue;
1247
1248 CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1249 NextForkCI->getPrevNode());
1250 }
1251
1252 OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1253 DL);
1254 IRBuilder<>::InsertPoint AllocaIP(
1255 &OriginalFn->getEntryBlock(),
1256 OriginalFn->getEntryBlock().getFirstInsertionPt());
1257 // Create the merged parallel region with default proc binding, to
1258 // avoid overriding binding settings, and without explicit cancellation.
1260 OMPInfoCache.OMPBuilder.createParallel(
1261 Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1262 OMP_PROC_BIND_default, /* IsCancellable */ false);
1263 assert(AfterIP && "Unexpected error creating parallel");
1264 BranchInst::Create(AfterBB, AfterIP->getBlock());
1265
1266 // Perform the actual outlining.
1267 OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1268
1269 Function *OutlinedFn = MergableCIs.front()->getCaller();
1270
1271 // Replace the __kmpc_fork_call calls with direct calls to the outlined
1272 // callbacks.
1274 for (auto *CI : MergableCIs) {
1275 Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1276 FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1277 Args.clear();
1278 Args.push_back(OutlinedFn->getArg(0));
1279 Args.push_back(OutlinedFn->getArg(1));
1280 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1281 ++U)
1282 Args.push_back(CI->getArgOperand(U));
1283
1284 CallInst *NewCI =
1285 CallInst::Create(FT, Callee, Args, "", CI->getIterator());
1286 if (CI->getDebugLoc())
1287 NewCI->setDebugLoc(CI->getDebugLoc());
1288
1289 // Forward parameter attributes from the callback to the callee.
1290 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1291 ++U)
1292 for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1293 NewCI->addParamAttr(
1294 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1295
1296 // Emit an explicit barrier to replace the implicit fork-join barrier.
1297 if (CI != MergableCIs.back()) {
1298 // TODO: Remove barrier if the merged parallel region includes the
1299 // 'nowait' clause.
1301 OMPInfoCache.OMPBuilder.createBarrier(
1302 InsertPointTy(NewCI->getParent(),
1303 NewCI->getNextNode()->getIterator()),
1304 OMPD_parallel);
1305 assert(AfterIP && "Unexpected error creating barrier");
1306 }
1307
1308 CI->eraseFromParent();
1309 }
1310
1311 assert(OutlinedFn != OriginalFn && "Outlining failed");
1312 CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1313 CGUpdater.reanalyzeFunction(*OriginalFn);
1314
1315 NumOpenMPParallelRegionsMerged += MergableCIs.size();
1316
1317 return true;
1318 };
1319
1320 // Helper function that identifes sequences of
1321 // __kmpc_fork_call uses in a basic block.
1322 auto DetectPRsCB = [&](Use &U, Function &F) {
1323 CallInst *CI = getCallIfRegularCall(U, &RFI);
1324 BB2PRMap[CI->getParent()].insert(CI);
1325
1326 return false;
1327 };
1328
1329 BB2PRMap.clear();
1330 RFI.foreachUse(SCC, DetectPRsCB);
1331 SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1332 // Find mergable parallel regions within a basic block that are
1333 // safe to merge, that is any in-between instructions can safely
1334 // execute in parallel after merging.
1335 // TODO: support merging across basic-blocks.
1336 for (auto &It : BB2PRMap) {
1337 auto &CIs = It.getSecond();
1338 if (CIs.size() < 2)
1339 continue;
1340
1341 BasicBlock *BB = It.getFirst();
1342 SmallVector<CallInst *, 4> MergableCIs;
1343
1344 /// Returns true if the instruction is mergable, false otherwise.
1345 /// A terminator instruction is unmergable by definition since merging
1346 /// works within a BB. Instructions before the mergable region are
1347 /// mergable if they are not calls to OpenMP runtime functions that may
1348 /// set different execution parameters for subsequent parallel regions.
1349 /// Instructions in-between parallel regions are mergable if they are not
1350 /// calls to any non-intrinsic function since that may call a non-mergable
1351 /// OpenMP runtime function.
1352 auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1353 // We do not merge across BBs, hence return false (unmergable) if the
1354 // instruction is a terminator.
1355 if (I.isTerminator())
1356 return false;
1357
1358 if (!isa<CallInst>(&I))
1359 return true;
1360
1361 CallInst *CI = cast<CallInst>(&I);
1362 if (IsBeforeMergableRegion) {
1363 Function *CalledFunction = CI->getCalledFunction();
1364 if (!CalledFunction)
1365 return false;
1366 // Return false (unmergable) if the call before the parallel
1367 // region calls an explicit affinity (proc_bind) or number of
1368 // threads (num_threads) compiler-generated function. Those settings
1369 // may be incompatible with following parallel regions.
1370 // TODO: ICV tracking to detect compatibility.
1371 for (const auto &RFI : UnmergableCallsInfo) {
1372 if (CalledFunction == RFI.Declaration)
1373 return false;
1374 }
1375 } else {
1376 // Return false (unmergable) if there is a call instruction
1377 // in-between parallel regions when it is not an intrinsic. It
1378 // may call an unmergable OpenMP runtime function in its callpath.
1379 // TODO: Keep track of possible OpenMP calls in the callpath.
1380 if (!isa<IntrinsicInst>(CI))
1381 return false;
1382 }
1383
1384 return true;
1385 };
1386 // Find maximal number of parallel region CIs that are safe to merge.
1387 for (auto It = BB->begin(), End = BB->end(); It != End;) {
1388 Instruction &I = *It;
1389 ++It;
1390
1391 if (CIs.count(&I)) {
1392 MergableCIs.push_back(cast<CallInst>(&I));
1393 continue;
1394 }
1395
1396 // Continue expanding if the instruction is mergable.
1397 if (IsMergable(I, MergableCIs.empty()))
1398 continue;
1399
1400 // Forward the instruction iterator to skip the next parallel region
1401 // since there is an unmergable instruction which can affect it.
1402 for (; It != End; ++It) {
1403 Instruction &SkipI = *It;
1404 if (CIs.count(&SkipI)) {
1405 LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1406 << " due to " << I << "\n");
1407 ++It;
1408 break;
1409 }
1410 }
1411
1412 // Store mergable regions found.
1413 if (MergableCIs.size() > 1) {
1414 MergableCIsVector.push_back(MergableCIs);
1415 LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1416 << " parallel regions in block " << BB->getName()
1417 << " of function " << BB->getParent()->getName()
1418 << "\n";);
1419 }
1420
1421 MergableCIs.clear();
1422 }
1423
1424 if (!MergableCIsVector.empty()) {
1425 Changed = true;
1426
1427 for (auto &MergableCIs : MergableCIsVector)
1428 Merge(MergableCIs, BB);
1429 MergableCIsVector.clear();
1430 }
1431 }
1432
1433 if (Changed) {
1434 /// Re-collect use for fork calls, emitted barrier calls, and
1435 /// any emitted master/end_master calls.
1436 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1437 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1438 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1439 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1440 }
1441
1442 return Changed;
1443 }
1444
1445 /// Try to delete parallel regions if possible.
1446 bool deleteParallelRegions() {
1447 const unsigned CallbackCalleeOperand = 2;
1448
1449 OMPInformationCache::RuntimeFunctionInfo &RFI =
1450 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1451
1452 if (!RFI.Declaration)
1453 return false;
1454
1455 bool Changed = false;
1456 auto DeleteCallCB = [&](Use &U, Function &) {
1457 CallInst *CI = getCallIfRegularCall(U);
1458 if (!CI)
1459 return false;
1460 auto *Fn = dyn_cast<Function>(
1461 CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1462 if (!Fn)
1463 return false;
1464 if (!Fn->onlyReadsMemory())
1465 return false;
1466 if (!Fn->hasFnAttribute(Attribute::WillReturn))
1467 return false;
1468
1469 LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1470 << CI->getCaller()->getName() << "\n");
1471
1472 auto Remark = [&](OptimizationRemark OR) {
1473 return OR << "Removing parallel region with no side-effects.";
1474 };
1475 emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1476
1477 CI->eraseFromParent();
1478 Changed = true;
1479 ++NumOpenMPParallelRegionsDeleted;
1480 return true;
1481 };
1482
1483 RFI.foreachUse(SCC, DeleteCallCB);
1484
1485 return Changed;
1486 }
1487
1488 /// Try to eliminate runtime calls by reusing existing ones.
1489 bool deduplicateRuntimeCalls() {
1490 bool Changed = false;
1491
1492 RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1493 OMPRTL_omp_get_num_threads,
1494 OMPRTL_omp_in_parallel,
1495 OMPRTL_omp_get_cancellation,
1496 OMPRTL_omp_get_supported_active_levels,
1497 OMPRTL_omp_get_level,
1498 OMPRTL_omp_get_ancestor_thread_num,
1499 OMPRTL_omp_get_team_size,
1500 OMPRTL_omp_get_active_level,
1501 OMPRTL_omp_in_final,
1502 OMPRTL_omp_get_proc_bind,
1503 OMPRTL_omp_get_num_places,
1504 OMPRTL_omp_get_num_procs,
1505 OMPRTL_omp_get_place_num,
1506 OMPRTL_omp_get_partition_num_places,
1507 OMPRTL_omp_get_partition_place_nums};
1508
1509 // Global-tid is handled separately.
1511 collectGlobalThreadIdArguments(GTIdArgs);
1512 LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1513 << " global thread ID arguments\n");
1514
1515 for (Function *F : SCC) {
1516 for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1517 Changed |= deduplicateRuntimeCalls(
1518 *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1519
1520 // __kmpc_global_thread_num is special as we can replace it with an
1521 // argument in enough cases to make it worth trying.
1522 Value *GTIdArg = nullptr;
1523 for (Argument &Arg : F->args())
1524 if (GTIdArgs.count(&Arg)) {
1525 GTIdArg = &Arg;
1526 break;
1527 }
1528 Changed |= deduplicateRuntimeCalls(
1529 *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1530 }
1531
1532 return Changed;
1533 }
1534
1535 /// Tries to remove known runtime symbols that are optional from the module.
1536 bool removeRuntimeSymbols() {
1537 // The RPC client symbol is defined in `libc` and indicates that something
1538 // required an RPC server. If its users were all optimized out then we can
1539 // safely remove it.
1540 // TODO: This should be somewhere more common in the future.
1541 if (GlobalVariable *GV = M.getNamedGlobal("__llvm_rpc_client")) {
1542 if (GV->getNumUses() >= 1)
1543 return false;
1544
1545 GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
1546 GV->eraseFromParent();
1547 return true;
1548 }
1549 return false;
1550 }
1551
1552 /// Tries to hide the latency of runtime calls that involve host to
1553 /// device memory transfers by splitting them into their "issue" and "wait"
1554 /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1555 /// moved downards as much as possible. The "issue" issues the memory transfer
1556 /// asynchronously, returning a handle. The "wait" waits in the returned
1557 /// handle for the memory transfer to finish.
1558 bool hideMemTransfersLatency() {
1559 auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1560 bool Changed = false;
1561 auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1562 auto *RTCall = getCallIfRegularCall(U, &RFI);
1563 if (!RTCall)
1564 return false;
1565
1566 OffloadArray OffloadArrays[3];
1567 if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1568 return false;
1569
1570 LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1571
1572 // TODO: Check if can be moved upwards.
1573 bool WasSplit = false;
1574 Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1575 if (WaitMovementPoint)
1576 WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1577
1578 Changed |= WasSplit;
1579 return WasSplit;
1580 };
1581 if (OMPInfoCache.runtimeFnsAvailable(
1582 {OMPRTL___tgt_target_data_begin_mapper_issue,
1583 OMPRTL___tgt_target_data_begin_mapper_wait}))
1584 RFI.foreachUse(SCC, SplitMemTransfers);
1585
1586 return Changed;
1587 }
1588
1589 void analysisGlobalization() {
1590 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1591
1592 auto CheckGlobalization = [&](Use &U, Function &Decl) {
1593 if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1594 auto Remark = [&](OptimizationRemarkMissed ORM) {
1595 return ORM
1596 << "Found thread data sharing on the GPU. "
1597 << "Expect degraded performance due to data globalization.";
1598 };
1599 emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1600 }
1601
1602 return false;
1603 };
1604
1605 RFI.foreachUse(SCC, CheckGlobalization);
1606 }
1607
1608 /// Maps the values stored in the offload arrays passed as arguments to
1609 /// \p RuntimeCall into the offload arrays in \p OAs.
1610 bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1612 assert(OAs.size() == 3 && "Need space for three offload arrays!");
1613
1614 // A runtime call that involves memory offloading looks something like:
1615 // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1616 // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1617 // ...)
1618 // So, the idea is to access the allocas that allocate space for these
1619 // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1620 // Therefore:
1621 // i8** %offload_baseptrs.
1622 Value *BasePtrsArg =
1623 RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1624 // i8** %offload_ptrs.
1625 Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1626 // i8** %offload_sizes.
1627 Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1628
1629 // Get values stored in **offload_baseptrs.
1630 auto *V = getUnderlyingObject(BasePtrsArg);
1631 if (!isa<AllocaInst>(V))
1632 return false;
1633 auto *BasePtrsArray = cast<AllocaInst>(V);
1634 if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1635 return false;
1636
1637 // Get values stored in **offload_baseptrs.
1638 V = getUnderlyingObject(PtrsArg);
1639 if (!isa<AllocaInst>(V))
1640 return false;
1641 auto *PtrsArray = cast<AllocaInst>(V);
1642 if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1643 return false;
1644
1645 // Get values stored in **offload_sizes.
1646 V = getUnderlyingObject(SizesArg);
1647 // If it's a [constant] global array don't analyze it.
1648 if (isa<GlobalValue>(V))
1649 return isa<Constant>(V);
1650 if (!isa<AllocaInst>(V))
1651 return false;
1652
1653 auto *SizesArray = cast<AllocaInst>(V);
1654 if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1655 return false;
1656
1657 return true;
1658 }
1659
1660 /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1661 /// For now this is a way to test that the function getValuesInOffloadArrays
1662 /// is working properly.
1663 /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1664 void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1665 assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1666
1667 LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1668 std::string ValuesStr;
1669 raw_string_ostream Printer(ValuesStr);
1670 std::string Separator = " --- ";
1671
1672 for (auto *BP : OAs[0].StoredValues) {
1673 BP->print(Printer);
1674 Printer << Separator;
1675 }
1676 LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << ValuesStr << "\n");
1677 ValuesStr.clear();
1678
1679 for (auto *P : OAs[1].StoredValues) {
1680 P->print(Printer);
1681 Printer << Separator;
1682 }
1683 LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << ValuesStr << "\n");
1684 ValuesStr.clear();
1685
1686 for (auto *S : OAs[2].StoredValues) {
1687 S->print(Printer);
1688 Printer << Separator;
1689 }
1690 LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << ValuesStr << "\n");
1691 }
1692
1693 /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
1694 /// moved. Returns nullptr if the movement is not possible, or not worth it.
1695 Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1696 // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1697 // Make it traverse the CFG.
1698
1699 Instruction *CurrentI = &RuntimeCall;
1700 bool IsWorthIt = false;
1701 while ((CurrentI = CurrentI->getNextNode())) {
1702
1703 // TODO: Once we detect the regions to be offloaded we should use the
1704 // alias analysis manager to check if CurrentI may modify one of
1705 // the offloaded regions.
1706 if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1707 if (IsWorthIt)
1708 return CurrentI;
1709
1710 return nullptr;
1711 }
1712
1713 // FIXME: For now if we move it over anything without side effect
1714 // is worth it.
1715 IsWorthIt = true;
1716 }
1717
1718 // Return end of BasicBlock.
1719 return RuntimeCall.getParent()->getTerminator();
1720 }
1721
1722 /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1723 bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1724 Instruction &WaitMovementPoint) {
1725 // Create stack allocated handle (__tgt_async_info) at the beginning of the
1726 // function. Used for storing information of the async transfer, allowing to
1727 // wait on it later.
1728 auto &IRBuilder = OMPInfoCache.OMPBuilder;
1729 Function *F = RuntimeCall.getCaller();
1730 BasicBlock &Entry = F->getEntryBlock();
1731 IRBuilder.Builder.SetInsertPoint(&Entry,
1732 Entry.getFirstNonPHIOrDbgOrAlloca());
1733 Value *Handle = IRBuilder.Builder.CreateAlloca(
1734 IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1735 Handle =
1736 IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1737
1738 // Add "issue" runtime call declaration:
1739 // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1740 // i8**, i8**, i64*, i64*)
1741 FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1742 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1743
1744 // Change RuntimeCall call site for its asynchronous version.
1746 for (auto &Arg : RuntimeCall.args())
1747 Args.push_back(Arg.get());
1748 Args.push_back(Handle);
1749
1750 CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"",
1751 RuntimeCall.getIterator());
1752 OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1753 RuntimeCall.eraseFromParent();
1754
1755 // Add "wait" runtime call declaration:
1756 // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1757 FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1758 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1759
1760 Value *WaitParams[2] = {
1761 IssueCallsite->getArgOperand(
1762 OffloadArray::DeviceIDArgNum), // device_id.
1763 Handle // handle to wait on.
1764 };
1765 CallInst *WaitCallsite = CallInst::Create(
1766 WaitDecl, WaitParams, /*NameStr=*/"", WaitMovementPoint.getIterator());
1767 OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1768
1769 return true;
1770 }
1771
1772 static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1773 bool GlobalOnly, bool &SingleChoice) {
1774 if (CurrentIdent == NextIdent)
1775 return CurrentIdent;
1776
1777 // TODO: Figure out how to actually combine multiple debug locations. For
1778 // now we just keep an existing one if there is a single choice.
1779 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1780 SingleChoice = !CurrentIdent;
1781 return NextIdent;
1782 }
1783 return nullptr;
1784 }
1785
1786 /// Return an `struct ident_t*` value that represents the ones used in the
1787 /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1788 /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1789 /// return value we create one from scratch. We also do not yet combine
1790 /// information, e.g., the source locations, see combinedIdentStruct.
1791 Value *
1792 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1793 Function &F, bool GlobalOnly) {
1794 bool SingleChoice = true;
1795 Value *Ident = nullptr;
1796 auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1797 CallInst *CI = getCallIfRegularCall(U, &RFI);
1798 if (!CI || &F != &Caller)
1799 return false;
1800 Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1801 /* GlobalOnly */ true, SingleChoice);
1802 return false;
1803 };
1804 RFI.foreachUse(SCC, CombineIdentStruct);
1805
1806 if (!Ident || !SingleChoice) {
1807 // The IRBuilder uses the insertion block to get to the module, this is
1808 // unfortunate but we work around it for now.
1809 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1810 OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1811 &F.getEntryBlock(), F.getEntryBlock().begin()));
1812 // Create a fallback location if non was found.
1813 // TODO: Use the debug locations of the calls instead.
1814 uint32_t SrcLocStrSize;
1815 Constant *Loc =
1816 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1817 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1818 }
1819 return Ident;
1820 }
1821
1822 /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1823 /// \p ReplVal if given.
1824 bool deduplicateRuntimeCalls(Function &F,
1825 OMPInformationCache::RuntimeFunctionInfo &RFI,
1826 Value *ReplVal = nullptr) {
1827 auto *UV = RFI.getUseVector(F);
1828 if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1829 return false;
1830
1831 LLVM_DEBUG(
1832 dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1833 << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1834
1835 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1836 cast<Argument>(ReplVal)->getParent() == &F)) &&
1837 "Unexpected replacement value!");
1838
1839 // TODO: Use dominance to find a good position instead.
1840 auto CanBeMoved = [this](CallBase &CB) {
1841 unsigned NumArgs = CB.arg_size();
1842 if (NumArgs == 0)
1843 return true;
1844 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1845 return false;
1846 for (unsigned U = 1; U < NumArgs; ++U)
1847 if (isa<Instruction>(CB.getArgOperand(U)))
1848 return false;
1849 return true;
1850 };
1851
1852 if (!ReplVal) {
1853 auto *DT =
1854 OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1855 if (!DT)
1856 return false;
1857 Instruction *IP = nullptr;
1858 for (Use *U : *UV) {
1859 if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1860 if (IP)
1861 IP = DT->findNearestCommonDominator(IP, CI);
1862 else
1863 IP = CI;
1864 if (!CanBeMoved(*CI))
1865 continue;
1866 if (!ReplVal)
1867 ReplVal = CI;
1868 }
1869 }
1870 if (!ReplVal)
1871 return false;
1872 assert(IP && "Expected insertion point!");
1873 cast<Instruction>(ReplVal)->moveBefore(IP);
1874 }
1875
1876 // If we use a call as a replacement value we need to make sure the ident is
1877 // valid at the new location. For now we just pick a global one, either
1878 // existing and used by one of the calls, or created from scratch.
1879 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1880 if (!CI->arg_empty() &&
1881 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1882 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1883 /* GlobalOnly */ true);
1884 CI->setArgOperand(0, Ident);
1885 }
1886 }
1887
1888 bool Changed = false;
1889 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1890 CallInst *CI = getCallIfRegularCall(U, &RFI);
1891 if (!CI || CI == ReplVal || &F != &Caller)
1892 return false;
1893 assert(CI->getCaller() == &F && "Unexpected call!");
1894
1895 auto Remark = [&](OptimizationRemark OR) {
1896 return OR << "OpenMP runtime call "
1897 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1898 };
1899 if (CI->getDebugLoc())
1900 emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1901 else
1902 emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1903
1904 CI->replaceAllUsesWith(ReplVal);
1905 CI->eraseFromParent();
1906 ++NumOpenMPRuntimeCallsDeduplicated;
1907 Changed = true;
1908 return true;
1909 };
1910 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1911
1912 return Changed;
1913 }
1914
1915 /// Collect arguments that represent the global thread id in \p GTIdArgs.
1916 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1917 // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1918 // initialization. We could define an AbstractAttribute instead and
1919 // run the Attributor here once it can be run as an SCC pass.
1920
1921 // Helper to check the argument \p ArgNo at all call sites of \p F for
1922 // a GTId.
1923 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1924 if (!F.hasLocalLinkage())
1925 return false;
1926 for (Use &U : F.uses()) {
1927 if (CallInst *CI = getCallIfRegularCall(U)) {
1928 Value *ArgOp = CI->getArgOperand(ArgNo);
1929 if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1930 getCallIfRegularCall(
1931 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1932 continue;
1933 }
1934 return false;
1935 }
1936 return true;
1937 };
1938
1939 // Helper to identify uses of a GTId as GTId arguments.
1940 auto AddUserArgs = [&](Value &GTId) {
1941 for (Use &U : GTId.uses())
1942 if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1943 if (CI->isArgOperand(&U))
1944 if (Function *Callee = CI->getCalledFunction())
1945 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1946 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1947 };
1948
1949 // The argument users of __kmpc_global_thread_num calls are GTIds.
1950 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1951 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1952
1953 GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1954 if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1955 AddUserArgs(*CI);
1956 return false;
1957 });
1958
1959 // Transitively search for more arguments by looking at the users of the
1960 // ones we know already. During the search the GTIdArgs vector is extended
1961 // so we cannot cache the size nor can we use a range based for.
1962 for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1963 AddUserArgs(*GTIdArgs[U]);
1964 }
1965
1966 /// Kernel (=GPU) optimizations and utility functions
1967 ///
1968 ///{{
1969
1970 /// Cache to remember the unique kernel for a function.
// NOTE(review): the member this comment documents is missing from this
// listing. getUniqueKernelFor(Function &) reads and writes `UniqueKernelMap`
// (holding std::optional<Kernel> per function), which is presumably declared
// here. Verify against the real source.
1972
1973 /// Find the unique kernel that will execute \p F, if any.
1974 Kernel getUniqueKernelFor(Function &F);
1975
1976 /// Find the unique kernel that will execute \p I, if any.
// Convenience overload: delegates to the Function overload using the
// function that contains \p I.
1977 Kernel getUniqueKernelFor(Instruction &I) {
1978 return getUniqueKernelFor(*I.getFunction());
1979 }
1980
1981 /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
1982 /// the cases we can avoid taking the address of a function.
// Returns true if any use was rewritten (see the out-of-line definition).
1983 bool rewriteDeviceCodeStateMachine();
1984
1985 ///
1986 ///}}
1987
1988 /// Emit a remark generically
1989 ///
1990 /// This template function can be used to generically emit a remark. The
1991 /// RemarkKind should be one of the following:
1992 /// - OptimizationRemark to indicate a successful optimization attempt
1993 /// - OptimizationRemarkMissed to report a failed optimization attempt
1994 /// - OptimizationRemarkAnalysis to provide additional information about an
1995 /// optimization attempt
1996 ///
1997 /// The remark is built using a callback function provided by the caller that
1998 /// takes a RemarkKind as input and returns a RemarkKind.
1999 template <typename RemarkKind, typename RemarkCallBack>
2000 void emitRemark(Instruction *I, StringRef RemarkName,
2001 RemarkCallBack &&RemarkCB) const {
2002 Function *F = I->getParent()->getParent();
2003 auto &ORE = OREGetter(F);
2004
2005 if (RemarkName.starts_with("OMP"))
2006 ORE.emit([&]() {
2007 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2008 << " [" << RemarkName << "]";
2009 });
2010 else
2011 ORE.emit(
2012 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2013 }
2014
2015 /// Emit a remark on a function.
2016 template <typename RemarkKind, typename RemarkCallBack>
2017 void emitRemark(Function *F, StringRef RemarkName,
2018 RemarkCallBack &&RemarkCB) const {
2019 auto &ORE = OREGetter(F);
2020
2021 if (RemarkName.starts_with("OMP"))
2022 ORE.emit([&]() {
2023 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2024 << " [" << RemarkName << "]";
2025 });
2026 else
2027 ORE.emit(
2028 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2029 }
2030
2031 /// The underlying module.
2032 Module &M;
2033
2034 /// The SCC we are operating on.
// NOTE(review): the declaration this comment documents is missing from this
// listing. The code iterates `SCC` as a range of Function * and calls
// SCC.empty() / SCC.size() on it. Verify its exact type in the real source.
2036
2037 /// Callback to update the call graph, the first argument is a removed call,
2038 /// the second an optional replacement call.
// NOTE(review): despite the wording above, this member is a CallGraphUpdater
// reference rather than a callback; it is not used in this part of the file.
2039 CallGraphUpdater &CGUpdater;
2040
2041 /// Callback to get an OptimizationRemarkEmitter from a Function *
2042 OptimizationRemarkGetter OREGetter;
2043
2044 /// OpenMP-specific information cache. Also used for Attributor runs.
2045 OMPInformationCache &OMPInfoCache;
2046
2047 /// Attributor instance.
// Driven by runAttributor (A.run()) after registerAAs has populated it.
2048 Attributor &A;
2049
2050 /// Helper function to run Attributor on SCC.
2051 bool runAttributor(bool IsModulePass) {
2052 if (SCC.empty())
2053 return false;
2054
2055 registerAAs(IsModulePass);
2056
2057 ChangeStatus Changed = A.run();
2058
2059 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2060 << " functions, result: " << Changed << ".\n");
2061
2062 if (Changed == ChangeStatus::CHANGED)
2063 OMPInfoCache.invalidateAnalyses();
2064
2065 return Changed == ChangeStatus::CHANGED;
2066 }
2067
/// NOTE(review): defined out of line and undocumented here; judging by the
/// name it registers folding of calls to runtime function \p RF. Confirm
/// against the definition.
2068 void registerFoldRuntimeCall(RuntimeFunction RF);
2069
2070 /// Populate the Attributor with abstract attribute opportunities in the
2071 /// functions.
/// Called by runAttributor before the Attributor is run.
2072 void registerAAs(bool IsModulePass);
2073
2074public:
2075 /// Callback to register AAs for live functions, including internal functions
2076 /// marked live during the traversal.
2077 static void registerAAsForFunction(Attributor &A, const Function &F);
2078};
2079
2080Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2081 if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2082 !OMPInfoCache.CGSCC->contains(&F))
2083 return nullptr;
2084
2085 // Use a scope to keep the lifetime of the CachedKernel short.
2086 {
2087 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2088 if (CachedKernel)
2089 return *CachedKernel;
2090
2091 // TODO: We should use an AA to create an (optimistic and callback
2092 // call-aware) call graph. For now we stick to simple patterns that
2093 // are less powerful, basically the worst fixpoint.
2094 if (isOpenMPKernel(F)) {
2095 CachedKernel = Kernel(&F);
2096 return *CachedKernel;
2097 }
2098
2099 CachedKernel = nullptr;
2100 if (!F.hasLocalLinkage()) {
2101
2102 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2103 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2104 return ORA << "Potentially unknown OpenMP target region caller.";
2105 };
2106 emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2107
2108 return nullptr;
2109 }
2110 }
2111
2112 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2113 if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2114 // Allow use in equality comparisons.
2115 if (Cmp->isEquality())
2116 return getUniqueKernelFor(*Cmp);
2117 return nullptr;
2118 }
2119 if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2120 // Allow direct calls.
2121 if (CB->isCallee(&U))
2122 return getUniqueKernelFor(*CB);
2123
2124 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2125 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2126 // Allow the use in __kmpc_parallel_51 calls.
2127 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2128 return getUniqueKernelFor(*CB);
2129 return nullptr;
2130 }
2131 // Disallow every other use.
2132 return nullptr;
2133 };
2134
2135 // TODO: In the future we want to track more than just a unique kernel.
2136 SmallPtrSet<Kernel, 2> PotentialKernels;
2137 OMPInformationCache::foreachUse(F, [&](const Use &U) {
2138 PotentialKernels.insert(GetUniqueKernelForUse(U));
2139 });
2140
2141 Kernel K = nullptr;
2142 if (PotentialKernels.size() == 1)
2143 K = *PotentialKernels.begin();
2144
2145 // Cache the result.
2146 UniqueKernelMap[&F] = K;
2147
2148 return K;
2149}
2150
/// For each function in the SCC that is passed as the wrapper argument of a
/// __kmpc_parallel_51 call, replace its compare uses and that wrapper use with
/// a private ".ID" global. This is done only when the function has exactly one
/// direct call, no unknown uses, and a unique kernel. Remarks OMP101/OMP102
/// explain why a function was skipped.
/// \returns true if any function was rewritten.
2151bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2152 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2153 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2154
2155 bool Changed = false;
2156 if (!KernelParallelRFI)
2157 return Changed;
2158
2159 // If we have disabled state machine changes, exit
// NOTE(review): the `if (...)` guard for the following return is missing
// from this listing; presumably it tests a command-line option that disables
// the state machine rewrite. Verify.
2161 return Changed;
2162
2163 for (Function *F : SCC) {
2164
2165 // Check if the function is a use in a __kmpc_parallel_51 call at
2166 // all.
2167 bool UnknownUse = false;
2168 bool KernelParallelUse = false;
2169 unsigned NumDirectCalls = 0;
2170
 // Classify every use of F: direct calls are counted, compare uses and the
 // first wrapper-argument use in __kmpc_parallel_51 are collected for
 // replacement, and anything else (including a second wrapper use) marks
 // an unknown use.
2171 SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2172 OMPInformationCache::foreachUse(*F, [&](Use &U) {
2173 if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2174 if (CB->isCallee(&U)) {
2175 ++NumDirectCalls;
2176 return;
2177 }
2178
2179 if (isa<ICmpInst>(U.getUser())) {
2180 ToBeReplacedStateMachineUses.push_back(&U);
2181 return;
2182 }
2183
2184 // Find wrapper functions that represent parallel kernels.
2185 CallInst *CI =
2186 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
 // Argument position of __kmpc_parallel_51 through which F is accepted as
 // the wrapper function.
2187 const unsigned int WrapperFunctionArgNo = 6;
2188 if (!KernelParallelUse && CI &&
2189 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2190 KernelParallelUse = true;
2191 ToBeReplacedStateMachineUses.push_back(&U);
2192 return;
2193 }
2194 UnknownUse = true;
2195 });
2196
2197 // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2198 // use.
2199 if (!KernelParallelUse)
2200 continue;
2201
2202 // If this ever hits, we should investigate.
2203 // TODO: Checking the number of uses is not a necessary restriction and
2204 // should be lifted.
2205 if (UnknownUse || NumDirectCalls != 1 ||
2206 ToBeReplacedStateMachineUses.size() > 2) {
2207 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2208 return ORA << "Parallel region is used in "
2209 << (UnknownUse ? "unknown" : "unexpected")
2210 << " ways. Will not attempt to rewrite the state machine.";
2211 };
2212 emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2213 continue;
2214 }
2215
2216 // Even if we have __kmpc_parallel_51 calls, we (for now) give
2217 // up if the function is not called from a unique kernel.
2218 Kernel K = getUniqueKernelFor(*F);
2219 if (!K) {
2220 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2221 return ORA << "Parallel region is not called from a unique kernel. "
2222 "Will not attempt to rewrite the state machine.";
2223 };
2224 emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2225 continue;
2226 }
2227
2228 // We now know F is a parallel body function called only from the kernel K.
2229 // We also identified the state machine uses in which we replace the
2230 // function pointer by a new global symbol for identification purposes. This
2231 // ensures only direct calls to the function are left.
2232
2233 Module &M = *F->getParent();
2234 Type *Int8Ty = Type::getInt8Ty(M.getContext());
2235
 // A private constant i8 global (undef initializer) named "<F>.ID".
2236 auto *ID = new GlobalVariable(
2237 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2238 UndefValue::get(Int8Ty), F->getName() + ".ID");
2239
 // NOTE(review): the head of the statement in this loop body is missing
 // from this listing. The visible tail passes ID and the use's current type,
 // presumably to set each use to ID cast to that type. Verify.
2240 for (Use *U : ToBeReplacedStateMachineUses)
2242 ID, U->get()->getType()));
2243
2244 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2245
2246 Changed = true;
2247 }
2248
2249 return Changed;
2250}
2251
2252/// Abstract Attribute for tracking ICV values.
2253struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2255 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2256
2257 /// Returns true if value is assumed to be tracked.
2258 bool isAssumedTracked() const { return getAssumed(); }
2259
2260 /// Returns true if value is known to be tracked.
2261 bool isKnownTracked() const { return getAssumed(); }
2262
2263 /// Create an abstract attribute biew for the position \p IRP.
2264 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2265
2266 /// Return the value with which \p I can be replaced for specific \p ICV.
2267 virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2268 const Instruction *I,
2269 Attributor &A) const {
2270 return std::nullopt;
2271 }
2272
2273 /// Return an assumed unique ICV value if a single candidate is found. If
2274 /// there cannot be one, return a nullptr. If it is not clear yet, return
2275 /// std::nullopt.
2276 virtual std::optional<Value *>
2277 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2278
2279 // Currently only nthreads is being tracked.
2280 // this array will only grow with time.
2281 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2282
2283 /// See AbstractAttribute::getName()
2284 const std::string getName() const override { return "AAICVTracker"; }
2285
2286 /// See AbstractAttribute::getIdAddr()
2287 const char *getIdAddr() const override { return &ID; }
2288
2289 /// This function should return true if the type of the \p AA is AAICVTracker
2290 static bool classof(const AbstractAttribute *AA) {
2291 return (AA->getIdAddr() == &ID);
2292 }
2293
2294 static const char ID;
2295};
2296
/// Function-level ICV tracker. For each trackable ICV it maps instructions to
/// the ICV value in effect after them. Setter calls map to their first
/// argument; other calls map to the result of getValueForCall, where nullptr
/// means "unknown / possibly changed". getReplacementValue walks backwards
/// from a query instruction through its block and the terminators of
/// predecessor blocks.
2297struct AAICVTrackerFunction : public AAICVTracker {
2298 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2299 : AAICVTracker(IRP, A) {}
2300
2301 // FIXME: come up with better string.
2302 const std::string getAsStr(Attributor *) const override {
2303 return "ICVTrackerFunction";
2304 }
2305
2306 // FIXME: come up with some stats.
2307 void trackStatistics() const override {}
2308
2309 /// We don't manifest anything for this AA.
2310 ChangeStatus manifest(Attributor &A) override {
2311 return ChangeStatus::UNCHANGED;
2312 }
2313
2314 // Map of ICV to their values at specific program point.
// NOTE(review): the first line of this declaration is missing from this
// listing. Its uses below (insert of (instruction, Value *) pairs, count,
// lookup, try_emplace) suggest a per-ICV map from instruction to Value *.
// Verify.
2316 InternalControlVar::ICV___last>
2317 ICVReplacementValuesMap;
2318
 /// For each trackable ICV, record the value passed to every setter call of
 /// that ICV in this function, and the value (or nullptr) determined by
 /// getValueForCall for every call that is not dead. New entries make the
 /// result CHANGED. In that case the entry instruction is also given a
 /// default-constructed entry.
2319 ChangeStatus updateImpl(Attributor &A) override {
2320 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2321
2322 Function *F = getAnchorScope();
2323
2324 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2325
2326 for (InternalControlVar ICV : TrackableICVs) {
2327 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2328
2329 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2330 auto TrackValues = [&](Use &U, Function &) {
2331 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2332 if (!CI)
2333 return false;
2334
2335 // FIXME: handle setters with more than one argument.
2336 /// Track new value.
2337 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2338 HasChanged = ChangeStatus::CHANGED;
2339
2340 return false;
2341 };
2342
2343 auto CallCheck = [&](Instruction &I) {
2344 std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2345 if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2346 HasChanged = ChangeStatus::CHANGED;
2347
2348 return true;
2349 };
2350
2351 // Track all changes of an ICV.
2352 SetterRFI.foreachUse(TrackValues, F);
2353
2354 bool UsedAssumedInformation = false;
2355 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2356 UsedAssumedInformation,
2357 /* CheckBBLivenessOnly */ true);
2358
2359 /// TODO: Figure out a way to avoid adding entry in
2360 /// ICVReplacementValuesMap
2361 Instruction *Entry = &F->getEntryBlock().front();
2362 if (HasChanged == ChangeStatus::CHANGED)
2363 ValuesMap.try_emplace(Entry);
2364 }
2365
2366 return HasChanged;
2367 }
2368
2369 /// Helper to check if \p I is a call and get the value for it if it is
2370 /// unique.
 /// Returns std::nullopt when \p I does not affect the ICV: it is not a call,
 /// it is marked no_openmp / no_openmp_routines, or it is the getter.
 /// Returns nullptr when the ICV must be assumed changed: the call is
 /// indirect, it is a setter without a recorded value, or the callee is a
 /// declaration or has no usable unique value.
 /// Otherwise returns the recorded setter value, or the callee's unique
 /// returned ICV value if it is valid at \p I.
2371 std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2372 InternalControlVar &ICV) const {
2373
2374 const auto *CB = dyn_cast<CallBase>(&I);
2375 if (!CB || CB->hasFnAttr("no_openmp") ||
2376 CB->hasFnAttr("no_openmp_routines"))
2377 return std::nullopt;
2378
2379 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2380 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2381 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2382 Function *CalledFunction = CB->getCalledFunction();
2383
2384 // Indirect call, assume ICV changes.
2385 if (CalledFunction == nullptr)
2386 return nullptr;
2387 if (CalledFunction == GetterRFI.Declaration)
2388 return std::nullopt;
2389 if (CalledFunction == SetterRFI.Declaration) {
2390 if (ICVReplacementValuesMap[ICV].count(&I))
2391 return ICVReplacementValuesMap[ICV].lookup(&I);
2392
2393 return nullptr;
2394 }
2395
2396 // Since we don't know, assume it changes the ICV.
2397 if (CalledFunction->isDeclaration())
2398 return nullptr;
2399
 // Defined callee: ask the AAICVTracker for the call site's returned value.
2400 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2401 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2402
2403 if (ICVTrackingAA->isAssumedTracked()) {
2404 std::optional<Value *> URV =
2405 ICVTrackingAA->getUniqueReplacementValue(ICV);
2406 if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2407 OMPInfoCache)))
2408 return URV;
2409 }
2410
2411 // If we don't know, assume it changes.
2412 return nullptr;
2413 }
2414
2415 // We don't check unique value for a function, so return std::nullopt.
2416 std::optional<Value *>
2417 getUniqueReplacementValue(InternalControlVar ICV) const override {
2418 return std::nullopt;
2419 }
2420
2421 /// Return the value with which \p I can be replaced for specific \p ICV.
 /// If \p I itself has a recorded entry, that entry is returned. Otherwise
 /// the search walks backwards from \p I within its block, then continues
 /// from the terminators of predecessor blocks. It returns nullptr as soon
 /// as two different values are found.
2422 std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2423 const Instruction *I,
2424 Attributor &A) const override {
2425 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2426 if (ValuesMap.count(I))
2427 return ValuesMap.lookup(I);
2428
 // NOTE(review): the declarations of `Worklist` (push_back / pop_back_val)
 // and `Visited` (insert(...).second) are missing from this listing. Verify.
2431 Worklist.push_back(I);
2432
2433 std::optional<Value *> ReplVal;
2434
2435 while (!Worklist.empty()) {
2436 const Instruction *CurrInst = Worklist.pop_back_val();
2437 if (!Visited.insert(CurrInst).second)
2438 continue;
2439
2440 const BasicBlock *CurrBB = CurrInst->getParent();
2441
2442 // Go up and look for all potential setters/calls that might change the
2443 // ICV.
2444 while ((CurrInst = CurrInst->getPrevNode())) {
2445 if (ValuesMap.count(CurrInst)) {
2446 std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2447 // Unknown value, track new.
2448 if (!ReplVal) {
2449 ReplVal = NewReplVal;
2450 break;
2451 }
2452
 // NOTE(review): if ValuesMap.lookup returns a plain Value *, NewReplVal
 // is always engaged here, so this check never filters anything. Verify.
2453 // If we found a new value, we can't know the icv value anymore.
2454 if (NewReplVal)
2455 if (ReplVal != NewReplVal)
2456 return nullptr;
2457
2458 break;
2459 }
2460
2461 std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2462 if (!NewReplVal)
2463 continue;
2464
2465 // Unknown value, track new.
2466 if (!ReplVal) {
2467 ReplVal = NewReplVal;
2468 break;
2469 }
2470
2471 // if (NewReplVal.hasValue())
2472 // We found a new value, we can't know the icv value anymore.
2473 if (ReplVal != NewReplVal)
2474 return nullptr;
2475 }
2476
2477 // If we are in the same BB and we have a value, we are done.
2478 if (CurrBB == I->getParent() && ReplVal)
2479 return ReplVal;
2480
2481 // Go through all predecessors and add terminators for analysis.
2482 for (const BasicBlock *Pred : predecessors(CurrBB))
2483 if (const Instruction *Terminator = Pred->getTerminator())
2484 Worklist.push_back(Terminator);
2485 }
2486
2487 return ReplVal;
2488 }
2489};
2490
/// Tracks, per trackable ICV, the unique ICV value in effect at all (live)
/// return instructions of the associated function. The values come from the
/// function-level AAICVTracker.
2491struct AAICVTrackerFunctionReturned : AAICVTracker {
2492 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2493 : AAICVTracker(IRP, A) {}
2494
2495 // FIXME: come up with better string.
2496 const std::string getAsStr(Attributor *) const override {
2497 return "ICVTrackerFunctionReturned";
2498 }
2499
2500 // FIXME: come up with some stats.
2501 void trackStatistics() const override {}
2502
2503 /// We don't manifest anything for this AA.
2504 ChangeStatus manifest(Attributor &A) override {
2505 return ChangeStatus::UNCHANGED;
2506 }
2507
2508 // Map of ICV to their values at specific program point.
// NOTE(review): the first line of this declaration is missing from this
// listing. updateImpl binds an element as std::optional<Value *> &, so this
// appears to hold one (returned) value per ICV, not per program point.
// Verify.
2510 InternalControlVar::ICV___last>
2511 ICVReplacementValuesMap;
2512
2513 /// Return the unique value of \p ICV at the function's returns, as
2514 /// computed by the last update.
2515 std::optional<Value *>
2516 getUniqueReplacementValue(InternalControlVar ICV) const override {
2517 return ICVReplacementValuesMap[ICV];
2518 }
2519
 /// Recompute, per ICV, the value common to all live return instructions.
 /// Two different values, or a failed instruction check, yield nullptr.
2519 ChangeStatus updateImpl(Attributor &A) override {
2520 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2521 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2522 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2523
2524 if (!ICVTrackingAA->isAssumedTracked())
2525 return indicatePessimisticFixpoint();
2526
2527 for (InternalControlVar ICV : TrackableICVs) {
2528 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2529 std::optional<Value *> UniqueICVValue;
2530
2531 auto CheckReturnInst = [&](Instruction &I) {
2532 std::optional<Value *> NewReplVal =
2533 ICVTrackingAA->getReplacementValue(ICV, &I, A);
2534
2535 // If we found a second ICV value there is no unique returned value.
2536 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2537 return false;
2538
2539 UniqueICVValue = NewReplVal;
2540
2541 return true;
2542 };
2543
2544 bool UsedAssumedInformation = false;
2545 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2546 UsedAssumedInformation,
2547 /* CheckBBLivenessOnly */ true))
2548 UniqueICVValue = nullptr;
2549
2550 if (UniqueICVValue == ReplVal)
2551 continue;
2552
2553 ReplVal = UniqueICVValue;
2554 Changed = ChangeStatus::CHANGED;
2555 }
2556
2557 return Changed;
2558 }
2559};
2560
2561struct AAICVTrackerCallSite : AAICVTracker {
2562 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2563 : AAICVTracker(IRP, A) {}
2564
2565 void initialize(Attributor &A) override {
2566 assert(getAnchorScope() && "Expected anchor function");
2567
2568 // We only initialize this AA for getters, so we need to know which ICV it
2569 // gets.
2570 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2571 for (InternalControlVar ICV : TrackableICVs) {
2572 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2573 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2574 if (Getter.Declaration == getAssociatedFunction()) {
2575 AssociatedICV = ICVInfo.Kind;
2576 return;
2577 }
2578 }
2579
2580 /// Unknown ICV.
2581 indicatePessimisticFixpoint();
2582 }
2583
2584 ChangeStatus manifest(Attributor &A) override {
2585 if (!ReplVal || !*ReplVal)
2586 return ChangeStatus::UNCHANGED;
2587
2588 A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2589 A.deleteAfterManifest(*getCtxI());
2590
2591 return ChangeStatus::CHANGED;
2592 }
2593
2594 // FIXME: come up with better string.
2595 const std::string getAsStr(Attributor *) const override {
2596 return "ICVTrackerCallSite";
2597 }
2598
2599 // FIXME: come up with some stats.
2600 void trackStatistics() const override {}
2601
2602 InternalControlVar AssociatedICV;
2603 std::optional<Value *> ReplVal;
2604
2605 ChangeStatus updateImpl(Attributor &A) override {
2606 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2607 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2608
2609 // We don't have any information, so we assume it changes the ICV.
2610 if (!ICVTrackingAA->isAssumedTracked())
2611 return indicatePessimisticFixpoint();
2612
2613 std::optional<Value *> NewReplVal =
2614 ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
2615
2616 if (ReplVal == NewReplVal)
2617 return ChangeStatus::UNCHANGED;
2618
2619 ReplVal = NewReplVal;
2620 return ChangeStatus::CHANGED;
2621 }
2622
2623 // Return the value with which associated value can be replaced for specific
2624 // \p ICV.
2625 std::optional<Value *>
2626 getUniqueReplacementValue(InternalControlVar ICV) const override {
2627 return ReplVal;
2628 }
2629};
2630
/// Tracks, per trackable ICV, the ICV value after a call returns. It mirrors
/// the unique value computed by the AAICVTracker at the callee's returned
/// position.
2631struct AAICVTrackerCallSiteReturned : AAICVTracker {
2632 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2633 : AAICVTracker(IRP, A) {}
2634
2635 // FIXME: come up with better string.
2636 const std::string getAsStr(Attributor *) const override {
2637 return "ICVTrackerCallSiteReturned";
2638 }
2639
2640 // FIXME: come up with some stats.
2641 void trackStatistics() const override {}
2642
2643 /// We don't manifest anything for this AA.
2644 ChangeStatus manifest(Attributor &A) override {
2645 return ChangeStatus::UNCHANGED;
2646 }
2647
2648 // Map of ICV to their values at specific program point.
// NOTE(review): the first line of this declaration is missing from this
// listing. updateImpl binds an element as std::optional<Value *> &, so this
// appears to hold one value per ICV. Verify.
2650 InternalControlVar::ICV___last>
2651 ICVReplacementValuesMap;
2652
2653 /// Return the value with which associated value can be replaced for specific
2654 /// \p ICV.
2655 std::optional<Value *>
2656 getUniqueReplacementValue(InternalControlVar ICV) const override {
2657 return ICVReplacementValuesMap[ICV];
2658 }
2659
 /// Copy, per ICV, the callee's unique returned ICV value; CHANGED if any
 /// copied value differs from the previous one.
2660 ChangeStatus updateImpl(Attributor &A) override {
2661 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2662 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2663 *this, IRPosition::returned(*getAssociatedFunction()),
2664 DepClassTy::REQUIRED);
2665
2666 // We don't have any information, so we assume it changes the ICV.
2667 if (!ICVTrackingAA->isAssumedTracked())
2668 return indicatePessimisticFixpoint();
2669
2670 for (InternalControlVar ICV : TrackableICVs) {
2671 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2672 std::optional<Value *> NewReplVal =
2673 ICVTrackingAA->getUniqueReplacementValue(ICV);
2674
2675 if (ReplVal == NewReplVal)
2676 continue;
2677
2678 ReplVal = NewReplVal;
2679 Changed = ChangeStatus::CHANGED;
2680 }
2681 return Changed;
2682 }
2683};
2684
2685/// Determines if \p BB exits the function unconditionally itself or reaches a
2686/// block that does through only unique successors.
2687static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2688 if (succ_empty(BB))
2689 return true;
2690 const BasicBlock *const Successor = BB->getUniqueSuccessor();
2691 if (!Successor)
2692 return false;
2693 return hasFunctionEndAsUniqueSuccessor(Successor);
2694}
2695
2696struct AAExecutionDomainFunction : public AAExecutionDomain {
/// Construct the execution-domain AA for position \p IRP.
2697 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2698 : AAExecutionDomain(IRP, A) {}
2699
/// Releases RPOT, which this attribute owns through a raw pointer.
// NOTE(review): RPOT's declaration and allocation are not visible in this
// listing; consider std::unique_ptr so ownership is expressed in the type.
2700 ~AAExecutionDomainFunction() { delete RPOT; }
2701
/// Set up per-function state for this AA.
// NOTE(review): lines of this body are missing from this listing: the
// declaration of `F` asserted below, and at least one statement after the
// assert. Given that the destructor deletes RPOT, this method presumably
// allocates it. Verify.
2702 void initialize(Attributor &A) override {
2704 assert(F && "Expected anchor function");
2706 }
2707
2708 const std::string getAsStr(Attributor *) const override {
2709 unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2710 for (auto &It : BEDMap) {
2711 if (!It.getFirst())
2712 continue;
2713 TotalBlocks++;
2714 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2715 AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2716 It.getSecond().IsReachingAlignedBarrierOnly;
2717 }
2718 return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
2719 std::to_string(AlignedBlocks) + " of " +
2720 std::to_string(TotalBlocks) +
2721 " executed by initial thread / aligned";
2722 }
2723
2724 /// See AbstractAttribute::trackStatistics().
2725 void trackStatistics() const override {}
2726
/// Delete aligned barriers that are only reached from aligned barriers and
/// did not encounter non-local side effects. For the "kernel end" state
/// (CB == nullptr), the aligned barriers reaching the end are deleted
/// transitively, as long as the function end is their unique successor.
/// Encountered llvm.assume calls are deleted along with eliminated barriers.
/// If assumes were encountered and this is not a module pass, the barrier is
/// kept.
2727 ChangeStatus manifest(Attributor &A) override {
2728 LLVM_DEBUG({
2729 for (const BasicBlock &BB : *getAnchorScope()) {
 // NOTE(review): the `if (...)` guarding this `continue` is missing from
 // this listing. Judging by the message below, it presumably skips blocks
 // not executed by a single (initial) thread. Verify.
2731 continue;
2732 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2733 << BB.getName() << " is executed by a single thread.\n";
2734 }
2735 });
2736
 // NOTE(review): the declaration of `Changed` and the `if (...)` guarding
 // the following early return are missing from this listing. Verify.
2738
2740 return Changed;
2741
2742 SmallPtrSet<CallBase *, 16> DeletedBarriers;
2743 auto HandleAlignedBarrier = [&](CallBase *CB) {
2744 const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
2745 if (!ED.IsReachedFromAlignedBarrierOnly ||
2746 ED.EncounteredNonLocalSideEffect)
2747 return;
2748 if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
2749 return;
2750
2751 // We can remove this barrier, if it is one, or aligned barriers reaching
2752 // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
2753 // end should only be removed if the kernel end is their unique successor;
2754 // otherwise, they may have side-effects that aren't accounted for in the
2755 // kernel end in their other successors. If those barriers have other
2756 // barriers reaching them, those can be transitively removed as well as
2757 // long as the kernel end is also their unique successor.
2758 if (CB) {
2759 DeletedBarriers.insert(CB);
2760 A.deleteAfterManifest(*CB);
2761 ++NumBarriersEliminated;
2762 Changed = ChangeStatus::CHANGED;
2763 } else if (!ED.AlignedBarriers.empty()) {
2764 Changed = ChangeStatus::CHANGED;
2765 SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
2766 ED.AlignedBarriers.end());
 // NOTE(review): the declaration of `Visited` is missing from this
 // listing; its insert(...) result is used as a bool below. Verify.
2768 while (!Worklist.empty()) {
2769 CallBase *LastCB = Worklist.pop_back_val();
2770 if (!Visited.insert(LastCB))
2771 continue;
2772 if (LastCB->getFunction() != getAnchorScope())
2773 continue;
2774 if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
2775 continue;
2776 if (!DeletedBarriers.count(LastCB)) {
2777 ++NumBarriersEliminated;
2778 A.deleteAfterManifest(*LastCB);
2779 continue;
2780 }
2781 // The final aligned barrier (LastCB) reaching the kernel end was
2782 // removed already. This means we can go one step further and remove
2783 // the barriers encountered last before (LastCB).
2784 const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2785 Worklist.append(LastED.AlignedBarriers.begin(),
2786 LastED.AlignedBarriers.end());
2787 }
2788 }
2789
2790 // If we actually eliminated a barrier we need to eliminate the associated
2791 // llvm.assumes as well to avoid creating UB.
2792 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2793 for (auto *AssumeCB : ED.EncounteredAssumes)
2794 A.deleteAfterManifest(*AssumeCB);
2795 };
2796
2797 for (auto *CB : AlignedBarriers)
2798 HandleAlignedBarrier(CB);
2799
2800 // Handle the "kernel end barrier" for kernels too.
 // NOTE(review): a guarding `if (...)` line appears to be missing here;
 // given the comment above, it presumably restricts this to kernels.
 // Verify.
2802 HandleAlignedBarrier(nullptr);
2803
2804 return Changed;
2805 }
2806
2807 bool isNoOpFence(const FenceInst &FI) const override {
2808 return getState().isValidState() && !NonNoOpFences.count(&FI);
2809 }
2810
2811 /// Merge barrier and assumption information from \p PredED into the successor
2812 /// \p ED.
2813 void
2814 mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
2815 const ExecutionDomainTy &PredED);
2816
2817 /// Merge all information from \p PredED into the successor \p ED. If
2818 /// \p InitialEdgeOnly is set, only the initial edge will enter the block
2819 /// represented by \p ED from this predecessor.
 /// Defined out of line; the meaning of the bool result is not visible here.
2820 bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
2821 const ExecutionDomainTy &PredED,
2822 bool InitialEdgeOnly = false);
2823
2824 /// Accumulate information for the entry block in \p EntryBBED.
2825 bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);
2826
2827 /// See AbstractAttribute::updateImpl.
 // NOTE(review): the updateImpl declaration this comment documents is
 // missing from this listing. Verify.
2829
2830 /// Query interface, see AAExecutionDomain
2831 ///{
2832 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2833 if (!isValidState())
2834 return false;
2835 assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
2836 return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2837 }
2838
2840 const Instruction &I) const override {
// Decide whether \p I executes in an aligned region. The forward scan stops
// at the first call other than \p I that is an aligned barrier (answer: true)
// or at the first call with a recorded PRE domain, which then must be
// reaching aligned barriers only. If the scan runs off the block end, the
// block's own domain is consulted. The backward scan mirrors this using POST
// domains and IsReachedFromAlignedBarrierOnly. If it reaches the block start,
// it falls back to the predecessors' domains, or to the function entry
// domain (BEDMap[nullptr]) for the entry block.
2841 assert(I.getFunction() == getAnchorScope() &&
2842 "Instruction is out of scope!");
2843 if (!isValidState())
2844 return false;
2845
2846 bool ForwardIsOk = true;
2847 const Instruction *CurI;
2848
2849 // Check forward until a call or the block end is reached.
// Note: `continue` inside a do-while jumps to the loop condition, so it
// advances CurI to the next non-debug instruction.
2850 CurI = &I;
2851 do {
2852 auto *CB = dyn_cast<CallBase>(CurI);
2853 if (!CB)
2854 continue;
2855 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2856 return true;
2857 const auto &It = CEDMap.find({CB, PRE});
2858 if (It == CEDMap.end())
2859 continue;
2860 if (!It->getSecond().IsReachingAlignedBarrierOnly)
2861 ForwardIsOk = false;
2862 break;
2863 } while ((CurI = CurI->getNextNonDebugInstruction()));
2864
// A null CurI means the forward scan ran off the end of the block.
2865 if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
2866 ForwardIsOk = false;
2867
2868 // Check backward until a call or the block beginning is reached.
2869 CurI = &I;
2870 do {
2871 auto *CB = dyn_cast<CallBase>(CurI);
2872 if (!CB)
2873 continue;
2874 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2875 return true;
2876 const auto &It = CEDMap.find({CB, POST});
2877 if (It == CEDMap.end())
2878 continue;
2879 if (It->getSecond().IsReachedFromAlignedBarrierOnly)
2880 break;
2881 return false;
2882 } while ((CurI = CurI->getPrevNonDebugInstruction()));
2883
2884 // Delayed decision on the forward pass to allow aligned barrier detection
2885 // in the backwards traversal.
2886 if (!ForwardIsOk)
2887 return false;
2888
// A null CurI means the backward scan reached the start of the block.
2889 if (!CurI) {
2890 const BasicBlock *BB = I.getParent();
2891 if (BB == &BB->getParent()->getEntryBlock())
2892 return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
2893 if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
2894 return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
2895 })) {
2896 return false;
2897 }
2898 }
2899
2900 // On neither traversal did we find anything but aligned barriers.
2901 return true;
2902 }
2903
2904 ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
2905 assert(isValidState() &&
2906 "No request should be made against an invalid state!");
2907 return BEDMap.lookup(&BB);
2908 }
2909 std::pair<ExecutionDomainTy, ExecutionDomainTy>
2910 getExecutionDomain(const CallBase &CB) const override {
2911 assert(isValidState() &&
2912 "No request should be made against an invalid state!");
2913 return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
2914 }
2915 ExecutionDomainTy getFunctionExecutionDomain() const override {
2916 assert(isValidState() &&
2917 "No request should be made against an invalid state!");
2918 return InterProceduralED;
2919 }
2920 ///}
2921
2922 /// Check if the edge into \p SuccessorBB is guarded by a condition that only
2923 /// lets the initial (main) thread take it.
/// The edge qualifies only if \p Edge is a conditional branch whose first
/// successor (the true edge) is \p SuccessorBB, and whose condition is an
/// equality compare that holds when its operands are equal with a ConstantInt
/// as the second operand. The recognized patterns are listed below; anything
/// else yields false.
2924 static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,
2925 BasicBlock &SuccessorBB) {
2926 if (!Edge || !Edge->isConditional())
2927 return false;
2928 if (Edge->getSuccessor(0) != &SuccessorBB)
2929 return false;
2930
2931 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2932 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2933 return false;
2934
2935 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2936 if (!C)
2937 return false;
2938
2939 // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2940 if (C->isAllOnesValue()) {
2941 auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2942 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2943 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2944 CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2945 if (!CB)
2946 return false;
2947 ConstantStruct *KernelEnvC =
2949 ConstantInt *ExecModeC =
2950 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
// NOTE(review): only the GENERIC bit of the exec mode is tested here; confirm
// how exec modes that combine GENERIC with other bits are meant to be treated.
2951 return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
2952 }
2953
2954 if (C->isZero()) {
2955 // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2956 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2957 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2958 return true;
2959
2960 // Match: 0 == llvm.amdgcn.workitem.id.x()
2961 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2962 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2963 return true;
2964 }
2965
2966 return false;
2967 };
2968
2969 /// Execution domain merged from the function's non-unreachable exit blocks;
/// exposed to other AAs through getFunctionExecutionDomain().
2970 ExecutionDomainTy InterProceduralED;
2971
/// Selects whether a CEDMap entry describes the state right before (PRE) or
/// right after (POST) a call.
2972 enum Direction { PRE = 0, POST = 1 };
2973 /// Mapping containing information per block (BEDMap); CEDMap holds the
/// per-call information keyed by (call, Direction).
2976 CEDMap;
/// Aligned barrier calls collected during the forward traversal.
2977 SmallSetVector<CallBase *, 16> AlignedBarriers;
2978
2980
/// Store \p V into \p R and return true iff this changed the value of \p R.
static bool setAndRecord(bool &R, bool V) {
  if (R == V)
    return false;
  R = V;
  return true;
}
2986 }
2987
2988 /// Collection of fences known to be non-no-op. All fences not in this set
2989 /// can be assumed to be no-ops.
2991};
2992
2993void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2994 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2995 for (auto *EA : PredED.EncounteredAssumes)
2996 ED.addAssumeInst(A, *EA);
2997
2998 for (auto *AB : PredED.AlignedBarriers)
2999 ED.addAlignedBarrier(A, *AB);
3000}
3001
3002bool AAExecutionDomainFunction::mergeInPredecessor(
3003 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
3004 bool InitialEdgeOnly) {
3005
3006 bool Changed = false;
3007 Changed |=
3008 setAndRecord(ED.IsExecutedByInitialThreadOnly,
3009 InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3010 ED.IsExecutedByInitialThreadOnly));
3011
3012 Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3013 ED.IsReachedFromAlignedBarrierOnly &&
3014 PredED.IsReachedFromAlignedBarrierOnly);
3015 Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3016 ED.EncounteredNonLocalSideEffect |
3017 PredED.EncounteredNonLocalSideEffect);
3018 // Do not track assumptions and barriers as part of Changed.
3019 if (ED.IsReachedFromAlignedBarrierOnly)
3020 mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3021 else
3022 ED.clearAssumeInstAndAlignedBarriers();
3023 return Changed;
3024}
3025
/// Compute the execution domain at the entry of the anchor function.
///
/// If all call sites are known, the entry domain \p EntryBBED is merged from
/// the domains right before each call site, and the exit requirement is the
/// conjunction of the domains right after them. Otherwise, a kernel starts
/// aligned and without prior non-local side effects, while any other function
/// gets pessimistic values. The results are folded into the function-level
/// entry BEDMap[nullptr]; returns true if that entry changed.
3026bool AAExecutionDomainFunction::handleCallees(Attributor &A,
3027 ExecutionDomainTy &EntryBBED) {
// Records the (before, after) execution domains of each call site as seen by
// the caller's AAExecutionDomain; fails if that AA is unavailable or invalid.
3029 auto PredForCallSite = [&](AbstractCallSite ACS) {
3030 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3031 *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
3032 DepClassTy::OPTIONAL);
3033 if (!EDAA || !EDAA->getState().isValidState())
3034 return false;
3035 CallSiteEDs.emplace_back(
3036 EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
3037 return true;
3038 };
3039
// NOTE(review): the &= below acts as a conjunction only if a default
// ExecutionDomainTy starts with IsReachingAlignedBarrierOnly == true; confirm.
3040 ExecutionDomainTy ExitED;
3041 bool AllCallSitesKnown;
3042 if (A.checkForAllCallSites(PredForCallSite, *this,
3043 /* RequiresAllCallSites */ true,
3044 AllCallSitesKnown)) {
// The per-merge change status is ignored here; changes are reported through
// FnED below.
3045 for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
3046 mergeInPredecessor(A, EntryBBED, CSInED);
3047 ExitED.IsReachingAlignedBarrierOnly &=
3048 CSOutED.IsReachingAlignedBarrierOnly;
3049 }
3050
3051 } else {
3052 // We could not find all predecessors, so this is either a kernel or a
3053 // function with external linkage (or with some other weird uses).
3054 if (omp::isOpenMPKernel(*getAnchorScope())) {
// Kernels are treated as entered from an aligned point without prior
// non-local side effects.
3055 EntryBBED.IsExecutedByInitialThreadOnly = false;
3056 EntryBBED.IsReachedFromAlignedBarrierOnly = true;
3057 EntryBBED.EncounteredNonLocalSideEffect = false;
3058 ExitED.IsReachingAlignedBarrierOnly = false;
3059 } else {
// Unknown callers: assume nothing.
3060 EntryBBED.IsExecutedByInitialThreadOnly = false;
3061 EntryBBED.IsReachedFromAlignedBarrierOnly = false;
3062 EntryBBED.EncounteredNonLocalSideEffect = true;
3063 ExitED.IsReachingAlignedBarrierOnly = false;
3064 }
3065 }
3066
// Fold entry/exit information into the function-level entry BEDMap[nullptr].
3067 bool Changed = false;
3068 auto &FnED = BEDMap[nullptr];
3069 Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
3070 FnED.IsReachedFromAlignedBarrierOnly &
3071 EntryBBED.IsReachedFromAlignedBarrierOnly);
3072 Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
3073 FnED.IsReachingAlignedBarrierOnly &
3074 ExitED.IsReachingAlignedBarrierOnly);
3075 Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
3076 EntryBBED.IsExecutedByInitialThreadOnly);
3077 return Changed;
3078}
3079
/// One fixpoint iteration of the execution domain analysis.
///
/// A forward pass over the blocks in RPOT order computes, per block (BEDMap)
/// and per call (CEDMap, PRE/POST), the flags IsExecutedByInitialThreadOnly,
/// IsReachedFromAlignedBarrierOnly and EncounteredNonLocalSideEffect. It also
/// collects aligned barriers, assumptions and non-no-op fences. A backward
/// pass then propagates "not reaching aligned barriers only" from every
/// collected sync instruction. Returns CHANGED if any tracked flag changed.
3080ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
3081
3082 bool Changed = false;
3083
3084 // Helper to deal with an aligned barrier encountered during the forward
3085 // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
3086 // it was encountered.
3087 auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
3088 Changed |= AlignedBarriers.insert(&CB);
3089 // First, update the barrier ED kept in the separate CEDMap.
3090 auto &CallInED = CEDMap[{&CB, PRE}];
3091 Changed |= mergeInPredecessor(A, CallInED, ED);
3092 CallInED.IsReachingAlignedBarrierOnly = true;
3093 // Next adjust the ED we use for the traversal.
3094 ED.EncounteredNonLocalSideEffect = false;
3095 ED.IsReachedFromAlignedBarrierOnly = true;
3096 // Aligned barrier collection has to come last.
3097 ED.clearAssumeInstAndAlignedBarriers();
3098 ED.addAlignedBarrier(A, CB);
3099 auto &CallOutED = CEDMap[{&CB, POST}];
3100 Changed |= mergeInPredecessor(A, CallOutED, ED);
3101 };
3102
3103 auto *LivenessAA =
3104 A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
3105
3106 Function *F = getAnchorScope();
3107 BasicBlock &EntryBB = F->getEntryBlock();
3108 bool IsKernel = omp::isOpenMPKernel(*F);
3109
// Instructions (and block terminators) from which the backward pass starts.
3110 SmallVector<Instruction *> SyncInstWorklist;
3111 for (auto &RIt : *RPOT) {
3112 BasicBlock &BB = *RIt;
3113
3114 bool IsEntryBB = &BB == &EntryBB;
3115 // TODO: We use local reasoning since we don't have a divergence analysis
3116 // running as well. We could basically allow uniform branches here.
// Both flags are cleared by any instruction that is not nosync and are
// re-armed by an aligned barrier.
3117 bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
3118 bool IsExplicitlyAligned = IsEntryBB && IsKernel;
3119 ExecutionDomainTy ED;
3120 // Propagate "incoming edges" into information about this block.
3121 if (IsEntryBB) {
3122 Changed |= handleCallees(A, ED);
3123 } else {
3124 // For live non-entry blocks we only propagate
3125 // information via live edges.
3126 if (LivenessAA && LivenessAA->isAssumedDead(&BB))
3127 continue;
3128
3129 for (auto *PredBB : predecessors(&BB)) {
3130 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
3131 continue;
3132 bool InitialEdgeOnly = isInitialThreadOnlyEdge(
3133 A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
// The per-edge change status is not used; the comparison against StoredED
// at the end of the block decides whether the block changed.
3134 mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
3135 }
3136 }
3137
3138 // Now we traverse the block, accumulate effects in ED and attach
3139 // information to calls.
3140 for (Instruction &I : BB) {
3141 bool UsedAssumedInformation;
3142 if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
3143 /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
3144 /* CheckForDeadStore */ true))
3145 continue;
3146
3147 // Assumes and "assume-like" (dbg, lifetime, ...) are handled first, the
3148 // former are collected, the latter are ignored.
3149 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
3150 if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
3151 ED.addAssumeInst(A, *AI);
3152 continue;
3153 }
3154 // TODO: Should we also collect and delete lifetime markers?
3155 if (II->isAssumeLikeIntrinsic())
3156 continue;
3157 }
3158
3159 if (auto *FI = dyn_cast<FenceInst>(&I)) {
3160 if (!ED.EncounteredNonLocalSideEffect) {
3161 // An aligned fence without non-local side-effects is a no-op.
3162 if (ED.IsReachedFromAlignedBarrierOnly)
3163 continue;
3164 // A non-aligned fence without non-local side-effects is a no-op
3165 // if the ordering only publishes non-local side-effects (or less).
3166 switch (FI->getOrdering()) {
3167 case AtomicOrdering::NotAtomic:
3168 continue;
3169 case AtomicOrdering::Unordered:
3170 continue;
3171 case AtomicOrdering::Monotonic:
3172 continue;
3173 case AtomicOrdering::Acquire:
3174 break;
3175 case AtomicOrdering::Release:
3176 continue;
3177 case AtomicOrdering::AcquireRelease:
3178 break;
3179 case AtomicOrdering::SequentiallyConsistent:
3180 break;
3181 };
3182 }
// Any fence not skipped above is recorded as required (see isNoOpFence).
3183 NonNoOpFences.insert(FI);
3184 }
3185
3186 auto *CB = dyn_cast<CallBase>(&I);
3187 bool IsNoSync = AA::isNoSyncInst(A, I, *this);
3188 bool IsAlignedBarrier =
3189 !IsNoSync && CB &&
3190 AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);
3191
3192 AlignedBarrierLastInBlock &= IsNoSync;
3193 IsExplicitlyAligned &= IsNoSync;
3194
3195 // Next we check for calls. Aligned barriers are handled
3196 // explicitly, everything else is kept for the backward traversal and will
3197 // also affect our state.
3198 if (CB) {
3199 if (IsAlignedBarrier) {
3200 HandleAlignedBarrier(*CB, ED);
3201 AlignedBarrierLastInBlock = true;
3202 IsExplicitlyAligned = true;
3203 continue;
3204 }
3205
3206 // Check the pointer(s) of a memory intrinsic explicitly.
3207 if (isa<MemIntrinsic>(&I)) {
3208 if (!ED.EncounteredNonLocalSideEffect &&
3210 ED.EncounteredNonLocalSideEffect = true;
3211 if (!IsNoSync) {
3212 ED.IsReachedFromAlignedBarrierOnly = false;
3213 SyncInstWorklist.push_back(&I);
3214 }
3215 continue;
3216 }
3217
3218 // Record how we entered the call, then accumulate the effect of the
3219 // call in ED for potential use by the callee.
3220 auto &CallInED = CEDMap[{CB, PRE}];
3221 Changed |= mergeInPredecessor(A, CallInED, ED);
3222
3223 // If we have a sync-definition we can check if it starts/ends in an
3224 // aligned barrier. If we are unsure we assume any sync breaks
3225 // alignment.
3227 if (!IsNoSync && Callee && !Callee->isDeclaration()) {
3228 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3229 *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
3230 if (EDAA && EDAA->getState().isValidState()) {
3231 const auto &CalleeED = EDAA->getFunctionExecutionDomain();
3233 CalleeED.IsReachedFromAlignedBarrierOnly;
3234 AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3235 if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3236 ED.EncounteredNonLocalSideEffect |=
3237 CalleeED.EncounteredNonLocalSideEffect;
3238 else
3239 ED.EncounteredNonLocalSideEffect =
3240 CalleeED.EncounteredNonLocalSideEffect;
3241 if (!CalleeED.IsReachingAlignedBarrierOnly) {
3242 Changed |=
3243 setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3244 SyncInstWorklist.push_back(&I);
3245 }
3246 if (CalleeED.IsReachedFromAlignedBarrierOnly)
3247 mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
3248 auto &CallOutED = CEDMap[{CB, POST}];
3249 Changed |= mergeInPredecessor(A, CallOutED, ED);
3250 continue;
3251 }
3252 }
// Unknown or unanalyzable sync call: it breaks alignment.
3253 if (!IsNoSync) {
3254 ED.IsReachedFromAlignedBarrierOnly = false;
3255 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3256 SyncInstWorklist.push_back(&I);
3257 }
3258 AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3259 ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
3260 auto &CallOutED = CEDMap[{CB, POST}];
3261 Changed |= mergeInPredecessor(A, CallOutED, ED);
3262 }
3263
3264 if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
3265 continue;
3266
3267 // If we have a callee we try to use fine-grained information to
3268 // determine local side-effects.
3269 if (CB) {
3270 const auto *MemAA = A.getAAFor<AAMemoryLocation>(
3271 *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
3272
3273 auto AccessPred = [&](const Instruction *I, const Value *Ptr,
3276 return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
3277 };
3278 if (MemAA && MemAA->getState().isValidState() &&
3279 MemAA->checkForAllAccessesToMemoryKind(
3281 continue;
3282 }
3283
3284 auto &InfoCache = A.getInfoCache();
3285 if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
3286 continue;
3287
3288 if (auto *LI = dyn_cast<LoadInst>(&I))
3289 if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3290 continue;
3291
3292 if (!ED.EncounteredNonLocalSideEffect &&
3294 ED.EncounteredNonLocalSideEffect = true;
3295 }
3296
// Blocks that leave the function (no successors, not unreachable) feed the
// function-level summaries.
3297 bool IsEndAndNotReachingAlignedBarriersOnly = false;
3298 if (!isa<UnreachableInst>(BB.getTerminator()) &&
3299 !BB.getTerminator()->getNumSuccessors()) {
3300
3301 Changed |= mergeInPredecessor(A, InterProceduralED, ED);
3302
3303 auto &FnED = BEDMap[nullptr];
3304 if (IsKernel && !IsExplicitlyAligned)
3305 FnED.IsReachingAlignedBarrierOnly = false;
3306 Changed |= mergeInPredecessor(A, FnED, ED);
3307
3308 if (!FnED.IsReachingAlignedBarrierOnly) {
3309 IsEndAndNotReachingAlignedBarriersOnly = true;
3310 SyncInstWorklist.push_back(BB.getTerminator());
3311 auto &BBED = BEDMap[&BB];
3312 Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
3313 }
3314 }
3315
3316 ExecutionDomainTy &StoredED = BEDMap[&BB];
3317 ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3318 !IsEndAndNotReachingAlignedBarriersOnly;
3319
3320 // Check if we computed anything different as part of the forward
3321 // traversal. We do not take assumptions and aligned barriers into account
3322 // as they do not influence the state we iterate. Backward traversal values
3323 // are handled later on.
3324 if (ED.IsExecutedByInitialThreadOnly !=
3325 StoredED.IsExecutedByInitialThreadOnly ||
3326 ED.IsReachedFromAlignedBarrierOnly !=
3327 StoredED.IsReachedFromAlignedBarrierOnly ||
3328 ED.EncounteredNonLocalSideEffect !=
3329 StoredED.EncounteredNonLocalSideEffect)
3330 Changed = true;
3331
3332 // Update the state with the new value.
3333 StoredED = std::move(ED);
3334 }
3335
3336 // Propagate (non-aligned) sync instruction effects backwards until the
3337 // entry is hit or an aligned barrier.
3339 while (!SyncInstWorklist.empty()) {
3340 Instruction *SyncInst = SyncInstWorklist.pop_back_val();
3341 Instruction *CurInst = SyncInst;
3342 bool HitAlignedBarrierOrKnownEnd = false;
// Walk backwards within the block. Stop at an aligned barrier or at a call
// already known not to reach aligned barriers only.
3343 while ((CurInst = CurInst->getPrevNode())) {
3344 auto *CB = dyn_cast<CallBase>(CurInst);
3345 if (!CB)
3346 continue;
3347 auto &CallOutED = CEDMap[{CB, POST}];
3348 Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
3349 auto &CallInED = CEDMap[{CB, PRE}];
3350 HitAlignedBarrierOrKnownEnd =
3351 AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
3352 if (HitAlignedBarrierOrKnownEnd)
3353 break;
3354 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3355 }
3356 if (HitAlignedBarrierOrKnownEnd)
3357 continue;
// Reached the block start: continue in the live predecessors.
3358 BasicBlock *SyncBB = SyncInst->getParent();
3359 for (auto *PredBB : predecessors(SyncBB)) {
3360 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
3361 continue;
// Skip predecessors already handled by this backward pass.
3362 if (!Visited.insert(PredBB))
3363 continue;
3364 auto &PredED = BEDMap[PredBB];
3365 if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
3366 Changed = true;
3367 SyncInstWorklist.push_back(PredBB->getTerminator());
3368 }
3369 }
3370 if (SyncBB != &EntryBB)
3371 continue;
3372 Changed |=
3373 setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
3374 }
3375
3376 return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
3377}
3378
3379/// Try to replace memory allocation calls made by a single thread with a
3380/// static buffer of shared memory.
3381struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
/// Construct the attribute for position \p IRP.
3383 AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3384
3385 /// Create an abstract attribute view for the position \p IRP.
3386 static AAHeapToShared &createForPosition(const IRPosition &IRP,
3387 Attributor &A);
3388
3389 /// Returns true if HeapToShared conversion is assumed to be possible.
3390 virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3391
3392 /// Returns true if HeapToShared conversion is assumed and the CB is a
3393 /// callsite to a free operation to be removed.
3394 virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3395
3396 /// See AbstractAttribute::getName().
3397 const std::string getName() const override { return "AAHeapToShared"; }
3398
3399 /// See AbstractAttribute::getIdAddr().
3400 const char *getIdAddr() const override { return &ID; }
3401
3402 /// This function should return true if the type of the \p AA is
3403 /// AAHeapToShared.
3404 static bool classof(const AbstractAttribute *AA) {
3405 return (AA->getIdAddr() == &ID);
3406 }
3407
3408 /// Unique ID (due to the unique address)
3409 static const char ID;
3410};
3411
/// Function-level HeapToShared. It collects the __kmpc_alloc_shared calls of
/// the anchor function as candidates and, at manifest time, replaces eligible
/// ones with internal globals in the shared address space.
3412struct AAHeapToSharedFunction : public AAHeapToShared {
3413 AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
3414 : AAHeapToShared(IRP, A) {}
3415
/// See AbstractAttribute::getAsStr().
3416 const std::string getAsStr(Attributor *) const override {
3417 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3418 " malloc calls eligible.";
3419 }
3420
3421 /// See AbstractAttribute::trackStatistics().
3422 void trackStatistics() const override {}
3423
3424 /// This function finds the free calls that will be removed by the
3425 /// HeapToShared transformation, i.e. the unique __kmpc_free_shared user of
/// each candidate allocation. Allocations with zero or several such users
/// contribute nothing.
3426 void findPotentialRemovedFreeCalls(Attributor &A) {
3427 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3428 auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3429
3430 PotentialRemovedFreeCalls.clear();
3431 // Update free call users of found malloc calls.
3432 for (CallBase *CB : MallocCalls) {
3434 for (auto *U : CB->users()) {
3435 CallBase *C = dyn_cast<CallBase>(U);
3436 if (C && C->getCalledFunction() == FreeRFI.Declaration)
3437 FreeCalls.push_back(C);
3438 }
3439
3440 if (FreeCalls.size() != 1)
3441 continue;
3442
3443 PotentialRemovedFreeCalls.insert(FreeCalls.front());
3444 }
3445 }
3446
/// Collect every __kmpc_alloc_shared call in the anchor function as a
/// candidate. For each call, register a simplification callback on its
/// returned position that yields nullptr.
/// NOTE(review): presumably this keeps other AAs from simplifying the
/// returned pointer while it may still be replaced; confirm.
3447 void initialize(Attributor &A) override {
3449 indicatePessimisticFixpoint();
3450 return;
3451 }
3452
3453 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3454 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3455 if (!RFI.Declaration)
3456 return;
3457
3459 [](const IRPosition &, const AbstractAttribute *,
3460 bool &) -> std::optional<Value *> { return nullptr; };
3461
3462 Function *F = getAnchorScope();
3463 for (User *U : RFI.Declaration->users())
3464 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3465 if (CB->getFunction() != F)
3466 continue;
3467 MallocCalls.insert(CB);
3468 A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
3469 SCB);
3470 }
3471
3472 findPotentialRemovedFreeCalls(A);
3473 }
3474
/// See AAHeapToShared::isAssumedHeapToShared.
3475 bool isAssumedHeapToShared(CallBase &CB) const override {
3476 return isValidState() && MallocCalls.count(&CB);
3477 }
3478
/// See AAHeapToShared::isAssumedHeapToSharedRemovedFree.
3479 bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
3480 return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3481 }
3482
/// Replace each remaining candidate with a new internal global in the
/// shared address space, then delete the allocation and its free call.
/// A candidate qualifies only if HeapToStack did not claim it, it has exactly
/// one __kmpc_free_shared user, and its size fits within what is left of
/// SharedMemoryLimit.
3483 ChangeStatus manifest(Attributor &A) override {
3484 if (MallocCalls.empty())
3485 return ChangeStatus::UNCHANGED;
3486
3487 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3488 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3489
3490 Function *F = getAnchorScope();
3491 auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3492 DepClassTy::OPTIONAL);
3493
3494 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3495 for (CallBase *CB : MallocCalls) {
3496 // Skip replacing this if HeapToStack has already claimed it.
3497 if (HS && HS->isAssumedHeapToStack(*CB))
3498 continue;
3499
3500 // Find the unique free call to remove it.
3502 for (auto *U : CB->users()) {
3503 CallBase *C = dyn_cast<CallBase>(U);
3504 if (C && C->getCalledFunction() == FreeCall.Declaration)
3505 FreeCalls.push_back(C);
3506 }
3507 if (FreeCalls.size() != 1)
3508 continue;
3509
// updateImpl drops candidates whose size is not a ConstantInt.
3510 auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3511
3512 if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3513 LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3514 << " with shared memory."
3515 << " Shared memory usage is limited to "
3516 << SharedMemoryLimit << " bytes\n");
3517 continue;
3518 }
3519
3520 LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3521 << " with " << AllocSize->getZExtValue()
3522 << " bytes of shared memory\n");
3523
3524 // Create a new shared memory buffer of the same size as the allocation
3525 // and replace all the uses of the original allocation with it.
3526 Module *M = CB->getModule();
3527 Type *Int8Ty = Type::getInt8Ty(M->getContext());
3528 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3529 auto *SharedMem = new GlobalVariable(
3530 *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3531 PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3533 static_cast<unsigned>(AddressSpace::Shared));
3534 auto *NewBuffer = ConstantExpr::getPointerCast(
3535 SharedMem, PointerType::getUnqual(M->getContext()));
3536
3537 auto Remark = [&](OptimizationRemark OR) {
3538 return OR << "Replaced globalized variable with "
3539 << ore::NV("SharedMemory", AllocSize->getZExtValue())
3540 << (AllocSize->isOne() ? " byte " : " bytes ")
3541 << "of shared memory.";
3542 };
3543 A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3544
3545 MaybeAlign Alignment = CB->getRetAlign();
3546 assert(Alignment &&
3547 "HeapToShared on allocation without alignment attribute");
3548 SharedMem->setAlignment(*Alignment);
3549
3550 A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3551 A.deleteAfterManifest(*CB);
3552 A.deleteAfterManifest(*FreeCalls.front());
3553
3554 SharedMemoryUsed += AllocSize->getZExtValue();
3555 NumBytesMovedToSharedMemory = SharedMemoryUsed;
3556 Changed = ChangeStatus::CHANGED;
3557 }
3558
3559 return Changed;
3560 }
3561
/// Drop candidates whose size argument is not a ConstantInt, as well as
/// candidates not known to be executed by the initial thread only. Returns
/// CHANGED when the candidate set shrank.
3562 ChangeStatus updateImpl(Attributor &A) override {
3563 if (MallocCalls.empty())
3564 return indicatePessimisticFixpoint();
3565 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3566 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3567 if (!RFI.Declaration)
3568 return ChangeStatus::UNCHANGED;
3569
3570 Function *F = getAnchorScope();
3571
3572 auto NumMallocCalls = MallocCalls.size();
3573
3574 // Only keep malloc calls executed by a single thread with a constant size.
3575 for (User *U : RFI.Declaration->users()) {
3576 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3577 if (CB->getCaller() != F)
3578 continue;
3579 if (!MallocCalls.count(CB))
3580 continue;
3581 if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3582 MallocCalls.remove(CB);
3583 continue;
3584 }
// NOTE(review): this query does not depend on CB and could be hoisted out
// of the loop.
3585 const auto *ED = A.getAAFor<AAExecutionDomain>(
3586 *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3587 if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
3588 MallocCalls.remove(CB);
3589 }
3590 }
3591
3592 findPotentialRemovedFreeCalls(A);
3593
3594 if (NumMallocCalls != MallocCalls.size())
3595 return ChangeStatus::CHANGED;
3596
3597 return ChangeStatus::UNCHANGED;
3598 }
3599
3600 /// Collection of all malloc calls in a function.
3602 /// Collection of potentially removed free calls in a function.
3603 SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3604 /// The total amount of shared memory that has been used for HeapToShared.
3605 unsigned SharedMemoryUsed = 0;
3606};
3607
/// Abstract attribute exposing a KernelInfoState for a position. The state
/// covers SPMD compatibility, reached known/unknown parallel regions, reaching
/// kernel entries, parallel levels and nested parallelism.
3608struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
/// Construct the attribute for position \p IRP.
3610 AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3611
3612 /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
3613 /// unknown callees.
3614 static bool requiresCalleeForCallBase() { return false; }
3615
3616 /// Statistics are tracked as part of manifest for now.
3617 void trackStatistics() const override {}
3618
3619 /// See AbstractAttribute::getAsStr()
3620 const std::string getAsStr(Attributor *) const override {
3621 if (!isValidState())
3622 return "<invalid>";
3623 return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3624 : "generic") +
3625 std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3626 : "") +
3627 std::string(" #PRs: ") +
3628 (ReachedKnownParallelRegions.isValidState()
3629 ? std::to_string(ReachedKnownParallelRegions.size())
3630 : "<invalid>") +
3631 ", #Unknown PRs: " +
3632 (ReachedUnknownParallelRegions.isValidState()
3633 ? std::to_string(ReachedUnknownParallelRegions.size())
3634 : "<invalid>") +
3635 ", #Reaching Kernels: " +
3636 (ReachingKernelEntries.isValidState()
3637 ? std::to_string(ReachingKernelEntries.size())
3638 : "<invalid>") +
3639 ", #ParLevels: " +
3640 (ParallelLevels.isValidState()
3641 ? std::to_string(ParallelLevels.size())
3642 : "<invalid>") +
3643 ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3644 }
3645
3646 /// Create an abstract attribute view for the position \p IRP.
3647 static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3648
3649 /// See AbstractAttribute::getName()
3650 const std::string getName() const override { return "AAKernelInfo"; }
3651
3652 /// See AbstractAttribute::getIdAddr()
3653 const char *getIdAddr() const override { return &ID; }
3654
3655 /// This function should return true if the type of the \p AA is AAKernelInfo
3656 static bool classof(const AbstractAttribute *AA) {
3657 return (AA->getIdAddr() == &ID);
3658 }
3659
/// Unique ID (due to the unique address)
3660 static const char ID;
3661};
3662
3663/// The function kernel info abstract attribute, basically, what can we say
3664/// about a function with regards to the KernelInfoState.
3665struct AAKernelInfoFunction : AAKernelInfo {
3666 AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3667 : AAKernelInfo(IRP, A) {}
3668
3669 SmallPtrSet<Instruction *, 4> GuardedInstructions;
3670
3671 SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3672 return GuardedInstructions;
3673 }
3674
3675 void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3677 KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3678 assert(NewKernelEnvC && "Failed to create new kernel environment");
3679 KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3680 }
3681
3682#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \
3683 void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \
3684 ConstantStruct *ConfigC = \
3685 KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \
3686 Constant *NewConfigC = ConstantFoldInsertValueInstruction( \
3687 ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \
3688 assert(NewConfigC && "Failed to create new configuration environment"); \
3689 setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \
3690 }
3691
3692 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
3693 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
3699
3700#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
3701
3702 /// See AbstractAttribute::initialize(...).
3703 void initialize(Attributor &A) override {
3704 // This is a high-level transform that might change the constant arguments
3705 // of the init and deinit calls. We need to tell the Attributor about this
3706 // to avoid other parts using the current constant value for simplification.
3707 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3708
3709 Function *Fn = getAnchorScope();
3710
3711 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3712 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3713 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3714 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3715
3716 // For kernels we perform more initialization work, first we find the init
3717 // and deinit calls.
3718 auto StoreCallBase = [](Use &U,
3719 OMPInformationCache::RuntimeFunctionInfo &RFI,
3720 CallBase *&Storage) {
3721 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3722 assert(CB &&
3723 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3724 assert(!Storage &&
3725 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3726 Storage = CB;
3727 return false;
3728 };
3729 InitRFI.foreachUse(
3730 [&](Use &U, Function &) {
3731 StoreCallBase(U, InitRFI, KernelInitCB);
3732 return false;
3733 },
3734 Fn);
3735 DeinitRFI.foreachUse(
3736 [&](Use &U, Function &) {
3737 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3738 return false;
3739 },
3740 Fn);
3741
3742 // Ignore kernels without initializers such as global constructors.
3743 if (!KernelInitCB || !KernelDeinitCB)
3744 return;
3745
3746 // Add itself to the reaching kernel and set IsKernelEntry.
3747 ReachingKernelEntries.insert(Fn);
3748 IsKernelEntry = true;
3749
3750 KernelEnvC =
3752 GlobalVariable *KernelEnvGV =
3754
3756 KernelConfigurationSimplifyCB =
3757 [&](const GlobalVariable &GV, const AbstractAttribute *AA,
3758 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3759 if (!isAtFixpoint()) {
3760 if (!AA)
3761 return nullptr;
3762 UsedAssumedInformation = true;
3763 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3764 }
3765 return KernelEnvC;
3766 };
3767
3768 A.registerGlobalVariableSimplificationCallback(
3769 *KernelEnvGV, KernelConfigurationSimplifyCB);
3770
3771 // We cannot change to SPMD mode if the runtime functions aren't availible.
3772 bool CanChangeToSPMD = OMPInfoCache.runtimeFnsAvailable(
3773 {OMPRTL___kmpc_get_hardware_thread_id_in_block,
3774 OMPRTL___kmpc_barrier_simple_spmd});
3775
3776 // Check if we know we are in SPMD-mode already.
3777 ConstantInt *ExecModeC =
3778 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3779 ConstantInt *AssumedExecModeC = ConstantInt::get(
3780 ExecModeC->getIntegerType(),
3782 if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
3783 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3784 else if (DisableOpenMPOptSPMDization || !CanChangeToSPMD)
3785 // This is a generic region but SPMDization is disabled so stop
3786 // tracking.
3787 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3788 else
3789 setExecModeOfKernelEnvironment(AssumedExecModeC);
3790
3791 const Triple T(Fn->getParent()->getTargetTriple());
3792 auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
3793 auto [MinThreads, MaxThreads] =
3795 if (MinThreads)
3796 setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3797 if (MaxThreads)
3798 setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3799 auto [MinTeams, MaxTeams] =
3801 if (MinTeams)
3802 setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3803 if (MaxTeams)
3804 setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3805
3806 ConstantInt *MayUseNestedParallelismC =
3807 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3808 ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3809 MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
3810 setMayUseNestedParallelismOfKernelEnvironment(
3811 AssumedMayUseNestedParallelismC);
3812
3814 ConstantInt *UseGenericStateMachineC =
3815 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3816 KernelEnvC);
3817 ConstantInt *AssumedUseGenericStateMachineC =
3818 ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3819 setUseGenericStateMachineOfKernelEnvironment(
3820 AssumedUseGenericStateMachineC);
3821 }
3822
3823 // Register virtual uses of functions we might need to preserve.
3824 auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3826 if (!OMPInfoCache.RFIs[RFKind].Declaration)
3827 return;
3828 A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3829 };
3830
3831 // Add a dependence to ensure updates if the state changes.
3832 auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3833 const AbstractAttribute *QueryingAA) {
3834 if (QueryingAA) {
3835 A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3836 }
3837 return true;
3838 };
3839
3840 Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3841 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3842 // Whenever we create a custom state machine we will insert calls to
3843 // __kmpc_get_hardware_num_threads_in_block,
3844 // __kmpc_get_warp_size,
3845 // __kmpc_barrier_simple_generic,
3846 // __kmpc_kernel_parallel, and
3847 // __kmpc_kernel_end_parallel.
3848 // Not needed if we are on track for SPMDzation.
3849 if (SPMDCompatibilityTracker.isValidState())
3850 return AddDependence(A, this, QueryingAA);
3851 // Not needed if we can't rewrite due to an invalid state.
3852 if (!ReachedKnownParallelRegions.isValidState())
3853 return AddDependence(A, this, QueryingAA);
3854 return false;
3855 };
3856
3857 // Not needed if we are pre-runtime merge.
3858 if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3859 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3860 CustomStateMachineUseCB);
3861 RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3862 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3863 CustomStateMachineUseCB);
3864 RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3865 CustomStateMachineUseCB);
3866 RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3867 CustomStateMachineUseCB);
3868 }
3869
3870 // If we do not perform SPMDzation we do not need the virtual uses below.
3871 if (SPMDCompatibilityTracker.isAtFixpoint())
3872 return;
3873
3874 Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3875 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3876 // Whenever we perform SPMDzation we will insert
3877 // __kmpc_get_hardware_thread_id_in_block calls.
3878 if (!SPMDCompatibilityTracker.isValidState())
3879 return AddDependence(A, this, QueryingAA);
3880 return false;
3881 };
3882 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3883 HWThreadIdUseCB);
3884
3885 Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3886 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3887 // Whenever we perform SPMDzation with guarding we will insert
3888 // __kmpc_simple_barrier_spmd calls. If SPMDzation failed, there is
3889 // nothing to guard, or there are no parallel regions, we don't need
3890 // the calls.
3891 if (!SPMDCompatibilityTracker.isValidState())
3892 return AddDependence(A, this, QueryingAA);
3893 if (SPMDCompatibilityTracker.empty())
3894 return AddDependence(A, this, QueryingAA);
3895 if (!mayContainParallelRegion())
3896 return AddDependence(A, this, QueryingAA);
3897 return false;
3898 };
3899 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3900 }
3901
3902 /// Sanitize the string \p S such that it is a suitable global symbol name.
3903 static std::string sanitizeForGlobalName(std::string S) {
3904 std::replace_if(
3905 S.begin(), S.end(),
3906 [](const char C) {
3907 return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3908 (C >= '0' && C <= '9') || C == '_');
3909 },
3910 '.');
3911 return S;
3912 }
3913
3914 /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3915 /// finished now.
3916 ChangeStatus manifest(Attributor &A) override {
3917 // If we are not looking at a kernel with __kmpc_target_init and
3918 // __kmpc_target_deinit call we cannot actually manifest the information.
3919 if (!KernelInitCB || !KernelDeinitCB)
3920 return ChangeStatus::UNCHANGED;
3921
3922 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3923
3924 bool HasBuiltStateMachine = true;
3925 if (!changeToSPMDMode(A, Changed)) {
3926 if (!KernelInitCB->getCalledFunction()->isDeclaration())
3927 HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3928 else
3929 HasBuiltStateMachine = false;
3930 }
3931
3932 // We need to reset KernelEnvC if specific rewriting is not done.
3933 ConstantStruct *ExistingKernelEnvC =
3935 ConstantInt *OldUseGenericStateMachineVal =
3936 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3937 ExistingKernelEnvC);
3938 if (!HasBuiltStateMachine)
3939 setUseGenericStateMachineOfKernelEnvironment(
3940 OldUseGenericStateMachineVal);
3941
3942 // At last, update the KernelEnvc
3943 GlobalVariable *KernelEnvGV =
3945 if (KernelEnvGV->getInitializer() != KernelEnvC) {
3946 KernelEnvGV->setInitializer(KernelEnvC);
3947 Changed = ChangeStatus::CHANGED;
3948 }
3949
3950 return Changed;
3951 }
3952
3953 void insertInstructionGuardsHelper(Attributor &A) {
3954 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3955
3956 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3957 Instruction *RegionEndI) {
3958 LoopInfo *LI = nullptr;
3959 DominatorTree *DT = nullptr;
3960 MemorySSAUpdater *MSU = nullptr;
3961 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3962
3963 BasicBlock *ParentBB = RegionStartI->getParent();
3964 Function *Fn = ParentBB->getParent();
3965 Module &M = *Fn->getParent();
3966
3967 // Create all the blocks and logic.
3968 // ParentBB:
3969 // goto RegionCheckTidBB
3970 // RegionCheckTidBB:
3971 // Tid = __kmpc_hardware_thread_id()
3972 // if (Tid != 0)
3973 // goto RegionBarrierBB
3974 // RegionStartBB:
3975 // <execute instructions guarded>
3976 // goto RegionEndBB
3977 // RegionEndBB:
3978 // <store escaping values to shared mem>
3979 // goto RegionBarrierBB
3980 // RegionBarrierBB:
3981 // __kmpc_simple_barrier_spmd()
3982 // // second barrier is omitted if lacking escaping values.
3983 // <load escaping values from shared mem>
3984 // __kmpc_simple_barrier_spmd()
3985 // goto RegionExitBB
3986 // RegionExitBB:
3987 // <execute rest of instructions>
3988
3989 BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3990 DT, LI, MSU, "region.guarded.end");
3991 BasicBlock *RegionBarrierBB =
3992 SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3993 MSU, "region.barrier");
3994 BasicBlock *RegionExitBB =
3995 SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3996 DT, LI, MSU, "region.exit");
3997 BasicBlock *RegionStartBB =
3998 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3999
4000 assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
4001 "Expected a different CFG");
4002
4003 BasicBlock *RegionCheckTidBB = SplitBlock(
4004 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
4005
4006 // Register basic blocks with the Attributor.
4007 A.registerManifestAddedBasicBlock(*RegionEndBB);
4008 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
4009 A.registerManifestAddedBasicBlock(*RegionExitBB);
4010 A.registerManifestAddedBasicBlock(*RegionStartBB);
4011 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4012
4013 bool HasBroadcastValues = false;
4014 // Find escaping outputs from the guarded region to outside users and
4015 // broadcast their values to them.
4016 for (Instruction &I : *RegionStartBB) {
4017 SmallVector<Use *, 4> OutsideUses;
4018 for (Use &U : I.uses()) {
4019 Instruction &UsrI = *cast<Instruction>(U.getUser());
4020 if (UsrI.getParent() != RegionStartBB)
4021 OutsideUses.push_back(&U);
4022 }
4023
4024 if (OutsideUses.empty())
4025 continue;
4026
4027 HasBroadcastValues = true;
4028
4029 // Emit a global variable in shared memory to store the broadcasted
4030 // value.
4031 auto *SharedMem = new GlobalVariable(
4032 M, I.getType(), /* IsConstant */ false,
4034 sanitizeForGlobalName(
4035 (I.getName() + ".guarded.output.alloc").str()),
4037 static_cast<unsigned>(AddressSpace::Shared));
4038
4039 // Emit a store instruction to update the value.
4040 new StoreInst(&I, SharedMem,
4041 RegionEndBB->getTerminator()->getIterator());
4042
4043 LoadInst *LoadI = new LoadInst(
4044 I.getType(), SharedMem, I.getName() + ".guarded.output.load",
4045 RegionBarrierBB->getTerminator()->getIterator());
4046
4047 // Emit a load instruction and replace uses of the output value.
4048 for (Use *U : OutsideUses)
4049 A.changeUseAfterManifest(*U, *LoadI);
4050 }
4051
4052 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4053
4054 // Go to tid check BB in ParentBB.
4055 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
4056 ParentBB->getTerminator()->eraseFromParent();
4058 InsertPointTy(ParentBB, ParentBB->end()), DL);
4059 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4060 uint32_t SrcLocStrSize;
4061 auto *SrcLocStr =
4062 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4063 Value *Ident =
4064 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4065 BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
4066
4067 // Add check for Tid in RegionCheckTidBB
4068 RegionCheckTidBB->getTerminator()->eraseFromParent();
4069 OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
4070 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4071 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4072 FunctionCallee HardwareTidFn =
4073 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4074 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4075 CallInst *Tid =
4076 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4077 Tid->setDebugLoc(DL);
4078 OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4079 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4080 OMPInfoCache.OMPBuilder.Builder
4081 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4082 ->setDebugLoc(DL);
4083
4084 // First barrier for synchronization, ensures main thread has updated
4085 // values.
4086 FunctionCallee BarrierFn =
4087 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4088 M, OMPRTL___kmpc_barrier_simple_spmd);
4089 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4090 RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
4091 CallInst *Barrier =
4092 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4093 Barrier->setDebugLoc(DL);
4094 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4095
4096 // Second barrier ensures workers have read broadcast values.
4097 if (HasBroadcastValues) {
4098 CallInst *Barrier =
4099 CallInst::Create(BarrierFn, {Ident, Tid}, "",
4100 RegionBarrierBB->getTerminator()->getIterator());
4101 Barrier->setDebugLoc(DL);
4102 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4103 }
4104 };
4105
4106 auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4108 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4109 BasicBlock *BB = GuardedI->getParent();
4110 if (!Visited.insert(BB).second)
4111 continue;
4112
4114 Instruction *LastEffect = nullptr;
4115 BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
4116 while (++IP != IPEnd) {
4117 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4118 continue;
4119 Instruction *I = &*IP;
4120 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4121 continue;
4122 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4123 LastEffect = nullptr;
4124 continue;
4125 }
4126 if (LastEffect)
4127 Reorders.push_back({I, LastEffect});
4128 LastEffect = &*IP;
4129 }
4130 for (auto &Reorder : Reorders)
4131 Reorder.first->moveBefore(Reorder.second);
4132 }
4133
4135
4136 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4137 BasicBlock *BB = GuardedI->getParent();
4138 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4139 IRPosition::function(*GuardedI->getFunction()), nullptr,
4140 DepClassTy::NONE);
4141 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4142 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4143 // Continue if instruction is already guarded.
4144 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4145 continue;
4146
4147 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4148 for (Instruction &I : *BB) {
4149 // If instruction I needs to be guarded update the guarded region
4150 // bounds.
4151 if (SPMDCompatibilityTracker.contains(&I)) {
4152 CalleeAAFunction.getGuardedInstructions().insert(&I);
4153 if (GuardedRegionStart)
4154 GuardedRegionEnd = &I;
4155 else
4156 GuardedRegionStart = GuardedRegionEnd = &I;
4157
4158 continue;
4159 }
4160
4161 // Instruction I does not need guarding, store
4162 // any region found and reset bounds.
4163 if (GuardedRegionStart) {
4164 GuardedRegions.push_back(
4165 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4166 GuardedRegionStart = nullptr;
4167 GuardedRegionEnd = nullptr;
4168 }
4169 }
4170 }
4171
4172 for (auto &GR : GuardedRegions)
4173 CreateGuardedRegion(GR.first, GR.second);
4174 }
4175
4176 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4177 // Only allow 1 thread per workgroup to continue executing the user code.
4178 //
4179 // InitCB = __kmpc_target_init(...)
4180 // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4181 // if (ThreadIdInBlock != 0) return;
4182 // UserCode:
4183 // // user code
4184 //
4185 auto &Ctx = getAnchorValue().getContext();
4186 Function *Kernel = getAssociatedFunction();
4187 assert(Kernel && "Expected an associated function!");
4188
4189 // Create block for user code to branch to from initial block.
4190 BasicBlock *InitBB = KernelInitCB->getParent();
4191 BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4192 KernelInitCB->getNextNode(), "main.thread.user_code");
4193 BasicBlock *ReturnBB =
4194 BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
4195
4196 // Register blocks with attributor:
4197 A.registerManifestAddedBasicBlock(*InitBB);
4198 A.registerManifestAddedBasicBlock(*UserCodeBB);
4199 A.registerManifestAddedBasicBlock(*ReturnBB);
4200
4201 // Debug location:
4202 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4203 ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
4204 InitBB->getTerminator()->eraseFromParent();
4205
4206 // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4207 Module &M = *Kernel->getParent();
4208 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4209 FunctionCallee ThreadIdInBlockFn =
4210 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4211 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4212
4213 // Get thread ID in block.
4214 CallInst *ThreadIdInBlock =
4215 CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
4216 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4217 ThreadIdInBlock->setDebugLoc(DLoc);
4218
4219 // Eliminate all threads in the block with ID not equal to 0:
4220 Instruction *IsMainThread =
4221 ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
4222 ConstantInt::get(ThreadIdInBlock->getType(), 0),
4223 "thread.is_main", InitBB);
4224 IsMainThread->setDebugLoc(DLoc);
4225 BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
4226 }
4227
4228 bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
4229 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4230
4231 if (!SPMDCompatibilityTracker.isAssumed()) {
4232 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4233 if (!NonCompatibleI)
4234 continue;
4235
4236 // Skip diagnostics on calls to known OpenMP runtime functions for now.
4237 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4238 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4239 continue;
4240
4241 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4242 ORA << "Value has potential side effects preventing SPMD-mode "
4243 "execution";
4244 if (isa<CallBase>(NonCompatibleI)) {
4245 ORA << ". Add `[[omp::assume(\"ompx_spmd_amenable\")]]` to "
4246 "the called function to override";
4247 }
4248 return ORA << ".";
4249 };
4250 A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
4251 Remark);
4252
4253 LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
4254 << *NonCompatibleI << "\n");
4255 }
4256
4257 return false;
4258 }
4259
4260 // Get the actual kernel, could be the caller of the anchor scope if we have
4261 // a debug wrapper.
4262 Function *Kernel = getAnchorScope();
4263 if (Kernel->hasLocalLinkage()) {
4264 assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
4265 auto *CB = cast<CallBase>(Kernel->user_back());
4266 Kernel = CB->getCaller();
4267 }
4268 assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
4269
4270 // Check if the kernel is already in SPMD mode, if so, return success.
4271 ConstantStruct *ExistingKernelEnvC =
4273 auto *ExecModeC =
4274 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4275 const int8_t ExecModeVal = ExecModeC->getSExtValue();
4276 if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4277 return true;
4278
4279 // We will now unconditionally modify the IR, indicate a change.
4280 Changed = ChangeStatus::CHANGED;
4281
4282 // Do not use instruction guards when no parallel is present inside
4283 // the target region.
4284 if (mayContainParallelRegion())
4285 insertInstructionGuardsHelper(A);
4286 else
4287 forceSingleThreadPerWorkgroupHelper(A);
4288
4289 // Adjust the global exec mode flag that tells the runtime what mode this
4290 // kernel is executed in.
4291 assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4292 "Initially non-SPMD kernel has SPMD exec mode!");
4293 setExecModeOfKernelEnvironment(
4294 ConstantInt::get(ExecModeC->getIntegerType(),
4295 ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
4296
4297 ++NumOpenMPTargetRegionKernelsSPMD;
4298
4299 auto Remark = [&](OptimizationRemark OR) {
4300 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4301 };
4302 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
4303 return true;
4304 };
4305
4306 bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4307 // If we have disabled state machine rewrites, don't make a custom one
4309 return false;
4310
4311 // Don't rewrite the state machine if we are not in a valid state.
4312 if (!ReachedKnownParallelRegions.isValidState())
4313 return false;
4314
4315 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4316 if (!OMPInfoCache.runtimeFnsAvailable(
4317 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4318 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4319 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4320 return false;
4321
4322 ConstantStruct *ExistingKernelEnvC =
4324
4325 // Check if the current configuration is non-SPMD and generic state machine.
4326 // If we already have SPMD mode or a custom state machine we do not need to
4327 // go any further. If it is anything but a constant something is weird and
4328 // we give up.
4329 ConstantInt *UseStateMachineC =
4330 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4331 ExistingKernelEnvC);
4332 ConstantInt *ModeC =
4333 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4334
4335 // If we are stuck with generic mode, try to create a custom device (=GPU)
4336 // state machine which is specialized for the parallel regions that are
4337 // reachable by the kernel.
4338 if (UseStateMachineC->isZero() ||
4340 return false;
4341
4342 Changed = ChangeStatus::CHANGED;
4343
4344 // If not SPMD mode, indicate we use a custom state machine now.
4345 setUseGenericStateMachineOfKernelEnvironment(
4346 ConstantInt::get(UseStateMachineC->getIntegerType(), false));
4347
4348 // If we don't actually need a state machine we are done here. This can
4349 // happen if there simply are no parallel regions. In the resulting kernel
4350 // all worker threads will simply exit right away, leaving the main thread
4351 // to do the work alone.
4352 if (!mayContainParallelRegion()) {
4353 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4354
4355 auto Remark = [&](OptimizationRemark OR) {
4356 return OR << "Removing unused state machine from generic-mode kernel.";
4357 };
4358 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4359
4360 return true;
4361 }
4362
4363 // Keep track in the statistics of our new shiny custom state machine.
4364 if (ReachedUnknownParallelRegions.empty()) {
4365 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4366
4367 auto Remark = [&](OptimizationRemark OR) {
4368 return OR << "Rewriting generic-mode kernel with a customized state "
4369 "machine.";
4370 };
4371 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4372 } else {
4373 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4374
4376 return OR << "Generic-mode kernel is executed with a customized state "
4377 "machine that requires a fallback.";
4378 };
4379 A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4380
4381 // Tell the user why we ended up with a fallback.
4382 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4383 if (!UnknownParallelRegionCB)
4384 continue;
4385 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4386 return ORA << "Call may contain unknown parallel regions. Use "
4387 << "`[[omp::assume(\"omp_no_parallelism\")]]` to "
4388 "override.";
4389 };
4390 A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4391 "OMP133", Remark);
4392 }
4393 }
4394
4395 // Create all the blocks:
4396 //
4397 // InitCB = __kmpc_target_init(...)
4398 // BlockHwSize =
4399 // __kmpc_get_hardware_num_threads_in_block();
4400 // WarpSize = __kmpc_get_warp_size();
4401 // BlockSize = BlockHwSize - WarpSize;
4402 // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
4403 // if (IsWorker) {
4404 // if (InitCB >= BlockSize) return;
4405 // SMBeginBB: __kmpc_barrier_simple_generic(...);
4406 // void *WorkFn;
4407 // bool Active = __kmpc_kernel_parallel(&WorkFn);
4408 // if (!WorkFn) return;
4409 // SMIsActiveCheckBB: if (Active) {
4410 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
4411 // ParFn0(...);
4412 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
4413 // ParFn1(...);
4414 // ...
4415 // SMIfCascadeCurrentBB: else
4416 // ((WorkFnTy*)WorkFn)(...);
4417 // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
4418 // }
4419 // SMDoneBB: __kmpc_barrier_simple_generic(...);
4420 // goto SMBeginBB;
4421 // }
4422 // UserCodeEntryBB: // user code
4423 // __kmpc_target_deinit(...)
4424 //
4425 auto &Ctx = getAnchorValue().getContext();
4426 Function *Kernel = getAssociatedFunction();
4427 assert(Kernel && "Expected an associated function!");
4428
4429 BasicBlock *InitBB = KernelInitCB->getParent();
4430 BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4431 KernelInitCB->getNextNode(), "thread.user_code.check");
4432 BasicBlock *IsWorkerCheckBB =
4433 BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4434 BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4435 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4436 BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4437 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4438 BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4439 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4440 BasicBlock *StateMachineIfCascadeCurrentBB =
4441 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4442 Kernel, UserCodeEntryBB);
4443 BasicBlock *StateMachineEndParallelBB =
4444 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4445 Kernel, UserCodeEntryBB);
4446 BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4447 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4448 A.registerManifestAddedBasicBlock(*InitBB);
4449 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4450 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4451 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4452 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4453 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4454 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4455 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4456 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4457
4458 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4459 ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4460 InitBB->getTerminator()->eraseFromParent();
4461
4462 Instruction *IsWorker =
4463 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4464 ConstantInt::get(KernelInitCB->getType(), -1),
4465 "thread.is_worker", InitBB);
4466 IsWorker->setDebugLoc(DLoc);
4467 BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
4468
4469 Module &M = *Kernel->getParent();
4470 FunctionCallee BlockHwSizeFn =
4471 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4472 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4473 FunctionCallee WarpSizeFn =
4474 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4475 M, OMPRTL___kmpc_get_warp_size);
4476 CallInst *BlockHwSize =
4477 CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4478 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4479 BlockHwSize->setDebugLoc(DLoc);
4480 CallInst *WarpSize =
4481 CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4482 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4483 WarpSize->setDebugLoc(DLoc);
4484 Instruction *BlockSize = BinaryOperator::CreateSub(
4485 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4486 BlockSize->setDebugLoc(DLoc);
4487 Instruction *IsMainOrWorker = ICmpInst::Create(
4488 ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4489 "thread.is_main_or_worker", IsWorkerCheckBB);
4490 IsMainOrWorker->setDebugLoc(DLoc);
4491 BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
4492 IsMainOrWorker, IsWorkerCheckBB);
4493
4494 // Create local storage for the work function pointer.
4495 const DataLayout &DL = M.getDataLayout();
4496 Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4497 Instruction *WorkFnAI =
4498 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4499 "worker.work_fn.addr", Kernel->getEntryBlock().begin());
4500 WorkFnAI->setDebugLoc(DLoc);
4501
4502 OMPInfoCache.OMPBuilder.updateToLocation(
4504 IRBuilder<>::InsertPoint(StateMachineBeginBB,
4505 StateMachineBeginBB->end()),
4506 DLoc));
4507
4508 Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4509 Value *GTid = KernelInitCB;
4510
4511 FunctionCallee BarrierFn =
4512 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4513 M, OMPRTL___kmpc_barrier_simple_generic);
4514 CallInst *Barrier =
4515 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4516 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4517 Barrier->setDebugLoc(DLoc);
4518
4519 if (WorkFnAI->getType()->getPointerAddressSpace() !=
4520 (unsigned int)AddressSpace::Generic) {
4521 WorkFnAI = new AddrSpaceCastInst(
4522 WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4523 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4524 WorkFnAI->setDebugLoc(DLoc);
4525 }
4526
4527 FunctionCallee KernelParallelFn =
4528 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4529 M, OMPRTL___kmpc_kernel_parallel);
4530 CallInst *IsActiveWorker = CallInst::Create(
4531 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4532 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4533 IsActiveWorker->setDebugLoc(DLoc);
4534 Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4535 StateMachineBeginBB);
4536 WorkFn->setDebugLoc(DLoc);
4537
4538 FunctionType *ParallelRegionFnTy = FunctionType::get(
4540 false);
4541
4542 Instruction *IsDone =
4543 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4544 Constant::getNullValue(VoidPtrTy), "worker.is_done",
4545 StateMachineBeginBB);
4546 IsDone->setDebugLoc(DLoc);
4547 BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
4548 IsDone, StateMachineBeginBB)
4549 ->setDebugLoc(DLoc);
4550
4551 BranchInst::Create(StateMachineIfCascadeCurrentBB,
4552 StateMachineDoneBarrierBB, IsActiveWorker,
4553 StateMachineIsActiveCheckBB)
4554 ->setDebugLoc(DLoc);
4555
4556 Value *ZeroArg =
4557 Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4558
4559 const unsigned int WrapperFunctionArgNo = 6;
4560
4561 // Now that we have most of the CFG skeleton it is time for the if-cascade
4562 // that checks the function pointer we got from the runtime against the
4563 // parallel regions we expect, if there are any.
4564 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4565 auto *CB = ReachedKnownParallelRegions[I];
4566 auto *ParallelRegion = dyn_cast<Function>(
4567 CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4568 BasicBlock *PRExecuteBB = BasicBlock::Create(
4569 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4570 StateMachineEndParallelBB);
4571 CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4572 ->setDebugLoc(DLoc);
4573 BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4574 ->setDebugLoc(DLoc);
4575
4576 BasicBlock *PRNextBB =
4577 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4578 Kernel, StateMachineEndParallelBB);
4579 A.registerManifestAddedBasicBlock(*PRExecuteBB);
4580 A.registerManifestAddedBasicBlock(*PRNextBB);
4581
4582 // Check if we need to compare the pointer at all or if we can just
4583 // call the parallel region function.
4584 Value *IsPR;
4585 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4586 Instruction *CmpI = ICmpInst::Create(
4587 ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
4588 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4589 CmpI->setDebugLoc(DLoc);
4590 IsPR = CmpI;
4591 } else {
4592 IsPR = ConstantInt::getTrue(Ctx);
4593 }
4594
4595 BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
4596 StateMachineIfCascadeCurrentBB)
4597 ->setDebugLoc(DLoc);
4598 StateMachineIfCascadeCurrentBB = PRNextBB;
4599 }
4600
4601 // At the end of the if-cascade we place the indirect function pointer call
4602 // in case we might need it, that is if there can be parallel regions we
4603 // have not handled in the if-cascade above.
4604 if (!ReachedUnknownParallelRegions.empty()) {
4605 StateMachineIfCascadeCurrentBB->setName(
4606 "worker_state_machine.parallel_region.fallback.execute");
4607 CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
4608 StateMachineIfCascadeCurrentBB)
4609 ->setDebugLoc(DLoc);
4610 }
4611 BranchInst::Create(StateMachineEndParallelBB,
4612 StateMachineIfCascadeCurrentBB)
4613 ->setDebugLoc(DLoc);
4614
4615 FunctionCallee EndParallelFn =
4616 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4617 M, OMPRTL___kmpc_kernel_end_parallel);
4618 CallInst *EndParallel =
4619 CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4620 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4621 EndParallel->setDebugLoc(DLoc);
4622 BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4623 ->setDebugLoc(DLoc);
4624
4625 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4626 ->setDebugLoc(DLoc);
4627 BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4628 ->setDebugLoc(DLoc);
4629
4630 return true;
4631 }
4632
  /// Fixpoint iteration update function. Will be called every time a dependence
  /// changed its state (and in the beginning).
  ///
  /// Stages:
  ///  1. Collect non-call write instructions that would need guarding for
  ///     SPMD execution (CheckRWInst).
  ///  2. For non-kernel functions, refresh ParallelLevels and
  ///     ReachingKernelEntries from the callers; if guarding is required but
  ///     that information is invalid or the reaching kernels disagree on
  ///     their mode, give up on SPMD-zation.
  ///  3. Merge the AAKernelInfo state of every call-like instruction into
  ///     this state (CheckCallInst).
  ///  4. Fix sub-states optimistically when no assumed information was used.
  ///
  /// \returns CHANGED if the state differs from the one on entry.
  ChangeStatus updateImpl(Attributor &A) override {
    KernelInfoState StateBefore = getState();

    // When we leave this function this RAII will make sure the member
    // KernelEnvC is updated properly depending on the state. That member is
    // used for simplification of values and needs to be up to date at all
    // times.
    struct UpdateKernelEnvCRAII {
      AAKernelInfoFunction &AA;

      UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}

      ~UpdateKernelEnvCRAII() {
        // Nothing to do if KernelEnvC is not set.
        if (!AA.KernelEnvC)
          return;

        // NOTE(review): the initializer expression of ExistingKernelEnvC is
        // missing from this view; presumably it is the kernel environment
        // obtained from AA.KernelInitCB — confirm against the full file.
        ConstantStruct *ExistingKernelEnvC =

        // An invalid state reverts KernelEnvC to the existing environment.
        if (!AA.isValidState()) {
          AA.KernelEnvC = ExistingKernelEnvC;
          return;
        }

        // Invalid sub-states restore their field from ExistingKernelEnvC.
        if (!AA.ReachedKnownParallelRegions.isValidState())
          AA.setUseGenericStateMachineOfKernelEnvironment(
              KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
                  ExistingKernelEnvC));

        if (!AA.SPMDCompatibilityTracker.isValidState())
          AA.setExecModeOfKernelEnvironment(
              KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));

        // Always rewrite the MayUseNestedParallelism field from the current
        // NestedParallelism flag, keeping the field's integer type.
        ConstantInt *MayUseNestedParallelismC =
            KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
                AA.KernelEnvC);
        ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
            MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
        AA.setMayUseNestedParallelismOfKernelEnvironment(
            NewMayUseNestedParallelismC);
      }
    } RAII(*this);

    // Callback to check a read/write instruction.
    auto CheckRWInst = [&](Instruction &I) {
      // We handle calls later.
      if (isa<CallBase>(I))
        return true;
      // We only care about write effects.
      if (!I.mayWriteToMemory())
        return true;
      // A store needs no guarding if every underlying object is assumed
      // thread-local or is an allocation AAHeapToStack assumes it moves.
      if (auto *SI = dyn_cast<StoreInst>(&I)) {
        const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
            *this, IRPosition::value(*SI->getPointerOperand()),
            DepClassTy::OPTIONAL);
        auto *HS = A.getAAFor<AAHeapToStack>(
            *this, IRPosition::function(*I.getFunction()),
            DepClassTy::OPTIONAL);
        if (UnderlyingObjsAA &&
            UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
              if (AA::isAssumedThreadLocalObject(A, Obj, *this))
                return true;
              // Check for AAHeapToStack moved objects which must not be
              // guarded.
              auto *CB = dyn_cast<CallBase>(&Obj);
              return CB && HS && HS->isAssumedHeapToStack(*CB);
            }))
          return true;
      }

      // Insert instruction that needs guarding.
      SPMDCompatibilityTracker.insert(&I);
      return true;
    };

    bool UsedAssumedInformationInCheckRWInst = false;
    if (!SPMDCompatibilityTracker.isAtFixpoint())
      if (!A.checkForAllReadWriteInstructions(
              CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();

    bool UsedAssumedInformationFromReachingKernels = false;
    if (!IsKernelEntry) {
      updateParallelLevels(A);

      bool AllReachingKernelsKnown = true;
      updateReachingKernelEntries(A, AllReachingKernelsKnown);
      UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;

      // Guarding is only possible if we know the context we are called from.
      if (!SPMDCompatibilityTracker.empty()) {
        if (!ParallelLevels.isValidState())
          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        else if (!ReachingKernelEntries.isValidState())
          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        else {
          // Check if all reaching kernels agree on the mode as we can otherwise
          // not guard instructions. We might not be sure about the mode so we
          // cannot fix the internal spmd-zation state either.
          int SPMD = 0, Generic = 0;
          for (auto *Kernel : ReachingKernelEntries) {
            auto *CBAA = A.getAAFor<AAKernelInfo>(
                *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
            if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
                CBAA->SPMDCompatibilityTracker.isAssumed())
              ++SPMD;
            else
              ++Generic;
            if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
              UsedAssumedInformationFromReachingKernels = true;
          }
          if (SPMD != 0 && Generic != 0)
            SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        }
      }
    }

    // Callback to check a call instruction: merge the call site's state into
    // ours and remember whether the merged sub-states were already fixed.
    bool AllParallelRegionStatesWereFixed = true;
    bool AllSPMDStatesWereFixed = true;
    auto CheckCallInst = [&](Instruction &I) {
      auto &CB = cast<CallBase>(I);
      auto *CBAA = A.getAAFor<AAKernelInfo>(
          *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
      if (!CBAA)
        return false;
      getState() ^= CBAA->getState();
      AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
      AllParallelRegionStatesWereFixed &=
          CBAA->ReachedKnownParallelRegions.isAtFixpoint();
      AllParallelRegionStatesWereFixed &=
          CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
      return true;
    };

    bool UsedAssumedInformationInCheckCallInst = false;
    if (!A.checkForAllCallLikeInstructions(
            CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
      LLVM_DEBUG(dbgs() << TAG
                        << "Failed to visit all call-like instructions!\n";);
      return indicatePessimisticFixpoint();
    }

    // If we haven't used any assumed information for the reached parallel
    // region states we can fix it.
    if (!UsedAssumedInformationInCheckCallInst &&
        AllParallelRegionStatesWereFixed) {
      ReachedKnownParallelRegions.indicateOptimisticFixpoint();
      ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
    }

    // If we haven't used any assumed information for the SPMD state we can fix
    // it.
    if (!UsedAssumedInformationInCheckRWInst &&
        !UsedAssumedInformationInCheckCallInst &&
        !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
      SPMDCompatibilityTracker.indicateOptimisticFixpoint();

    // The RAII destructor updates KernelEnvC after this comparison.
    return StateBefore == getState() ? ChangeStatus::UNCHANGED
                                     : ChangeStatus::CHANGED;
  }
4795
4796private:
4797 /// Update info regarding reaching kernels.
4798 void updateReachingKernelEntries(Attributor &A,
4799 bool &AllReachingKernelsKnown) {
4800 auto PredCallSite = [&](AbstractCallSite ACS) {
4801 Function *Caller = ACS.getInstruction()->getFunction();
4802
4803 assert(Caller && "Caller is nullptr");
4804
4805 auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4806 IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4807 if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4808 ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4809 return true;
4810 }
4811
4812 // We lost track of the caller of the associated function, any kernel
4813 // could reach now.
4814 ReachingKernelEntries.indicatePessimisticFixpoint();
4815
4816 return true;
4817 };
4818
4819 if (!A.checkForAllCallSites(PredCallSite, *this,
4820 true /* RequireAllCallSites */,
4821 AllReachingKernelsKnown))
4822 ReachingKernelEntries.indicatePessimisticFixpoint();
4823 }
4824
4825 /// Update info regarding parallel levels.
4826 void updateParallelLevels(Attributor &A) {
4827 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4828 OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4829 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4830
4831 auto PredCallSite = [&](AbstractCallSite ACS) {
4832 Function *Caller = ACS.getInstruction()->getFunction();
4833
4834 assert(Caller && "Caller is nullptr");
4835
4836 auto *CAA =
4837 A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4838 if (CAA && CAA->ParallelLevels.isValidState()) {
4839 // Any function that is called by `__kmpc_parallel_51` will not be
4840 // folded as the parallel level in the function is updated. In order to
4841 // get it right, all the analysis would depend on the implentation. That
4842 // said, if in the future any change to the implementation, the analysis
4843 // could be wrong. As a consequence, we are just conservative here.
4844 if (Caller == Parallel51RFI.Declaration) {
4845 ParallelLevels.indicatePessimisticFixpoint();
4846 return true;
4847 }
4848
4849 ParallelLevels ^= CAA->ParallelLevels;
4850
4851 return true;
4852 }
4853
4854 // We lost track of the caller of the associated function, any kernel
4855 // could reach now.
4856 ParallelLevels.indicatePessimisticFixpoint();
4857
4858 return true;
4859 };
4860
4861 bool AllCallSitesKnown = true;
4862 if (!A.checkForAllCallSites(PredCallSite, *this,
4863 true /* RequireAllCallSites */,
4864 AllCallSitesKnown))
4865 ParallelLevels.indicatePessimisticFixpoint();
4866 }
4867};
4868
/// The call site kernel info abstract attribute, basically, what can we say
/// about a call site with regards to the KernelInfoState. Calls to known
/// OpenMP runtime functions are modeled directly in initialize(); for other
/// callees the information is forwarded from the callee's AAKernelInfo in
/// updateImpl().
struct AAKernelInfoCallSite : AAKernelInfo {
  AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
      : AAKernelInfo(IRP, A) {}

  /// See AbstractAttribute::initialize(...).
  ///
  /// Resolves as much as possible up front. Call sites with the
  /// "ompx_spmd_amenable" assumption, calls that cannot write memory, and
  /// intrinsics are fixed immediately. Known runtime calls are modeled by the
  /// switch in CheckCallee. Only IPO-amendable non-runtime callees, handled
  /// `__kmpc_parallel_51` calls, and `__kmpc_alloc_shared`/`__kmpc_free_shared`
  /// calls are left without a fixpoint for updateImpl.
  void initialize(Attributor &A) override {
    AAKernelInfo::initialize(A);

    CallBase &CB = cast<CallBase>(getAssociatedValue());
    auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
        *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);

    // Check for SPMD-mode assumptions.
    if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
      indicateOptimisticFixpoint();
      return;
    }

    // First weed out calls we do not care about, that is calls that cannot
    // write memory (readonly/readnone) and intrinsics. Neither of these can
    // reach a parallel region or anything else we are looking for. (The
    // "omp_no_openmp"/"omp_no_parallelism" assumptions are honored below, for
    // unknown callees only.)
    if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
      indicateOptimisticFixpoint();
      return;
    }

    // Next we check if we know the callee. If it is a known OpenMP function
    // we will handle them explicitly in the switch below. If it is not, we
    // will use an AAKernelInfo object on the callee to gather information and
    // merge that into the current state. The latter happens in the updateImpl.
    // \p NumCallees is the number of possible callees of this call site.
    auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
      const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
      if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
        // Unknown callees (nullptr) or declarations are not analyzable, we
        // give up.
        if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {

          // Unknown callees might contain parallel regions, except if they have
          // an appropriate assumption attached.
          if (!AssumptionAA ||
              !(AssumptionAA->hasAssumption("omp_no_openmp") ||
                AssumptionAA->hasAssumption("omp_no_parallelism")))
            ReachedUnknownParallelRegions.insert(&CB);

          // If SPMDCompatibilityTracker is not fixed, we need to give up on the
          // idea we can run something unknown in SPMD-mode.
          if (!SPMDCompatibilityTracker.isAtFixpoint()) {
            SPMDCompatibilityTracker.indicatePessimisticFixpoint();
            SPMDCompatibilityTracker.insert(&CB);
          }

          // We have updated the state for this unknown call properly, there
          // won't be any change so we indicate a fixpoint.
          indicateOptimisticFixpoint();
        }
        // If the callee is known and can be used in IPO, we will update the
        // state based on the callee state in updateImpl.
        return;
      }
      // A runtime function that is only one of several possible callees is
      // not modeled; give up.
      if (NumCallees > 1) {
        indicatePessimisticFixpoint();
        return;
      }

      RuntimeFunction RF = It->getSecond();
      switch (RF) {
      // All the functions we know are compatible with SPMD mode.
      case OMPRTL___kmpc_is_spmd_exec_mode:
      case OMPRTL___kmpc_distribute_static_fini:
      case OMPRTL___kmpc_for_static_fini:
      case OMPRTL___kmpc_global_thread_num:
      case OMPRTL___kmpc_get_hardware_num_threads_in_block:
      case OMPRTL___kmpc_get_hardware_num_blocks:
      case OMPRTL___kmpc_single:
      case OMPRTL___kmpc_end_single:
      case OMPRTL___kmpc_master:
      case OMPRTL___kmpc_end_master:
      case OMPRTL___kmpc_barrier:
      case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
      case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
      case OMPRTL___kmpc_error:
      case OMPRTL___kmpc_flush:
      case OMPRTL___kmpc_get_hardware_thread_id_in_block:
      case OMPRTL___kmpc_get_warp_size:
      case OMPRTL_omp_get_thread_num:
      case OMPRTL_omp_get_num_threads:
      case OMPRTL_omp_get_max_threads:
      case OMPRTL_omp_in_parallel:
      case OMPRTL_omp_get_dynamic:
      case OMPRTL_omp_get_cancellation:
      case OMPRTL_omp_get_nested:
      case OMPRTL_omp_get_schedule:
      case OMPRTL_omp_get_thread_limit:
      case OMPRTL_omp_get_supported_active_levels:
      case OMPRTL_omp_get_max_active_levels:
      case OMPRTL_omp_get_level:
      case OMPRTL_omp_get_ancestor_thread_num:
      case OMPRTL_omp_get_team_size:
      case OMPRTL_omp_get_active_level:
      case OMPRTL_omp_in_final:
      case OMPRTL_omp_get_proc_bind:
      case OMPRTL_omp_get_num_places:
      case OMPRTL_omp_get_num_procs:
      case OMPRTL_omp_get_place_proc_ids:
      case OMPRTL_omp_get_place_num:
      case OMPRTL_omp_get_partition_num_places:
      case OMPRTL_omp_get_partition_place_nums:
      case OMPRTL_omp_get_wtime:
        break;
      case OMPRTL___kmpc_distribute_static_init_4:
      case OMPRTL___kmpc_distribute_static_init_4u:
      case OMPRTL___kmpc_distribute_static_init_8:
      case OMPRTL___kmpc_distribute_static_init_8u:
      case OMPRTL___kmpc_for_static_init_4:
      case OMPRTL___kmpc_for_static_init_4u:
      case OMPRTL___kmpc_for_static_init_8:
      case OMPRTL___kmpc_for_static_init_8u: {
        // Check the schedule and allow static schedule in SPMD mode. A
        // non-constant schedule argument is treated as schedule value 0
        // (assumed not to be one of the accepted static kinds — confirm).
        unsigned ScheduleArgOpNo = 2;
        auto *ScheduleTypeCI =
            dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
        unsigned ScheduleTypeVal =
            ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
        switch (OMPScheduleType(ScheduleTypeVal)) {
        case OMPScheduleType::UnorderedStatic:
        case OMPScheduleType::UnorderedStaticChunked:
        case OMPScheduleType::OrderedDistribute:
        case OMPScheduleType::OrderedDistributeChunked:
          break;
        default:
          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
          SPMDCompatibilityTracker.insert(&CB);
          break;
        };
      } break;
      // Remember the kernel init/deinit calls.
      case OMPRTL___kmpc_target_init:
        KernelInitCB = &CB;
        break;
      case OMPRTL___kmpc_target_deinit:
        KernelDeinitCB = &CB;
        break;
      case OMPRTL___kmpc_parallel_51:
        // Handled parallel calls stay open for updateImpl.
        if (!handleParallel51(A, CB))
          indicatePessimisticFixpoint();
        return;
      case OMPRTL___kmpc_omp_task:
        // We do not look into tasks right now, just give up.
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        SPMDCompatibilityTracker.insert(&CB);
        ReachedUnknownParallelRegions.insert(&CB);
        break;
      case OMPRTL___kmpc_alloc_shared:
      case OMPRTL___kmpc_free_shared:
        // Return without setting a fixpoint, to be resolved in updateImpl.
        return;
      default:
        // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
        // generally. However, they do not hide parallel regions.
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        SPMDCompatibilityTracker.insert(&CB);
        break;
      }
      // All other OpenMP runtime calls will not reach parallel regions so they
      // can be safely ignored for now. Since it is a known OpenMP runtime call
      // we have now modeled all effects and there is no need for any update.
      indicateOptimisticFixpoint();
    };

    // Prefer the optimistic call edges; without usable edge information fall
    // back to the directly associated function (which may be nullptr).
    const auto *AACE =
        A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
    if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
      CheckCallee(getAssociatedFunction(), 1);
      return;
    }
    const auto &OptimisticEdges = AACE->getOptimisticEdges();
    for (auto *Callee : OptimisticEdges) {
      CheckCallee(Callee, OptimisticEdges.size());
      if (isAtFixpoint())
        break;
    }
  }

  /// Propagate callee information for the call sites initialize() left open:
  /// non-runtime callees, `__kmpc_parallel_51`, and the shared-memory
  /// allocation/free runtime calls.
  ChangeStatus updateImpl(Attributor &A) override {
    // TODO: Once we have call site specific value information we can provide
    //       call site specific liveness information and then it makes
    //       sense to specialize attributes for call sites arguments instead of
    //       redirecting requests to the callee argument.
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    KernelInfoState StateBefore = getState();

    auto CheckCallee = [&](Function *F, int NumCallees) {
      const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);

      // If F is not a runtime function, propagate the AAKernelInfo of the
      // callee.
      // NOTE(review): with several optimistic callees this assignment
      // overwrites the state with each callee's state in turn instead of
      // merging them (cf. `^=` in the function-level AA) — confirm intent.
      if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
        const IRPosition &FnPos = IRPosition::function(*F);
        auto *FnAA =
            A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
        if (!FnAA)
          return indicatePessimisticFixpoint();
        if (getState() == FnAA->getState())
          return ChangeStatus::UNCHANGED;
        getState() = FnAA->getState();
        return ChangeStatus::CHANGED;
      }
      if (NumCallees > 1)
        return indicatePessimisticFixpoint();

      CallBase &CB = cast<CallBase>(getAssociatedValue());
      if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
        if (!handleParallel51(A, CB))
          return indicatePessimisticFixpoint();
        return StateBefore == getState() ? ChangeStatus::UNCHANGED
                                         : ChangeStatus::CHANGED;
      }

      // F is a runtime function that allocates or frees memory, check
      // AAHeapToStack and AAHeapToShared.
      assert(
          (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
           It->getSecond() == OMPRTL___kmpc_free_shared) &&
          "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");

      auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
          *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
      auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
          *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);

      RuntimeFunction RF = It->getSecond();

      switch (RF) {
      // If neither HeapToStack nor HeapToShared assume the call is removed,
      // assume SPMD incompatibility.
      case OMPRTL___kmpc_alloc_shared:
        if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
            (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
          SPMDCompatibilityTracker.insert(&CB);
        break;
      case OMPRTL___kmpc_free_shared:
        if ((!HeapToStackAA ||
             !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
            (!HeapToSharedAA ||
             !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
          SPMDCompatibilityTracker.insert(&CB);
        break;
      default:
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        SPMDCompatibilityTracker.insert(&CB);
      }
      return ChangeStatus::CHANGED;
    };

    // Same callee discovery as in initialize(); the per-callee results are
    // not used, the final state comparison decides the change status.
    const auto *AACE =
        A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
    if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
      if (Function *F = getAssociatedFunction())
        CheckCallee(F, /*NumCallees=*/1);
    } else {
      const auto &OptimisticEdges = AACE->getOptimisticEdges();
      for (auto *Callee : OptimisticEdges) {
        CheckCallee(Callee, OptimisticEdges.size());
        if (isAtFixpoint())
          break;
      }
    }

    return StateBefore == getState() ? ChangeStatus::UNCHANGED
                                     : ChangeStatus::CHANGED;
  }

  /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call was
  /// handled, if a problem occurred, false is returned.
  ///
  /// While SPMD mode is assumed, the outlined region (argument 5) is the
  /// parallel region; otherwise the wrapper function (argument 6) is used.
  /// NestedParallelism is set if the region's AAKernelInfo is missing or
  /// invalid, or if it may itself reach further parallel regions.
  bool handleParallel51(Attributor &A, CallBase &CB) {
    const unsigned int NonWrapperFunctionArgNo = 5;
    const unsigned int WrapperFunctionArgNo = 6;
    auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
                                     ? NonWrapperFunctionArgNo
                                     : WrapperFunctionArgNo;

    auto *ParallelRegion = dyn_cast<Function>(
        CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
    if (!ParallelRegion)
      return false;

    ReachedKnownParallelRegions.insert(&CB);
    /// Check nested parallelism
    auto *FnAA = A.getAAFor<AAKernelInfo>(
        *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
    NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
                         !FnAA->ReachedKnownParallelRegions.empty() ||
                         !FnAA->ReachedKnownParallelRegions.isValidState() ||
                         !FnAA->ReachedUnknownParallelRegions.isValidState() ||
                         !FnAA->ReachedUnknownParallelRegions.empty();
    return true;
  }
};
5169
5170struct AAFoldRuntimeCall
5171 : public StateWrapper<BooleanState, AbstractAttribute> {
5173
5174 AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
5175
5176 /// Statistics are tracked as part of manifest for now.
5177 void trackStatistics() const override {}
5178
5179 /// Create an abstract attribute biew for the position \p IRP.
5180 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5181 Attributor &A);
5182
5183 /// See AbstractAttribute::getName()
5184 const std::string getName() const override { return "AAFoldRuntimeCall"; }
5185
5186 /// See AbstractAttribute::getIdAddr()
5187 const char *getIdAddr() const override { return &ID; }
5188
5189 /// This function should return true if the type of the \p AA is
5190 /// AAFoldRuntimeCall
5191 static bool classof(const AbstractAttribute *AA) {
5192 return (AA->getIdAddr() == &ID);
5193 }
5194
5195 static const char ID;
5196};
5197
5198struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
  /// Construct the attribute for position \p IRP; all initialization is
  /// delegated to AAFoldRuntimeCall.
  AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
      : AAFoldRuntimeCall(IRP, A) {}
5201
5202 /// See AbstractAttribute::getAsStr()
5203 const std::string getAsStr(Attributor *) const override {
5204 if (!isValidState())
5205 return "<invalid>";
5206
5207 std::string Str("simplified value: ");
5208
5209 if (!SimplifiedValue)
5210 return Str + std::string("none");
5211
5212 if (!*SimplifiedValue)
5213 return Str + std::string("nullptr");
5214
5215 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5216 return Str + std::to_string(CI->getSExtValue());
5217
5218 return Str + std::string("unknown");
5219 }
5220
  /// Record the runtime function this call targets in RFKind and register a
  /// simplification callback that answers queries with SimplifiedValue.
  void initialize(Attributor &A) override {
    // NOTE(review): as shown, the pessimistic fixpoint is indicated
    // unconditionally, which would prevent any folding; a guarding condition
    // (e.g. an option disabling folding) appears to be missing from this view
    // — confirm against the full file.
    indicatePessimisticFixpoint();

    Function *Callee = getAssociatedFunction();

    // Only call sites of known OpenMP runtime functions are expected here.
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
    assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
           "Expected a known OpenMP runtime function");

    RFKind = It->getSecond();

    CallBase &CB = cast<CallBase>(getAssociatedValue());
    // NOTE(review): the IRPosition argument of this call appears to be missing
    // from this view (presumably the returned position of CB) — confirm
    // against the full file.
    A.registerSimplificationCallback(
        // Report the current SimplifiedValue. While this AA is not at a
        // fixpoint the answer is flagged as assumed and, when a querying AA is
        // given, an optional dependence between the two AAs is recorded.
        [&](const IRPosition &IRP, const AbstractAttribute *AA,
            bool &UsedAssumedInformation) -> std::optional<Value *> {
          assert((isValidState() ||
                  (SimplifiedValue && *SimplifiedValue == nullptr)) &&
                 "Unexpected invalid state!");

          if (!isAtFixpoint()) {
            UsedAssumedInformation = true;
            if (AA)
              A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
          }
          return SimplifiedValue;
        });
  }
5251
5252 ChangeStatus updateImpl(Attributor &A) override {
5253 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5254 switch (RFKind) {
5255 case OMPRTL___kmpc_is_spmd_exec_mode:
5256 Changed |= foldIsSPMDExecMode(A);
5257 break;
5258 case OMPRTL___kmpc_parallel_level:
5259 Changed |= foldParallelLevel(A);
5260 break;
5261 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5262 Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
5263 break;
5264 case OMPRTL___kmpc_get_hardware_num_blocks:
5265 Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
5266 break;
5267 default:
5268 llvm_unreachable("Unhandled OpenMP runtime function!");
5269 }
5270
5271 return Changed;
5272 }
5273
5274 ChangeStatus manifest(Attributor &A) override {
5275 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5276
5277 if (SimplifiedValue && *SimplifiedValue) {
5278 Instruction &I = *getCtxI();
5279 A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
5280 A.deleteAfterManifest(I);
5281
5282 CallBase *CB = dyn_cast<CallBase>(&I);
5283 auto Remark = [&](OptimizationRemark OR) {
5284 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5285 return OR << "Replacing OpenMP runtime call "
5286 << CB->getCalledFunction()->getName() << " with "
5287 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5288 return OR << "Replacing OpenMP runtime call "
5289 << CB->getCalledFunction()->getName() << ".";
5290 };
5291
5292 if (CB && EnableVerboseRemarks)
5293 A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
5294
5295 LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5296 << **SimplifiedValue << "\n");
5297
5298 Changed = ChangeStatus::CHANGED;
5299 }
5300
5301 return Changed;
5302 }
5303
5304 ChangeStatus indicatePessimisticFixpoint() override {
5305 SimplifiedValue = nullptr;
5306 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5307 }
5308
5309private:
5310 /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
5311 ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5312 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5313
5314 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5315 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5316 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5317 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5318
5319 if (!CallerKernelInfoAA ||
5320 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5321 return indicatePessimisticFixpoint();
5322
5323 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5324 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5325 DepClassTy::REQUIRED);
5326
5327 if (!AA || !AA->isValidState()) {
5328 SimplifiedValue = nullptr;
5329 return indicatePessimisticFixpoint();
5330 }
5331
5332 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5333 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5334 ++KnownSPMDCount;
5335 else
5336 ++AssumedSPMDCount;
5337 } else {
5338 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5339 ++KnownNonSPMDCount;
5340 else
5341 ++AssumedNonSPMDCount;
5342 }
5343 }
5344
5345 if ((AssumedSPMDCount + KnownSPMDCount) &&
5346 (AssumedNonSPMDCount + KnownNonSPMDCount))
5347 return indicatePessimisticFixpoint();
5348
5349 auto &Ctx = getAnchorValue().getContext();
5350 if (KnownSPMDCount || AssumedSPMDCount) {
5351 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5352 "Expected only SPMD kernels!");
5353 // All reaching kernels are in SPMD mode. Update all function calls to
5354 // __kmpc_is_spmd_exec_mode to 1.
5355 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
5356 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5357 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5358 "Expected only non-SPMD kernels!");
5359 // All reaching kernels are in non-SPMD mode. Update all function
5360 // calls to __kmpc_is_spmd_exec_mode to 0.
5361 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
5362 } else {
5363 // We have empty reaching kernels, therefore we cannot tell if the
5364 // associated call site can be folded. At this moment, SimplifiedValue
5365 // must be none.
5366 assert(!SimplifiedValue && "SimplifiedValue should be none");
5367 }
5368
5369 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5370 : ChangeStatus::CHANGED;
5371 }
5372
5373 /// Fold __kmpc_parallel_level into a constant if possible.
5374 ChangeStatus foldParallelLevel(Attributor &A) {
5375 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5376
5377 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5378 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5379
5380 if (!CallerKernelInfoAA ||
5381 !CallerKernelInfoAA->ParallelLevels.isValidState())
5382 return indicatePessimisticFixpoint();
5383
5384 if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5385 return indicatePessimisticFixpoint();
5386
5387 if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5388 assert(!SimplifiedValue &&
5389 "SimplifiedValue should keep none at this point");
5390 return ChangeStatus::UNCHANGED;
5391 }
5392
5393 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5394 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5395 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5396 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5397 DepClassTy::REQUIRED);
5398 if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5399 return indicatePessimisticFixpoint();
5400
5401 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5402 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5403 ++KnownSPMDCount;
5404 else
5405 ++AssumedSPMDCount;
5406 } else {
5407 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5408 ++KnownNonSPMDCount;
5409 else
5410 ++AssumedNonSPMDCount;
5411 }
5412 }
5413
5414 if ((AssumedSPMDCount + KnownSPMDCount) &&
5415 (AssumedNonSPMDCount + KnownNonSPMDCount))
5416 return indicatePessimisticFixpoint();
5417
5418 auto &Ctx = getAnchorValue().getContext();
5419 // If the caller can only be reached by SPMD kernel entries, the parallel
5420 // level is 1. Similarly, if the caller can only be reached by non-SPMD
5421 // kernel entries, it is 0.
5422 if (AssumedSPMDCount || KnownSPMDCount) {
5423 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5424 "Expected only SPMD kernels!");
5425 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
5426 } else {
5427 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5428 "Expected only non-SPMD kernels!");
5429 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
5430 }
5431 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5432 : ChangeStatus::CHANGED;
5433 }
5434
5435 ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
5436 // Specialize only if all the calls agree with the attribute constant value
5437 int32_t CurrentAttrValue = -1;
5438 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5439
5440 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5441 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5442
5443 if (!CallerKernelInfoAA ||
5444 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5445 return indicatePessimisticFixpoint();
5446
5447 // Iterate over the kernels that reach this function
5448 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5449 int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
5450
5451 if (NextAttrVal == -1 ||
5452 (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5453 return indicatePessimisticFixpoint();
5454 CurrentAttrValue = NextAttrVal;
5455 }
5456
5457 if (CurrentAttrValue != -1) {
5458 auto &Ctx = getAnchorValue().getContext();
5459 SimplifiedValue =
5460 ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
5461 }
5462 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5463 : ChangeStatus::CHANGED;
5464 }
5465
5466 /// An optional value the associated value is assumed to fold to. That is, we
5467 /// assume the associated value (which is a call) can be replaced by this
5468 /// simplified value.
/// It stays std::nullopt while no value has been determined (for example
/// while no reaching kernel is known). The fold* update routines snapshot it
/// before an update and compare afterwards to decide between
/// ChangeStatus::CHANGED and ChangeStatus::UNCHANGED.
5469 std::optional<Value *> SimplifiedValue;
5470
5471 /// The runtime function kind of the callee of the associated call site.
5472 RuntimeFunction RFKind;
5473};
5474
5475} // namespace
5476
5477/// Register folding callsite
5478void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5479 auto &RFI = OMPInfoCache.RFIs[RF];
5480 RFI.foreachUse(SCC, [&](Use &U, Function &F) {
5481 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5482 if (!CI)
5483 return false;
5484 A.getOrCreateAAFor<AAFoldRuntimeCall>(
5485 IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
5486 DepClassTy::NONE, /* ForceUpdate */ false,
5487 /* UpdateAfterInit */ false);
5488 return false;
5489 });
5490}
5491
5492void OpenMPOpt::registerAAs(bool IsModulePass) {
5493 if (SCC.empty())
5494 return;
5495
5496 if (IsModulePass) {
5497 // Ensure we create the AAKernelInfo AAs first and without triggering an
5498 // update. This will make sure we register all value simplification
5499 // callbacks before any other AA has the chance to create an AAValueSimplify
5500 // or similar.
5501 auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5502 A.getOrCreateAAFor<AAKernelInfo>(
5503 IRPosition::function(Kernel), /* QueryingAA */ nullptr,
5504 DepClassTy::NONE, /* ForceUpdate */ false,
5505 /* UpdateAfterInit */ false);
5506 return false;
5507 };
5508 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5509 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5510 InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5511
5512 registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5513 registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5514 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5515 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5516 }
5517
5518 // Create CallSite AA for all Getters.
5519 if (DeduceICVValues) {
5520 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5521 auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5522
5523 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5524
5525 auto CreateAA = [&](Use &U, Function &Caller) {
5526 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5527 if (!CI)
5528 return false;
5529
5530 auto &CB = cast<CallBase>(*CI);
5531
5533 A.getOrCreateAAFor<AAICVTracker>(CBPos);
5534 return false;
5535 };
5536
5537 GetterRFI.foreachUse(SCC, CreateAA);
5538 }
5539 }
5540
5541 // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5542 // every function if there is a device kernel.
5543 if (!isOpenMPDevice(M))
5544 return;
5545
5546 for (auto *F : SCC) {
5547 if (F->isDeclaration())
5548 continue;
5549
5550 // We look at internal functions only on-demand but if any use is not a
5551 // direct call or outside the current set of analyzed functions, we have
5552 // to do it eagerly.
5553 if (F->hasLocalLinkage()) {
5554 if (llvm::all_of(F->uses(), [this](const Use &U) {
5555 const auto *CB = dyn_cast<CallBase>(U.getUser());
5556 return CB && CB->isCallee(&U) &&
5557 A.isRunOn(const_cast<Function *>(CB->getCaller()));
5558 }))
5559 continue;
5560 }
5561 registerAAsForFunction(A, *F);
5562 }
5563}
5564
5565void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
5567 A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
5568 A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
5570 A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
5571 if (F.hasFnAttribute(Attribute::Convergent))
5572 A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));
5573
5574 for (auto &I : instructions(F)) {
5575 if (auto *LI = dyn_cast<LoadInst>(&I)) {
5576 bool UsedAssumedInformation = false;
5577 A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
5578 UsedAssumedInformation, AA::Interprocedural);
5579 A.getOrCreateAAFor<AAAddressSpace>(
5580 IRPosition::value(*LI->getPointerOperand()));
5581 continue;
5582 }
5583 if (auto *CI = dyn_cast<CallBase>(&I)) {
5584 if (CI->isIndirectCall())
5585 A.getOrCreateAAFor<AAIndirectCallInfo>(
5587 }
5588 if (auto *SI = dyn_cast<StoreInst>(&I)) {
5589 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
5590 A.getOrCreateAAFor<AAAddressSpace>(
5591 IRPosition::value(*SI->getPointerOperand()));
5592 continue;
5593 }
5594 if (auto *FI = dyn_cast<FenceInst>(&I)) {
5595 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
5596 continue;
5597 }
5598 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5599 if (II->getIntrinsicID() == Intrinsic::assume) {
5600 A.getOrCreateAAFor<AAPotentialValues>(
5601 IRPosition::value(*II->getArgOperand(0)));
5602 continue;
5603 }
5604 }
5605 }
5606}
5607
// Out-of-line definitions of the static ID members of these abstract
// attribute classes. NOTE(review): the Attributor presumably identifies AA
// classes by the address of ID rather than its value — confirm in Attributor.h.
5608const char AAICVTracker::ID = 0;
5609const char AAKernelInfo::ID = 0;
5610const char AAExecutionDomain::ID = 0;
5611const char AAHeapToShared::ID = 0;
5612const char AAFoldRuntimeCall::ID = 0;
5613
5614AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5615 Attributor &A) {
5616 AAICVTracker *AA = nullptr;
5617 switch (IRP.getPositionKind()) {
5622 llvm_unreachable("ICVTracker can only be created for function position!");
5624 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5625 break;
5627 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5628 break;
5630 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5631 break;
5633 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5634 break;
5635 }
5636
5637 return *AA;
5638}
5639
5641 Attributor &A) {
5642 AAExecutionDomainFunction *AA = nullptr;
5643 switch (IRP.getPositionKind()) {
5652 "AAExecutionDomain can only be created for function position!");
5654 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5655 break;
5656 }
5657
5658 return *AA;
5659}
5660
5661AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5662 Attributor &A) {
5663 AAHeapToSharedFunction *AA = nullptr;
5664 switch (IRP.getPositionKind()) {
5673 "AAHeapToShared can only be created for function position!");
5675 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5676 break;
5677 }
5678
5679 return *AA;
5680}
5681
5682AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5683 Attributor &A) {
5684 AAKernelInfo *AA = nullptr;
5685 switch (IRP.getPositionKind()) {
5692 llvm_unreachable("KernelInfo can only be created for function position!");
5694 AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5695 break;
5697 AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5698 break;
5699 }
5700
5701 return *AA;
5702}
5703
5704AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5705 Attributor &A) {
5706 AAFoldRuntimeCall *AA = nullptr;
5707 switch (IRP.getPositionKind()) {
5715 llvm_unreachable("KernelInfo can only be created for call site position!");
5717 AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5718 break;
5719 }
5720
5721 return *AA;
5722}
5723
// NOTE(review): this rendering of the module-pass OpenMPOptPass::run is
// incomplete. The embedded numbering skips source lines 5724 (the signature),
// 5727, 5730-5731, 5734, 5747, 5764-5765, 5778, 5791, 5794, 5798, 5802, 5805,
// 5822 and 5828. As a result `FAM`, `ORE`, `SCC` and `Allocator` are used below
// without a visible declaration and several statements are cut off. Restore
// those lines from upstream before changing the logic.
//
// Visible flow:
// - Return early (all analyses preserved) for modules without OpenMP.
// - On OpenMP device modules, internalize called non-kernel functions.
// - Run OpenMPOpt through a module-mode Attributor over every definition that
//   was not internalized.
// - Optionally mark non-kernel definitions AlwaysInline.
// - Return PreservedAnalyses::none() if anything changed.
5725 if (!containsOpenMP(M))
5726 return PreservedAnalyses::all();
// NOTE(review): the condition guarding this early return (source line 5727) is
// missing from this rendering.
5728 return PreservedAnalyses::all();
5729
// NOTE(review): source lines 5730-5731 (presumably the declaration of FAM) are
// missing here.
5732 KernelSet Kernels = getDeviceKernels(M);
5733
// NOTE(review): source line 5734 (presumably the condition guarding this dump)
// is missing here.
5735 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
5736
// A function counts as called if it is a device kernel or has any user that is
// not a BlockAddress.
5737 auto IsCalled = [&](Function &F) {
5738 if (Kernels.contains(&F))
5739 return true;
5740 for (const User *U : F.users())
5741 if (!isa<BlockAddress>(U))
5742 return true;
5743 return false;
5744 };
5745
// Emits the OMP140 analysis remark for a function that could not be
// internalized. NOTE(review): source line 5747 (presumably the declaration of
// ORE) is missing.
5746 auto EmitRemark = [&](Function &F) {
5748 ORE.emit([&]() {
5749 OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5750 return ORA << "Could not internalize function. "
5751 << "Some optimizations may not be possible. [OMP140]";
5752 });
5753 };
5754
5755 bool Changed = false;
5756
5757 // Create internal copies of each function if this is an OpenMP device module.
5758 // This allows interprocedural passes to see every call edge.
5759 DenseMap<Function *, Function *> InternalizedMap;
5760 if (isOpenMPDevice(M)) {
5761 SmallPtrSet<Function *, 16> InternalizeFns;
5762 for (Function &F : M)
// NOTE(review): source lines 5764-5765 are missing. They complete this
// condition and open the inner block that `} else if` below closes.
5763 if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
5766 InternalizeFns.insert(&F);
5767 } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5768 EmitRemark(F);
5769 }
5770 }
5771
5772 Changed |=
5773 Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5774 }
5775
5776 // Look at every function in the Module unless it was internalized.
// InternalizedMap.lookup yields nullptr for functions without an internal copy.
// NOTE(review): source line 5778 (presumably the declaration of SCC) is missing.
5777 SetVector<Function *> Functions;
5779 for (Function &F : M)
5780 if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
5781 SCC.push_back(&F);
5782 Functions.insert(&F);
5783 }
5784
5785 if (SCC.empty())
5786 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
5787
5788 AnalysisGetter AG(FAM);
5789
// NOTE(review): the body of this lambda (source line 5791) is missing.
5790 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5792 };
5793
// NOTE(review): source line 5794 (presumably the declaration of Allocator) is
// missing.
5795 CallGraphUpdater CGUpdater;
5796
// NOTE(review): the rest of this condition (source line 5798) is missing.
5797 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5799 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);
5800
// NOTE(review): the initializer (source line 5802) is missing.
5801 unsigned MaxFixpointIterations =
5803
// NOTE(review): source line 5805 (another AttributorConfig setting) is missing.
5804 AttributorConfig AC(CGUpdater);
5806 AC.IsModulePass = true;
5807 AC.RewriteSignatures = false;
5808 AC.MaxFixpointIterations = MaxFixpointIterations;
5809 AC.OREGetter = OREGetter;
5810 AC.PassName = DEBUG_TYPE;
5811 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
// Only functions carrying the "kernel" attribute are treated as IPO-amendable.
5812 AC.IPOAmendableCB = [](const Function &F) {
5813 return F.hasFnAttribute("kernel");
5814 };
5815
5816 Attributor A(Functions, InfoCache, AC);
5817
5818 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5819 Changed |= OMPOpt.run(true);
5820
5821 // Optionally inline device functions for potentially better performance.
// NOTE(review): the guarding condition (source line 5822) is missing.
5823 for (Function &F : M)
5824 if (!F.isDeclaration() && !Kernels.contains(&F) &&
5825 !F.hasFnAttribute(Attribute::NoInline))
5826 F.addFnAttr(Attribute::AlwaysInline);
5827
// NOTE(review): source line 5828 (presumably the condition guarding this dump)
// is missing.
5829 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5830
5831 if (Changed)
5832 return PreservedAnalyses::none();
5833
5834 return PreservedAnalyses::all();
5835}
5836
// NOTE(review): this rendering of the CGSCC-pass run method is incomplete. The
// embedded numbering skips source lines 5837-5838 (the start of the signature),
// 5843, 5846, 5858, 5863, 5869, 5872, 5877, 5883, 5886 and 5899. As a result
// `SCC`, `FAM` and `Allocator` have no visible declaration. Restore those lines
// from upstream before changing the logic.
//
// Visible flow:
// - Return early for modules without OpenMP.
// - Collect the functions of the SCC `C`.
// - Run OpenMPOpt through a CGSCC-mode Attributor (IsModulePass = false) whose
//   information cache is restricted to those functions.
// - Return PreservedAnalyses::none() if anything changed.
5839 LazyCallGraph &CG,
5840 CGSCCUpdateResult &UR) {
5841 if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5842 return PreservedAnalyses::all();
// NOTE(review): the condition guarding this early return (source line 5843) is
// missing.
5844 return PreservedAnalyses::all();
5845
// NOTE(review): source line 5846 (presumably the declaration of SCC) is missing.
5847 // If there are kernels in the module, we have to run on all SCCs.
5848 for (LazyCallGraph::Node &N : C) {
5849 Function *Fn = &N.getFunction();
5850 SCC.push_back(Fn);
5851 }
5852
5853 if (SCC.empty())
5854 return PreservedAnalyses::all();
5855
5856 Module &M = *C.begin()->getFunction().getParent();
5857
// NOTE(review): source line 5858 (presumably the condition guarding this dump)
// is missing.
5859 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5860
5861 KernelSet Kernels = getDeviceKernels(M);
5862
// NOTE(review): the start of this statement (source line 5863, presumably the
// declaration of FAM) is missing.
5864 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5865
5866 AnalysisGetter AG(FAM);
5867
// NOTE(review): the body of this lambda (source line 5869) is missing.
5868 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5870 };
5871
// NOTE(review): source line 5872 (presumably the declaration of Allocator) is
// missing.
5873 CallGraphUpdater CGUpdater;
5874 CGUpdater.initialize(CG, C, AM, UR);
5875
// NOTE(review): the rest of this condition (source line 5877) is missing.
5876 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
// The information cache is restricted to the functions of this SCC.
5878 SetVector<Function *> Functions(SCC.begin(), SCC.end());
5879 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5880 /*CGSCC*/ &Functions, PostLink);
5881
// NOTE(review): the initializer (source line 5883) is missing.
5882 unsigned MaxFixpointIterations =
5884
// NOTE(review): source line 5886 (another AttributorConfig setting) is missing.
5885 AttributorConfig AC(CGUpdater);
5887 AC.IsModulePass = false;
5888 AC.RewriteSignatures = false;
5889 AC.MaxFixpointIterations = MaxFixpointIterations;
5890 AC.OREGetter = OREGetter;
5891 AC.PassName = DEBUG_TYPE;
5892 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5893
5894 Attributor A(Functions, InfoCache, AC);
5895
5896 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5897 bool Changed = OMPOpt.run(false);
5898
// NOTE(review): source line 5899 (presumably the condition guarding this dump)
// is missing.
5900 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5901
5902 if (Changed)
5903 return PreservedAnalyses::none();
5904
5905 return PreservedAnalyses::all();
5906}
5907
5909 return Fn.hasFnAttribute("kernel");
5910}
5911
5913 // TODO: Create a more cross-platform way of determining device kernels.
5914 NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");
5915 KernelSet Kernels;
5916
5917 if (!MD)
5918 return Kernels;
5919
5920 for (auto *Op : MD->operands()) {
5921 if (Op->getNumOperands() < 2)
5922 continue;
5923 MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
5924 if (!KindID || KindID->getString() != "kernel")
5925 continue;
5926
5927 Function *KernelFn =
5928 mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
5929 if (!KernelFn)
5930 continue;
5931
5932 // We are only interested in OpenMP target regions. Others, such as kernels
5933 // generated by CUDA but linked together, are not interesting to this pass.
5934 if (isOpenMPKernel(*KernelFn)) {
5935 ++NumOpenMPTargetRegionKernels;
5936 Kernels.insert(KernelFn);
5937 } else
5938 ++NumNonOpenMPTargetRegionKernels;
5939 }
5940
5941 return Kernels;
5942}
5943
5945 Metadata *MD = M.getModuleFlag("openmp");
5946 if (!MD)
5947 return false;
5948
5949 return true;
5950}
5951
5953 Metadata *MD = M.getModuleFlag("openmp-device");
5954 if (!MD)
5955 return false;
5956
5957 return true;
5958}
@ Generic
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static cl::opt< unsigned > SetFixpointIterations("attributor-max-iterations", cl::Hidden, cl::desc("Maximal number of fixpoint iterations."), cl::init(32))
static const Function * getParent(const Value *V)
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
This file provides interfaces used to manipulate a call graph, regardless if it is a "old style" Call...
This file provides interfaces used to build and manipulate a call graph, which is a very useful tool ...
This file contains the declarations for the subclasses of Constant, which represent the different fla...
dxil pretty DXIL Metadata Pretty Printer
Returns the sub type a function will return at a given Idx. Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx.
#define LLVM_DEBUG(...)
Definition: Debug.h:106
bool End
Definition: ELF_riscv.cpp:480
This file defines an array type that can be indexed using scoped enum values.
#define DEBUG_TYPE
static void emitRemark(const Function &F, OptimizationRemarkEmitter &ORE, bool Skip)
static bool lookup(const GsymReader &GR, DataExtractor &Data, uint64_t &Offset, uint64_t BaseAddr, uint64_t Addr, SourceLocations &SrcLocs, llvm::Error &Err)
A Lookup helper functions.
Definition: InlineInfo.cpp:108
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
This file provides utility analysis objects describing memory locations.
uint64_t IntrinsicInst * II
This file defines constants and helpers used when dealing with OpenMP.
This file defines constants that will be used by both host and device compilation.
static constexpr auto TAG
Definition: OpenMPOpt.cpp:184
static cl::opt< bool > HideMemoryTransferLatency("openmp-hide-memory-transfer-latency", cl::desc("[WIP] Tries to hide the latency of host to device memory" " transfers"), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptStateMachineRewrite("openmp-opt-disable-state-machine-rewrite", cl::desc("Disable OpenMP optimizations that replace the state machine."), cl::Hidden, cl::init(false))
static cl::opt< bool > EnableParallelRegionMerging("openmp-opt-enable-merging", cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden, cl::init(false))
static cl::opt< bool > AlwaysInlineDeviceFunctions("openmp-opt-inline-device", cl::desc("Inline all applicible functions on the device."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleAfterOptimizations("openmp-opt-print-module-after", cl::desc("Print the current module after OpenMP optimizations."), cl::Hidden, cl::init(false))
#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)
Definition: OpenMPOpt.cpp:240
#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)
Definition: OpenMPOpt.cpp:217
#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER)
Definition: OpenMPOpt.cpp:3682
static cl::opt< bool > PrintOpenMPKernels("openmp-print-gpu-kernels", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptFolding("openmp-opt-disable-folding", cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleBeforeOptimizations("openmp-opt-print-module-before", cl::desc("Print the current module before OpenMP optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden, cl::desc("Maximal number of attributor iterations."), cl::init(256))
static cl::opt< bool > DisableInternalization("openmp-opt-disable-internalization", cl::desc("Disable function internalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintICVValues("openmp-print-icv-values", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptimizations("openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden, cl::desc("Maximum amount of shared memory to use."), cl::init(std::numeric_limits< unsigned >::max()))
static cl::opt< bool > EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::desc("Enables more verbose remarks."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptDeglobalization("openmp-opt-disable-deglobalization", cl::desc("Disable OpenMP optimizations involving deglobalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptBarrierElimination("openmp-opt-disable-barrier-elimination", cl::desc("Disable OpenMP optimizations that eliminate barriers."), cl::Hidden, cl::init(false))
#define DEBUG_TYPE
Definition: OpenMPOpt.cpp:66
static cl::opt< bool > DeduceICVValues("openmp-deduce-icv-values", cl::init(false), cl::Hidden)
#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)
Definition: OpenMPOpt.cpp:209
#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)
Definition: OpenMPOpt.cpp:230
static cl::opt< bool > DisableOpenMPOptSPMDization("openmp-opt-disable-spmdization", cl::desc("Disable OpenMP optimizations involving SPMD-ization."), cl::Hidden, cl::init(false))
#define P(N)
FunctionAnalysisManager FAM
if(PassOpts->AAPipeline)
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
static StringRef getName(Value *V)
R600 Clause Merge
Basic Register Allocator
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static bool contains(SmallPtrSetImpl< ConstantExpr * > &Cache, ConstantExpr *Expr, Constant *C)
Definition: Value.cpp:469
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:166
This file contains some functions that are useful when dealing with strings.
static const int BlockSize
Definition: TarWriter.cpp:33
static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, ArrayRef< StringLiteral > StandardNames)
Initialize the set of available library functions based on the specified target triple.
Value * RHS
static cl::opt< unsigned > MaxThreads("xcore-max-threads", cl::Optional, cl::desc("Maximum number of threads (for emulation thread-local storage)"), cl::Hidden, cl::value_desc("number"), cl::init(8))
AbstractCallSite.
This class represents a conversion between pointers from one address space to another.
an instruction to allocate memory on the stack
Definition: Instructions.h:63
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:410
This class represents an incoming formal argument to a Function.
Definition: Argument.h:31
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:168
AttributeSet getParamAttrs(unsigned ArgNo) const
The attributes for the argument or parameter at the given index are returned.
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
iterator end()
Definition: BasicBlock.h:461
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:448
const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
Definition: BasicBlock.cpp:416
reverse_iterator rbegin()
Definition: BasicBlock.h:464
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:212
BasicBlock * splitBasicBlock(iterator I, const Twine &BBName="", bool Before=false)
Split the basic block into two basic blocks at the specified instruction.
Definition: BasicBlock.cpp:577
const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
Definition: BasicBlock.cpp:497
InstListType::reverse_iterator reverse_iterator
Definition: BasicBlock.h:179
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:219
reverse_iterator rend()
Definition: BasicBlock.h:466
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:239
Conditional or Unconditional Branch instruction.
bool isConditional() const
static BranchInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
BasicBlock * getSuccessor(unsigned i) const
Value * getCondition() const
Allocate memory in an ever growing pool, as if by bump-pointer.
Definition: Allocator.h:66
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Definition: InstrTypes.h:1120
void setCallingConv(CallingConv::ID CC)
Definition: InstrTypes.h:1411
bool arg_empty() const
Definition: InstrTypes.h:1291
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1349
bool doesNotAccessMemory(unsigned OpNo) const
Definition: InstrTypes.h:1718
bool isIndirectCall() const
Return true if the callsite is an indirect call.
bool isCallee(Value::const_user_iterator UI) const
Determine whether the passed iterator points to the callee operand's Use.
Definition: InstrTypes.h:1360
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1294
void setArgOperand(unsigned i, Value *v)
Definition: InstrTypes.h:1299
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
Definition: InstrTypes.h:1285
unsigned getArgOperandNo(const Use *U) const
Given a use for a arg operand, get the arg operand number that corresponds to it.
Definition: InstrTypes.h:1325
unsigned arg_size() const
Definition: InstrTypes.h:1292
AttributeList getAttributes() const
Return the attributes for this call.
Definition: InstrTypes.h:1425
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind)
Adds the attribute to the indicated argument.
Definition: InstrTypes.h:1502
bool isArgOperand(const Use *U) const
Definition: InstrTypes.h:1314
bool hasOperandBundles() const
Return true if this User has any operand bundles.
Definition: InstrTypes.h:1969
Function * getCaller()
Helper to get the caller (the parent function).
Wrapper to unify "old style" CallGraph and "new style" LazyCallGraph.
void initialize(LazyCallGraph &LCG, LazyCallGraph::SCC &SCC, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR)
Initializers for usage outside of a CGSCC pass, inside a CGSCC pass in the old and new pass manager (...
This class represents a function call, abstracting a target machine's calling convention.
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:702
@ ICMP_EQ
equal
Definition: InstrTypes.h:694
@ ICMP_NE
not equal
Definition: InstrTypes.h:695
static Constant * getPointerCast(Constant *C, Type *Ty)
Create a BitCast, AddrSpaceCast, or a PtrToInt cast constant expression.
Definition: Constants.cpp:2253
static Constant * getPointerBitCastOrAddrSpaceCast(Constant *C, Type *Ty)
Create a BitCast or AddrSpaceCast for a pointer type depending on the address space.
Definition: Constants.cpp:2268
This is the shared class of boolean and integer constants.
Definition: Constants.h:83
IntegerType * getIntegerType() const
Variant of the getType() method to always return an IntegerType, which reduces the amount of casting ...
Definition: Constants.h:187
static ConstantInt * getTrue(LLVMContext &Context)
Definition: Constants.cpp:866
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition: Constants.h:208
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
Definition: Constants.h:163
This is an important base class in LLVM.
Definition: Constant.h:42
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Definition: Constants.cpp:373
This class represents an Operation in the Expression.
A parsed version of the target data layout string, and methods for querying it.
Definition: DataLayout.h:63
A debug info location.
Definition: DebugLoc.h:33
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:194
iterator begin()
Definition: DenseMap.h:75
iterator end()
Definition: DenseMap.h:84
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:211
Implements a dense probed hash-table based set.
Definition: DenseSet.h:278
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
Instruction * findNearestCommonDominator(Instruction *I1, Instruction *I2) const
Find the nearest instruction I that dominates both I1 and I2, in the sense that a result produced bef...
Definition: Dominators.cpp:344
static ErrorSuccess success()
Create a success value.
Definition: Error.h:337
Tagged union holding either a T or a Error.
Definition: Error.h:481
An instruction for ordering other memory operations.
Definition: Instructions.h:424
AtomicOrdering getOrdering() const
Returns the ordering constraint of this fence instruction.
Definition: Instructions.h:449
A proxy from a FunctionAnalysisManager to an SCC.
A handy container for a FunctionType+Callee-pointer pair, which can be passed around as a single enti...
Definition: DerivedTypes.h:170
const BasicBlock & getEntryBlock() const
Definition: Function.h:809
const BasicBlock & front() const
Definition: Function.h:860
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition: Function.cpp:369
Argument * getArg(unsigned i) const
Definition: Function.h:886
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
Definition: Function.cpp:731
bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
Definition: Globals.cpp:296
bool hasLocalLinkage() const
Definition: GlobalValue.h:528
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:656
@ PrivateLinkage
Like Internal, but omit from symbol table.
Definition: GlobalValue.h:60
@ InternalLinkage
Rename collisions when linking (static functions).
Definition: GlobalValue.h:59
const Constant * getInitializer() const
getInitializer - Return the initializer for this global variable.
void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
Definition: Globals.cpp:492
InsertPoint - A saved insertion point.
Definition: IRBuilder.h:254
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Definition: IRBuilder.h:1796
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:177
Value * CreateAddrSpaceCast(Value *V, Type *DestTy, const Twine &Name="")
Definition: IRBuilder.h:2160
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2697
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
Definition: PassManager.h:567
bool isLifetimeStartOrEnd() const LLVM_READONLY
Return true if the instruction is a llvm.lifetime.start or llvm.lifetime.end marker.
bool mayWriteToMemory() const LLVM_READONLY
Return true if this instruction may modify memory.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
Definition: Instruction.h:475
const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr it the function does not...
Definition: Instruction.cpp:68
const Instruction * getPrevNonDebugInstruction(bool SkipPseudoOp=false) const
Return a pointer to the previous non-debug instruction in the same basic block as 'this',...
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:94
const Function * getFunction() const
Return the function this instruction belongs to.
Definition: Instruction.cpp:72
bool mayHaveSideEffects() const LLVM_READONLY
Return true if the instruction may have side effects.
bool mayReadFromMemory() const LLVM_READONLY
Return true if this instruction may read memory.
const Instruction * getNextNonDebugInstruction(bool SkipPseudoOp=false) const
Return a pointer to the next non-debug instruction in the same basic block as 'this',...
void setDebugLoc(DebugLoc Loc)
Set the debug location information for this instruction.
Definition: Instruction.h:472
void setSuccessor(unsigned Idx, BasicBlock *BB)
Update the specified successor to point at the provided block.
void moveBefore(Instruction *MovePos)
Unlink this instruction from its current basic block and insert it into the basic block that MovePos ...
A node in the call graph.
An SCC of the call graph.
A lazily constructed view of the call graph of a module.
An instruction for reading from memory.
Definition: Instructions.h:176
A single uniqued string.
Definition: Metadata.h:720
StringRef getString() const
Definition: Metadata.cpp:616
void eraseFromParent()
This method unlinks 'this' from the containing function and deletes it.
Root of the metadata hierarchy.
Definition: Metadata.h:62
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
const std::string & getTargetTriple() const
Get the target triple which is a string describing the target host.
Definition: Module.h:298
MutableArrayRef - Represent a mutable reference to an array (0 or more elements consecutively in memo...
Definition: ArrayRef.h:310
A tuple of MDNodes.
Definition: Metadata.h:1731
iterator_range< op_iterator > operands()
Definition: Metadata.h:1827
An interface to create LLVM-IR for OpenMP directives.
Definition: OMPIRBuilder.h:474
static std::pair< int32_t, int32_t > readThreadBoundsForKernel(const Triple &T, Function &Kernel)
}
void addAttributes(omp::RuntimeFunction FnID, Function &Fn)
Add attributes known for FnID to Fn.
IRBuilder<>::InsertPoint InsertPointTy
Type used throughout for insertion points.
Definition: OMPIRBuilder.h:520
static std::pair< int32_t, int32_t > readTeamBoundsForKernel(const Triple &T, Function &Kernel)
Read/write a bounds on teams for Kernel.
PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR)
Definition: OpenMPOpt.cpp:5837
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
Definition: OpenMPOpt.cpp:5724
Diagnostic information for optimization analysis remarks.
The optimization diagnostic interface.
Diagnostic information for missed-optimization remarks.
Diagnostic information for applied optimization remarks.
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1878
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition: Analysis.h:114
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, InsertPosition InsertBefore=nullptr)
A vector that has set insertion semantics.
Definition: SetVector.h:57
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:98
const value_type & back() const
Return the last element of the SetVector.
Definition: SetVector.h:149
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:264
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:162
size_type size() const
Definition: SmallPtrSet.h:94
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:363
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:452
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:384
iterator begin() const
Definition: SmallPtrSet.h:472
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:519
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:370
bool empty() const
Definition: SmallVector.h:81
size_t size() const
Definition: SmallVector.h:78
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:573
void assign(size_type NumElts, ValueParamT Elt)
Definition: SmallVector.h:704
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:937
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:683
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
An instruction for storing to memory.
Definition: Instructions.h:292
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:51
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition: StringRef.h:265
Triple - Helper class for working with autoconf configuration names.
Definition: Triple.h:44
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
static Type * getVoidTy(LLVMContext &C)
static IntegerType * getInt16Ty(LLVMContext &C)
static IntegerType * getInt8Ty(LLVMContext &C)
static IntegerType * getInt32Ty(LLVMContext &C)
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1859
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Definition: User.cpp:21
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
void setName(const Twine &Name)
Change the name of the value.
Definition: Value.cpp:377
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
iterator_range< user_iterator > users()
Definition: Value.h:421
User * user_back()
Definition: Value.h:407
const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs and address space casts.
Definition: Value.cpp:694
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition: ilist_node.h:32
self_iterator getIterator()
Definition: ilist_node.h:132
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition: ilist_node.h:353
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:661
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
GlobalVariable * getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB)
Definition: OpenMPOpt.cpp:259
ConstantStruct * getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB)
Definition: OpenMPOpt.cpp:266
bool isValidAtPosition(const ValueAndContext &VAC, InformationCache &InfoCache)
Return true if the value of VAC is a valid at the position of VAC, that is a constant,...
Definition: Attributor.cpp:290
bool isPotentiallyAffectedByBarrier(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is potentially affected by a barrier.
Definition: Attributor.cpp:889
@ Interprocedural
Definition: Attributor.h:182
bool isNoSyncInst(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is a nosync instruction.
Definition: Attributor.cpp:205
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
E & operator^=(E &LHS, E RHS)
Definition: BitmaskEnum.h:195
@ Entry
Definition: COFF.h:844
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
PointerTypeMap run(const Module &M)
Compute the PointerTypeMap for the module M.
bool isOpenMPDevice(Module &M)
Helper to determine if M is a OpenMP target offloading device module.
Definition: OpenMPOpt.cpp:5952
bool containsOpenMP(Module &M)
Helper to determine if M contains OpenMP.
Definition: OpenMPOpt.cpp:5944
InternalControlVar
IDs for all Internal Control Variables (ICVs).
Definition: OMPConstants.h:26
RuntimeFunction
IDs for all omp runtime library (RTL) functions.
Definition: OMPConstants.h:45
KernelSet getDeviceKernels(Module &M)
Get OpenMP device kernels in M.
Definition: OpenMPOpt.cpp:5912
Function * Kernel
Summary of a kernel (=entry point for target offloading).
Definition: OpenMPOpt.h:21
@ OMP_TGT_EXEC_MODE_GENERIC_SPMD
bool isOpenMPKernel(Function &Fn)
Return true iff Fn is an OpenMP GPU kernel; Fn has the "kernel" attribute.
Definition: OpenMPOpt.cpp:5908
DiagnosticInfoOptimizationBase::Argument NV
const_iterator begin(StringRef path, Style style=Style::native)
Get begin iterator over path.
Definition: Path.cpp:226
const_iterator end(StringRef path)
Get end iterator over path.
Definition: Path.cpp:235
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:480
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1739
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1697
bool succ_empty(const Instruction *I)
Definition: CFG.h:255
std::string to_string(const T &Value)
Definition: ScopedPrinter.h:85
bool operator!=(uint64_t V1, const APInt &V2)
Definition: APInt.h:2082
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
@ CGSCC
Definition: Attributor.h:6477
@ FullLTOPostLink
Full LTO postlink (backend compile) phase.
@ ThinLTOPreLink
ThinLTO prelink (summary) phase.
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
bool operator&=(SparseBitVector< ElementSize > *LHS, const SparseBitVector< ElementSize > &RHS)
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition: STLExtras.h:1938
BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
auto predecessors(const MachineBasicBlock *BB)
ChangeStatus
{
Definition: Attributor.h:489
const char * toString(DWARFSectionKind Kind)
Constant * ConstantFoldInsertValueInstruction(Constant *Agg, Constant *Val, ArrayRef< unsigned > Idxs)
ConstantFoldInsertValueInstruction - Attempt to constant fold an insertvalue instruction with the spe...
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
#define N
An abstract interface for address space information.
Definition: Attributor.h:6285
An abstract attribute for getting assumption information.
Definition: Attributor.h:6211
An abstract state for querying live call edges.
Definition: Attributor.h:5486
virtual const SetVector< Function * > & getOptimisticEdges() const =0
Get the optimistic edges.
bool isExecutedByInitialThreadOnly(const Instruction &I) const
Check if an instruction is executed only by the initial thread.
Definition: Attributor.h:5646
static AAExecutionDomain & createForPosition(const IRPosition &IRP, Attributor &A)
Create an abstract attribute view for the position IRP.
Definition: OpenMPOpt.cpp:5640
virtual ExecutionDomainTy getFunctionExecutionDomain() const =0
virtual ExecutionDomainTy getExecutionDomain(const BasicBlock &) const =0
virtual bool isExecutedInAlignedRegion(Attributor &A, const Instruction &I) const =0
Check if the instruction I is executed in an aligned region, that is, the synchronizing effects befor...
virtual bool isNoOpFence(const FenceInst &FI) const =0
Helper function to determine if FI is a no-op given the information about its execution from ExecDoma...
static const char ID
Unique ID (due to the unique address)
Definition: Attributor.h:5677
An abstract interface for indirect call information interference.
Definition: Attributor.h:6404
An abstract interface for liveness abstract attribute.
Definition: Attributor.h:3976
An abstract interface for all memory location attributes (readnone/argmemonly/inaccessiblememonly/ina...
Definition: Attributor.h:4704
AccessKind
Simple enum to distinguish read/write/read-write accesses.
Definition: Attributor.h:4840
StateType::base_t MemoryLocationsKind
Definition: Attributor.h:4705
static bool isAlignedBarrier(const CallBase &CB, bool ExecutedAligned)
Helper function to determine if CB is an aligned (GPU) barrier.
An abstract Attribute for determining the necessity of the convergent attribute.
Definition: Attributor.h:5722
An abstract attribute for getting all assumption underlying objects.
Definition: Attributor.h:6243
Base struct for all "concrete attribute" deductions.
Definition: Attributor.h:3284
virtual ChangeStatus manifest(Attributor &A)
Hook for the Attributor to trigger the manifestation of the information represented by the abstract a...
Definition: Attributor.h:3399
virtual void initialize(Attributor &A)
Initialize the state with the information in the Attributor A.
Definition: Attributor.h:3348
virtual const std::string getAsStr(Attributor *A) const =0
This function should return the "summarized" assumed state as string.
virtual ChangeStatus updateImpl(Attributor &A)=0
The actual update/transfer function which has to be implemented by the derived classes.
virtual void trackStatistics() const =0
Hook to enable custom statistic tracking, called after manifest that resulted in a change if statisti...
virtual const char * getIdAddr() const =0
This function should return the address of the ID of the AbstractAttribute.
An interface to query the internal state of an abstract attribute.
Definition: Attributor.h:2604
virtual ChangeStatus indicatePessimisticFixpoint()=0
Indicate that the abstract state should converge to the pessimistic state.
virtual bool isAtFixpoint() const =0
Return if this abstract state is fixed, thus does not need to be updated if information changes as it...
virtual bool isValidState() const =0
Return if this abstract state is in a valid state.
virtual ChangeStatus indicateOptimisticFixpoint()=0
Indicate that the abstract state should converge to the optimistic state.
Wrapper for FunctionAnalysisManager.
Definition: Attributor.h:1127
Configuration for the Attributor.
Definition: Attributor.h:1422
std::function< void(Attributor &A, const Function &F)> InitializationCallback
Callback function to be invoked on internal functions marked live.
Definition: Attributor.h:1453
std::optional< unsigned > MaxFixpointIterations
Maximum number of iterations to run until fixpoint.
Definition: Attributor.h:1469
bool RewriteSignatures
Flag to determine if we rewrite function signatures.
Definition: Attributor.h:1439
const char * PassName
}
Definition: Attributor.h:1479
OptimizationRemarkGetter OREGetter
Definition: Attributor.h:1475
IPOAmendableCBTy IPOAmendableCB
Definition: Attributor.h:1482
bool IsModulePass
Is the user of the Attributor a module pass or not.
Definition: Attributor.h:1433
bool DefaultInitializeLiveInternals
Flag to determine if we want to initialize all default AAs for an internal function marked live.
Definition: Attributor.h:1443
The fixpoint analysis framework that orchestrates the attribute deduction.
Definition: Attributor.h:1516
static bool isInternalizable(Function &F)
Returns true if the function F can be internalized.
std::function< std::optional< Constant * >(const GlobalVariable &, const AbstractAttribute *, bool &)> GlobalVariableSimplifictionCallbackTy
Register CB as a simplification callback.
Definition: Attributor.h:2033
std::function< bool(Attributor &, const AbstractAttribute *)> VirtualUseCallbackTy
Definition: Attributor.h:2063
static bool internalizeFunctions(SmallPtrSetImpl< Function * > &FnSet, DenseMap< Function *, Function * > &FnMap)
Make copies of each function in the set FnSet such that the copied version has internal linkage after...
std::function< std::optional< Value * >(const IRPosition &, const AbstractAttribute *, bool &)> SimplifictionCallbackTy
Register CB as a simplification callback.
Definition: Attributor.h:2017
Simple wrapper for a single bit (boolean) state.
Definition: Attributor.h:2888
Support structure for SCC passes to communicate updates the call graph back to the CGSCC pass manager...
Helper to describe and deal with positions in the LLVM-IR.
Definition: Attributor.h:586
static const IRPosition callsite_returned(const CallBase &CB)
Create a position describing the returned value of CB.
Definition: Attributor.h:654
static const IRPosition returned(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the returned value of F.
Definition: Attributor.h:636
static const IRPosition value(const Value &V, const CallBaseContext *CBContext=nullptr)
Create a position describing the value of V.
Definition: Attributor.h:610
static const IRPosition inst(const Instruction &I, const CallBaseContext *CBContext=nullptr)
Create a position describing the instruction I.
Definition: Attributor.h:622
@ IRP_ARGUMENT
An attribute for a function argument.
Definition: Attributor.h:600
@ IRP_RETURNED
An attribute for the function return value.
Definition: Attributor.h:596
@ IRP_CALL_SITE
An attribute for a call site (function scope).
Definition: Attributor.h:599
@ IRP_CALL_SITE_RETURNED
An attribute for a call site return value.
Definition: Attributor.h:597
@ IRP_FUNCTION
An attribute for a function (scope).
Definition: Attributor.h:598
@ IRP_FLOAT
A position that is not associated with a spot suitable for attributes.
Definition: Attributor.h:594
@ IRP_CALL_SITE_ARGUMENT
An attribute for a call site argument.
Definition: Attributor.h:601
@ IRP_INVALID
An invalid position.
Definition: Attributor.h:593
static const IRPosition function(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the function scope of F.
Definition: Attributor.h:629
Kind getPositionKind() const
Return the associated position kind.
Definition: Attributor.h:882
static const IRPosition callsite_function(const CallBase &CB)
Create a position describing the function scope of CB.
Definition: Attributor.h:649
Function * getAnchorScope() const
Return the Function surrounding the anchor value.
Definition: Attributor.h:758
Data structure to hold cached (LLVM-IR) information.
Definition: Attributor.h:1203
bool isValidState() const override
See AbstractState::isValidState() NOTE: For now we simply pretend that the worst possible state is in...
Definition: Attributor.h:2663
ChangeStatus indicatePessimisticFixpoint() override
See AbstractState::indicatePessimisticFixpoint(...)
Definition: Attributor.h:2675
Direction
An enum for the direction of the loop.
Definition: LoopInfo.h:215
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:117
Description of a LLVM-IR insertion point (IP) and a debug/source location (filename,...
Definition: OMPIRBuilder.h:645
Helper to tie a abstract state implementation to an abstract attribute.
Definition: Attributor.h:3173
StateType & getState() override
See AbstractAttribute::getState(...).
Definition: Attributor.h:3181
Defines various target-specific GPU grid values that must be consistent between host RTL (plugin),...
Definition: OMPGridValues.h:57