OpenMPOpt.cpp
1//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpenMP specific optimizations:
10//
11// - Deduplication of runtime calls, e.g., omp_get_thread_num.
12// - Replacing globalized device memory with stack memory.
13// - Replacing globalized device memory with shared memory.
14// - Parallel region merging.
15// - Transforming generic-mode device kernels to SPMD mode.
16// - Specializing the state machine for generic-mode device kernels.
17//
18//===----------------------------------------------------------------------===//
19
20#include "llvm/Transforms/IPO/OpenMPOpt.h"
21
22#include "llvm/ADT/DenseSet.h"
25#include "llvm/ADT/SetVector.h"
28#include "llvm/ADT/Statistic.h"
30#include "llvm/ADT/StringRef.h"
38#include "llvm/IR/Assumptions.h"
39#include "llvm/IR/BasicBlock.h"
40#include "llvm/IR/CallingConv.h"
41#include "llvm/IR/Constants.h"
43#include "llvm/IR/Dominators.h"
44#include "llvm/IR/Function.h"
45#include "llvm/IR/GlobalValue.h"
47#include "llvm/IR/InstrTypes.h"
48#include "llvm/IR/Instruction.h"
51#include "llvm/IR/IntrinsicsAMDGPU.h"
52#include "llvm/IR/IntrinsicsNVPTX.h"
53#include "llvm/IR/LLVMContext.h"
56#include "llvm/Support/Debug.h"
60
61#include <algorithm>
62#include <optional>
63#include <string>
64
65using namespace llvm;
66using namespace omp;
67
68#define DEBUG_TYPE "openmp-opt"
69
71 "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
72 cl::Hidden, cl::init(false));
73
75 "openmp-opt-enable-merging",
76 cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
77 cl::init(false));
78
79static cl::opt<bool>
80 DisableInternalization("openmp-opt-disable-internalization",
81 cl::desc("Disable function internalization."),
82 cl::Hidden, cl::init(false));
83
84static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
85 cl::init(false), cl::Hidden);
86static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
87                                    cl::Hidden);
88static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
89 cl::init(false), cl::Hidden);
90
92 "openmp-hide-memory-transfer-latency",
93 cl::desc("[WIP] Tries to hide the latency of host to device memory"
94 " transfers"),
95 cl::Hidden, cl::init(false));
96
98 "openmp-opt-disable-deglobalization",
99 cl::desc("Disable OpenMP optimizations involving deglobalization."),
100 cl::Hidden, cl::init(false));
101
103 "openmp-opt-disable-spmdization",
104 cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
105 cl::Hidden, cl::init(false));
106
108 "openmp-opt-disable-folding",
109 cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
110 cl::init(false));
111
113 "openmp-opt-disable-state-machine-rewrite",
114 cl::desc("Disable OpenMP optimizations that replace the state machine."),
115 cl::Hidden, cl::init(false));
116
118 "openmp-opt-disable-barrier-elimination",
119 cl::desc("Disable OpenMP optimizations that eliminate barriers."),
120 cl::Hidden, cl::init(false));
121
123 "openmp-opt-print-module-after",
124 cl::desc("Print the current module after OpenMP optimizations."),
125 cl::Hidden, cl::init(false));
126
128 "openmp-opt-print-module-before",
129 cl::desc("Print the current module before OpenMP optimizations."),
130 cl::Hidden, cl::init(false));
131
133 "openmp-opt-inline-device",
134 cl::desc("Inline all applicable functions on the device."), cl::Hidden,
135 cl::init(false));
136
137static cl::opt<bool>
138 EnableVerboseRemarks("openmp-opt-verbose-remarks",
139 cl::desc("Enables more verbose remarks."), cl::Hidden,
140 cl::init(false));
141
142static cl::opt<unsigned>
143    SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
144 cl::desc("Maximal number of attributor iterations."),
145 cl::init(256));
146
147static cl::opt<unsigned>
148    SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
149 cl::desc("Maximum amount of shared memory to use."),
150 cl::init(std::numeric_limits<unsigned>::max()));
151
152STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
153 "Number of OpenMP runtime calls deduplicated");
154STATISTIC(NumOpenMPParallelRegionsDeleted,
155 "Number of OpenMP parallel regions deleted");
156STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
157 "Number of OpenMP runtime functions identified");
158STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
159 "Number of OpenMP runtime function uses identified");
160STATISTIC(NumOpenMPTargetRegionKernels,
161 "Number of OpenMP target region entry points (=kernels) identified");
162STATISTIC(NumNonOpenMPTargetRegionKernels,
163 "Number of non-OpenMP target region kernels identified");
164STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
165 "Number of OpenMP target region entry points (=kernels) executed in "
166 "SPMD-mode instead of generic-mode");
167STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
168 "Number of OpenMP target region entry points (=kernels) executed in "
169 "generic-mode without a state machines");
170STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
171 "Number of OpenMP target region entry points (=kernels) executed in "
172 "generic-mode with customized state machines with fallback");
173STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
174 "Number of OpenMP target region entry points (=kernels) executed in "
175 "generic-mode with customized state machines without fallback");
176STATISTIC(
177    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
178 "Number of OpenMP parallel regions replaced with ID in GPU state machines");
179STATISTIC(NumOpenMPParallelRegionsMerged,
180 "Number of OpenMP parallel regions merged");
181STATISTIC(NumBytesMovedToSharedMemory,
182 "Amount of memory pushed to shared memory");
183STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
184
185#if !defined(NDEBUG)
186static constexpr auto TAG = "[" DEBUG_TYPE "]";
187#endif
188
189namespace KernelInfo {
190
191// struct ConfigurationEnvironmentTy {
192// uint8_t UseGenericStateMachine;
193// uint8_t MayUseNestedParallelism;
194// llvm::omp::OMPTgtExecModeFlags ExecMode;
195// int32_t MinThreads;
196// int32_t MaxThreads;
197// int32_t MinTeams;
198// int32_t MaxTeams;
199// };
200
201// struct DynamicEnvironmentTy {
202// uint16_t DebugIndentionLevel;
203// };
204
205// struct KernelEnvironmentTy {
206// ConfigurationEnvironmentTy Configuration;
207// IdentTy *Ident;
208// DynamicEnvironmentTy *DynamicEnv;
209// };
210
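// The index constants defined below mirror the struct layouts sketched above:
// reading a member of the constant kernel environment amounts to
// KernelEnvC->getAggregateElement(MEMBERIdx), so these indices must stay in
// sync with the runtime's struct definitions.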
211#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX) \
212 constexpr const unsigned MEMBER##Idx = IDX;
213
214KERNEL_ENVIRONMENT_IDX(Configuration, 0)
215KERNEL_ENVIRONMENT_IDX(Ident, 1)
216
217#undef KERNEL_ENVIRONMENT_IDX
218
219#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX) \
220 constexpr const unsigned MEMBER##Idx = IDX;
221
222KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
223KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
224KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2)
225KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3)
226KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4)
227KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5)
228KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6)
229
230#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX
231
232#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE) \
233 RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
234 return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx)); \
235 }
236
237KERNEL_ENVIRONMENT_GETTER(Ident, Constant)
238KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)
239
240#undef KERNEL_ENVIRONMENT_GETTER
241
242#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER) \
243 ConstantInt *get##MEMBER##FromKernelEnvironment( \
244 ConstantStruct *KernelEnvC) { \
245 ConstantStruct *ConfigC = \
246 getConfigurationFromKernelEnvironment(KernelEnvC); \
247 return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx)); \
248 }
249
250KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
251KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
252KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode)
253KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads)
254KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads)
255KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams)
256KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams)
257
258#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
259
260GlobalVariable *
261getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) {
262  constexpr const int InitKernelEnvironmentArgNo = 0;
263 return cast<GlobalVariable>(
264 KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
265          ->stripPointerCasts());
266}
267
268ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) {
269  GlobalVariable *KernelEnvGV =
270      getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
271  return cast<ConstantStruct>(KernelEnvGV->getInitializer());
272}
273} // namespace KernelInfo
274
275namespace {
276
277struct AAHeapToShared;
278
279struct AAICVTracker;
280
281/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
282/// Attributor runs.
283struct OMPInformationCache : public InformationCache {
284 OMPInformationCache(Module &M, AnalysisGetter &AG,
285                      BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
286                      bool OpenMPPostLink)
287 : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
288 OpenMPPostLink(OpenMPPostLink) {
289
290 OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
291 const Triple T(OMPBuilder.M.getTargetTriple());
292 switch (T.getArch()) {
293    case llvm::Triple::nvptx:
294    case llvm::Triple::nvptx64:
295    case llvm::Triple::amdgcn:
296      assert(OMPBuilder.Config.IsTargetDevice &&
297 "OpenMP AMDGPU/NVPTX is only prepared to deal with device code.");
298 OMPBuilder.Config.IsGPU = true;
299 break;
300 default:
301 OMPBuilder.Config.IsGPU = false;
302 break;
303 }
304 OMPBuilder.initialize();
305 initializeRuntimeFunctions(M);
306 initializeInternalControlVars();
307 }
308
309 /// Generic information that describes an internal control variable.
310 struct InternalControlVarInfo {
311 /// The kind, as described by InternalControlVar enum.
312    InternalControlVar Kind;
313
314 /// The name of the ICV.
315    StringRef Name;
316
317 /// Environment variable associated with this ICV.
318 StringRef EnvVarName;
319
320 /// Initial value kind.
321 ICVInitValue InitKind;
322
323 /// Initial value.
324 ConstantInt *InitValue;
325
326 /// Setter RTL function associated with this ICV.
327 RuntimeFunction Setter;
328
329 /// Getter RTL function associated with this ICV.
330 RuntimeFunction Getter;
331
332 /// RTL Function corresponding to the override clause of this ICV
333    RuntimeFunction Clause;
334  };
335
336 /// Generic information that describes a runtime function
337 struct RuntimeFunctionInfo {
338
339 /// The kind, as described by the RuntimeFunction enum.
340    RuntimeFunction Kind;
341
342 /// The name of the function.
343    StringRef Name;
344
345 /// Flag to indicate a variadic function.
346 bool IsVarArg;
347
348 /// The return type of the function.
349    Type *ReturnType;
350
351 /// The argument types of the function.
352 SmallVector<Type *, 8> ArgumentTypes;
353
354 /// The declaration if available.
355 Function *Declaration = nullptr;
356
357 /// Uses of this runtime function per function containing the use.
358 using UseVector = SmallVector<Use *, 16>;
359
360 /// Clear UsesMap for runtime function.
361 void clearUsesMap() { UsesMap.clear(); }
362
363 /// Boolean conversion that is true if the runtime function was found.
364 operator bool() const { return Declaration; }
365
366 /// Return the vector of uses in function \p F.
367 UseVector &getOrCreateUseVector(Function *F) {
368 std::shared_ptr<UseVector> &UV = UsesMap[F];
369 if (!UV)
370 UV = std::make_shared<UseVector>();
371 return *UV;
372 }
373
374 /// Return the vector of uses in function \p F or `nullptr` if there are
375 /// none.
376 const UseVector *getUseVector(Function &F) const {
377 auto I = UsesMap.find(&F);
378 if (I != UsesMap.end())
379 return I->second.get();
380 return nullptr;
381 }
382
383 /// Return how many functions contain uses of this runtime function.
384 size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
385
386 /// Return the number of arguments (or the minimal number for variadic
387 /// functions).
388 size_t getNumArgs() const { return ArgumentTypes.size(); }
389
390 /// Run the callback \p CB on each use and forget the use if the result is
391 /// true. The callback will be fed the function in which the use was
392 /// encountered as second argument.
393 void foreachUse(SmallVectorImpl<Function *> &SCC,
394 function_ref<bool(Use &, Function &)> CB) {
395 for (Function *F : SCC)
396 foreachUse(CB, F);
397 }
398
399 /// Run the callback \p CB on each use within the function \p F and forget
400 /// the use if the result is true.
401 void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
402 SmallVector<unsigned, 8> ToBeDeleted;
403 ToBeDeleted.clear();
404
405 unsigned Idx = 0;
406 UseVector &UV = getOrCreateUseVector(F);
407
408 for (Use *U : UV) {
409 if (CB(*U, *F))
410 ToBeDeleted.push_back(Idx);
411 ++Idx;
412 }
413
414 // Remove the to-be-deleted indices in reverse order as prior
415 // modifications will not modify the smaller indices.
416 while (!ToBeDeleted.empty()) {
417 unsigned Idx = ToBeDeleted.pop_back_val();
418 UV[Idx] = UV.back();
419 UV.pop_back();
420 }
421 }
422
423 private:
424 /// Map from functions to all uses of this runtime function contained in
425 /// them.
426    DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
427
428 public:
429 /// Iterators for the uses of this runtime function.
430 decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
431 decltype(UsesMap)::iterator end() { return UsesMap.end(); }
432 };
433
434 /// An OpenMP-IR-Builder instance
435 OpenMPIRBuilder OMPBuilder;
436
437 /// Map from runtime function kind to the runtime function description.
438 EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
439 RuntimeFunction::OMPRTL___last>
440 RFIs;
441
442 /// Map from function declarations/definitions to their runtime enum type.
443 DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
444
445 /// Map from ICV kind to the ICV description.
446 EnumeratedArray<InternalControlVarInfo, InternalControlVar,
447 InternalControlVar::ICV___last>
448 ICVs;
449
450 /// Helper to initialize all internal control variable information for those
451 /// defined in OMPKinds.def.
452 void initializeInternalControlVars() {
453#define ICV_RT_SET(_Name, RTL) \
454 { \
455 auto &ICV = ICVs[_Name]; \
456 ICV.Setter = RTL; \
457 }
458#define ICV_RT_GET(Name, RTL) \
459 { \
460 auto &ICV = ICVs[Name]; \
461 ICV.Getter = RTL; \
462 }
463#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
464 { \
465 auto &ICV = ICVs[Enum]; \
466 ICV.Name = _Name; \
467 ICV.Kind = Enum; \
468 ICV.InitKind = Init; \
469 ICV.EnvVarName = _EnvVarName; \
470 switch (ICV.InitKind) { \
471 case ICV_IMPLEMENTATION_DEFINED: \
472 ICV.InitValue = nullptr; \
473 break; \
474 case ICV_ZERO: \
475 ICV.InitValue = ConstantInt::get( \
476 Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
477 break; \
478 case ICV_FALSE: \
479 ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
480 break; \
481 case ICV_LAST: \
482 break; \
483 } \
484 }
485#include "llvm/Frontend/OpenMP/OMPKinds.def"
486 }
487
488 /// Returns true if the function declaration \p F matches the runtime
489 /// function types, that is, return type \p RTFRetType, and argument types
490 /// \p RTFArgTypes.
491 static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
492 SmallVector<Type *, 8> &RTFArgTypes) {
493 // TODO: We should output information to the user (under debug output
494 // and via remarks).
495
496 if (!F)
497 return false;
498 if (F->getReturnType() != RTFRetType)
499 return false;
500 if (F->arg_size() != RTFArgTypes.size())
501 return false;
502
503 auto *RTFTyIt = RTFArgTypes.begin();
504 for (Argument &Arg : F->args()) {
505 if (Arg.getType() != *RTFTyIt)
506 return false;
507
508 ++RTFTyIt;
509 }
510
511 return true;
512 }
513
514 // Helper to collect all uses of the declaration in the UsesMap.
515 unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
516 unsigned NumUses = 0;
517 if (!RFI.Declaration)
518 return NumUses;
519 OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
520
521 if (CollectStats) {
522 NumOpenMPRuntimeFunctionsIdentified += 1;
523 NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
524 }
525
526 // TODO: We directly convert uses into proper calls and unknown uses.
527 for (Use &U : RFI.Declaration->uses()) {
528 if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
529 if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
530 RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
531 ++NumUses;
532 }
533 } else {
534 RFI.getOrCreateUseVector(nullptr).push_back(&U);
535 ++NumUses;
536 }
537 }
538 return NumUses;
539 }
540
541 // Helper function to recollect uses of a runtime function.
542 void recollectUsesForFunction(RuntimeFunction RTF) {
543 auto &RFI = RFIs[RTF];
544 RFI.clearUsesMap();
545 collectUses(RFI, /*CollectStats*/ false);
546 }
547
548 // Helper function to recollect uses of all runtime functions.
549 void recollectUses() {
550 for (int Idx = 0; Idx < RFIs.size(); ++Idx)
551 recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
552 }
553
554 // Helper function to inherit the calling convention of the function callee.
555 void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
556 if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
557 CI->setCallingConv(Fn->getCallingConv());
558 }
559
560 // Helper function to determine if it's legal to create a call to the runtime
561 // functions.
562 bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
563 // We can always emit calls if we haven't yet linked in the runtime.
564 if (!OpenMPPostLink)
565 return true;
566
567    // Once the runtime has already been linked in, we cannot emit calls to
568 // any undefined functions.
569 for (RuntimeFunction Fn : Fns) {
570 RuntimeFunctionInfo &RFI = RFIs[Fn];
571
572 if (RFI.Declaration && RFI.Declaration->isDeclaration())
573 return false;
574 }
575 return true;
576 }
577
578 /// Helper to initialize all runtime function information for those defined
579  /// in OMPKinds.def.
580 void initializeRuntimeFunctions(Module &M) {
581
582 // Helper macros for handling __VA_ARGS__ in OMP_RTL
583#define OMP_TYPE(VarName, ...) \
584 Type *VarName = OMPBuilder.VarName; \
585 (void)VarName;
586
587#define OMP_ARRAY_TYPE(VarName, ...) \
588 ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
589 (void)VarName##Ty; \
590 PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
591 (void)VarName##PtrTy;
592
593#define OMP_FUNCTION_TYPE(VarName, ...) \
594 FunctionType *VarName = OMPBuilder.VarName; \
595 (void)VarName; \
596 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
597 (void)VarName##Ptr;
598
599#define OMP_STRUCT_TYPE(VarName, ...) \
600 StructType *VarName = OMPBuilder.VarName; \
601 (void)VarName; \
602 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
603 (void)VarName##Ptr;
604
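// As an illustration (not part of the generated code), an OMPKinds.def entry
// such as
//   __OMP_RTL(__kmpc_barrier, false, Void, IdentPtr, Int32)
// expands through OMP_RTL below into code that looks up "__kmpc_barrier" in
// the module, checks that any existing declaration matches the expected
// signature, and, if so, records it (together with all of its uses) in RFIs.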
605#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
606 { \
607 SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
608 Function *F = M.getFunction(_Name); \
609 RTLFunctions.insert(F); \
610 if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
611 RuntimeFunctionIDMap[F] = _Enum; \
612 auto &RFI = RFIs[_Enum]; \
613 RFI.Kind = _Enum; \
614 RFI.Name = _Name; \
615 RFI.IsVarArg = _IsVarArg; \
616 RFI.ReturnType = OMPBuilder._ReturnType; \
617 RFI.ArgumentTypes = std::move(ArgsTypes); \
618 RFI.Declaration = F; \
619 unsigned NumUses = collectUses(RFI); \
620 (void)NumUses; \
621 LLVM_DEBUG({ \
622 dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
623 << " found\n"; \
624 if (RFI.Declaration) \
625 dbgs() << TAG << "-> got " << NumUses << " uses in " \
626 << RFI.getNumFunctionsWithUses() \
627 << " different functions.\n"; \
628 }); \
629 } \
630 }
631#include "llvm/Frontend/OpenMP/OMPKinds.def"
632
633 // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
634 // functions, except if `optnone` is present.
635 if (isOpenMPDevice(M)) {
636 for (Function &F : M) {
637 for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
638 if (F.hasFnAttribute(Attribute::NoInline) &&
639 F.getName().starts_with(Prefix) &&
640 !F.hasFnAttribute(Attribute::OptimizeNone))
641 F.removeFnAttr(Attribute::NoInline);
642 }
643 }
644
645 // TODO: We should attach the attributes defined in OMPKinds.def.
646 }
647
648  /// Collection of known OpenMP runtime functions.
649 DenseSet<const Function *> RTLFunctions;
650
651 /// Indicates if we have already linked in the OpenMP device library.
652 bool OpenMPPostLink = false;
653};
654
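/// A BooleanState paired with a SetVector of elements. Inserting a new element
/// optionally invalidates the boolean state (controlled by the
/// InsertInvalidates template parameter), and the clamp operator merges both
/// the boolean state and the element sets.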
655template <typename Ty, bool InsertInvalidates = true>
656struct BooleanStateWithSetVector : public BooleanState {
657 bool contains(const Ty &Elem) const { return Set.contains(Elem); }
658 bool insert(const Ty &Elem) {
659 if (InsertInvalidates)
660      BooleanState::indicatePessimisticFixpoint();
661    return Set.insert(Elem);
662 }
663
664 const Ty &operator[](int Idx) const { return Set[Idx]; }
665 bool operator==(const BooleanStateWithSetVector &RHS) const {
666 return BooleanState::operator==(RHS) && Set == RHS.Set;
667 }
668 bool operator!=(const BooleanStateWithSetVector &RHS) const {
669 return !(*this == RHS);
670 }
671
672 bool empty() const { return Set.empty(); }
673 size_t size() const { return Set.size(); }
674
675 /// "Clamp" this state with \p RHS.
676 BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
677 BooleanState::operator^=(RHS);
678 Set.insert(RHS.Set.begin(), RHS.Set.end());
679 return *this;
680 }
681
682private:
683 /// A set to keep track of elements.
684  SetVector<Ty> Set;
685
686public:
687 typename decltype(Set)::iterator begin() { return Set.begin(); }
688 typename decltype(Set)::iterator end() { return Set.end(); }
689 typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
690 typename decltype(Set)::const_iterator end() const { return Set.end(); }
691};
692
693template <typename Ty, bool InsertInvalidates = true>
694using BooleanStateWithPtrSetVector =
695 BooleanStateWithSetVector<Ty *, InsertInvalidates>;
696
697struct KernelInfoState : AbstractState {
698 /// Flag to track if we reached a fixpoint.
699 bool IsAtFixpoint = false;
700
701 /// The parallel regions (identified by the outlined parallel functions) that
702 /// can be reached from the associated function.
703 BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
704 ReachedKnownParallelRegions;
705
706 /// State to track what parallel region we might reach.
707 BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
708
709  /// State to track if we are in SPMD-mode, assumed or known, and why we decided
710 /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
711 /// false.
712 BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
713
714 /// The __kmpc_target_init call in this kernel, if any. If we find more than
715 /// one we abort as the kernel is malformed.
716 CallBase *KernelInitCB = nullptr;
717
718  /// The constant kernel environment as taken from and passed to
719 /// __kmpc_target_init.
720 ConstantStruct *KernelEnvC = nullptr;
721
722 /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
723 /// one we abort as the kernel is malformed.
724 CallBase *KernelDeinitCB = nullptr;
725
726 /// Flag to indicate if the associated function is a kernel entry.
727 bool IsKernelEntry = false;
728
729 /// State to track what kernel entries can reach the associated function.
730 BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
731
732 /// State to indicate if we can track parallel level of the associated
733 /// function. We will give up tracking if we encounter unknown caller or the
734 /// caller is __kmpc_parallel_51.
735 BooleanStateWithSetVector<uint8_t> ParallelLevels;
736
737 /// Flag that indicates if the kernel has nested Parallelism
738 bool NestedParallelism = false;
739
740 /// Abstract State interface
741 ///{
742
743 KernelInfoState() = default;
744 KernelInfoState(bool BestState) {
745 if (!BestState)
746      indicatePessimisticFixpoint();
747  }
748
749 /// See AbstractState::isValidState(...)
750 bool isValidState() const override { return true; }
751
752 /// See AbstractState::isAtFixpoint(...)
753 bool isAtFixpoint() const override { return IsAtFixpoint; }
754
755 /// See AbstractState::indicatePessimisticFixpoint(...)
756  ChangeStatus indicatePessimisticFixpoint() override {
757    IsAtFixpoint = true;
758 ParallelLevels.indicatePessimisticFixpoint();
759 ReachingKernelEntries.indicatePessimisticFixpoint();
760 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
761 ReachedKnownParallelRegions.indicatePessimisticFixpoint();
762 ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
763 NestedParallelism = true;
764    return ChangeStatus::CHANGED;
765  }
766
767 /// See AbstractState::indicateOptimisticFixpoint(...)
768  ChangeStatus indicateOptimisticFixpoint() override {
769    IsAtFixpoint = true;
770 ParallelLevels.indicateOptimisticFixpoint();
771 ReachingKernelEntries.indicateOptimisticFixpoint();
772 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
773 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
774 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
775    return ChangeStatus::UNCHANGED;
776  }
777
778 /// Return the assumed state
779 KernelInfoState &getAssumed() { return *this; }
780 const KernelInfoState &getAssumed() const { return *this; }
781
782 bool operator==(const KernelInfoState &RHS) const {
783 if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
784 return false;
785 if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
786 return false;
787 if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
788 return false;
789 if (ReachingKernelEntries != RHS.ReachingKernelEntries)
790 return false;
791 if (ParallelLevels != RHS.ParallelLevels)
792 return false;
793 if (NestedParallelism != RHS.NestedParallelism)
794 return false;
795 return true;
796 }
797
798 /// Returns true if this kernel contains any OpenMP parallel regions.
799 bool mayContainParallelRegion() {
800 return !ReachedKnownParallelRegions.empty() ||
801 !ReachedUnknownParallelRegions.empty();
802 }
803
804 /// Return empty set as the best state of potential values.
805 static KernelInfoState getBestState() { return KernelInfoState(true); }
806
807 static KernelInfoState getBestState(KernelInfoState &KIS) {
808 return getBestState();
809 }
810
811 /// Return full set as the worst state of potential values.
812 static KernelInfoState getWorstState() { return KernelInfoState(false); }
813
814 /// "Clamp" this state with \p KIS.
815 KernelInfoState operator^=(const KernelInfoState &KIS) {
816 // Do not merge two different _init and _deinit call sites.
817 if (KIS.KernelInitCB) {
818 if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
819 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
820 "assumptions.");
821 KernelInitCB = KIS.KernelInitCB;
822 }
823 if (KIS.KernelDeinitCB) {
824 if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
825 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
826 "assumptions.");
827 KernelDeinitCB = KIS.KernelDeinitCB;
828 }
829 if (KIS.KernelEnvC) {
830 if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
831 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
832 "assumptions.");
833 KernelEnvC = KIS.KernelEnvC;
834 }
835 SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
836 ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
837 ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
838 NestedParallelism |= KIS.NestedParallelism;
839 return *this;
840 }
841
842 KernelInfoState operator&=(const KernelInfoState &KIS) {
843 return (*this ^= KIS);
844 }
845
846 ///}
847};
848
849/// Used to map the values physically (in the IR) stored in an offload
850/// array, to a vector in memory.
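///
/// Roughly, for an array filled like
///   %array = alloca [2 x ptr]
///   store ptr %a, ptr %array
///   store ptr %b, ptr %array.gep.1   ; schematic GEP to element 1
/// initialize() records the underlying objects of %a and %b in StoredValues
/// and the two stores in LastAccesses.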
851struct OffloadArray {
852 /// Physical array (in the IR).
853 AllocaInst *Array = nullptr;
854 /// Mapped values.
855 SmallVector<Value *, 8> StoredValues;
856 /// Last stores made in the offload array.
857 SmallVector<StoreInst *, 8> LastAccesses;
858
859 OffloadArray() = default;
860
861 /// Initializes the OffloadArray with the values stored in \p Array before
862 /// instruction \p Before is reached. Returns false if the initialization
863 /// fails.
864 /// This MUST be used immediately after the construction of the object.
865 bool initialize(AllocaInst &Array, Instruction &Before) {
866 if (!Array.getAllocatedType()->isArrayTy())
867 return false;
868
869 if (!getValues(Array, Before))
870 return false;
871
872 this->Array = &Array;
873 return true;
874 }
875
876 static const unsigned DeviceIDArgNum = 1;
877 static const unsigned BasePtrsArgNum = 3;
878 static const unsigned PtrsArgNum = 4;
879 static const unsigned SizesArgNum = 5;
880
881private:
882 /// Traverses the BasicBlock where \p Array is, collecting the stores made to
883 /// \p Array, leaving StoredValues with the values stored before the
884 /// instruction \p Before is reached.
885 bool getValues(AllocaInst &Array, Instruction &Before) {
886 // Initialize container.
887 const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
888 StoredValues.assign(NumValues, nullptr);
889 LastAccesses.assign(NumValues, nullptr);
890
891 // TODO: This assumes the instruction \p Before is in the same
892 // BasicBlock as Array. Make it general, for any control flow graph.
893 BasicBlock *BB = Array.getParent();
894 if (BB != Before.getParent())
895 return false;
896
897 const DataLayout &DL = Array.getDataLayout();
898 const unsigned int PointerSize = DL.getPointerSize();
899
900 for (Instruction &I : *BB) {
901 if (&I == &Before)
902 break;
903
904 if (!isa<StoreInst>(&I))
905 continue;
906
907 auto *S = cast<StoreInst>(&I);
908 int64_t Offset = -1;
909 auto *Dst =
910 GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
911 if (Dst == &Array) {
912 int64_t Idx = Offset / PointerSize;
913 StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
914 LastAccesses[Idx] = S;
915 }
916 }
917
918 return isFilled();
919 }
920
921 /// Returns true if all values in StoredValues and
922 /// LastAccesses are not nullptrs.
923 bool isFilled() {
924 const unsigned NumValues = StoredValues.size();
925 for (unsigned I = 0; I < NumValues; ++I) {
926 if (!StoredValues[I] || !LastAccesses[I])
927 return false;
928 }
929
930 return true;
931 }
932};
933
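/// Driver for the OpenMP-specific optimizations. It bundles the module, the
/// current SCC, the call-graph updater, the remark emitter, the shared
/// OMPInformationCache, and the Attributor instance; run() applies the
/// transformations listed in the file header.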
934struct OpenMPOpt {
935
936 using OptimizationRemarkGetter =
937      function_ref<OptimizationRemarkEmitter &(Function *)>;
938
939 OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
940 OptimizationRemarkGetter OREGetter,
941 OMPInformationCache &OMPInfoCache, Attributor &A)
942 : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
943 OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
944
945 /// Check if any remarks are enabled for openmp-opt
946 bool remarksEnabled() {
947 auto &Ctx = M.getContext();
948 return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
949 }
950
951 /// Run all OpenMP optimizations on the underlying SCC.
952 bool run(bool IsModulePass) {
953 if (SCC.empty())
954 return false;
955
956 bool Changed = false;
957
958 LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
959 << " functions\n");
960
961 if (IsModulePass) {
962 Changed |= runAttributor(IsModulePass);
963
964 // Recollect uses, in case Attributor deleted any.
965 OMPInfoCache.recollectUses();
966
967 // TODO: This should be folded into buildCustomStateMachine.
968 Changed |= rewriteDeviceCodeStateMachine();
969
970 if (remarksEnabled())
971 analysisGlobalization();
972 } else {
973 if (PrintICVValues)
974 printICVs();
975      if (PrintOpenMPKernels)
976        printKernels();
977
978 Changed |= runAttributor(IsModulePass);
979
980 // Recollect uses, in case Attributor deleted any.
981 OMPInfoCache.recollectUses();
982
983 Changed |= deleteParallelRegions();
984
985      if (HideMemoryTransferLatency)
986        Changed |= hideMemTransfersLatency();
987 Changed |= deduplicateRuntimeCalls();
988      if (EnableParallelRegionMerging) {
989        if (mergeParallelRegions()) {
990 deduplicateRuntimeCalls();
991 Changed = true;
992 }
993 }
994 }
995
996 if (OMPInfoCache.OpenMPPostLink)
997 Changed |= removeRuntimeSymbols();
998
999 return Changed;
1000 }
1001
1002 /// Print initial ICV values for testing.
1003 /// FIXME: This should be done from the Attributor once it is added.
1004 void printICVs() const {
1005 InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
1006 ICV_proc_bind};
1007
1008 for (Function *F : SCC) {
1009 for (auto ICV : ICVs) {
1010 auto ICVInfo = OMPInfoCache.ICVs[ICV];
1011 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1012 return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
1013 << " Value: "
1014 << (ICVInfo.InitValue
1015 ? toString(ICVInfo.InitValue->getValue(), 10, true)
1016 : "IMPLEMENTATION_DEFINED");
1017 };
1018
1019 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
1020 }
1021 }
1022 }
1023
1024 /// Print OpenMP GPU kernels for testing.
1025 void printKernels() const {
1026 for (Function *F : SCC) {
1027 if (!omp::isOpenMPKernel(*F))
1028 continue;
1029
1030 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1031 return ORA << "OpenMP GPU kernel "
1032 << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
1033 };
1034
1035 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
1036 }
1037 }
1038
1039 /// Return the call if \p U is a callee use in a regular call. If \p RFI is
1040 /// given it has to be the callee or a nullptr is returned.
1041 static CallInst *getCallIfRegularCall(
1042 Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1043 CallInst *CI = dyn_cast<CallInst>(U.getUser());
1044 if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
1045 (!RFI ||
1046 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1047 return CI;
1048 return nullptr;
1049 }
1050
1051 /// Return the call if \p V is a regular call. If \p RFI is given it has to be
1052 /// the callee or a nullptr is returned.
1053 static CallInst *getCallIfRegularCall(
1054 Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1055 CallInst *CI = dyn_cast<CallInst>(&V);
1056 if (CI && !CI->hasOperandBundles() &&
1057 (!RFI ||
1058 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1059 return CI;
1060 return nullptr;
1061 }
1062
1063private:
1064 /// Merge parallel regions when it is safe.
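  ///
  /// Schematically, two adjacent calls
  ///   __kmpc_fork_call(..., @outlined1, ...)
  ///   <code safe to run in parallel>
  ///   __kmpc_fork_call(..., @outlined2, ...)
  /// become a single parallel region that calls @outlined1 and @outlined2
  /// directly, with the in-between code wrapped in a master construct followed
  /// by a barrier.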
1065 bool mergeParallelRegions() {
1066 const unsigned CallbackCalleeOperand = 2;
1067 const unsigned CallbackFirstArgOperand = 3;
1068 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1069
1070 // Check if there are any __kmpc_fork_call calls to merge.
1071 OMPInformationCache::RuntimeFunctionInfo &RFI =
1072 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1073
1074 if (!RFI.Declaration)
1075 return false;
1076
1077 // Unmergable calls that prevent merging a parallel region.
1078 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
1079 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
1080 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
1081 };
1082
1083 bool Changed = false;
1084 LoopInfo *LI = nullptr;
1085 DominatorTree *DT = nullptr;
1086
1087    SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
1088
1089 BasicBlock *StartBB = nullptr, *EndBB = nullptr;
1090 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1091 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1092 BasicBlock *CGEndBB =
1093 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1094 assert(StartBB != nullptr && "StartBB should not be null");
1095 CGStartBB->getTerminator()->setSuccessor(0, StartBB);
1096 assert(EndBB != nullptr && "EndBB should not be null");
1097 EndBB->getTerminator()->setSuccessor(0, CGEndBB);
1098 return Error::success();
1099 };
1100
1101 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
1102 Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
1103 ReplacementValue = &Inner;
1104 return CodeGenIP;
1105 };
1106
1107 auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };
1108
1109 /// Create a sequential execution region within a merged parallel region,
1110 /// encapsulated in a master construct with a barrier for synchronization.
1111 auto CreateSequentialRegion = [&](Function *OuterFn,
1112 BasicBlock *OuterPredBB,
1113 Instruction *SeqStartI,
1114 Instruction *SeqEndI) {
1115 // Isolate the instructions of the sequential region to a separate
1116 // block.
1117 BasicBlock *ParentBB = SeqStartI->getParent();
1118 BasicBlock *SeqEndBB =
1119 SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
1120 BasicBlock *SeqAfterBB =
1121 SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
1122 BasicBlock *SeqStartBB =
1123 SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
1124
1125 assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
1126 "Expected a different CFG");
1127 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
1128 ParentBB->getTerminator()->eraseFromParent();
1129
1130 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1131 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1132 BasicBlock *CGEndBB =
1133 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1134 assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
1135 CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
1136 assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
1137 SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
1138 return Error::success();
1139 };
1140 auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };
1141
1142 // Find outputs from the sequential region to outside users and
1143 // broadcast their values to them.
1144 for (Instruction &I : *SeqStartBB) {
1145 SmallPtrSet<Instruction *, 4> OutsideUsers;
1146 for (User *Usr : I.users()) {
1147 Instruction &UsrI = *cast<Instruction>(Usr);
1148          // Ignore outputs to lifetime intrinsics; code extraction for the merged
1149 // parallel region will fix them.
1150 if (UsrI.isLifetimeStartOrEnd())
1151 continue;
1152
1153 if (UsrI.getParent() != SeqStartBB)
1154 OutsideUsers.insert(&UsrI);
1155 }
1156
1157 if (OutsideUsers.empty())
1158 continue;
1159
1160 // Emit an alloca in the outer region to store the broadcasted
1161 // value.
1162 const DataLayout &DL = M.getDataLayout();
1163 AllocaInst *AllocaI = new AllocaInst(
1164 I.getType(), DL.getAllocaAddrSpace(), nullptr,
1165 I.getName() + ".seq.output.alloc", OuterFn->front().begin());
1166
1167 // Emit a store instruction in the sequential BB to update the
1168 // value.
1169 new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());
1170
1171 // Emit a load instruction and replace the use of the output value
1172 // with it.
1173 for (Instruction *UsrI : OutsideUsers) {
1174 LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,
1175 I.getName() + ".seq.output.load",
1176 UsrI->getIterator());
1177 UsrI->replaceUsesOfWith(&I, LoadI);
1178 }
1179 }
1180
1181      OpenMPIRBuilder::LocationDescription Loc(
1182          InsertPointTy(ParentBB, ParentBB->end()), DL);
1183      InsertPointTy SeqAfterIP = cantFail(
1184          OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB));
1185 cantFail(
1186 OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel));
1187
1188 BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1189
1190 LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1191 << "\n");
1192 };
1193
1194 // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1195 // contained in BB and only separated by instructions that can be
1196 // redundantly executed in parallel. The block BB is split before the first
1197 // call (in MergableCIs) and after the last so the entire region we merge
1198 // into a single parallel region is contained in a single basic block
1199 // without any other instructions. We use the OpenMPIRBuilder to outline
1200 // that block and call the resulting function via __kmpc_fork_call.
1201 auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1202 BasicBlock *BB) {
1203 // TODO: Change the interface to allow single CIs expanded, e.g, to
1204 // include an outer loop.
1205 assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1206
1207 auto Remark = [&](OptimizationRemark OR) {
1208 OR << "Parallel region merged with parallel region"
1209 << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1210 for (auto *CI : llvm::drop_begin(MergableCIs)) {
1211 OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1212 if (CI != MergableCIs.back())
1213 OR << ", ";
1214 }
1215 return OR << ".";
1216 };
1217
1218 emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1219
1220 Function *OriginalFn = BB->getParent();
1221 LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1222 << " parallel regions in " << OriginalFn->getName()
1223 << "\n");
1224
1225 // Isolate the calls to merge in a separate block.
1226 EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1227 BasicBlock *AfterBB =
1228 SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1229 StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1230 "omp.par.merged");
1231
1232 assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1233 const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1234 BB->getTerminator()->eraseFromParent();
1235
1236 // Create sequential regions for sequential instructions that are
1237 // in-between mergable parallel regions.
1238 for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1239 It != End; ++It) {
1240 Instruction *ForkCI = *It;
1241 Instruction *NextForkCI = *(It + 1);
1242
1243        // Continue if there are no in-between instructions.
1244 if (ForkCI->getNextNode() == NextForkCI)
1245 continue;
1246
1247 CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1248 NextForkCI->getPrevNode());
1249 }
1250
1251 OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1252 DL);
1253 IRBuilder<>::InsertPoint AllocaIP(
1254 &OriginalFn->getEntryBlock(),
1255 OriginalFn->getEntryBlock().getFirstInsertionPt());
1256 // Create the merged parallel region with default proc binding, to
1257 // avoid overriding binding settings, and without explicit cancellation.
1258      InsertPointTy AfterIP =
1259          cantFail(OMPInfoCache.OMPBuilder.createParallel(
1260 Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1261 OMP_PROC_BIND_default, /* IsCancellable */ false));
1262 BranchInst::Create(AfterBB, AfterIP.getBlock());
1263
1264 // Perform the actual outlining.
1265 OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1266
1267 Function *OutlinedFn = MergableCIs.front()->getCaller();
1268
1269 // Replace the __kmpc_fork_call calls with direct calls to the outlined
1270 // callbacks.
1271    SmallVector<Value *, 8> Args;
1272    for (auto *CI : MergableCIs) {
1273 Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1274 FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1275 Args.clear();
1276 Args.push_back(OutlinedFn->getArg(0));
1277 Args.push_back(OutlinedFn->getArg(1));
1278 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1279 ++U)
1280 Args.push_back(CI->getArgOperand(U));
1281
1282 CallInst *NewCI =
1283 CallInst::Create(FT, Callee, Args, "", CI->getIterator());
1284 if (CI->getDebugLoc())
1285 NewCI->setDebugLoc(CI->getDebugLoc());
1286
1287 // Forward parameter attributes from the callback to the callee.
1288 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1289 ++U)
1290 for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1291 NewCI->addParamAttr(
1292 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1293
1294 // Emit an explicit barrier to replace the implicit fork-join barrier.
1295 if (CI != MergableCIs.back()) {
1296 // TODO: Remove barrier if the merged parallel region includes the
1297 // 'nowait' clause.
1298 cantFail(OMPInfoCache.OMPBuilder.createBarrier(
1299 InsertPointTy(NewCI->getParent(),
1300 NewCI->getNextNode()->getIterator()),
1301 OMPD_parallel));
1302 }
1303
1304 CI->eraseFromParent();
1305 }
1306
1307 assert(OutlinedFn != OriginalFn && "Outlining failed");
1308 CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1309 CGUpdater.reanalyzeFunction(*OriginalFn);
1310
1311 NumOpenMPParallelRegionsMerged += MergableCIs.size();
1312
1313 return true;
1314 };
1315
1316    // Helper function that identifies sequences of
1317 // __kmpc_fork_call uses in a basic block.
1318 auto DetectPRsCB = [&](Use &U, Function &F) {
1319 CallInst *CI = getCallIfRegularCall(U, &RFI);
1320 BB2PRMap[CI->getParent()].insert(CI);
1321
1322 return false;
1323 };
1324
1325 BB2PRMap.clear();
1326 RFI.foreachUse(SCC, DetectPRsCB);
1327 SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1328 // Find mergable parallel regions within a basic block that are
1329 // safe to merge, that is any in-between instructions can safely
1330 // execute in parallel after merging.
1331 // TODO: support merging across basic-blocks.
1332 for (auto &It : BB2PRMap) {
1333 auto &CIs = It.getSecond();
1334 if (CIs.size() < 2)
1335 continue;
1336
1337 BasicBlock *BB = It.getFirst();
1338 SmallVector<CallInst *, 4> MergableCIs;
1339
1340 /// Returns true if the instruction is mergable, false otherwise.
1341 /// A terminator instruction is unmergable by definition since merging
1342 /// works within a BB. Instructions before the mergable region are
1343 /// mergable if they are not calls to OpenMP runtime functions that may
1344 /// set different execution parameters for subsequent parallel regions.
1345 /// Instructions in-between parallel regions are mergable if they are not
1346 /// calls to any non-intrinsic function since that may call a non-mergable
1347 /// OpenMP runtime function.
1348 auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1349 // We do not merge across BBs, hence return false (unmergable) if the
1350 // instruction is a terminator.
1351 if (I.isTerminator())
1352 return false;
1353
1354 if (!isa<CallInst>(&I))
1355 return true;
1356
1357 CallInst *CI = cast<CallInst>(&I);
1358 if (IsBeforeMergableRegion) {
1359 Function *CalledFunction = CI->getCalledFunction();
1360 if (!CalledFunction)
1361 return false;
1362 // Return false (unmergable) if the call before the parallel
1363 // region calls an explicit affinity (proc_bind) or number of
1364 // threads (num_threads) compiler-generated function. Those settings
1365 // may be incompatible with following parallel regions.
1366 // TODO: ICV tracking to detect compatibility.
1367 for (const auto &RFI : UnmergableCallsInfo) {
1368 if (CalledFunction == RFI.Declaration)
1369 return false;
1370 }
1371 } else {
1372 // Return false (unmergable) if there is a call instruction
1373 // in-between parallel regions when it is not an intrinsic. It
1374 // may call an unmergable OpenMP runtime function in its callpath.
1375 // TODO: Keep track of possible OpenMP calls in the callpath.
1376 if (!isa<IntrinsicInst>(CI))
1377 return false;
1378 }
1379
1380 return true;
1381 };
1382 // Find maximal number of parallel region CIs that are safe to merge.
1383 for (auto It = BB->begin(), End = BB->end(); It != End;) {
1384 Instruction &I = *It;
1385 ++It;
1386
1387 if (CIs.count(&I)) {
1388 MergableCIs.push_back(cast<CallInst>(&I));
1389 continue;
1390 }
1391
1392 // Continue expanding if the instruction is mergable.
1393 if (IsMergable(I, MergableCIs.empty()))
1394 continue;
1395
1396 // Forward the instruction iterator to skip the next parallel region
1397 // since there is an unmergable instruction which can affect it.
1398 for (; It != End; ++It) {
1399 Instruction &SkipI = *It;
1400 if (CIs.count(&SkipI)) {
1401 LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1402 << " due to " << I << "\n");
1403 ++It;
1404 break;
1405 }
1406 }
1407
1408 // Store mergable regions found.
1409 if (MergableCIs.size() > 1) {
1410 MergableCIsVector.push_back(MergableCIs);
1411 LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1412 << " parallel regions in block " << BB->getName()
1413 << " of function " << BB->getParent()->getName()
1414 << "\n";);
1415 }
1416
1417 MergableCIs.clear();
1418 }
1419
1420 if (!MergableCIsVector.empty()) {
1421 Changed = true;
1422
1423 for (auto &MergableCIs : MergableCIsVector)
1424 Merge(MergableCIs, BB);
1425 MergableCIsVector.clear();
1426 }
1427 }
1428
1429 if (Changed) {
1430 /// Re-collect use for fork calls, emitted barrier calls, and
1431 /// any emitted master/end_master calls.
1432 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1433 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1434 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1435 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1436 }
1437
1438 return Changed;
1439 }
1440
1441 /// Try to delete parallel regions if possible.
1442 bool deleteParallelRegions() {
1443 const unsigned CallbackCalleeOperand = 2;
1444
1445 OMPInformationCache::RuntimeFunctionInfo &RFI =
1446 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1447
1448 if (!RFI.Declaration)
1449 return false;
1450
1451 bool Changed = false;
1452 auto DeleteCallCB = [&](Use &U, Function &) {
1453 CallInst *CI = getCallIfRegularCall(U);
1454 if (!CI)
1455 return false;
1456 auto *Fn = dyn_cast<Function>(
1457 CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1458 if (!Fn)
1459 return false;
1460 if (!Fn->onlyReadsMemory())
1461 return false;
1462 if (!Fn->hasFnAttribute(Attribute::WillReturn))
1463 return false;
1464
1465 LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1466 << CI->getCaller()->getName() << "\n");
1467
1468 auto Remark = [&](OptimizationRemark OR) {
1469 return OR << "Removing parallel region with no side-effects.";
1470 };
1471 emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1472
1473 CI->eraseFromParent();
1474 Changed = true;
1475 ++NumOpenMPParallelRegionsDeleted;
1476 return true;
1477 };
1478
1479 RFI.foreachUse(SCC, DeleteCallCB);
1480
1481 return Changed;
1482 }
1483
1484 /// Try to eliminate runtime calls by reusing existing ones.
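  /// For example, repeated calls to a deduplicable runtime function such as
  /// omp_get_num_threads() within one function are collapsed into a single
  /// call whose result is reused; calls to __kmpc_global_thread_num can
  /// additionally be replaced by an existing global thread ID argument.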
1485 bool deduplicateRuntimeCalls() {
1486 bool Changed = false;
1487
1488 RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1489 OMPRTL_omp_get_num_threads,
1490 OMPRTL_omp_in_parallel,
1491 OMPRTL_omp_get_cancellation,
1492 OMPRTL_omp_get_supported_active_levels,
1493 OMPRTL_omp_get_level,
1494 OMPRTL_omp_get_ancestor_thread_num,
1495 OMPRTL_omp_get_team_size,
1496 OMPRTL_omp_get_active_level,
1497 OMPRTL_omp_in_final,
1498 OMPRTL_omp_get_proc_bind,
1499 OMPRTL_omp_get_num_places,
1500 OMPRTL_omp_get_num_procs,
1501 OMPRTL_omp_get_place_num,
1502 OMPRTL_omp_get_partition_num_places,
1503 OMPRTL_omp_get_partition_place_nums};
1504
1505 // Global-tid is handled separately.
1506    SmallSetVector<Value *, 16> GTIdArgs;
1507    collectGlobalThreadIdArguments(GTIdArgs);
1508 LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1509 << " global thread ID arguments\n");
1510
1511 for (Function *F : SCC) {
1512 for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1513 Changed |= deduplicateRuntimeCalls(
1514 *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1515
1516 // __kmpc_global_thread_num is special as we can replace it with an
1517 // argument in enough cases to make it worth trying.
1518 Value *GTIdArg = nullptr;
1519 for (Argument &Arg : F->args())
1520 if (GTIdArgs.count(&Arg)) {
1521 GTIdArg = &Arg;
1522 break;
1523 }
1524 Changed |= deduplicateRuntimeCalls(
1525 *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1526 }
1527
1528 return Changed;
1529 }
1530
1531 /// Tries to remove known runtime symbols that are optional from the module.
1532 bool removeRuntimeSymbols() {
1533 // The RPC client symbol is defined in `libc` and indicates that something
1534 // required an RPC server. If its users were all optimized out then we can
1535 // safely remove it.
1536 // TODO: This should be somewhere more common in the future.
1537 if (GlobalVariable *GV = M.getNamedGlobal("__llvm_rpc_client")) {
1538 if (GV->getNumUses() >= 1)
1539 return false;
1540
1541 GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
1542 GV->eraseFromParent();
1543 return true;
1544 }
1545 return false;
1546 }
1547
1548 /// Tries to hide the latency of runtime calls that involve host to
1549 /// device memory transfers by splitting them into their "issue" and "wait"
1550 /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1551  /// moved downwards as much as possible. The "issue" issues the memory transfer
1552 /// asynchronously, returning a handle. The "wait" waits in the returned
1553 /// handle for the memory transfer to finish.
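  ///
  /// Schematically:
  ///   call void @__tgt_target_data_begin_mapper(...)
  ///   <independent code>
  /// becomes
  ///   %handle = alloca %struct.__tgt_async_info
  ///   call void @__tgt_target_data_begin_mapper_issue(..., ptr %handle)
  ///   <independent code>
  ///   call void @__tgt_target_data_begin_mapper_wait(i64 %device_id, ptr %handle)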
1554 bool hideMemTransfersLatency() {
1555 auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1556 bool Changed = false;
1557 auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1558 auto *RTCall = getCallIfRegularCall(U, &RFI);
1559 if (!RTCall)
1560 return false;
1561
1562 OffloadArray OffloadArrays[3];
1563 if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1564 return false;
1565
1566 LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1567
1568 // TODO: Check if can be moved upwards.
1569 bool WasSplit = false;
1570 Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1571 if (WaitMovementPoint)
1572 WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1573
1574 Changed |= WasSplit;
1575 return WasSplit;
1576 };
1577 if (OMPInfoCache.runtimeFnsAvailable(
1578 {OMPRTL___tgt_target_data_begin_mapper_issue,
1579 OMPRTL___tgt_target_data_begin_mapper_wait}))
1580 RFI.foreachUse(SCC, SplitMemTransfers);
1581
1582 return Changed;
1583 }
1584
1585 void analysisGlobalization() {
1586 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1587
1588 auto CheckGlobalization = [&](Use &U, Function &Decl) {
1589 if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1590 auto Remark = [&](OptimizationRemarkMissed ORM) {
1591 return ORM
1592 << "Found thread data sharing on the GPU. "
1593 << "Expect degraded performance due to data globalization.";
1594 };
1595 emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1596 }
1597
1598 return false;
1599 };
1600
1601 RFI.foreachUse(SCC, CheckGlobalization);
1602 }
1603
1604 /// Maps the values stored in the offload arrays passed as arguments to
1605 /// \p RuntimeCall into the offload arrays in \p OAs.
1606 bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1607                                MutableArrayRef<OffloadArray> OAs) {
1608    assert(OAs.size() == 3 && "Need space for three offload arrays!");
1609
1610 // A runtime call that involves memory offloading looks something like:
1611 // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1612 // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1613 // ...)
1614 // So, the idea is to access the allocas that allocate space for these
1615 // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1616 // Therefore:
1617 // i8** %offload_baseptrs.
1618 Value *BasePtrsArg =
1619 RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1620 // i8** %offload_ptrs.
1621 Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1622 // i8** %offload_sizes.
1623 Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1624
1625 // Get values stored in **offload_baseptrs.
1626 auto *V = getUnderlyingObject(BasePtrsArg);
1627 if (!isa<AllocaInst>(V))
1628 return false;
1629 auto *BasePtrsArray = cast<AllocaInst>(V);
1630 if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1631 return false;
1632
1633    // Get values stored in **offload_ptrs.
1634 V = getUnderlyingObject(PtrsArg);
1635 if (!isa<AllocaInst>(V))
1636 return false;
1637 auto *PtrsArray = cast<AllocaInst>(V);
1638 if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1639 return false;
1640
1641 // Get values stored in **offload_sizes.
1642 V = getUnderlyingObject(SizesArg);
1643 // If it's a [constant] global array don't analyze it.
1644 if (isa<GlobalValue>(V))
1645 return isa<Constant>(V);
1646 if (!isa<AllocaInst>(V))
1647 return false;
1648
1649 auto *SizesArray = cast<AllocaInst>(V);
1650 if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1651 return false;
1652
1653 return true;
1654 }
1655
1656 /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1657 /// For now this is a way to test that the function getValuesInOffloadArrays
1658 /// is working properly.
1659 /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1660 void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1661 assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1662
1663 LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1664 std::string ValuesStr;
1665 raw_string_ostream Printer(ValuesStr);
1666 std::string Separator = " --- ";
1667
1668 for (auto *BP : OAs[0].StoredValues) {
1669 BP->print(Printer);
1670 Printer << Separator;
1671 }
1672 LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << ValuesStr << "\n");
1673 ValuesStr.clear();
1674
1675 for (auto *P : OAs[1].StoredValues) {
1676 P->print(Printer);
1677 Printer << Separator;
1678 }
1679 LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << ValuesStr << "\n");
1680 ValuesStr.clear();
1681
1682 for (auto *S : OAs[2].StoredValues) {
1683 S->print(Printer);
1684 Printer << Separator;
1685 }
1686 LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << ValuesStr << "\n");
1687 }
1688
1689  /// Returns the instruction where the "wait" counterpart of \p RuntimeCall can
1690  /// be moved. Returns nullptr if the movement is not possible, or not worth it.
1691 Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1692 // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1693 // Make it traverse the CFG.
1694
1695 Instruction *CurrentI = &RuntimeCall;
1696 bool IsWorthIt = false;
1697 while ((CurrentI = CurrentI->getNextNode())) {
1698
1699 // TODO: Once we detect the regions to be offloaded we should use the
1700 // alias analysis manager to check if CurrentI may modify one of
1701 // the offloaded regions.
1702 if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1703 if (IsWorthIt)
1704 return CurrentI;
1705
1706 return nullptr;
1707 }
1708
1709 // FIXME: For now, moving the call over anything without side effects
1710 // is considered worth it.
1711 IsWorthIt = true;
1712 }
1713
1714 // Return end of BasicBlock.
1715 return RuntimeCall.getParent()->getTerminator();
1716 }
1717
1718 /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1719 bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1720 Instruction &WaitMovementPoint) {
1721 // Create a stack-allocated handle (__tgt_async_info) at the beginning of
1722 // the function. It stores information about the async transfer so we can
1723 // wait on it later.
1724 auto &IRBuilder = OMPInfoCache.OMPBuilder;
1725 Function *F = RuntimeCall.getCaller();
1726 BasicBlock &Entry = F->getEntryBlock();
1727 IRBuilder.Builder.SetInsertPoint(&Entry,
1728 Entry.getFirstNonPHIOrDbgOrAlloca());
1729 Value *Handle = IRBuilder.Builder.CreateAlloca(
1730 IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1731 Handle =
1732 IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1733
1734 // Add "issue" runtime call declaration:
1735 // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1736 // i8**, i8**, i64*, i64*)
1737 FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1738 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1739
1740 // Replace the RuntimeCall call site with its asynchronous version.
1741 SmallVector<Value *, 16> Args;
1742 for (auto &Arg : RuntimeCall.args())
1743 Args.push_back(Arg.get());
1744 Args.push_back(Handle);
1745
1746 CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"",
1747 RuntimeCall.getIterator());
1748 OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1749 RuntimeCall.eraseFromParent();
1750
1751 // Add "wait" runtime call declaration:
1752 // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1753 FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1754 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1755
1756 Value *WaitParams[2] = {
1757 IssueCallsite->getArgOperand(
1758 OffloadArray::DeviceIDArgNum), // device_id.
1759 Handle // handle to wait on.
1760 };
1761 CallInst *WaitCallsite = CallInst::Create(
1762 WaitDecl, WaitParams, /*NameStr=*/"", WaitMovementPoint.getIterator());
1763 OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1764
1765 return true;
1766 }
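// For illustration only (a hedged sketch; argument lists are abbreviated and
// the exact runtime signatures may differ): the split above rewrites
//   call void @__tgt_target_data_begin_mapper(ptr %ident, i64 %dev, ...)
//   ; ... code independent of the mapped data ...
// into
//   %handle = alloca %struct.__tgt_async_info
//   call void @__tgt_target_data_begin_mapper_issue(ptr %ident, i64 %dev, ...,
//                                                   ptr %handle)
//   ; ... code independent of the mapped data ...
//   call void @__tgt_target_data_begin_mapper_wait(i64 %dev, ptr %handle)
// so the data transfer can overlap with the independent code in between.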
1767
1768 static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1769 bool GlobalOnly, bool &SingleChoice) {
1770 if (CurrentIdent == NextIdent)
1771 return CurrentIdent;
1772
1773 // TODO: Figure out how to actually combine multiple debug locations. For
1774 // now we just keep an existing one if there is a single choice.
1775 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1776 SingleChoice = !CurrentIdent;
1777 return NextIdent;
1778 }
1779 return nullptr;
1780 }
1781
1782 /// Return a `struct ident_t*` value that represents the ones used in the
1783 /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1784 /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1785 /// return value we create one from scratch. We also do not yet combine
1786 /// information, e.g., the source locations, see combinedIdentStruct.
1787 Value *
1788 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1789 Function &F, bool GlobalOnly) {
1790 bool SingleChoice = true;
1791 Value *Ident = nullptr;
1792 auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1793 CallInst *CI = getCallIfRegularCall(U, &RFI);
1794 if (!CI || &F != &Caller)
1795 return false;
1796 Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1797 /* GlobalOnly */ true, SingleChoice);
1798 return false;
1799 };
1800 RFI.foreachUse(SCC, CombineIdentStruct);
1801
1802 if (!Ident || !SingleChoice) {
1803 // The IRBuilder uses the insertion block to get to the module; this is
1804 // unfortunate but we work around it for now.
1805 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1806 OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1807 &F.getEntryBlock(), F.getEntryBlock().begin()));
1808 // Create a fallback location if none was found.
1809 // TODO: Use the debug locations of the calls instead.
1810 uint32_t SrcLocStrSize;
1811 Constant *Loc =
1812 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1813 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1814 }
1815 return Ident;
1816 }
1817
1818 /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1819 /// \p ReplVal if given.
1820 bool deduplicateRuntimeCalls(Function &F,
1821 OMPInformationCache::RuntimeFunctionInfo &RFI,
1822 Value *ReplVal = nullptr) {
1823 auto *UV = RFI.getUseVector(F);
1824 if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1825 return false;
1826
1827 LLVM_DEBUG(
1828 dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1829 << (ReplVal ? " with an existing value\n" : "\n"));
1830
1831 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1832 cast<Argument>(ReplVal)->getParent() == &F)) &&
1833 "Unexpected replacement value!");
1834
1835 // TODO: Use dominance to find a good position instead.
1836 auto CanBeMoved = [this](CallBase &CB) {
1837 unsigned NumArgs = CB.arg_size();
1838 if (NumArgs == 0)
1839 return true;
1840 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1841 return false;
1842 for (unsigned U = 1; U < NumArgs; ++U)
1843 if (isa<Instruction>(CB.getArgOperand(U)))
1844 return false;
1845 return true;
1846 };
1847
1848 if (!ReplVal) {
1849 auto *DT =
1850 OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1851 if (!DT)
1852 return false;
1853 Instruction *IP = nullptr;
1854 for (Use *U : *UV) {
1855 if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1856 if (IP)
1857 IP = DT->findNearestCommonDominator(IP, CI);
1858 else
1859 IP = CI;
1860 if (!CanBeMoved(*CI))
1861 continue;
1862 if (!ReplVal)
1863 ReplVal = CI;
1864 }
1865 }
1866 if (!ReplVal)
1867 return false;
1868 assert(IP && "Expected insertion point!");
1869 cast<Instruction>(ReplVal)->moveBefore(IP->getIterator());
1870 }
1871
1872 // If we use a call as a replacement value we need to make sure the ident is
1873 // valid at the new location. For now we just pick a global one, either
1874 // existing and used by one of the calls, or created from scratch.
1875 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1876 if (!CI->arg_empty() &&
1877 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1878 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1879 /* GlobalOnly */ true);
1880 CI->setArgOperand(0, Ident);
1881 }
1882 }
1883
1884 bool Changed = false;
1885 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1886 CallInst *CI = getCallIfRegularCall(U, &RFI);
1887 if (!CI || CI == ReplVal || &F != &Caller)
1888 return false;
1889 assert(CI->getCaller() == &F && "Unexpected call!");
1890
1891 auto Remark = [&](OptimizationRemark OR) {
1892 return OR << "OpenMP runtime call "
1893 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1894 };
1895 if (CI->getDebugLoc())
1896 emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1897 else
1898 emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1899
1900 CI->replaceAllUsesWith(ReplVal);
1901 CI->eraseFromParent();
1902 ++NumOpenMPRuntimeCallsDeduplicated;
1903 Changed = true;
1904 return true;
1905 };
1906 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1907
1908 return Changed;
1909 }
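// A hedged example of the effect (names are illustrative): two calls
//   %tid0 = call i32 @__kmpc_global_thread_num(ptr @loc0)
//   ...
//   %tid1 = call i32 @__kmpc_global_thread_num(ptr @loc1)
// in the same function are collapsed so that all uses of %tid1 are replaced
// by %tid0 (possibly after moving %tid0 to a common dominator) and the second
// call is erased; the surviving call's ident argument is switched to a global
// one that is valid at the new location.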
1910
1911 /// Collect arguments that represent the global thread id in \p GTIdArgs.
1912 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1913 // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1914 // initialization. We could define an AbstractAttribute instead and
1915 // run the Attributor here once it can be run as an SCC pass.
1916
1917 // Helper to check the argument \p ArgNo at all call sites of \p F for
1918 // a GTId.
1919 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1920 if (!F.hasLocalLinkage())
1921 return false;
1922 for (Use &U : F.uses()) {
1923 if (CallInst *CI = getCallIfRegularCall(U)) {
1924 Value *ArgOp = CI->getArgOperand(ArgNo);
1925 if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1926 getCallIfRegularCall(
1927 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1928 continue;
1929 }
1930 return false;
1931 }
1932 return true;
1933 };
1934
1935 // Helper to identify uses of a GTId as GTId arguments.
1936 auto AddUserArgs = [&](Value &GTId) {
1937 for (Use &U : GTId.uses())
1938 if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1939 if (CI->isArgOperand(&U))
1940 if (Function *Callee = CI->getCalledFunction())
1941 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1942 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1943 };
1944
1945 // The argument users of __kmpc_global_thread_num calls are GTIds.
1946 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1947 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1948
1949 GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1950 if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1951 AddUserArgs(*CI);
1952 return false;
1953 });
1954
1955 // Transitively search for more arguments by looking at the users of the
1956 // ones we know already. During the search the GTIdArgs vector is extended
1957 // so we cannot cache the size nor can we use a range based for.
1958 for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1959 AddUserArgs(*GTIdArgs[U]);
1960 }
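// Illustrative, simplified example (assumed names): given
//   %gtid = call i32 @__kmpc_global_thread_num(ptr @ident)
//   call void @helper(i32 %gtid)
// where @helper has local linkage and every call site passes a known GTId in
// that position, the first parameter of @helper is added to GTIdArgs; values
// it is forwarded to at further call sites are then picked up transitively.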
1961
1962 /// Kernel (=GPU) optimizations and utility functions
1963 ///
1964 ///{{
1965
1966 /// Cache to remember the unique kernel for a function.
1967 DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
1968 
1969 /// Find the unique kernel that will execute \p F, if any.
1970 Kernel getUniqueKernelFor(Function &F);
1971
1972 /// Find the unique kernel that will execute \p I, if any.
1973 Kernel getUniqueKernelFor(Instruction &I) {
1974 return getUniqueKernelFor(*I.getFunction());
1975 }
1976
1977 /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
1978 /// the cases where we can avoid taking the address of a function.
1979 bool rewriteDeviceCodeStateMachine();
1980
1981 ///
1982 ///}}
1983
1984 /// Emit a remark generically
1985 ///
1986 /// This template function can be used to generically emit a remark. The
1987 /// RemarkKind should be one of the following:
1988 /// - OptimizationRemark to indicate a successful optimization attempt
1989 /// - OptimizationRemarkMissed to report a failed optimization attempt
1990 /// - OptimizationRemarkAnalysis to provide additional information about an
1991 /// optimization attempt
1992 ///
1993 /// The remark is built using a callback function provided by the caller that
1994 /// takes a RemarkKind as input and returns a RemarkKind.
1995 template <typename RemarkKind, typename RemarkCallBack>
1996 void emitRemark(Instruction *I, StringRef RemarkName,
1997 RemarkCallBack &&RemarkCB) const {
1998 Function *F = I->getParent()->getParent();
1999 auto &ORE = OREGetter(F);
2000
2001 if (RemarkName.starts_with("OMP"))
2002 ORE.emit([&]() {
2003 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2004 << " [" << RemarkName << "]";
2005 });
2006 else
2007 ORE.emit(
2008 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2009 }
2010
2011 /// Emit a remark on a function.
2012 template <typename RemarkKind, typename RemarkCallBack>
2013 void emitRemark(Function *F, StringRef RemarkName,
2014 RemarkCallBack &&RemarkCB) const {
2015 auto &ORE = OREGetter(F);
2016
2017 if (RemarkName.starts_with("OMP"))
2018 ORE.emit([&]() {
2019 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2020 << " [" << RemarkName << "]";
2021 });
2022 else
2023 ORE.emit(
2024 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2025 }
2026
2027 /// The underlying module.
2028 Module &M;
2029
2030 /// The SCC we are operating on.
2031 SmallVectorImpl<Function *> &SCC;
2032 
2033 /// Callback to update the call graph, the first argument is a removed call,
2034 /// the second an optional replacement call.
2035 CallGraphUpdater &CGUpdater;
2036
2037 /// Callback to get an OptimizationRemarkEmitter from a Function *
2038 OptimizationRemarkGetter OREGetter;
2039
2040 /// OpenMP-specific information cache. Also used for Attributor runs.
2041 OMPInformationCache &OMPInfoCache;
2042
2043 /// Attributor instance.
2044 Attributor &A;
2045
2046 /// Helper function to run Attributor on SCC.
2047 bool runAttributor(bool IsModulePass) {
2048 if (SCC.empty())
2049 return false;
2050
2051 registerAAs(IsModulePass);
2052
2053 ChangeStatus Changed = A.run();
2054
2055 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2056 << " functions, result: " << Changed << ".\n");
2057
2058 if (Changed == ChangeStatus::CHANGED)
2059 OMPInfoCache.invalidateAnalyses();
2060
2061 return Changed == ChangeStatus::CHANGED;
2062 }
2063
2064 void registerFoldRuntimeCall(RuntimeFunction RF);
2065
2066 /// Populate the Attributor with abstract attribute opportunities in the
2067 /// functions.
2068 void registerAAs(bool IsModulePass);
2069
2070public:
2071 /// Callback to register AAs for live functions, including internal functions
2072 /// marked live during the traversal.
2073 static void registerAAsForFunction(Attributor &A, const Function &F);
2074};
2075
2076Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2077 if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2078 !OMPInfoCache.CGSCC->contains(&F))
2079 return nullptr;
2080
2081 // Use a scope to keep the lifetime of the CachedKernel short.
2082 {
2083 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2084 if (CachedKernel)
2085 return *CachedKernel;
2086
2087 // TODO: We should use an AA to create an (optimistic and callback
2088 // call-aware) call graph. For now we stick to simple patterns that
2089 // are less powerful, basically the worst fixpoint.
2090 if (isOpenMPKernel(F)) {
2091 CachedKernel = Kernel(&F);
2092 return *CachedKernel;
2093 }
2094
2095 CachedKernel = nullptr;
2096 if (!F.hasLocalLinkage()) {
2097
2098 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2099 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2100 return ORA << "Potentially unknown OpenMP target region caller.";
2101 };
2102 emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2103
2104 return nullptr;
2105 }
2106 }
2107
2108 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2109 if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2110 // Allow use in equality comparisons.
2111 if (Cmp->isEquality())
2112 return getUniqueKernelFor(*Cmp);
2113 return nullptr;
2114 }
2115 if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2116 // Allow direct calls.
2117 if (CB->isCallee(&U))
2118 return getUniqueKernelFor(*CB);
2119
2120 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2121 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2122 // Allow the use in __kmpc_parallel_51 calls.
2123 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2124 return getUniqueKernelFor(*CB);
2125 return nullptr;
2126 }
2127 // Disallow every other use.
2128 return nullptr;
2129 };
2130
2131 // TODO: In the future we want to track more than just a unique kernel.
2132 SmallPtrSet<Kernel, 2> PotentialKernels;
2133 OMPInformationCache::foreachUse(F, [&](const Use &U) {
2134 PotentialKernels.insert(GetUniqueKernelForUse(U));
2135 });
2136
2137 Kernel K = nullptr;
2138 if (PotentialKernels.size() == 1)
2139 K = *PotentialKernels.begin();
2140
2141 // Cache the result.
2142 UniqueKernelMap[&F] = K;
2143
2144 return K;
2145}
2146
2147bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2148 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2149 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2150
2151 bool Changed = false;
2152 if (!KernelParallelRFI)
2153 return Changed;
2154
2155 // If we have disabled state machine changes, exit.
2156 if (DisableOpenMPOptStateMachineRewrite)
2157 return Changed;
2158
2159 for (Function *F : SCC) {
2160
2161 // Check if the function is a use in a __kmpc_parallel_51 call at
2162 // all.
2163 bool UnknownUse = false;
2164 bool KernelParallelUse = false;
2165 unsigned NumDirectCalls = 0;
2166
2167 SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2168 OMPInformationCache::foreachUse(*F, [&](Use &U) {
2169 if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2170 if (CB->isCallee(&U)) {
2171 ++NumDirectCalls;
2172 return;
2173 }
2174
2175 if (isa<ICmpInst>(U.getUser())) {
2176 ToBeReplacedStateMachineUses.push_back(&U);
2177 return;
2178 }
2179
2180 // Find wrapper functions that represent parallel kernels.
2181 CallInst *CI =
2182 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2183 const unsigned int WrapperFunctionArgNo = 6;
2184 if (!KernelParallelUse && CI &&
2185 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2186 KernelParallelUse = true;
2187 ToBeReplacedStateMachineUses.push_back(&U);
2188 return;
2189 }
2190 UnknownUse = true;
2191 });
2192
2193 // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2194 // use.
2195 if (!KernelParallelUse)
2196 continue;
2197
2198 // If this ever hits, we should investigate.
2199 // TODO: Checking the number of uses is not a necessary restriction and
2200 // should be lifted.
2201 if (UnknownUse || NumDirectCalls != 1 ||
2202 ToBeReplacedStateMachineUses.size() > 2) {
2203 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2204 return ORA << "Parallel region is used in "
2205 << (UnknownUse ? "unknown" : "unexpected")
2206 << " ways. Will not attempt to rewrite the state machine.";
2207 };
2208 emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2209 continue;
2210 }
2211
2212 // Even if we have __kmpc_parallel_51 calls, we (for now) give
2213 // up if the function is not called from a unique kernel.
2214 Kernel K = getUniqueKernelFor(*F);
2215 if (!K) {
2216 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2217 return ORA << "Parallel region is not called from a unique kernel. "
2218 "Will not attempt to rewrite the state machine.";
2219 };
2220 emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2221 continue;
2222 }
2223
2224 // We now know F is a parallel body function called only from the kernel K.
2225 // We also identified the state machine uses in which we replace the
2226 // function pointer by a new global symbol for identification purposes. This
2227 // ensures only direct calls to the function are left.
2228
2229 Module &M = *F->getParent();
2230 Type *Int8Ty = Type::getInt8Ty(M.getContext());
2231
2232 auto *ID = new GlobalVariable(
2233 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2234 UndefValue::get(Int8Ty), F->getName() + ".ID");
2235
2236 for (Use *U : ToBeReplacedStateMachineUses)
2237 U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2238 ID, U->get()->getType()));
2239
2240 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2241
2242 Changed = true;
2243 }
2244
2245 return Changed;
2246}
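// Sketch of the rewrite performed above (illustrative and simplified; the
// names are examples): a parallel region wrapper @__omp_outlined__1_wrapper
// that was passed to __kmpc_parallel_51 and compared against inside the
// generic-mode state machine is from now on identified by a new private
// global @__omp_outlined__1_wrapper.ID instead. The state machine uses
// reference that ID, so the only remaining uses of the function itself are
// direct calls, which later lets the specialized state machine avoid an
// indirect call through the function pointer.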
2247
2248/// Abstract Attribute for tracking ICV values.
2249 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2250 using Base = StateWrapper<BooleanState, AbstractAttribute>;
2251 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2252
2253 /// Returns true if value is assumed to be tracked.
2254 bool isAssumedTracked() const { return getAssumed(); }
2255
2256 /// Returns true if value is known to be tracked.
2257 bool isKnownTracked() const { return getKnown(); }
2258
2259 /// Create an abstract attribute view for the position \p IRP.
2260 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2261
2262 /// Return the value with which \p I can be replaced for specific \p ICV.
2263 virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2264 const Instruction *I,
2265 Attributor &A) const {
2266 return std::nullopt;
2267 }
2268
2269 /// Return an assumed unique ICV value if a single candidate is found. If
2270 /// there cannot be one, return a nullptr. If it is not clear yet, return
2271 /// std::nullopt.
2272 virtual std::optional<Value *>
2273 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2274
2275 // Currently only nthreads is being tracked.
2276 // This array will only grow over time.
2277 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2278
2279 /// See AbstractAttribute::getName()
2280 const std::string getName() const override { return "AAICVTracker"; }
2281
2282 /// See AbstractAttribute::getIdAddr()
2283 const char *getIdAddr() const override { return &ID; }
2284
2285 /// This function should return true if the type of the \p AA is AAICVTracker
2286 static bool classof(const AbstractAttribute *AA) {
2287 return (AA->getIdAddr() == &ID);
2288 }
2289
2290 static const char ID;
2291};
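// A hedged example of what ICV tracking enables (currently for nthreads
// only): after
//   call void @omp_set_num_threads(i32 4)
// a later
//   %n = call i32 @omp_get_max_threads()
// can be folded to 4, provided no call in between may change the ICV; the
// call-site AAs below query the function-level tracker for the value that is
// valid at their program point.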
2292
2293struct AAICVTrackerFunction : public AAICVTracker {
2294 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2295 : AAICVTracker(IRP, A) {}
2296
2297 // FIXME: come up with better string.
2298 const std::string getAsStr(Attributor *) const override {
2299 return "ICVTrackerFunction";
2300 }
2301
2302 // FIXME: come up with some stats.
2303 void trackStatistics() const override {}
2304
2305 /// We don't manifest anything for this AA.
2306 ChangeStatus manifest(Attributor &A) override {
2307 return ChangeStatus::UNCHANGED;
2308 }
2309
2310 // Map of each ICV to its values at specific program points.
2311 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2312 InternalControlVar::ICV___last>
2313 ICVReplacementValuesMap;
2314
2315 ChangeStatus updateImpl(Attributor &A) override {
2316 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2317
2318 Function *F = getAnchorScope();
2319
2320 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2321
2322 for (InternalControlVar ICV : TrackableICVs) {
2323 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2324
2325 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2326 auto TrackValues = [&](Use &U, Function &) {
2327 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2328 if (!CI)
2329 return false;
2330
2331 // FIXME: handle setters with more than one argument.
2332 /// Track new value.
2333 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2334 HasChanged = ChangeStatus::CHANGED;
2335
2336 return false;
2337 };
2338
2339 auto CallCheck = [&](Instruction &I) {
2340 std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2341 if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2342 HasChanged = ChangeStatus::CHANGED;
2343
2344 return true;
2345 };
2346
2347 // Track all changes of an ICV.
2348 SetterRFI.foreachUse(TrackValues, F);
2349
2350 bool UsedAssumedInformation = false;
2351 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2352 UsedAssumedInformation,
2353 /* CheckBBLivenessOnly */ true);
2354
2355 /// TODO: Figure out a way to avoid adding an entry to
2356 /// ICVReplacementValuesMap.
2357 Instruction *Entry = &F->getEntryBlock().front();
2358 if (HasChanged == ChangeStatus::CHANGED)
2359 ValuesMap.try_emplace(Entry);
2360 }
2361
2362 return HasChanged;
2363 }
2364
2365 /// Helper to check if \p I is a call and get the value for it if it is
2366 /// unique.
2367 std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2368 InternalControlVar &ICV) const {
2369
2370 const auto *CB = dyn_cast<CallBase>(&I);
2371 if (!CB || CB->hasFnAttr("no_openmp") ||
2372 CB->hasFnAttr("no_openmp_routines"))
2373 return std::nullopt;
2374
2375 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2376 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2377 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2378 Function *CalledFunction = CB->getCalledFunction();
2379
2380 // Indirect call, assume ICV changes.
2381 if (CalledFunction == nullptr)
2382 return nullptr;
2383 if (CalledFunction == GetterRFI.Declaration)
2384 return std::nullopt;
2385 if (CalledFunction == SetterRFI.Declaration) {
2386 if (ICVReplacementValuesMap[ICV].count(&I))
2387 return ICVReplacementValuesMap[ICV].lookup(&I);
2388
2389 return nullptr;
2390 }
2391
2392 // Since we don't know, assume it changes the ICV.
2393 if (CalledFunction->isDeclaration())
2394 return nullptr;
2395
2396 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2397 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2398
2399 if (ICVTrackingAA->isAssumedTracked()) {
2400 std::optional<Value *> URV =
2401 ICVTrackingAA->getUniqueReplacementValue(ICV);
2402 if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2403 OMPInfoCache)))
2404 return URV;
2405 }
2406
2407 // If we don't know, assume it changes.
2408 return nullptr;
2409 }
2410
2411 // We don't check unique value for a function, so return std::nullopt.
2412 std::optional<Value *>
2413 getUniqueReplacementValue(InternalControlVar ICV) const override {
2414 return std::nullopt;
2415 }
2416
2417 /// Return the value with which \p I can be replaced for specific \p ICV.
2418 std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2419 const Instruction *I,
2420 Attributor &A) const override {
2421 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2422 if (ValuesMap.count(I))
2423 return ValuesMap.lookup(I);
2424 
2425 SmallVector<const Instruction *, 16> Worklist;
2426 SmallPtrSet<const Instruction *, 16> Visited;
2427 Worklist.push_back(I);
2428
2429 std::optional<Value *> ReplVal;
2430
2431 while (!Worklist.empty()) {
2432 const Instruction *CurrInst = Worklist.pop_back_val();
2433 if (!Visited.insert(CurrInst).second)
2434 continue;
2435
2436 const BasicBlock *CurrBB = CurrInst->getParent();
2437
2438 // Go up and look for all potential setters/calls that might change the
2439 // ICV.
2440 while ((CurrInst = CurrInst->getPrevNode())) {
2441 if (ValuesMap.count(CurrInst)) {
2442 std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2443 // Unknown value, track new.
2444 if (!ReplVal) {
2445 ReplVal = NewReplVal;
2446 break;
2447 }
2448
2449 // If we found a new value, we can't know the ICV value anymore.
2450 if (NewReplVal)
2451 if (ReplVal != NewReplVal)
2452 return nullptr;
2453
2454 break;
2455 }
2456
2457 std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2458 if (!NewReplVal)
2459 continue;
2460
2461 // Unknown value, track new.
2462 if (!ReplVal) {
2463 ReplVal = NewReplVal;
2464 break;
2465 }
2466
2467 // We found a new value, so we can't know the ICV value anymore.
2469 if (ReplVal != NewReplVal)
2470 return nullptr;
2471 }
2472
2473 // If we are in the same BB and we have a value, we are done.
2474 if (CurrBB == I->getParent() && ReplVal)
2475 return ReplVal;
2476
2477 // Go through all predecessors and add terminators for analysis.
2478 for (const BasicBlock *Pred : predecessors(CurrBB))
2479 if (const Instruction *Terminator = Pred->getTerminator())
2480 Worklist.push_back(Terminator);
2481 }
2482
2483 return ReplVal;
2484 }
2485};
2486
2487struct AAICVTrackerFunctionReturned : AAICVTracker {
2488 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2489 : AAICVTracker(IRP, A) {}
2490
2491 // FIXME: come up with better string.
2492 const std::string getAsStr(Attributor *) const override {
2493 return "ICVTrackerFunctionReturned";
2494 }
2495
2496 // FIXME: come up with some stats.
2497 void trackStatistics() const override {}
2498
2499 /// We don't manifest anything for this AA.
2500 ChangeStatus manifest(Attributor &A) override {
2501 return ChangeStatus::UNCHANGED;
2502 }
2503
2504 // Map of each ICV to its value at specific program points.
2505 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2506 InternalControlVar::ICV___last>
2507 ICVReplacementValuesMap;
2508
2509 /// Return the value with which \p I can be replaced for specific \p ICV.
2510 std::optional<Value *>
2511 getUniqueReplacementValue(InternalControlVar ICV) const override {
2512 return ICVReplacementValuesMap[ICV];
2513 }
2514
2515 ChangeStatus updateImpl(Attributor &A) override {
2516 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2517 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2518 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2519
2520 if (!ICVTrackingAA->isAssumedTracked())
2521 return indicatePessimisticFixpoint();
2522
2523 for (InternalControlVar ICV : TrackableICVs) {
2524 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2525 std::optional<Value *> UniqueICVValue;
2526
2527 auto CheckReturnInst = [&](Instruction &I) {
2528 std::optional<Value *> NewReplVal =
2529 ICVTrackingAA->getReplacementValue(ICV, &I, A);
2530
2531 // If we found a second ICV value there is no unique returned value.
2532 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2533 return false;
2534
2535 UniqueICVValue = NewReplVal;
2536
2537 return true;
2538 };
2539
2540 bool UsedAssumedInformation = false;
2541 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2542 UsedAssumedInformation,
2543 /* CheckBBLivenessOnly */ true))
2544 UniqueICVValue = nullptr;
2545
2546 if (UniqueICVValue == ReplVal)
2547 continue;
2548
2549 ReplVal = UniqueICVValue;
2550 Changed = ChangeStatus::CHANGED;
2551 }
2552
2553 return Changed;
2554 }
2555};
2556
2557struct AAICVTrackerCallSite : AAICVTracker {
2558 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2559 : AAICVTracker(IRP, A) {}
2560
2561 void initialize(Attributor &A) override {
2562 assert(getAnchorScope() && "Expected anchor function");
2563
2564 // We only initialize this AA for getters, so we need to know which ICV it
2565 // gets.
2566 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2567 for (InternalControlVar ICV : TrackableICVs) {
2568 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2569 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2570 if (Getter.Declaration == getAssociatedFunction()) {
2571 AssociatedICV = ICVInfo.Kind;
2572 return;
2573 }
2574 }
2575
2576 /// Unknown ICV.
2577 indicatePessimisticFixpoint();
2578 }
2579
2580 ChangeStatus manifest(Attributor &A) override {
2581 if (!ReplVal || !*ReplVal)
2582 return ChangeStatus::UNCHANGED;
2583
2584 A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2585 A.deleteAfterManifest(*getCtxI());
2586
2587 return ChangeStatus::CHANGED;
2588 }
2589
2590 // FIXME: come up with better string.
2591 const std::string getAsStr(Attributor *) const override {
2592 return "ICVTrackerCallSite";
2593 }
2594
2595 // FIXME: come up with some stats.
2596 void trackStatistics() const override {}
2597
2598 InternalControlVar AssociatedICV;
2599 std::optional<Value *> ReplVal;
2600
2601 ChangeStatus updateImpl(Attributor &A) override {
2602 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2603 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2604
2605 // We don't have any information, so we assume it changes the ICV.
2606 if (!ICVTrackingAA->isAssumedTracked())
2607 return indicatePessimisticFixpoint();
2608
2609 std::optional<Value *> NewReplVal =
2610 ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
2611
2612 if (ReplVal == NewReplVal)
2613 return ChangeStatus::UNCHANGED;
2614
2615 ReplVal = NewReplVal;
2616 return ChangeStatus::CHANGED;
2617 }
2618
2619 // Return the value with which associated value can be replaced for specific
2620 // \p ICV.
2621 std::optional<Value *>
2622 getUniqueReplacementValue(InternalControlVar ICV) const override {
2623 return ReplVal;
2624 }
2625};
2626
2627struct AAICVTrackerCallSiteReturned : AAICVTracker {
2628 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2629 : AAICVTracker(IRP, A) {}
2630
2631 // FIXME: come up with better string.
2632 const std::string getAsStr(Attributor *) const override {
2633 return "ICVTrackerCallSiteReturned";
2634 }
2635
2636 // FIXME: come up with some stats.
2637 void trackStatistics() const override {}
2638
2639 /// We don't manifest anything for this AA.
2640 ChangeStatus manifest(Attributor &A) override {
2641 return ChangeStatus::UNCHANGED;
2642 }
2643
2644 // Map of each ICV to its value at specific program points.
2645 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2646 InternalControlVar::ICV___last>
2647 ICVReplacementValuesMap;
2648
2649 /// Return the value with which associated value can be replaced for specific
2650 /// \p ICV.
2651 std::optional<Value *>
2652 getUniqueReplacementValue(InternalControlVar ICV) const override {
2653 return ICVReplacementValuesMap[ICV];
2654 }
2655
2656 ChangeStatus updateImpl(Attributor &A) override {
2657 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2658 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2659 *this, IRPosition::returned(*getAssociatedFunction()),
2660 DepClassTy::REQUIRED);
2661
2662 // We don't have any information, so we assume it changes the ICV.
2663 if (!ICVTrackingAA->isAssumedTracked())
2664 return indicatePessimisticFixpoint();
2665
2666 for (InternalControlVar ICV : TrackableICVs) {
2667 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2668 std::optional<Value *> NewReplVal =
2669 ICVTrackingAA->getUniqueReplacementValue(ICV);
2670
2671 if (ReplVal == NewReplVal)
2672 continue;
2673
2674 ReplVal = NewReplVal;
2675 Changed = ChangeStatus::CHANGED;
2676 }
2677 return Changed;
2678 }
2679};
2680
2681/// Determines if \p BB exits the function unconditionally itself or reaches a
2682/// block that does through only unique successors.
2683static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2684 if (succ_empty(BB))
2685 return true;
2686 const BasicBlock *const Successor = BB->getUniqueSuccessor();
2687 if (!Successor)
2688 return false;
2689 return hasFunctionEndAsUniqueSuccessor(Successor);
2690}
2691
2692struct AAExecutionDomainFunction : public AAExecutionDomain {
2693 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2694 : AAExecutionDomain(IRP, A) {}
2695
2696 ~AAExecutionDomainFunction() { delete RPOT; }
2697
2698 void initialize(Attributor &A) override {
2699 Function *F = getAnchorScope();
2700 assert(F && "Expected anchor function");
2701 RPOT = new ReversePostOrderTraversal<Function *>(F);
2702 }
2703
2704 const std::string getAsStr(Attributor *) const override {
2705 unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2706 for (auto &It : BEDMap) {
2707 if (!It.getFirst())
2708 continue;
2709 TotalBlocks++;
2710 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2711 AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2712 It.getSecond().IsReachingAlignedBarrierOnly;
2713 }
2714 return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
2715 std::to_string(AlignedBlocks) + " of " +
2716 std::to_string(TotalBlocks) +
2717 " executed by initial thread / aligned";
2718 }
2719
2720 /// See AbstractAttribute::trackStatistics().
2721 void trackStatistics() const override {}
2722
2723 ChangeStatus manifest(Attributor &A) override {
2724 LLVM_DEBUG({
2725 for (const BasicBlock &BB : *getAnchorScope()) {
2726 if (!isExecutedByInitialThreadOnly(BB))
2727 continue;
2728 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2729 << BB.getName() << " is executed by a single thread.\n";
2730 }
2731 });
2732 
2733 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2734 
2735 if (DisableOpenMPOptBarrierElimination)
2736 return Changed;
2737
2738 SmallPtrSet<CallBase *, 16> DeletedBarriers;
2739 auto HandleAlignedBarrier = [&](CallBase *CB) {
2740 const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
2741 if (!ED.IsReachedFromAlignedBarrierOnly ||
2742 ED.EncounteredNonLocalSideEffect)
2743 return;
2744 if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
2745 return;
2746
2747 // We can remove this barrier, if it is one, or aligned barriers reaching
2748 // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
2749 // end should only be removed if the kernel end is their unique successor;
2750 // otherwise, they may have side-effects that aren't accounted for in the
2751 // kernel end in their other successors. If those barriers have other
2752 // barriers reaching them, those can be transitively removed as well as
2753 // long as the kernel end is also their unique successor.
2754 if (CB) {
2755 DeletedBarriers.insert(CB);
2756 A.deleteAfterManifest(*CB);
2757 ++NumBarriersEliminated;
2758 Changed = ChangeStatus::CHANGED;
2759 } else if (!ED.AlignedBarriers.empty()) {
2760 Changed = ChangeStatus::CHANGED;
2761 SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
2762 ED.AlignedBarriers.end());
2763 SmallSetVector<CallBase *, 16> Visited;
2764 while (!Worklist.empty()) {
2765 CallBase *LastCB = Worklist.pop_back_val();
2766 if (!Visited.insert(LastCB))
2767 continue;
2768 if (LastCB->getFunction() != getAnchorScope())
2769 continue;
2770 if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
2771 continue;
2772 if (!DeletedBarriers.count(LastCB)) {
2773 ++NumBarriersEliminated;
2774 A.deleteAfterManifest(*LastCB);
2775 continue;
2776 }
2777 // The final aligned barrier (LastCB) reaching the kernel end was
2778 // removed already. This means we can go one step further and remove
2779 // the barriers encountered last before (LastCB).
2780 const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2781 Worklist.append(LastED.AlignedBarriers.begin(),
2782 LastED.AlignedBarriers.end());
2783 }
2784 }
2785
2786 // If we actually eliminated a barrier we need to eliminate the associated
2787 // llvm.assumes as well to avoid creating UB.
2788 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2789 for (auto *AssumeCB : ED.EncounteredAssumes)
2790 A.deleteAfterManifest(*AssumeCB);
2791 };
2792
2793 for (auto *CB : AlignedBarriers)
2794 HandleAlignedBarrier(CB);
2795
2796 // Handle the "kernel end barrier" for kernels too.
2797 if (omp::isOpenMPKernel(*getAnchorScope()))
2798 HandleAlignedBarrier(nullptr);
2799
2800 return Changed;
2801 }
2802
2803 bool isNoOpFence(const FenceInst &FI) const override {
2804 return getState().isValidState() && !NonNoOpFences.count(&FI);
2805 }
2806
2807 /// Merge barrier and assumption information from \p PredED into the successor
2808 /// \p ED.
2809 void
2810 mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
2811 const ExecutionDomainTy &PredED);
2812
2813 /// Merge all information from \p PredED into the successor \p ED. If
2814 /// \p InitialEdgeOnly is set, only the initial edge will enter the block
2815 /// represented by \p ED from this predecessor.
2816 bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
2817 const ExecutionDomainTy &PredED,
2818 bool InitialEdgeOnly = false);
2819
2820 /// Accumulate information for the entry block in \p EntryBBED.
2821 bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);
2822
2823 /// See AbstractAttribute::updateImpl.
2824 ChangeStatus updateImpl(Attributor &A) override;
2825 
2826 /// Query interface, see AAExecutionDomain
2827 ///{
2828 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2829 if (!isValidState())
2830 return false;
2831 assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
2832 return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2833 }
2834 
2835 bool isExecutedInAlignedRegion(Attributor &A,
2836 const Instruction &I) const override {
2837 assert(I.getFunction() == getAnchorScope() &&
2838 "Instruction is out of scope!");
2839 if (!isValidState())
2840 return false;
2841
2842 bool ForwardIsOk = true;
2843 const Instruction *CurI;
2844
2845 // Check forward until a call or the block end is reached.
2846 CurI = &I;
2847 do {
2848 auto *CB = dyn_cast<CallBase>(CurI);
2849 if (!CB)
2850 continue;
2851 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2852 return true;
2853 const auto &It = CEDMap.find({CB, PRE});
2854 if (It == CEDMap.end())
2855 continue;
2856 if (!It->getSecond().IsReachingAlignedBarrierOnly)
2857 ForwardIsOk = false;
2858 break;
2859 } while ((CurI = CurI->getNextNonDebugInstruction()));
2860
2861 if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
2862 ForwardIsOk = false;
2863
2864 // Check backward until a call or the block beginning is reached.
2865 CurI = &I;
2866 do {
2867 auto *CB = dyn_cast<CallBase>(CurI);
2868 if (!CB)
2869 continue;
2870 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2871 return true;
2872 const auto &It = CEDMap.find({CB, POST});
2873 if (It == CEDMap.end())
2874 continue;
2875 if (It->getSecond().IsReachedFromAlignedBarrierOnly)
2876 break;
2877 return false;
2878 } while ((CurI = CurI->getPrevNonDebugInstruction()));
2879
2880 // Delayed decision on the forward pass to allow aligned barrier detection
2881 // in the backwards traversal.
2882 if (!ForwardIsOk)
2883 return false;
2884
2885 if (!CurI) {
2886 const BasicBlock *BB = I.getParent();
2887 if (BB == &BB->getParent()->getEntryBlock())
2888 return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
2889 if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
2890 return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
2891 })) {
2892 return false;
2893 }
2894 }
2895
2896 // On neither traversal did we find anything but aligned barriers.
2897 return true;
2898 }
2899
2900 ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
2901 assert(isValidState() &&
2902 "No request should be made against an invalid state!");
2903 return BEDMap.lookup(&BB);
2904 }
2905 std::pair<ExecutionDomainTy, ExecutionDomainTy>
2906 getExecutionDomain(const CallBase &CB) const override {
2907 assert(isValidState() &&
2908 "No request should be made against an invalid state!");
2909 return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
2910 }
2911 ExecutionDomainTy getFunctionExecutionDomain() const override {
2912 assert(isValidState() &&
2913 "No request should be made against an invalid state!");
2914 return InterProceduralED;
2915 }
2916 ///}
2917
2918 // Check if the edge into the successor block contains a condition that only
2919 // lets the main thread execute it.
2920 static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,
2921 BasicBlock &SuccessorBB) {
2922 if (!Edge || !Edge->isConditional())
2923 return false;
2924 if (Edge->getSuccessor(0) != &SuccessorBB)
2925 return false;
2926
2927 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2928 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2929 return false;
2930
2931 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2932 if (!C)
2933 return false;
2934
2935 // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2936 if (C->isAllOnesValue()) {
2937 auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2938 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2939 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2940 CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2941 if (!CB)
2942 return false;
2943 ConstantStruct *KernelEnvC =
2944 KernelInfo::getKernelEnvironementFromKernelInitCB(CB);
2945 ConstantInt *ExecModeC =
2946 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
2947 return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
2948 }
2949
2950 if (C->isZero()) {
2951 // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2952 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2953 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2954 return true;
2955
2956 // Match: 0 == llvm.amdgcn.workitem.id.x()
2957 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2958 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2959 return true;
2960 }
2961
2962 return false;
2963 };
2964
2965 /// Mapping containing information about the function for other AAs.
2966 ExecutionDomainTy InterProceduralED;
2967
2968 enum Direction { PRE = 0, POST = 1 };
2969 /// Mapping containing information per block.
2970 DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
2971 DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>
2972 CEDMap;
2973 SmallSetVector<CallBase *, 16> AlignedBarriers;
2974 
2975 ReversePostOrderTraversal<Function *> *RPOT = nullptr;
2976 
2977 /// Set \p R to \p V and report true if that changed \p R.
2978 static bool setAndRecord(bool &R, bool V) {
2979 bool Eq = (R == V);
2980 R = V;
2981 return !Eq;
2982 }
2983
2984 /// Collection of fences known to be non-no-op. All fences not in this set
2985 /// can be assumed to be no-ops.
2986 SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
2987 };
2988
2989void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2990 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2991 for (auto *EA : PredED.EncounteredAssumes)
2992 ED.addAssumeInst(A, *EA);
2993
2994 for (auto *AB : PredED.AlignedBarriers)
2995 ED.addAlignedBarrier(A, *AB);
2996}
2997
2998bool AAExecutionDomainFunction::mergeInPredecessor(
2999 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
3000 bool InitialEdgeOnly) {
3001
3002 bool Changed = false;
3003 Changed |=
3004 setAndRecord(ED.IsExecutedByInitialThreadOnly,
3005 InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3006 ED.IsExecutedByInitialThreadOnly));
3007
3008 Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3009 ED.IsReachedFromAlignedBarrierOnly &&
3010 PredED.IsReachedFromAlignedBarrierOnly);
3011 Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3012 ED.EncounteredNonLocalSideEffect |
3013 PredED.EncounteredNonLocalSideEffect);
3014 // Do not track assumptions and barriers as part of Changed.
3015 if (ED.IsReachedFromAlignedBarrierOnly)
3016 mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3017 else
3018 ED.clearAssumeInstAndAlignedBarriers();
3019 return Changed;
3020}
3021
3022bool AAExecutionDomainFunction::handleCallees(Attributor &A,
3023 ExecutionDomainTy &EntryBBED) {
3024 SmallVector<std::pair<ExecutionDomainTy, ExecutionDomainTy>, 4> CallSiteEDs;
3025 auto PredForCallSite = [&](AbstractCallSite ACS) {
3026 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3027 *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
3028 DepClassTy::OPTIONAL);
3029 if (!EDAA || !EDAA->getState().isValidState())
3030 return false;
3031 CallSiteEDs.emplace_back(
3032 EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
3033 return true;
3034 };
3035
3036 ExecutionDomainTy ExitED;
3037 bool AllCallSitesKnown;
3038 if (A.checkForAllCallSites(PredForCallSite, *this,
3039 /* RequiresAllCallSites */ true,
3040 AllCallSitesKnown)) {
3041 for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
3042 mergeInPredecessor(A, EntryBBED, CSInED);
3043 ExitED.IsReachingAlignedBarrierOnly &=
3044 CSOutED.IsReachingAlignedBarrierOnly;
3045 }
3046
3047 } else {
3048 // We could not find all predecessors, so this is either a kernel or a
3049 // function with external linkage (or with some other weird uses).
3050 if (omp::isOpenMPKernel(*getAnchorScope())) {
3051 EntryBBED.IsExecutedByInitialThreadOnly = false;
3052 EntryBBED.IsReachedFromAlignedBarrierOnly = true;
3053 EntryBBED.EncounteredNonLocalSideEffect = false;
3054 ExitED.IsReachingAlignedBarrierOnly = false;
3055 } else {
3056 EntryBBED.IsExecutedByInitialThreadOnly = false;
3057 EntryBBED.IsReachedFromAlignedBarrierOnly = false;
3058 EntryBBED.EncounteredNonLocalSideEffect = true;
3059 ExitED.IsReachingAlignedBarrierOnly = false;
3060 }
3061 }
3062
3063 bool Changed = false;
3064 auto &FnED = BEDMap[nullptr];
3065 Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
3066 FnED.IsReachedFromAlignedBarrierOnly &
3067 EntryBBED.IsReachedFromAlignedBarrierOnly);
3068 Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
3069 FnED.IsReachingAlignedBarrierOnly &
3070 ExitED.IsReachingAlignedBarrierOnly);
3071 Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
3072 EntryBBED.IsExecutedByInitialThreadOnly);
3073 return Changed;
3074}
3075
3076ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
3077
3078 bool Changed = false;
3079
3080 // Helper to deal with an aligned barrier encountered during the forward
3081 // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
3082 // it was encountered.
3083 auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
3084 Changed |= AlignedBarriers.insert(&CB);
3085 // First, update the barrier ED kept in the separate CEDMap.
3086 auto &CallInED = CEDMap[{&CB, PRE}];
3087 Changed |= mergeInPredecessor(A, CallInED, ED);
3088 CallInED.IsReachingAlignedBarrierOnly = true;
3089 // Next adjust the ED we use for the traversal.
3090 ED.EncounteredNonLocalSideEffect = false;
3091 ED.IsReachedFromAlignedBarrierOnly = true;
3092 // Aligned barrier collection has to come last.
3093 ED.clearAssumeInstAndAlignedBarriers();
3094 ED.addAlignedBarrier(A, CB);
3095 auto &CallOutED = CEDMap[{&CB, POST}];
3096 Changed |= mergeInPredecessor(A, CallOutED, ED);
3097 };
3098
3099 auto *LivenessAA =
3100 A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
3101
3102 Function *F = getAnchorScope();
3103 BasicBlock &EntryBB = F->getEntryBlock();
3104 bool IsKernel = omp::isOpenMPKernel(*F);
3105
3106 SmallVector<Instruction *> SyncInstWorklist;
3107 for (auto &RIt : *RPOT) {
3108 BasicBlock &BB = *RIt;
3109
3110 bool IsEntryBB = &BB == &EntryBB;
3111 // TODO: We use local reasoning since we don't have a divergence analysis
3112 // running as well. We could basically allow uniform branches here.
3113 bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
3114 bool IsExplicitlyAligned = IsEntryBB && IsKernel;
3115 ExecutionDomainTy ED;
3116 // Propagate "incoming edges" into information about this block.
3117 if (IsEntryBB) {
3118 Changed |= handleCallees(A, ED);
3119 } else {
3120 // For live non-entry blocks we only propagate
3121 // information via live edges.
3122 if (LivenessAA && LivenessAA->isAssumedDead(&BB))
3123 continue;
3124
3125 for (auto *PredBB : predecessors(&BB)) {
3126 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
3127 continue;
3128 bool InitialEdgeOnly = isInitialThreadOnlyEdge(
3129 A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
3130 mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
3131 }
3132 }
3133
3134 // Now we traverse the block, accumulate effects in ED and attach
3135 // information to calls.
3136 for (Instruction &I : BB) {
3137 bool UsedAssumedInformation;
3138 if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
3139 /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
3140 /* CheckForDeadStore */ true))
3141 continue;
3142
3143 // Assumes and "assume-like" intrinsics (dbg, lifetime, ...) are handled
3144 // first; the former are collected, the latter are ignored.
3145 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
3146 if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
3147 ED.addAssumeInst(A, *AI);
3148 continue;
3149 }
3150 // TODO: Should we also collect and delete lifetime markers?
3151 if (II->isAssumeLikeIntrinsic())
3152 continue;
3153 }
3154
3155 if (auto *FI = dyn_cast<FenceInst>(&I)) {
3156 if (!ED.EncounteredNonLocalSideEffect) {
3157 // An aligned fence without non-local side-effects is a no-op.
3158 if (ED.IsReachedFromAlignedBarrierOnly)
3159 continue;
3160 // A non-aligned fence without non-local side-effects is a no-op
3161 // if the ordering only publishes non-local side-effects (or less).
3162 switch (FI->getOrdering()) {
3163 case AtomicOrdering::NotAtomic:
3164 continue;
3165 case AtomicOrdering::Unordered:
3166 continue;
3167 case AtomicOrdering::Monotonic:
3168 continue;
3169 case AtomicOrdering::Acquire:
3170 break;
3171 case AtomicOrdering::Release:
3172 continue;
3173 case AtomicOrdering::AcquireRelease:
3174 break;
3175 case AtomicOrdering::SequentiallyConsistent:
3176 break;
3177 };
3178 }
3179 NonNoOpFences.insert(FI);
3180 }
3181
3182 auto *CB = dyn_cast<CallBase>(&I);
3183 bool IsNoSync = AA::isNoSyncInst(A, I, *this);
3184 bool IsAlignedBarrier =
3185 !IsNoSync && CB &&
3186 AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);
3187
3188 AlignedBarrierLastInBlock &= IsNoSync;
3189 IsExplicitlyAligned &= IsNoSync;
3190
3191 // Next we check for calls. Aligned barriers are handled
3192 // explicitly, everything else is kept for the backward traversal and will
3193 // also affect our state.
3194 if (CB) {
3195 if (IsAlignedBarrier) {
3196 HandleAlignedBarrier(*CB, ED);
3197 AlignedBarrierLastInBlock = true;
3198 IsExplicitlyAligned = true;
3199 continue;
3200 }
3201
3202 // Check the pointer(s) of a memory intrinsic explicitly.
3203 if (isa<MemIntrinsic>(&I)) {
3204 if (!ED.EncounteredNonLocalSideEffect &&
3205 AA::isPotentiallyAffectedByBarrier(A, I, *this))
3206 ED.EncounteredNonLocalSideEffect = true;
3207 if (!IsNoSync) {
3208 ED.IsReachedFromAlignedBarrierOnly = false;
3209 SyncInstWorklist.push_back(&I);
3210 }
3211 continue;
3212 }
3213
3214 // Record how we entered the call, then accumulate the effect of the
3215 // call in ED for potential use by the callee.
3216 auto &CallInED = CEDMap[{CB, PRE}];
3217 Changed |= mergeInPredecessor(A, CallInED, ED);
3218
3219 // If we have a sync-definition we can check if it starts/ends in an
3220 // aligned barrier. If we are unsure we assume any sync breaks
3221 // alignment.
3222 Function *Callee = CB->getCalledFunction();
3223 if (!IsNoSync && Callee && !Callee->isDeclaration()) {
3224 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3225 *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
3226 if (EDAA && EDAA->getState().isValidState()) {
3227 const auto &CalleeED = EDAA->getFunctionExecutionDomain();
3228 ED.IsReachedFromAlignedBarrierOnly =
3229 CalleeED.IsReachedFromAlignedBarrierOnly;
3230 AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3231 if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3232 ED.EncounteredNonLocalSideEffect |=
3233 CalleeED.EncounteredNonLocalSideEffect;
3234 else
3235 ED.EncounteredNonLocalSideEffect =
3236 CalleeED.EncounteredNonLocalSideEffect;
3237 if (!CalleeED.IsReachingAlignedBarrierOnly) {
3238 Changed |=
3239 setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3240 SyncInstWorklist.push_back(&I);
3241 }
3242 if (CalleeED.IsReachedFromAlignedBarrierOnly)
3243 mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
3244 auto &CallOutED = CEDMap[{CB, POST}];
3245 Changed |= mergeInPredecessor(A, CallOutED, ED);
3246 continue;
3247 }
3248 }
3249 if (!IsNoSync) {
3250 ED.IsReachedFromAlignedBarrierOnly = false;
3251 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3252 SyncInstWorklist.push_back(&I);
3253 }
3254 AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3255 ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
3256 auto &CallOutED = CEDMap[{CB, POST}];
3257 Changed |= mergeInPredecessor(A, CallOutED, ED);
3258 }
3259
3260 if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
3261 continue;
3262
3263 // If we have a callee we try to use fine-grained information to
3264 // determine local side-effects.
3265 if (CB) {
3266 const auto *MemAA = A.getAAFor<AAMemoryLocation>(
3267 *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
3268
3269 auto AccessPred = [&](const Instruction *I, const Value *Ptr,
3270 AAMemoryLocation::AccessKind,
3271 AAMemoryLocation::MemoryLocationsKind) {
3272 return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
3273 };
3274 if (MemAA && MemAA->getState().isValidState() &&
3275 MemAA->checkForAllAccessesToMemoryKind(
3276 AccessPred, AAMemoryLocation::ALL_LOCATIONS))
3277 continue;
3278 }
3279
3280 auto &InfoCache = A.getInfoCache();
3281 if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
3282 continue;
3283
3284 if (auto *LI = dyn_cast<LoadInst>(&I))
3285 if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3286 continue;
3287
3288 if (!ED.EncounteredNonLocalSideEffect &&
3289 AA::isPotentiallyAffectedByBarrier(A, I, *this))
3290 ED.EncounteredNonLocalSideEffect = true;
3291 }
3292
3293 bool IsEndAndNotReachingAlignedBarriersOnly = false;
3294 if (!isa<UnreachableInst>(BB.getTerminator()) &&
3295 !BB.getTerminator()->getNumSuccessors()) {
3296
3297 Changed |= mergeInPredecessor(A, InterProceduralED, ED);
3298
3299 auto &FnED = BEDMap[nullptr];
3300 if (IsKernel && !IsExplicitlyAligned)
3301 FnED.IsReachingAlignedBarrierOnly = false;
3302 Changed |= mergeInPredecessor(A, FnED, ED);
3303
3304 if (!FnED.IsReachingAlignedBarrierOnly) {
3305 IsEndAndNotReachingAlignedBarriersOnly = true;
3306 SyncInstWorklist.push_back(BB.getTerminator());
3307 auto &BBED = BEDMap[&BB];
3308 Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
3309 }
3310 }
3311
3312 ExecutionDomainTy &StoredED = BEDMap[&BB];
3313 ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3314 !IsEndAndNotReachingAlignedBarriersOnly;
3315
3316 // Check if we computed anything different as part of the forward
3317 // traversal. We do not take assumptions and aligned barriers into account
3318 // as they do not influence the state we iterate. Backward traversal values
3319 // are handled later on.
3320 if (ED.IsExecutedByInitialThreadOnly !=
3321 StoredED.IsExecutedByInitialThreadOnly ||
3322 ED.IsReachedFromAlignedBarrierOnly !=
3323 StoredED.IsReachedFromAlignedBarrierOnly ||
3324 ED.EncounteredNonLocalSideEffect !=
3325 StoredED.EncounteredNonLocalSideEffect)
3326 Changed = true;
3327
3328 // Update the state with the new value.
3329 StoredED = std::move(ED);
3330 }
3331
3332 // Propagate (non-aligned) sync instruction effects backwards until the
3333 // entry is hit or an aligned barrier.
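  // A (non-aligned) sync instruction invalidates "reaching an aligned barrier
  // only" for everything that can reach it: the calls preceding it in its
  // block and, transitively, all predecessor blocks, until an aligned barrier
  // or an already-invalidated state stops the walk.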
3334   SmallSetVector<BasicBlock *, 16> Visited;
3335   while (!SyncInstWorklist.empty()) {
3336 Instruction *SyncInst = SyncInstWorklist.pop_back_val();
3337 Instruction *CurInst = SyncInst;
3338 bool HitAlignedBarrierOrKnownEnd = false;
3339 while ((CurInst = CurInst->getPrevNode())) {
3340 auto *CB = dyn_cast<CallBase>(CurInst);
3341 if (!CB)
3342 continue;
3343 auto &CallOutED = CEDMap[{CB, POST}];
3344 Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
3345 auto &CallInED = CEDMap[{CB, PRE}];
3346 HitAlignedBarrierOrKnownEnd =
3347 AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
3348 if (HitAlignedBarrierOrKnownEnd)
3349 break;
3350 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3351 }
3352 if (HitAlignedBarrierOrKnownEnd)
3353 continue;
3354 BasicBlock *SyncBB = SyncInst->getParent();
3355 for (auto *PredBB : predecessors(SyncBB)) {
3356 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
3357 continue;
3358 if (!Visited.insert(PredBB))
3359 continue;
3360 auto &PredED = BEDMap[PredBB];
3361 if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
3362 Changed = true;
3363 SyncInstWorklist.push_back(PredBB->getTerminator());
3364 }
3365 }
3366 if (SyncBB != &EntryBB)
3367 continue;
3368 Changed |=
3369 setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
3370 }
3371
3372 return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
3373}
3374
3375 /// Try to replace memory allocation calls executed by a single thread with a
3376 /// static buffer of shared memory.
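///
/// For illustration only (the value names and the concrete shared address
/// space are assumptions for a typical GPU target), a globalization call such
/// as
///   %x = call ptr @__kmpc_alloc_shared(i64 4)
///   ...
///   call void @__kmpc_free_shared(ptr %x, i64 4)
/// may be rewritten to use a static buffer,
///   @x_shared = internal addrspace(3) global [4 x i8] poison
/// whose address (cast back to a generic pointer) replaces %x, while both
/// runtime calls are deleted.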
3377 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
3378   using Base = StateWrapper<BooleanState, AbstractAttribute>;
3379   AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3380
3381 /// Create an abstract attribute view for the position \p IRP.
3382 static AAHeapToShared &createForPosition(const IRPosition &IRP,
3383 Attributor &A);
3384
3385 /// Returns true if HeapToShared conversion is assumed to be possible.
3386 virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3387
3388 /// Returns true if HeapToShared conversion is assumed and the CB is a
3389 /// callsite to a free operation to be removed.
3390 virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3391
3392 /// See AbstractAttribute::getName().
3393 const std::string getName() const override { return "AAHeapToShared"; }
3394
3395 /// See AbstractAttribute::getIdAddr().
3396 const char *getIdAddr() const override { return &ID; }
3397
3398 /// This function should return true if the type of the \p AA is
3399 /// AAHeapToShared.
3400 static bool classof(const AbstractAttribute *AA) {
3401 return (AA->getIdAddr() == &ID);
3402 }
3403
3404 /// Unique ID (due to the unique address)
3405 static const char ID;
3406};
3407
3408struct AAHeapToSharedFunction : public AAHeapToShared {
3409 AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
3410 : AAHeapToShared(IRP, A) {}
3411
3412 const std::string getAsStr(Attributor *) const override {
3413 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3414 " malloc calls eligible.";
3415 }
3416
3417 /// See AbstractAttribute::trackStatistics().
3418 void trackStatistics() const override {}
3419
3420   /// This function finds free calls that will be removed by the
3421 /// HeapToShared transformation.
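  /// Only malloc calls with exactly one associated free call are recorded;
  /// the FreeCalls.size() == 1 check below drops everything else.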
3422 void findPotentialRemovedFreeCalls(Attributor &A) {
3423 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3424 auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3425
3426 PotentialRemovedFreeCalls.clear();
3427 // Update free call users of found malloc calls.
3428     for (CallBase *CB : MallocCalls) {
3429       SmallVector<CallBase *, 4> FreeCalls;
3430       for (auto *U : CB->users()) {
3431 CallBase *C = dyn_cast<CallBase>(U);
3432 if (C && C->getCalledFunction() == FreeRFI.Declaration)
3433 FreeCalls.push_back(C);
3434 }
3435
3436 if (FreeCalls.size() != 1)
3437 continue;
3438
3439 PotentialRemovedFreeCalls.insert(FreeCalls.front());
3440 }
3441 }
3442
3443   void initialize(Attributor &A) override {
3444     if (DisableOpenMPOptDeglobalization) {
3445       indicatePessimisticFixpoint();
3446 return;
3447 }
3448
3449 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3450 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3451 if (!RFI.Declaration)
3452 return;
3453
3454     Attributor::SimplifictionCallbackTy SCB =
3455         [](const IRPosition &, const AbstractAttribute *,
3456 bool &) -> std::optional<Value *> { return nullptr; };
3457
3458 Function *F = getAnchorScope();
3459 for (User *U : RFI.Declaration->users())
3460 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3461 if (CB->getFunction() != F)
3462 continue;
3463 MallocCalls.insert(CB);
3464 A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
3465 SCB);
3466 }
3467
3468 findPotentialRemovedFreeCalls(A);
3469 }
3470
3471 bool isAssumedHeapToShared(CallBase &CB) const override {
3472 return isValidState() && MallocCalls.count(&CB);
3473 }
3474
3475 bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
3476 return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3477 }
3478
3479 ChangeStatus manifest(Attributor &A) override {
3480 if (MallocCalls.empty())
3481 return ChangeStatus::UNCHANGED;
3482
3483 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3484 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3485
3486 Function *F = getAnchorScope();
3487 auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3488 DepClassTy::OPTIONAL);
3489
3490 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3491 for (CallBase *CB : MallocCalls) {
3492 // Skip replacing this if HeapToStack has already claimed it.
3493 if (HS && HS->isAssumedHeapToStack(*CB))
3494 continue;
3495
3496       // Find the unique free call to remove it.
3497       SmallVector<CallBase *, 4> FreeCalls;
3498       for (auto *U : CB->users()) {
3499 CallBase *C = dyn_cast<CallBase>(U);
3500 if (C && C->getCalledFunction() == FreeCall.Declaration)
3501 FreeCalls.push_back(C);
3502 }
3503 if (FreeCalls.size() != 1)
3504 continue;
3505
3506 auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3507
3508 if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3509 LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3510 << " with shared memory."
3511 << " Shared memory usage is limited to "
3512 << SharedMemoryLimit << " bytes\n");
3513 continue;
3514 }
3515
3516 LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3517 << " with " << AllocSize->getZExtValue()
3518 << " bytes of shared memory\n");
3519
3520 // Create a new shared memory buffer of the same size as the allocation
3521 // and replace all the uses of the original allocation with it.
3522 Module *M = CB->getModule();
3523 Type *Int8Ty = Type::getInt8Ty(M->getContext());
3524 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3525 auto *SharedMem = new GlobalVariable(
3526 *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3527           PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3528           GlobalValue::NotThreadLocal,
3529           static_cast<unsigned>(AddressSpace::Shared));
3530 auto *NewBuffer = ConstantExpr::getPointerCast(
3531 SharedMem, PointerType::getUnqual(M->getContext()));
3532
3533 auto Remark = [&](OptimizationRemark OR) {
3534 return OR << "Replaced globalized variable with "
3535 << ore::NV("SharedMemory", AllocSize->getZExtValue())
3536 << (AllocSize->isOne() ? " byte " : " bytes ")
3537 << "of shared memory.";
3538 };
3539 A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3540
3541 MaybeAlign Alignment = CB->getRetAlign();
3542 assert(Alignment &&
3543 "HeapToShared on allocation without alignment attribute");
3544 SharedMem->setAlignment(*Alignment);
3545
3546 A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3547 A.deleteAfterManifest(*CB);
3548 A.deleteAfterManifest(*FreeCalls.front());
3549
3550 SharedMemoryUsed += AllocSize->getZExtValue();
3551 NumBytesMovedToSharedMemory = SharedMemoryUsed;
3552 Changed = ChangeStatus::CHANGED;
3553 }
3554
3555 return Changed;
3556 }
3557
3558 ChangeStatus updateImpl(Attributor &A) override {
3559 if (MallocCalls.empty())
3560 return indicatePessimisticFixpoint();
3561 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3562 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3563 if (!RFI.Declaration)
3564 return ChangeStatus::UNCHANGED;
3565
3566 Function *F = getAnchorScope();
3567
3568 auto NumMallocCalls = MallocCalls.size();
3569
3570     // Only consider malloc calls executed by a single thread with a constant size.
3571 for (User *U : RFI.Declaration->users()) {
3572 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3573 if (CB->getCaller() != F)
3574 continue;
3575 if (!MallocCalls.count(CB))
3576 continue;
3577 if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3578 MallocCalls.remove(CB);
3579 continue;
3580 }
3581 const auto *ED = A.getAAFor<AAExecutionDomain>(
3582 *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3583 if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
3584 MallocCalls.remove(CB);
3585 }
3586 }
3587
3588 findPotentialRemovedFreeCalls(A);
3589
3590 if (NumMallocCalls != MallocCalls.size())
3591 return ChangeStatus::CHANGED;
3592
3593 return ChangeStatus::UNCHANGED;
3594 }
3595
3596   /// Collection of all malloc calls in a function.
3597   SmallSetVector<CallBase *, 4> MallocCalls;
3598   /// Collection of potentially removed free calls in a function.
3599 SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3600 /// The total amount of shared memory that has been used for HeapToShared.
3601 unsigned SharedMemoryUsed = 0;
3602};
3603
3604 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3605   using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3606   AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3607
3608 /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
3609 /// unknown callees.
3610 static bool requiresCalleeForCallBase() { return false; }
3611
3612 /// Statistics are tracked as part of manifest for now.
3613 void trackStatistics() const override {}
3614
3615 /// See AbstractAttribute::getAsStr()
3616 const std::string getAsStr(Attributor *) const override {
3617 if (!isValidState())
3618 return "<invalid>";
3619 return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3620 : "generic") +
3621 std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3622 : "") +
3623 std::string(" #PRs: ") +
3624 (ReachedKnownParallelRegions.isValidState()
3625 ? std::to_string(ReachedKnownParallelRegions.size())
3626 : "<invalid>") +
3627 ", #Unknown PRs: " +
3628 (ReachedUnknownParallelRegions.isValidState()
3629 ? std::to_string(ReachedUnknownParallelRegions.size())
3630 : "<invalid>") +
3631 ", #Reaching Kernels: " +
3632 (ReachingKernelEntries.isValidState()
3633 ? std::to_string(ReachingKernelEntries.size())
3634 : "<invalid>") +
3635 ", #ParLevels: " +
3636 (ParallelLevels.isValidState()
3637 ? std::to_string(ParallelLevels.size())
3638 : "<invalid>") +
3639 ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3640 }
3641
3642   /// Create an abstract attribute view for the position \p IRP.
3643 static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3644
3645 /// See AbstractAttribute::getName()
3646 const std::string getName() const override { return "AAKernelInfo"; }
3647
3648 /// See AbstractAttribute::getIdAddr()
3649 const char *getIdAddr() const override { return &ID; }
3650
3651 /// This function should return true if the type of the \p AA is AAKernelInfo
3652 static bool classof(const AbstractAttribute *AA) {
3653 return (AA->getIdAddr() == &ID);
3654 }
3655
3656 static const char ID;
3657};
3658
3659/// The function kernel info abstract attribute, basically, what can we say
3660/// about a function with regards to the KernelInfoState.
3661struct AAKernelInfoFunction : AAKernelInfo {
3662 AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3663 : AAKernelInfo(IRP, A) {}
3664
3665 SmallPtrSet<Instruction *, 4> GuardedInstructions;
3666
3667 SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3668 return GuardedInstructions;
3669 }
3670
3671   void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3672     Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction(
3673         KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3674 assert(NewKernelEnvC && "Failed to create new kernel environment");
3675 KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3676 }
3677
3678#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \
3679 void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \
3680 ConstantStruct *ConfigC = \
3681 KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \
3682 Constant *NewConfigC = ConstantFoldInsertValueInstruction( \
3683 ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \
3684 assert(NewConfigC && "Failed to create new configuration environment"); \
3685 setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \
3686 }
3687
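  // Each instantiation below defines a setter such as
  // setExecModeOfKernelEnvironment(ConstantInt *NewVal) that folds the new
  // value into the configuration struct of the kernel environment constant.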
3688 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
3689   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
3690   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)
3691   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads)
3692   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads)
3693   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams)
3694   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams)
3695
3696#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
3697
3698 /// See AbstractAttribute::initialize(...).
3699 void initialize(Attributor &A) override {
3700     // This is a high-level transform that might change the constant arguments
3701     // of the init and deinit calls. We need to tell the Attributor about this
3702     // to avoid other parts using the current constant value for simplification.
3703 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3704
3705 Function *Fn = getAnchorScope();
3706
3707 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3708 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3709 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3710 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3711
3712 // For kernels we perform more initialization work, first we find the init
3713 // and deinit calls.
3714 auto StoreCallBase = [](Use &U,
3715 OMPInformationCache::RuntimeFunctionInfo &RFI,
3716 CallBase *&Storage) {
3717 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3718 assert(CB &&
3719 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3720 assert(!Storage &&
3721 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3722 Storage = CB;
3723 return false;
3724 };
3725 InitRFI.foreachUse(
3726 [&](Use &U, Function &) {
3727 StoreCallBase(U, InitRFI, KernelInitCB);
3728 return false;
3729 },
3730 Fn);
3731 DeinitRFI.foreachUse(
3732 [&](Use &U, Function &) {
3733 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3734 return false;
3735 },
3736 Fn);
3737
3738 // Ignore kernels without initializers such as global constructors.
3739 if (!KernelInitCB || !KernelDeinitCB)
3740 return;
3741
3742 // Add itself to the reaching kernel and set IsKernelEntry.
3743 ReachingKernelEntries.insert(Fn);
3744 IsKernelEntry = true;
3745
3746     KernelEnvC =
3747         KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3748     GlobalVariable *KernelEnvGV =
3749         KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3750
3751     Attributor::GlobalVariableSimplifictionCallbackTy
3752         KernelConfigurationSimplifyCB =
3753 [&](const GlobalVariable &GV, const AbstractAttribute *AA,
3754 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3755 if (!isAtFixpoint()) {
3756 if (!AA)
3757 return nullptr;
3758 UsedAssumedInformation = true;
3759 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3760 }
3761 return KernelEnvC;
3762 };
3763
3764 A.registerGlobalVariableSimplificationCallback(
3765 *KernelEnvGV, KernelConfigurationSimplifyCB);
3766
3767     // We cannot change to SPMD mode if the runtime functions aren't available.
3768 bool CanChangeToSPMD = OMPInfoCache.runtimeFnsAvailable(
3769 {OMPRTL___kmpc_get_hardware_thread_id_in_block,
3770 OMPRTL___kmpc_barrier_simple_spmd});
3771
3772 // Check if we know we are in SPMD-mode already.
3773 ConstantInt *ExecModeC =
3774 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3775     ConstantInt *AssumedExecModeC = ConstantInt::get(
3776         ExecModeC->getIntegerType(),
3777         OMP_TGT_EXEC_MODE_GENERIC | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
3778     if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
3779 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3780 else if (DisableOpenMPOptSPMDization || !CanChangeToSPMD)
3781 // This is a generic region but SPMDization is disabled so stop
3782 // tracking.
3783 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3784 else
3785 setExecModeOfKernelEnvironment(AssumedExecModeC);
3786
3787 const Triple T(Fn->getParent()->getTargetTriple());
3788 auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
3789     auto [MinThreads, MaxThreads] =
3790         OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn);
3791     if (MinThreads)
3792 setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3793 if (MaxThreads)
3794 setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3795     auto [MinTeams, MaxTeams] =
3796         OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn);
3797     if (MinTeams)
3798 setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3799 if (MaxTeams)
3800 setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3801
3802 ConstantInt *MayUseNestedParallelismC =
3803 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3804 ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3805 MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
3806 setMayUseNestedParallelismOfKernelEnvironment(
3807 AssumedMayUseNestedParallelismC);
3808
3809     if (!DisableOpenMPOptStateMachineRewrite) {
3810       ConstantInt *UseGenericStateMachineC =
3811 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3812 KernelEnvC);
3813 ConstantInt *AssumedUseGenericStateMachineC =
3814 ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3815 setUseGenericStateMachineOfKernelEnvironment(
3816 AssumedUseGenericStateMachineC);
3817 }
3818
3819 // Register virtual uses of functions we might need to preserve.
3820     auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3821                                   Attributor::VirtualUseCallbackTy &CB) {
3822       if (!OMPInfoCache.RFIs[RFKind].Declaration)
3823 return;
3824 A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3825 };
3826
3827 // Add a dependence to ensure updates if the state changes.
3828 auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3829 const AbstractAttribute *QueryingAA) {
3830 if (QueryingAA) {
3831 A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3832 }
3833 return true;
3834 };
3835
3836 Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3837 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3838 // Whenever we create a custom state machine we will insert calls to
3839 // __kmpc_get_hardware_num_threads_in_block,
3840 // __kmpc_get_warp_size,
3841 // __kmpc_barrier_simple_generic,
3842 // __kmpc_kernel_parallel, and
3843 // __kmpc_kernel_end_parallel.
3844 // Not needed if we are on track for SPMDzation.
3845 if (SPMDCompatibilityTracker.isValidState())
3846 return AddDependence(A, this, QueryingAA);
3847 // Not needed if we can't rewrite due to an invalid state.
3848 if (!ReachedKnownParallelRegions.isValidState())
3849 return AddDependence(A, this, QueryingAA);
3850 return false;
3851 };
3852
3853 // Not needed if we are pre-runtime merge.
3854 if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3855 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3856 CustomStateMachineUseCB);
3857 RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3858 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3859 CustomStateMachineUseCB);
3860 RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3861 CustomStateMachineUseCB);
3862 RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3863 CustomStateMachineUseCB);
3864 }
3865
3866 // If we do not perform SPMDzation we do not need the virtual uses below.
3867 if (SPMDCompatibilityTracker.isAtFixpoint())
3868 return;
3869
3870 Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3871 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3872 // Whenever we perform SPMDzation we will insert
3873 // __kmpc_get_hardware_thread_id_in_block calls.
3874 if (!SPMDCompatibilityTracker.isValidState())
3875 return AddDependence(A, this, QueryingAA);
3876 return false;
3877 };
3878 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3879 HWThreadIdUseCB);
3880
3881 Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3882 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3883 // Whenever we perform SPMDzation with guarding we will insert
3884 // __kmpc_simple_barrier_spmd calls. If SPMDzation failed, there is
3885 // nothing to guard, or there are no parallel regions, we don't need
3886 // the calls.
3887 if (!SPMDCompatibilityTracker.isValidState())
3888 return AddDependence(A, this, QueryingAA);
3889 if (SPMDCompatibilityTracker.empty())
3890 return AddDependence(A, this, QueryingAA);
3891 if (!mayContainParallelRegion())
3892 return AddDependence(A, this, QueryingAA);
3893 return false;
3894 };
3895 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3896 }
3897
3898 /// Sanitize the string \p S such that it is a suitable global symbol name.
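  /// For example, "x.guarded output alloc" becomes "x.guarded.output.alloc";
  /// every character outside [A-Za-z0-9_] is replaced by '.'.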
3899 static std::string sanitizeForGlobalName(std::string S) {
3900 std::replace_if(
3901 S.begin(), S.end(),
3902 [](const char C) {
3903 return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3904 (C >= '0' && C <= '9') || C == '_');
3905 },
3906 '.');
3907 return S;
3908 }
3909
3910 /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3911 /// finished now.
3912 ChangeStatus manifest(Attributor &A) override {
3913 // If we are not looking at a kernel with __kmpc_target_init and
3914 // __kmpc_target_deinit call we cannot actually manifest the information.
3915 if (!KernelInitCB || !KernelDeinitCB)
3916 return ChangeStatus::UNCHANGED;
3917
3918 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3919
3920 bool HasBuiltStateMachine = true;
3921 if (!changeToSPMDMode(A, Changed)) {
3922 if (!KernelInitCB->getCalledFunction()->isDeclaration())
3923 HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3924 else
3925 HasBuiltStateMachine = false;
3926 }
3927
3928 // We need to reset KernelEnvC if specific rewriting is not done.
3929     ConstantStruct *ExistingKernelEnvC =
3930         KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3931     ConstantInt *OldUseGenericStateMachineVal =
3932 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3933 ExistingKernelEnvC);
3934 if (!HasBuiltStateMachine)
3935 setUseGenericStateMachineOfKernelEnvironment(
3936 OldUseGenericStateMachineVal);
3937
3938     // Finally, update the KernelEnvC global variable initializer if needed.
3939     GlobalVariable *KernelEnvGV =
3940         KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3941     if (KernelEnvGV->getInitializer() != KernelEnvC) {
3942 KernelEnvGV->setInitializer(KernelEnvC);
3943 Changed = ChangeStatus::CHANGED;
3944 }
3945
3946 return Changed;
3947 }
3948
3949 void insertInstructionGuardsHelper(Attributor &A) {
3950 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3951
3952 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3953 Instruction *RegionEndI) {
3954 LoopInfo *LI = nullptr;
3955 DominatorTree *DT = nullptr;
3956 MemorySSAUpdater *MSU = nullptr;
3957 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3958
3959 BasicBlock *ParentBB = RegionStartI->getParent();
3960 Function *Fn = ParentBB->getParent();
3961 Module &M = *Fn->getParent();
3962
3963 // Create all the blocks and logic.
3964 // ParentBB:
3965 // goto RegionCheckTidBB
3966 // RegionCheckTidBB:
3967 // Tid = __kmpc_hardware_thread_id()
3968 // if (Tid != 0)
3969 // goto RegionBarrierBB
3970 // RegionStartBB:
3971 // <execute instructions guarded>
3972 // goto RegionEndBB
3973 // RegionEndBB:
3974 // <store escaping values to shared mem>
3975 // goto RegionBarrierBB
3976 // RegionBarrierBB:
3977 // __kmpc_simple_barrier_spmd()
3978 // // second barrier is omitted if lacking escaping values.
3979 // <load escaping values from shared mem>
3980 // __kmpc_simple_barrier_spmd()
3981 // goto RegionExitBB
3982 // RegionExitBB:
3983 // <execute rest of instructions>
3984
3985 BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3986 DT, LI, MSU, "region.guarded.end");
3987 BasicBlock *RegionBarrierBB =
3988 SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3989 MSU, "region.barrier");
3990 BasicBlock *RegionExitBB =
3991 SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3992 DT, LI, MSU, "region.exit");
3993 BasicBlock *RegionStartBB =
3994 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3995
3996 assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3997 "Expected a different CFG");
3998
3999 BasicBlock *RegionCheckTidBB = SplitBlock(
4000 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
4001
4002 // Register basic blocks with the Attributor.
4003 A.registerManifestAddedBasicBlock(*RegionEndBB);
4004 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
4005 A.registerManifestAddedBasicBlock(*RegionExitBB);
4006 A.registerManifestAddedBasicBlock(*RegionStartBB);
4007 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4008
4009 bool HasBroadcastValues = false;
4010 // Find escaping outputs from the guarded region to outside users and
4011 // broadcast their values to them.
4012 for (Instruction &I : *RegionStartBB) {
4013 SmallVector<Use *, 4> OutsideUses;
4014 for (Use &U : I.uses()) {
4015 Instruction &UsrI = *cast<Instruction>(U.getUser());
4016 if (UsrI.getParent() != RegionStartBB)
4017 OutsideUses.push_back(&U);
4018 }
4019
4020 if (OutsideUses.empty())
4021 continue;
4022
4023 HasBroadcastValues = true;
4024
4025 // Emit a global variable in shared memory to store the broadcasted
4026 // value.
4027         auto *SharedMem = new GlobalVariable(
4028             M, I.getType(), /* IsConstant */ false,
4029             GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
4030             sanitizeForGlobalName(
4031                 (I.getName() + ".guarded.output.alloc").str()),
4032             nullptr, GlobalValue::NotThreadLocal,
4033             static_cast<unsigned>(AddressSpace::Shared));
4034
4035 // Emit a store instruction to update the value.
4036 new StoreInst(&I, SharedMem,
4037 RegionEndBB->getTerminator()->getIterator());
4038
4039 LoadInst *LoadI = new LoadInst(
4040 I.getType(), SharedMem, I.getName() + ".guarded.output.load",
4041 RegionBarrierBB->getTerminator()->getIterator());
4042
4043 // Emit a load instruction and replace uses of the output value.
4044 for (Use *U : OutsideUses)
4045 A.changeUseAfterManifest(*U, *LoadI);
4046 }
4047
4048 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4049
4050 // Go to tid check BB in ParentBB.
4051 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
4052       ParentBB->getTerminator()->eraseFromParent();
4053       OpenMPIRBuilder::LocationDescription Loc(
4054           InsertPointTy(ParentBB, ParentBB->end()), DL);
4055 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4056 uint32_t SrcLocStrSize;
4057 auto *SrcLocStr =
4058 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4059 Value *Ident =
4060 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4061 BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
4062
4063 // Add check for Tid in RegionCheckTidBB
4064 RegionCheckTidBB->getTerminator()->eraseFromParent();
4065 OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
4066 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4067 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4068 FunctionCallee HardwareTidFn =
4069 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4070 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4071 CallInst *Tid =
4072 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4073 Tid->setDebugLoc(DL);
4074 OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4075 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4076 OMPInfoCache.OMPBuilder.Builder
4077 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4078 ->setDebugLoc(DL);
4079
4080 // First barrier for synchronization, ensures main thread has updated
4081 // values.
4082 FunctionCallee BarrierFn =
4083 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4084 M, OMPRTL___kmpc_barrier_simple_spmd);
4085 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4086 RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
4087 CallInst *Barrier =
4088 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4089 Barrier->setDebugLoc(DL);
4090 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4091
4092 // Second barrier ensures workers have read broadcast values.
4093 if (HasBroadcastValues) {
4094 CallInst *Barrier =
4095 CallInst::Create(BarrierFn, {Ident, Tid}, "",
4096 RegionBarrierBB->getTerminator()->getIterator());
4097 Barrier->setDebugLoc(DL);
4098 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4099 }
4100 };
4101
4102     auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4103     SmallPtrSet<BasicBlock *, 8> Visited;
4104     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4105 BasicBlock *BB = GuardedI->getParent();
4106 if (!Visited.insert(BB).second)
4107 continue;
4108
4109       SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
4110       Instruction *LastEffect = nullptr;
4111 BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
4112 while (++IP != IPEnd) {
4113 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4114 continue;
4115 Instruction *I = &*IP;
4116 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4117 continue;
4118 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4119 LastEffect = nullptr;
4120 continue;
4121 }
4122 if (LastEffect)
4123 Reorders.push_back({I, LastEffect});
4124 LastEffect = &*IP;
4125 }
4126 for (auto &Reorder : Reorders)
4127 Reorder.first->moveBefore(Reorder.second->getIterator());
4128 }
4129
4130     SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
4131
4132 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4133 BasicBlock *BB = GuardedI->getParent();
4134 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4135 IRPosition::function(*GuardedI->getFunction()), nullptr,
4136 DepClassTy::NONE);
4137 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4138 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4139 // Continue if instruction is already guarded.
4140 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4141 continue;
4142
4143 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4144 for (Instruction &I : *BB) {
4145 // If instruction I needs to be guarded update the guarded region
4146 // bounds.
4147 if (SPMDCompatibilityTracker.contains(&I)) {
4148 CalleeAAFunction.getGuardedInstructions().insert(&I);
4149 if (GuardedRegionStart)
4150 GuardedRegionEnd = &I;
4151 else
4152 GuardedRegionStart = GuardedRegionEnd = &I;
4153
4154 continue;
4155 }
4156
4157 // Instruction I does not need guarding, store
4158 // any region found and reset bounds.
4159 if (GuardedRegionStart) {
4160 GuardedRegions.push_back(
4161 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4162 GuardedRegionStart = nullptr;
4163 GuardedRegionEnd = nullptr;
4164 }
4165 }
4166 }
4167
4168 for (auto &GR : GuardedRegions)
4169 CreateGuardedRegion(GR.first, GR.second);
4170 }
4171
4172 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4173 // Only allow 1 thread per workgroup to continue executing the user code.
4174 //
4175 // InitCB = __kmpc_target_init(...)
4176 // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4177 // if (ThreadIdInBlock != 0) return;
4178 // UserCode:
4179 // // user code
4180 //
4181 auto &Ctx = getAnchorValue().getContext();
4182 Function *Kernel = getAssociatedFunction();
4183 assert(Kernel && "Expected an associated function!");
4184
4185 // Create block for user code to branch to from initial block.
4186 BasicBlock *InitBB = KernelInitCB->getParent();
4187 BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4188 KernelInitCB->getNextNode(), "main.thread.user_code");
4189 BasicBlock *ReturnBB =
4190 BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
4191
4192 // Register blocks with attributor:
4193 A.registerManifestAddedBasicBlock(*InitBB);
4194 A.registerManifestAddedBasicBlock(*UserCodeBB);
4195 A.registerManifestAddedBasicBlock(*ReturnBB);
4196
4197 // Debug location:
4198 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4199 ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
4200 InitBB->getTerminator()->eraseFromParent();
4201
4202 // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4203 Module &M = *Kernel->getParent();
4204 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4205 FunctionCallee ThreadIdInBlockFn =
4206 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4207 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4208
4209 // Get thread ID in block.
4210 CallInst *ThreadIdInBlock =
4211 CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
4212 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4213 ThreadIdInBlock->setDebugLoc(DLoc);
4214
4215 // Eliminate all threads in the block with ID not equal to 0:
4216 Instruction *IsMainThread =
4217 ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
4218 ConstantInt::get(ThreadIdInBlock->getType(), 0),
4219 "thread.is_main", InitBB);
4220 IsMainThread->setDebugLoc(DLoc);
4221 BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
4222 }
4223
4224 bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
4225 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4226
4227 if (!SPMDCompatibilityTracker.isAssumed()) {
4228 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4229 if (!NonCompatibleI)
4230 continue;
4231
4232 // Skip diagnostics on calls to known OpenMP runtime functions for now.
4233 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4234 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4235 continue;
4236
4237 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4238 ORA << "Value has potential side effects preventing SPMD-mode "
4239 "execution";
4240 if (isa<CallBase>(NonCompatibleI)) {
4241 ORA << ". Add `[[omp::assume(\"ompx_spmd_amenable\")]]` to "
4242 "the called function to override";
4243 }
4244 return ORA << ".";
4245 };
4246 A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
4247 Remark);
4248
4249 LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
4250 << *NonCompatibleI << "\n");
4251 }
4252
4253 return false;
4254 }
4255
4256     // Get the actual kernel; it could be the caller of the anchor scope if we
4257     // have a debug wrapper.
4258 Function *Kernel = getAnchorScope();
4259 if (Kernel->hasLocalLinkage()) {
4260 assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
4261 auto *CB = cast<CallBase>(Kernel->user_back());
4262 Kernel = CB->getCaller();
4263 }
4264 assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
4265
4266 // Check if the kernel is already in SPMD mode, if so, return success.
4267     ConstantStruct *ExistingKernelEnvC =
4268         KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4269     auto *ExecModeC =
4270 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4271 const int8_t ExecModeVal = ExecModeC->getSExtValue();
4272 if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4273 return true;
4274
4275 // We will now unconditionally modify the IR, indicate a change.
4276 Changed = ChangeStatus::CHANGED;
4277
4278     // Do not use instruction guards when no parallel region is present inside
4279     // the target region.
4280 if (mayContainParallelRegion())
4281 insertInstructionGuardsHelper(A);
4282 else
4283 forceSingleThreadPerWorkgroupHelper(A);
4284
4285 // Adjust the global exec mode flag that tells the runtime what mode this
4286 // kernel is executed in.
4287 assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4288 "Initially non-SPMD kernel has SPMD exec mode!");
4289 setExecModeOfKernelEnvironment(
4290 ConstantInt::get(ExecModeC->getIntegerType(),
4291 ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
4292
4293 ++NumOpenMPTargetRegionKernelsSPMD;
4294
4295 auto Remark = [&](OptimizationRemark OR) {
4296 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4297 };
4298 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
4299 return true;
4300 };
4301
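  /// Replace the generic-mode worker state machine of this kernel with a
  /// specialized one that dispatches directly to the parallel region
  /// functions known to be reachable from the kernel (see the CFG sketch
  /// below).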
4302 bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4303     // If we have disabled state machine rewrites, don't make a custom one.
4304     if (DisableOpenMPOptStateMachineRewrite)
4305       return false;
4306
4307 // Don't rewrite the state machine if we are not in a valid state.
4308 if (!ReachedKnownParallelRegions.isValidState())
4309 return false;
4310
4311 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4312 if (!OMPInfoCache.runtimeFnsAvailable(
4313 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4314 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4315 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4316 return false;
4317
4318     ConstantStruct *ExistingKernelEnvC =
4319         KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4320
4321 // Check if the current configuration is non-SPMD and generic state machine.
4322 // If we already have SPMD mode or a custom state machine we do not need to
4323 // go any further. If it is anything but a constant something is weird and
4324 // we give up.
4325 ConstantInt *UseStateMachineC =
4326 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4327 ExistingKernelEnvC);
4328 ConstantInt *ModeC =
4329 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4330
4331 // If we are stuck with generic mode, try to create a custom device (=GPU)
4332 // state machine which is specialized for the parallel regions that are
4333 // reachable by the kernel.
4334     if (UseStateMachineC->isZero() ||
4335         (ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
4336       return false;
4337
4338 Changed = ChangeStatus::CHANGED;
4339
4340 // If not SPMD mode, indicate we use a custom state machine now.
4341 setUseGenericStateMachineOfKernelEnvironment(
4342 ConstantInt::get(UseStateMachineC->getIntegerType(), false));
4343
4344 // If we don't actually need a state machine we are done here. This can
4345 // happen if there simply are no parallel regions. In the resulting kernel
4346 // all worker threads will simply exit right away, leaving the main thread
4347 // to do the work alone.
4348 if (!mayContainParallelRegion()) {
4349 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4350
4351 auto Remark = [&](OptimizationRemark OR) {
4352 return OR << "Removing unused state machine from generic-mode kernel.";
4353 };
4354 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4355
4356 return true;
4357 }
4358
4359 // Keep track in the statistics of our new shiny custom state machine.
4360 if (ReachedUnknownParallelRegions.empty()) {
4361 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4362
4363 auto Remark = [&](OptimizationRemark OR) {
4364 return OR << "Rewriting generic-mode kernel with a customized state "
4365 "machine.";
4366 };
4367 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4368 } else {
4369 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4370
4371       auto Remark = [&](OptimizationRemarkAnalysis OR) {
4372         return OR << "Generic-mode kernel is executed with a customized state "
4373 "machine that requires a fallback.";
4374 };
4375 A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4376
4377 // Tell the user why we ended up with a fallback.
4378 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4379 if (!UnknownParallelRegionCB)
4380 continue;
4381 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4382 return ORA << "Call may contain unknown parallel regions. Use "
4383 << "`[[omp::assume(\"omp_no_parallelism\")]]` to "
4384 "override.";
4385 };
4386 A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4387 "OMP133", Remark);
4388 }
4389 }
4390
4391 // Create all the blocks:
4392 //
4393 // InitCB = __kmpc_target_init(...)
4394 // BlockHwSize =
4395 // __kmpc_get_hardware_num_threads_in_block();
4396 // WarpSize = __kmpc_get_warp_size();
4397 // BlockSize = BlockHwSize - WarpSize;
4398 // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
4399 // if (IsWorker) {
4400 // if (InitCB >= BlockSize) return;
4401 // SMBeginBB: __kmpc_barrier_simple_generic(...);
4402 // void *WorkFn;
4403 // bool Active = __kmpc_kernel_parallel(&WorkFn);
4404 // if (!WorkFn) return;
4405 // SMIsActiveCheckBB: if (Active) {
4406 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
4407 // ParFn0(...);
4408 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
4409 // ParFn1(...);
4410 // ...
4411 // SMIfCascadeCurrentBB: else
4412 // ((WorkFnTy*)WorkFn)(...);
4413 // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
4414 // }
4415 // SMDoneBB: __kmpc_barrier_simple_generic(...);
4416 // goto SMBeginBB;
4417 // }
4418 // UserCodeEntryBB: // user code
4419 // __kmpc_target_deinit(...)
4420 //
4421 auto &Ctx = getAnchorValue().getContext();
4422 Function *Kernel = getAssociatedFunction();
4423 assert(Kernel && "Expected an associated function!");
4424
4425 BasicBlock *InitBB = KernelInitCB->getParent();
4426 BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4427 KernelInitCB->getNextNode(), "thread.user_code.check");
4428 BasicBlock *IsWorkerCheckBB =
4429 BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4430 BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4431 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4432 BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4433 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4434 BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4435 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4436 BasicBlock *StateMachineIfCascadeCurrentBB =
4437 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4438 Kernel, UserCodeEntryBB);
4439 BasicBlock *StateMachineEndParallelBB =
4440 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4441 Kernel, UserCodeEntryBB);
4442 BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4443 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4444 A.registerManifestAddedBasicBlock(*InitBB);
4445 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4446 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4447 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4448 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4449 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4450 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4451 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4452 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4453
4454 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4455 ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4456 InitBB->getTerminator()->eraseFromParent();
4457
4458 Instruction *IsWorker =
4459 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4460 ConstantInt::get(KernelInitCB->getType(), -1),
4461 "thread.is_worker", InitBB);
4462 IsWorker->setDebugLoc(DLoc);
4463 BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
4464
4465 Module &M = *Kernel->getParent();
4466 FunctionCallee BlockHwSizeFn =
4467 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4468 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4469 FunctionCallee WarpSizeFn =
4470 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4471 M, OMPRTL___kmpc_get_warp_size);
4472 CallInst *BlockHwSize =
4473 CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4474 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4475 BlockHwSize->setDebugLoc(DLoc);
4476 CallInst *WarpSize =
4477 CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4478 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4479 WarpSize->setDebugLoc(DLoc);
4480 Instruction *BlockSize = BinaryOperator::CreateSub(
4481 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4482 BlockSize->setDebugLoc(DLoc);
4483 Instruction *IsMainOrWorker = ICmpInst::Create(
4484 ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4485 "thread.is_main_or_worker", IsWorkerCheckBB);
4486 IsMainOrWorker->setDebugLoc(DLoc);
4487 BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
4488 IsMainOrWorker, IsWorkerCheckBB);
4489
4490 // Create local storage for the work function pointer.
4491 const DataLayout &DL = M.getDataLayout();
4492 Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4493 Instruction *WorkFnAI =
4494 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4495 "worker.work_fn.addr", Kernel->getEntryBlock().begin());
4496 WorkFnAI->setDebugLoc(DLoc);
4497
4498     OMPInfoCache.OMPBuilder.updateToLocation(
4499         OpenMPIRBuilder::LocationDescription(
4500             IRBuilder<>::InsertPoint(StateMachineBeginBB,
4501 StateMachineBeginBB->end()),
4502 DLoc));
4503
4504 Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4505 Value *GTid = KernelInitCB;
4506
4507 FunctionCallee BarrierFn =
4508 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4509 M, OMPRTL___kmpc_barrier_simple_generic);
4510 CallInst *Barrier =
4511 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4512 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4513 Barrier->setDebugLoc(DLoc);
4514
4515 if (WorkFnAI->getType()->getPointerAddressSpace() !=
4516 (unsigned int)AddressSpace::Generic) {
4517 WorkFnAI = new AddrSpaceCastInst(
4518 WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4519 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4520 WorkFnAI->setDebugLoc(DLoc);
4521 }
4522
4523 FunctionCallee KernelParallelFn =
4524 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4525 M, OMPRTL___kmpc_kernel_parallel);
4526 CallInst *IsActiveWorker = CallInst::Create(
4527 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4528 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4529 IsActiveWorker->setDebugLoc(DLoc);
4530 Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4531 StateMachineBeginBB);
4532 WorkFn->setDebugLoc(DLoc);
4533
4534     FunctionType *ParallelRegionFnTy = FunctionType::get(
4535         Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
4536         false);
4537
4538 Instruction *IsDone =
4539 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4540 Constant::getNullValue(VoidPtrTy), "worker.is_done",
4541 StateMachineBeginBB);
4542 IsDone->setDebugLoc(DLoc);
4543 BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
4544 IsDone, StateMachineBeginBB)
4545 ->setDebugLoc(DLoc);
4546
4547 BranchInst::Create(StateMachineIfCascadeCurrentBB,
4548 StateMachineDoneBarrierBB, IsActiveWorker,
4549 StateMachineIsActiveCheckBB)
4550 ->setDebugLoc(DLoc);
4551
4552 Value *ZeroArg =
4553 Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4554
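    // Operand index at which the outlined wrapper function is passed to the
    // __kmpc_parallel_51 call sites recorded in ReachedKnownParallelRegions;
    // it is used below to resolve the known parallel region functions.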
4555 const unsigned int WrapperFunctionArgNo = 6;
4556
4557 // Now that we have most of the CFG skeleton it is time for the if-cascade
4558 // that checks the function pointer we got from the runtime against the
4559 // parallel regions we expect, if there are any.
4560 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4561 auto *CB = ReachedKnownParallelRegions[I];
4562 auto *ParallelRegion = dyn_cast<Function>(
4563 CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4564 BasicBlock *PRExecuteBB = BasicBlock::Create(
4565 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4566 StateMachineEndParallelBB);
4567 CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4568 ->setDebugLoc(DLoc);
4569 BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4570 ->setDebugLoc(DLoc);
4571
4572 BasicBlock *PRNextBB =
4573 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4574 Kernel, StateMachineEndParallelBB);
4575 A.registerManifestAddedBasicBlock(*PRExecuteBB);
4576 A.registerManifestAddedBasicBlock(*PRNextBB);
4577
4578 // Check if we need to compare the pointer at all or if we can just
4579 // call the parallel region function.
4580 Value *IsPR;
4581 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4582 Instruction *CmpI = ICmpInst::Create(
4583 ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
4584 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4585 CmpI->setDebugLoc(DLoc);
4586 IsPR = CmpI;
4587 } else {
4588 IsPR = ConstantInt::getTrue(Ctx);
4589 }
4590
4591 BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
4592 StateMachineIfCascadeCurrentBB)
4593 ->setDebugLoc(DLoc);
4594 StateMachineIfCascadeCurrentBB = PRNextBB;
4595 }
4596
4597 // At the end of the if-cascade we place the indirect function pointer call
4598 // in case we might need it, that is if there can be parallel regions we
4599 // have not handled in the if-cascade above.
4600 if (!ReachedUnknownParallelRegions.empty()) {
4601 StateMachineIfCascadeCurrentBB->setName(
4602 "worker_state_machine.parallel_region.fallback.execute");
4603 CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
4604 StateMachineIfCascadeCurrentBB)
4605 ->setDebugLoc(DLoc);
4606 }
4607 BranchInst::Create(StateMachineEndParallelBB,
4608 StateMachineIfCascadeCurrentBB)
4609 ->setDebugLoc(DLoc);
4610
4611 FunctionCallee EndParallelFn =
4612 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4613 M, OMPRTL___kmpc_kernel_end_parallel);
4614 CallInst *EndParallel =
4615 CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4616 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4617 EndParallel->setDebugLoc(DLoc);
4618 BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4619 ->setDebugLoc(DLoc);
4620
4621 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4622 ->setDebugLoc(DLoc);
4623 BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4624 ->setDebugLoc(DLoc);
4625
4626 return true;
4627 }
4628
4629 /// Fixpoint iteration update function. Will be called every time a dependence
4630 /// changed its state (and in the beginning).
4631 ChangeStatus updateImpl(Attributor &A) override {
4632 KernelInfoState StateBefore = getState();
4633
4634 // When we leave this function this RAII will make sure the member
4635 // KernelEnvC is updated properly depending on the state. That member is
4636 // used for simplification of values and needs to be up to date at all
4637 // times.
4638 struct UpdateKernelEnvCRAII {
4639 AAKernelInfoFunction &AA;
4640
4641 UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4642
4643 ~UpdateKernelEnvCRAII() {
4644 if (!AA.KernelEnvC)
4645 return;
4646
4647         ConstantStruct *ExistingKernelEnvC =
4648             KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB);
4649
4650 if (!AA.isValidState()) {
4651 AA.KernelEnvC = ExistingKernelEnvC;
4652 return;
4653 }
4654
4655 if (!AA.ReachedKnownParallelRegions.isValidState())
4656 AA.setUseGenericStateMachineOfKernelEnvironment(
4657 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4658 ExistingKernelEnvC));
4659
4660 if (!AA.SPMDCompatibilityTracker.isValidState())
4661 AA.setExecModeOfKernelEnvironment(
4662 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4663
4664 ConstantInt *MayUseNestedParallelismC =
4665 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4666 AA.KernelEnvC);
4667 ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4668 MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4669 AA.setMayUseNestedParallelismOfKernelEnvironment(
4670 NewMayUseNestedParallelismC);
4671 }
4672 } RAII(*this);
4673
4674 // Callback to check a read/write instruction.
4675 auto CheckRWInst = [&](Instruction &I) {
4676 // We handle calls later.
4677 if (isa<CallBase>(I))
4678 return true;
4679 // We only care about write effects.
4680 if (!I.mayWriteToMemory())
4681 return true;
4682 if (auto *SI = dyn_cast<StoreInst>(&I)) {
4683 const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4684 *this, IRPosition::value(*SI->getPointerOperand()),
4685 DepClassTy::OPTIONAL);
4686 auto *HS = A.getAAFor<AAHeapToStack>(
4687 *this, IRPosition::function(*I.getFunction()),
4688 DepClassTy::OPTIONAL);
4689 if (UnderlyingObjsAA &&
4690 UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4691 if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4692 return true;
4693 // Check for AAHeapToStack moved objects which must not be
4694 // guarded.
4695 auto *CB = dyn_cast<CallBase>(&Obj);
4696 return CB && HS && HS->isAssumedHeapToStack(*CB);
4697 }))
4698 return true;
4699 }
4700
4701 // Insert instruction that needs guarding.
4702 SPMDCompatibilityTracker.insert(&I);
4703 return true;
4704 };
4705
4706 bool UsedAssumedInformationInCheckRWInst = false;
4707 if (!SPMDCompatibilityTracker.isAtFixpoint())
4708 if (!A.checkForAllReadWriteInstructions(
4709 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4710 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4711
4712 bool UsedAssumedInformationFromReachingKernels = false;
4713 if (!IsKernelEntry) {
4714 updateParallelLevels(A);
4715
4716 bool AllReachingKernelsKnown = true;
4717 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4718 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4719
4720 if (!SPMDCompatibilityTracker.empty()) {
4721 if (!ParallelLevels.isValidState())
4722 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4723 else if (!ReachingKernelEntries.isValidState())
4724 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4725 else {
4726           // Check if all reaching kernels agree on the mode as we can otherwise
4727           // not guard instructions. We might not be sure about the mode, so we
4728           // cannot fix the internal SPMD-zation state either.
4729 int SPMD = 0, Generic = 0;
4730 for (auto *Kernel : ReachingKernelEntries) {
4731 auto *CBAA = A.getAAFor<AAKernelInfo>(
4732 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4733 if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4734 CBAA->SPMDCompatibilityTracker.isAssumed())
4735 ++SPMD;
4736 else
4737 ++Generic;
4738 if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4739 UsedAssumedInformationFromReachingKernels = true;
4740 }
4741 if (SPMD != 0 && Generic != 0)
4742 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4743 }
4744 }
4745 }
4746
4747 // Callback to check a call instruction.
4748 bool AllParallelRegionStatesWereFixed = true;
4749 bool AllSPMDStatesWereFixed = true;
4750 auto CheckCallInst = [&](Instruction &I) {
4751 auto &CB = cast<CallBase>(I);
4752 auto *CBAA = A.getAAFor<AAKernelInfo>(
4753 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4754 if (!CBAA)
4755 return false;
4756 getState() ^= CBAA->getState();
4757 AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4758 AllParallelRegionStatesWereFixed &=
4759 CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4760 AllParallelRegionStatesWereFixed &=
4761 CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4762 return true;
4763 };
4764
4765 bool UsedAssumedInformationInCheckCallInst = false;
4766 if (!A.checkForAllCallLikeInstructions(
4767 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4768 LLVM_DEBUG(dbgs() << TAG
4769 << "Failed to visit all call-like instructions!\n";);
4770 return indicatePessimisticFixpoint();
4771 }
4772
4773 // If we haven't used any assumed information for the reached parallel
4774 // region states we can fix it.
4775 if (!UsedAssumedInformationInCheckCallInst &&
4776 AllParallelRegionStatesWereFixed) {
4777 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4778 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4779 }
4780
4781 // If we haven't used any assumed information for the SPMD state we can fix
4782 // it.
4783 if (!UsedAssumedInformationInCheckRWInst &&
4784 !UsedAssumedInformationInCheckCallInst &&
4785 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4786 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4787
4788 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4789 : ChangeStatus::CHANGED;
4790 }
4791
4792private:
4793 /// Update info regarding reaching kernels.
4794 void updateReachingKernelEntries(Attributor &A,
4795 bool &AllReachingKernelsKnown) {
4796 auto PredCallSite = [&](AbstractCallSite ACS) {
4797 Function *Caller = ACS.getInstruction()->getFunction();
4798
4799 assert(Caller && "Caller is nullptr");
4800
4801 auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4802 IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4803 if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4804 ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4805 return true;
4806 }
4807
4808       // We lost track of the caller of the associated function; any kernel
4809       // could reach it now.
4810 ReachingKernelEntries.indicatePessimisticFixpoint();
4811
4812 return true;
4813 };
4814
4815 if (!A.checkForAllCallSites(PredCallSite, *this,
4816 true /* RequireAllCallSites */,
4817 AllReachingKernelsKnown))
4818 ReachingKernelEntries.indicatePessimisticFixpoint();
4819 }
4820
4821 /// Update info regarding parallel levels.
4822 void updateParallelLevels(Attributor &A) {
4823 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4824 OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4825 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4826
4827 auto PredCallSite = [&](AbstractCallSite ACS) {
4828 Function *Caller = ACS.getInstruction()->getFunction();
4829
4830 assert(Caller && "Caller is nullptr");
4831
4832 auto *CAA =
4833 A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4834 if (CAA && CAA->ParallelLevels.isValidState()) {
4835 // Any function that is called by `__kmpc_parallel_51` will not be
4836 // folded, as the parallel level inside such a function is updated. Getting
4837 // this right would make the analysis depend on the runtime implementation,
4838 // and any future change to that implementation could silently invalidate
4839 // the result. As a consequence, we are just conservative here.
4840 if (Caller == Parallel51RFI.Declaration) {
4841 ParallelLevels.indicatePessimisticFixpoint();
4842 return true;
4843 }
4844
4845 ParallelLevels ^= CAA->ParallelLevels;
4846
4847 return true;
4848 }
4849
4850 // We lost track of the caller of the associated function, so any kernel
4851 // could reach it now.
4852 ParallelLevels.indicatePessimisticFixpoint();
4853
4854 return true;
4855 };
4856
4857 bool AllCallSitesKnown = true;
4858 if (!A.checkForAllCallSites(PredCallSite, *this,
4859 true /* RequireAllCallSites */,
4860 AllCallSitesKnown))
4861 ParallelLevels.indicatePessimisticFixpoint();
4862 }
4863};
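// Editor's note (illustrative only, not part of the pass): the SPMD/Generic
// counting in updateImpl above can be pictured with a device helper reached
// from two kernels, e.g.
//   @k1 (assumed SPMD)  -> calls @helper
//   @k2 (generic mode)  -> calls @helper
// updateReachingKernelEntries collects {@k1, @k2} for @helper; because the
// reaching kernels disagree on execution mode, the loop sees both SPMD != 0
// and Generic != 0 and pins SPMDCompatibilityTracker for @helper to a
// pessimistic fixpoint.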
4864
4865/// The call site kernel info abstract attribute, basically, what can we say
4866/// about a call site with regards to the KernelInfoState. For now this simply
4867/// forwards the information from the callee.
4868struct AAKernelInfoCallSite : AAKernelInfo {
4869 AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4870 : AAKernelInfo(IRP, A) {}
4871
4872 /// See AbstractAttribute::initialize(...).
4873 void initialize(Attributor &A) override {
4874 AAKernelInfo::initialize(A);
4875
4876 CallBase &CB = cast<CallBase>(getAssociatedValue());
4877 auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4878 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4879
4880 // Check for SPMD-mode assumptions.
4881 if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4882 indicateOptimisticFixpoint();
4883 return;
4884 }
4885
4886 // First weed out calls we do not care about, that is, readonly/readnone
4887 // calls, intrinsics, and "no_openmp" calls. None of these can reach a
4888 // parallel region or anything else we are looking for.
4889 if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4890 indicateOptimisticFixpoint();
4891 return;
4892 }
4893
4894 // Next we check if we know the callee. If it is a known OpenMP function
4895 // we will handle it explicitly in the switch below. If it is not, we
4896 // will use an AAKernelInfo object on the callee to gather information and
4897 // merge that into the current state. The latter happens in the updateImpl.
4898 auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
4899 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4900 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4901 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4902 // Unknown callees or declarations are not analyzable; we give up.
4903 if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4904
4905 // Unknown callees might contain parallel regions, except if they have
4906 // an appropriate assumption attached.
4907 if (!AssumptionAA ||
4908 !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4909 AssumptionAA->hasAssumption("omp_no_parallelism")))
4910 ReachedUnknownParallelRegions.insert(&CB);
4911
4912 // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4913 // idea we can run something unknown in SPMD-mode.
4914 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4915 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4916 SPMDCompatibilityTracker.insert(&CB);
4917 }
4918
4919 // We have updated the state for this unknown call properly; there
4920 // won't be any change, so we indicate a fixpoint.
4921 indicateOptimisticFixpoint();
4922 }
4923 // If the callee is known and can be used in IPO, we will update the
4924 // state based on the callee state in updateImpl.
4925 return;
4926 }
4927 if (NumCallees > 1) {
4928 indicatePessimisticFixpoint();
4929 return;
4930 }
4931
4932 RuntimeFunction RF = It->getSecond();
4933 switch (RF) {
4934 // All the functions we know are compatible with SPMD mode.
4935 case OMPRTL___kmpc_is_spmd_exec_mode:
4936 case OMPRTL___kmpc_distribute_static_fini:
4937 case OMPRTL___kmpc_for_static_fini:
4938 case OMPRTL___kmpc_global_thread_num:
4939 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4940 case OMPRTL___kmpc_get_hardware_num_blocks:
4941 case OMPRTL___kmpc_single:
4942 case OMPRTL___kmpc_end_single:
4943 case OMPRTL___kmpc_master:
4944 case OMPRTL___kmpc_end_master:
4945 case OMPRTL___kmpc_barrier:
4946 case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4947 case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4948 case OMPRTL___kmpc_error:
4949 case OMPRTL___kmpc_flush:
4950 case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4951 case OMPRTL___kmpc_get_warp_size:
4952 case OMPRTL_omp_get_thread_num:
4953 case OMPRTL_omp_get_num_threads:
4954 case OMPRTL_omp_get_max_threads:
4955 case OMPRTL_omp_in_parallel:
4956 case OMPRTL_omp_get_dynamic:
4957 case OMPRTL_omp_get_cancellation:
4958 case OMPRTL_omp_get_nested:
4959 case OMPRTL_omp_get_schedule:
4960 case OMPRTL_omp_get_thread_limit:
4961 case OMPRTL_omp_get_supported_active_levels:
4962 case OMPRTL_omp_get_max_active_levels:
4963 case OMPRTL_omp_get_level:
4964 case OMPRTL_omp_get_ancestor_thread_num:
4965 case OMPRTL_omp_get_team_size:
4966 case OMPRTL_omp_get_active_level:
4967 case OMPRTL_omp_in_final:
4968 case OMPRTL_omp_get_proc_bind:
4969 case OMPRTL_omp_get_num_places:
4970 case OMPRTL_omp_get_num_procs:
4971 case OMPRTL_omp_get_place_proc_ids:
4972 case OMPRTL_omp_get_place_num:
4973 case OMPRTL_omp_get_partition_num_places:
4974 case OMPRTL_omp_get_partition_place_nums:
4975 case OMPRTL_omp_get_wtime:
4976 break;
4977 case OMPRTL___kmpc_distribute_static_init_4:
4978 case OMPRTL___kmpc_distribute_static_init_4u:
4979 case OMPRTL___kmpc_distribute_static_init_8:
4980 case OMPRTL___kmpc_distribute_static_init_8u:
4981 case OMPRTL___kmpc_for_static_init_4:
4982 case OMPRTL___kmpc_for_static_init_4u:
4983 case OMPRTL___kmpc_for_static_init_8:
4984 case OMPRTL___kmpc_for_static_init_8u: {
4985 // Check the schedule and allow static schedule in SPMD mode.
4986 unsigned ScheduleArgOpNo = 2;
4987 auto *ScheduleTypeCI =
4988 dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4989 unsigned ScheduleTypeVal =
4990 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4991 switch (OMPScheduleType(ScheduleTypeVal)) {
4992 case OMPScheduleType::UnorderedStatic:
4993 case OMPScheduleType::UnorderedStaticChunked:
4994 case OMPScheduleType::OrderedDistribute:
4995 case OMPScheduleType::OrderedDistributeChunked:
4996 break;
4997 default:
4998 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4999 SPMDCompatibilityTracker.insert(&CB);
5000 break;
5001 };
5002 } break;
5003 case OMPRTL___kmpc_target_init:
5004 KernelInitCB = &CB;
5005 break;
5006 case OMPRTL___kmpc_target_deinit:
5007 KernelDeinitCB = &CB;
5008 break;
5009 case OMPRTL___kmpc_parallel_51:
5010 if (!handleParallel51(A, CB))
5011 indicatePessimisticFixpoint();
5012 return;
5013 case OMPRTL___kmpc_omp_task:
5014 // We do not look into tasks right now, just give up.
5015 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5016 SPMDCompatibilityTracker.insert(&CB);
5017 ReachedUnknownParallelRegions.insert(&CB);
5018 break;
5019 case OMPRTL___kmpc_alloc_shared:
5020 case OMPRTL___kmpc_free_shared:
5021 // Return without setting a fixpoint, to be resolved in updateImpl.
5022 return;
5023 default:
5024 // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
5025 // generally. However, they do not hide parallel regions.
5026 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5027 SPMDCompatibilityTracker.insert(&CB);
5028 break;
5029 }
5030 // All other OpenMP runtime calls will not reach parallel regions so they
5031 // can be safely ignored for now. Since it is a known OpenMP runtime call
5032 // we have now modeled all effects and there is no need for any update.
5033 indicateOptimisticFixpoint();
5034 };
5035
5036 const auto *AACE =
5037 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5038 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5039 CheckCallee(getAssociatedFunction(), 1);
5040 return;
5041 }
5042 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5043 for (auto *Callee : OptimisticEdges) {
5044 CheckCallee(Callee, OptimisticEdges.size());
5045 if (isAtFixpoint())
5046 break;
5047 }
5048 }
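// Editor's note (illustrative only, not part of the pass): two typical
// outcomes of the initialization above:
//  - a call to a known SPMD-friendly runtime entry such as
//    omp_get_thread_num() falls into the first switch group and immediately
//    reaches an optimistic fixpoint;
//  - a call to an unknown external function without an "omp_no_openmp" or
//    "omp_no_parallelism" assumption is recorded in
//    ReachedUnknownParallelRegions and pins SPMDCompatibilityTracker
//    pessimistically for this call site.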
5049
5050 ChangeStatus updateImpl(Attributor &A) override {
5051 // TODO: Once we have call site specific value information we can provide
5052 // call site specific liveness information and then it makes
5053 // sense to specialize attributes for call sites arguments instead of
5054 // redirecting requests to the callee argument.
5055 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5056 KernelInfoState StateBefore = getState();
5057
5058 auto CheckCallee = [&](Function *F, int NumCallees) {
5059 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5060
5061 // If F is not a runtime function, propagate the AAKernelInfo of the
5062 // callee.
5063 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5064 const IRPosition &FnPos = IRPosition::function(*F);
5065 auto *FnAA =
5066 A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5067 if (!FnAA)
5068 return indicatePessimisticFixpoint();
5069 if (getState() == FnAA->getState())
5070 return ChangeStatus::UNCHANGED;
5071 getState() = FnAA->getState();
5072 return ChangeStatus::CHANGED;
5073 }
5074 if (NumCallees > 1)
5075 return indicatePessimisticFixpoint();
5076
5077 CallBase &CB = cast<CallBase>(getAssociatedValue());
5078 if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
5079 if (!handleParallel51(A, CB))
5080 return indicatePessimisticFixpoint();
5081 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5082 : ChangeStatus::CHANGED;
5083 }
5084
5085 // F is a runtime function that allocates or frees memory, check
5086 // AAHeapToStack and AAHeapToShared.
5087 assert(
5088 (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5089 It->getSecond() == OMPRTL___kmpc_free_shared) &&
5090 "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5091
5092 auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
5093 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5094 auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5095 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5096
5097 RuntimeFunction RF = It->getSecond();
5098
5099 switch (RF) {
5100 // If neither HeapToStack nor HeapToShared assume the call is removed,
5101 // assume SPMD incompatibility.
5102 case OMPRTL___kmpc_alloc_shared:
5103 if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5104 (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5105 SPMDCompatibilityTracker.insert(&CB);
5106 break;
5107 case OMPRTL___kmpc_free_shared:
5108 if ((!HeapToStackAA ||
5109 !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5110 (!HeapToSharedAA ||
5111 !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5112 SPMDCompatibilityTracker.insert(&CB);
5113 break;
5114 default:
5115 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5116 SPMDCompatibilityTracker.insert(&CB);
5117 }
5118 return ChangeStatus::CHANGED;
5119 };
5120
5121 const auto *AACE =
5122 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5123 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5124 if (Function *F = getAssociatedFunction())
5125 CheckCallee(F, /*NumCallees=*/1);
5126 } else {
5127 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5128 for (auto *Callee : OptimisticEdges) {
5129 CheckCallee(Callee, OptimisticEdges.size());
5130 if (isAtFixpoint())
5131 break;
5132 }
5133 }
5134
5135 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5136 : ChangeStatus::CHANGED;
5137 }
5138
5139 /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call was
5140 /// handled; if a problem occurred, false is returned.
5141 bool handleParallel51(Attributor &A, CallBase &CB) {
5142 const unsigned int NonWrapperFunctionArgNo = 5;
5143 const unsigned int WrapperFunctionArgNo = 6;
5144 auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5145 ? NonWrapperFunctionArgNo
5146 : WrapperFunctionArgNo;
5147
5148 auto *ParallelRegion = dyn_cast<Function>(
5149 CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
5150 if (!ParallelRegion)
5151 return false;
5152
5153 ReachedKnownParallelRegions.insert(&CB);
5154 /// Check nested parallelism
5155 auto *FnAA = A.getAAFor<AAKernelInfo>(
5156 *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
5157 NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5158 !FnAA->ReachedKnownParallelRegions.empty() ||
5159 !FnAA->ReachedKnownParallelRegions.isValidState() ||
5160 !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5161 !FnAA->ReachedUnknownParallelRegions.empty();
5162 return true;
5163 }
5164};
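// Editor's note (illustrative sketch; the operand indices are taken from
// handleParallel51 above, the function names are placeholders): a device
// parallel region is reached through a __kmpc_parallel_51 call whose operand
// at index 5 is the outlined parallel region and whose operand at index 6 is
// its generic-mode wrapper, e.g.
//   call void @__kmpc_parallel_51(..., ptr @__omp_outlined,
//                                 ptr @__omp_outlined_wrapper, ...)
// The call is added to ReachedKnownParallelRegions, and the selected region
// function is queried for nested parallelism.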
5165
5166struct AAFoldRuntimeCall
5167 : public StateWrapper<BooleanState, AbstractAttribute> {
5168 using Base = StateWrapper<BooleanState, AbstractAttribute>;
5169
5170 AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
5171
5172 /// Statistics are tracked as part of manifest for now.
5173 void trackStatistics() const override {}
5174
5175 /// Create an abstract attribute view for the position \p IRP.
5176 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5177 Attributor &A);
5178
5179 /// See AbstractAttribute::getName()
5180 const std::string getName() const override { return "AAFoldRuntimeCall"; }
5181
5182 /// See AbstractAttribute::getIdAddr()
5183 const char *getIdAddr() const override { return &ID; }
5184
5185 /// This function should return true if the type of the \p AA is
5186 /// AAFoldRuntimeCall
5187 static bool classof(const AbstractAttribute *AA) {
5188 return (AA->getIdAddr() == &ID);
5189 }
5190
5191 static const char ID;
5192};
5193
5194struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5195 AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
5196 : AAFoldRuntimeCall(IRP, A) {}
5197
5198 /// See AbstractAttribute::getAsStr()
5199 const std::string getAsStr(Attributor *) const override {
5200 if (!isValidState())
5201 return "<invalid>";
5202
5203 std::string Str("simplified value: ");
5204
5205 if (!SimplifiedValue)
5206 return Str + std::string("none");
5207
5208 if (!*SimplifiedValue)
5209 return Str + std::string("nullptr");
5210
5211 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5212 return Str + std::to_string(CI->getSExtValue());
5213
5214 return Str + std::string("unknown");
5215 }
5216
5217 void initialize(Attributor &A) override {
5218 if (DisableOpenMPOptFolding)
5219 indicatePessimisticFixpoint();
5220
5221 Function *Callee = getAssociatedFunction();
5222
5223 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5224 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5225 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5226 "Expected a known OpenMP runtime function");
5227
5228 RFKind = It->getSecond();
5229
5230 CallBase &CB = cast<CallBase>(getAssociatedValue());
5231 A.registerSimplificationCallback(
5232 IRPosition::callsite_returned(CB),
5233 [&](const IRPosition &IRP, const AbstractAttribute *AA,
5234 bool &UsedAssumedInformation) -> std::optional<Value *> {
5235 assert((isValidState() ||
5236 (SimplifiedValue && *SimplifiedValue == nullptr)) &&
5237 "Unexpected invalid state!");
5238
5239 if (!isAtFixpoint()) {
5240 UsedAssumedInformation = true;
5241 if (AA)
5242 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5243 }
5244 return SimplifiedValue;
5245 });
5246 }
5247
5248 ChangeStatus updateImpl(Attributor &A) override {
5249 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5250 switch (RFKind) {
5251 case OMPRTL___kmpc_is_spmd_exec_mode:
5252 Changed |= foldIsSPMDExecMode(A);
5253 break;
5254 case OMPRTL___kmpc_parallel_level:
5255 Changed |= foldParallelLevel(A);
5256 break;
5257 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5258 Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
5259 break;
5260 case OMPRTL___kmpc_get_hardware_num_blocks:
5261 Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
5262 break;
5263 default:
5264 llvm_unreachable("Unhandled OpenMP runtime function!");
5265 }
5266
5267 return Changed;
5268 }
5269
5270 ChangeStatus manifest(Attributor &A) override {
5271 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5272
5273 if (SimplifiedValue && *SimplifiedValue) {
5274 Instruction &I = *getCtxI();
5275 A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
5276 A.deleteAfterManifest(I);
5277
5278 CallBase *CB = dyn_cast<CallBase>(&I);
5279 auto Remark = [&](OptimizationRemark OR) {
5280 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5281 return OR << "Replacing OpenMP runtime call "
5282 << CB->getCalledFunction()->getName() << " with "
5283 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5284 return OR << "Replacing OpenMP runtime call "
5285 << CB->getCalledFunction()->getName() << ".";
5286 };
5287
5288 if (CB && EnableVerboseRemarks)
5289 A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
5290
5291 LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5292 << **SimplifiedValue << "\n");
5293
5294 Changed = ChangeStatus::CHANGED;
5295 }
5296
5297 return Changed;
5298 }
5299
5300 ChangeStatus indicatePessimisticFixpoint() override {
5301 SimplifiedValue = nullptr;
5302 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5303 }
5304
5305private:
5306 /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
5307 ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5308 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5309
5310 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5311 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5312 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5313 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5314
5315 if (!CallerKernelInfoAA ||
5316 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5317 return indicatePessimisticFixpoint();
5318
5319 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5320 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5321 DepClassTy::REQUIRED);
5322
5323 if (!AA || !AA->isValidState()) {
5324 SimplifiedValue = nullptr;
5325 return indicatePessimisticFixpoint();
5326 }
5327
5328 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5329 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5330 ++KnownSPMDCount;
5331 else
5332 ++AssumedSPMDCount;
5333 } else {
5334 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5335 ++KnownNonSPMDCount;
5336 else
5337 ++AssumedNonSPMDCount;
5338 }
5339 }
5340
5341 if ((AssumedSPMDCount + KnownSPMDCount) &&
5342 (AssumedNonSPMDCount + KnownNonSPMDCount))
5343 return indicatePessimisticFixpoint();
5344
5345 auto &Ctx = getAnchorValue().getContext();
5346 if (KnownSPMDCount || AssumedSPMDCount) {
5347 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5348 "Expected only SPMD kernels!");
5349 // All reaching kernels are in SPMD mode. Update all function calls to
5350 // __kmpc_is_spmd_exec_mode to 1.
5351 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
5352 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5353 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5354 "Expected only non-SPMD kernels!");
5355 // All reaching kernels are in non-SPMD mode. Update all function
5356 // calls to __kmpc_is_spmd_exec_mode to 0.
5357 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
5358 } else {
5359 // We have empty reaching kernels, therefore we cannot tell if the
5360 // associated call site can be folded. At this moment, SimplifiedValue
5361 // must be none.
5362 assert(!SimplifiedValue && "SimplifiedValue should be none");
5363 }
5364
5365 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5366 : ChangeStatus::CHANGED;
5367 }
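// Editor's note (illustrative only, not part of the pass): when every kernel
// reaching the caller is (assumed) SPMD, a query like
//   %mode = call i8 @__kmpc_is_spmd_exec_mode()
// is simplified to the constant i8 1 (or i8 0 in the all-generic case), so
// any guard built on it can subsequently be folded away.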
5368
5369 /// Fold __kmpc_parallel_level into a constant if possible.
5370 ChangeStatus foldParallelLevel(Attributor &A) {
5371 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5372
5373 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5374 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5375
5376 if (!CallerKernelInfoAA ||
5377 !CallerKernelInfoAA->ParallelLevels.isValidState())
5378 return indicatePessimisticFixpoint();
5379
5380 if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5381 return indicatePessimisticFixpoint();
5382
5383 if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5384 assert(!SimplifiedValue &&
5385 "SimplifiedValue should keep none at this point");
5386 return ChangeStatus::UNCHANGED;
5387 }
5388
5389 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5390 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5391 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5392 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5393 DepClassTy::REQUIRED);
5394 if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5395 return indicatePessimisticFixpoint();
5396
5397 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5398 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5399 ++KnownSPMDCount;
5400 else
5401 ++AssumedSPMDCount;
5402 } else {
5403 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5404 ++KnownNonSPMDCount;
5405 else
5406 ++AssumedNonSPMDCount;
5407 }
5408 }
5409
5410 if ((AssumedSPMDCount + KnownSPMDCount) &&
5411 (AssumedNonSPMDCount + KnownNonSPMDCount))
5412 return indicatePessimisticFixpoint();
5413
5414 auto &Ctx = getAnchorValue().getContext();
5415 // If the caller can only be reached by SPMD kernel entries, the parallel
5416 // level is 1. Similarly, if the caller can only be reached by non-SPMD
5417 // kernel entries, it is 0.
5418 if (AssumedSPMDCount || KnownSPMDCount) {
5419 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5420 "Expected only SPMD kernels!");
5421 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
5422 } else {
5423 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5424 "Expected only non-SPMD kernels!");
5425 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
5426 }
5427 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5428 : ChangeStatus::CHANGED;
5429 }
5430
5431 ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
5432 // Specialize only if all the calls agree with the attribute constant value
5433 int32_t CurrentAttrValue = -1;
5434 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5435
5436 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5437 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5438
5439 if (!CallerKernelInfoAA ||
5440 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5441 return indicatePessimisticFixpoint();
5442
5443 // Iterate over the kernels that reach this function
5444 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5445 int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
5446
5447 if (NextAttrVal == -1 ||
5448 (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5449 return indicatePessimisticFixpoint();
5450 CurrentAttrValue = NextAttrVal;
5451 }
5452
5453 if (CurrentAttrValue != -1) {
5454 auto &Ctx = getAnchorValue().getContext();
5455 SimplifiedValue =
5456 ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
5457 }
5458 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5459 : ChangeStatus::CHANGED;
5460 }
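// Editor's note (illustrative; attribute names are the ones used by
// updateImpl above): if every reaching kernel carries, say,
// "omp_target_thread_limit"="256", then calls to
// __kmpc_get_hardware_num_threads_in_block in the reached code fold to the
// i32 constant 256; disagreeing or missing attribute values keep the call
// unfolded (pessimistic fixpoint).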
5461
5462 /// An optional value the associated value is assumed to fold to. That is, we
5463 /// assume the associated value (which is a call) can be replaced by this
5464 /// simplified value.
5465 std::optional<Value *> SimplifiedValue;
5466
5467 /// The runtime function kind of the callee of the associated call site.
5468 RuntimeFunction RFKind;
5469};
5470
5471} // namespace
5472
5473 /// Register an AAFoldRuntimeCall attribute for every call site of \p RF.
5474void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5475 auto &RFI = OMPInfoCache.RFIs[RF];
5476 RFI.foreachUse(SCC, [&](Use &U, Function &F) {
5477 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5478 if (!CI)
5479 return false;
5480 A.getOrCreateAAFor<AAFoldRuntimeCall>(
5481 IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
5482 DepClassTy::NONE, /* ForceUpdate */ false,
5483 /* UpdateAfterInit */ false);
5484 return false;
5485 });
5486}
5487
5488void OpenMPOpt::registerAAs(bool IsModulePass) {
5489 if (SCC.empty())
5490 return;
5491
5492 if (IsModulePass) {
5493 // Ensure we create the AAKernelInfo AAs first and without triggering an
5494 // update. This will make sure we register all value simplification
5495 // callbacks before any other AA has the chance to create an AAValueSimplify
5496 // or similar.
5497 auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5498 A.getOrCreateAAFor<AAKernelInfo>(
5499 IRPosition::function(Kernel), /* QueryingAA */ nullptr,
5500 DepClassTy::NONE, /* ForceUpdate */ false,
5501 /* UpdateAfterInit */ false);
5502 return false;
5503 };
5504 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5505 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5506 InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5507
5508 registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5509 registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5510 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5511 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5512 }
5513
5514 // Create CallSite AA for all Getters.
5515 if (DeduceICVValues) {
5516 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5517 auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5518
5519 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5520
5521 auto CreateAA = [&](Use &U, Function &Caller) {
5522 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5523 if (!CI)
5524 return false;
5525
5526 auto &CB = cast<CallBase>(*CI);
5527
5529 A.getOrCreateAAFor<AAICVTracker>(CBPos);
5530 return false;
5531 };
5532
5533 GetterRFI.foreachUse(SCC, CreateAA);
5534 }
5535 }
5536
5537 // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5538 // every function if there is a device kernel.
5539 if (!isOpenMPDevice(M))
5540 return;
5541
5542 for (auto *F : SCC) {
5543 if (F->isDeclaration())
5544 continue;
5545
5546 // We look at internal functions only on demand, but if any use is not a
5547 // direct call or is outside the current set of analyzed functions, we have
5548 // to do it eagerly.
5549 if (F->hasLocalLinkage()) {
5550 if (llvm::all_of(F->uses(), [this](const Use &U) {
5551 const auto *CB = dyn_cast<CallBase>(U.getUser());
5552 return CB && CB->isCallee(&U) &&
5553 A.isRunOn(const_cast<Function *>(CB->getCaller()));
5554 }))
5555 continue;
5556 }
5557 registerAAsForFunction(A, *F);
5558 }
5559}
5560
5561void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
5562 if (!DisableOpenMPOptDeglobalization)
5563 A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
5564 A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
5565 if (!DisableOpenMPOptDeglobalization)
5566 A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
5567 if (F.hasFnAttribute(Attribute::Convergent))
5568 A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));
5569
5570 for (auto &I : instructions(F)) {
5571 if (auto *LI = dyn_cast<LoadInst>(&I)) {
5572 bool UsedAssumedInformation = false;
5573 A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
5574 UsedAssumedInformation, AA::Interprocedural);
5575 A.getOrCreateAAFor<AAAddressSpace>(
5576 IRPosition::value(*LI->getPointerOperand()));
5577 continue;
5578 }
5579 if (auto *CI = dyn_cast<CallBase>(&I)) {
5580 if (CI->isIndirectCall())
5581 A.getOrCreateAAFor<AAIndirectCallInfo>(
5582 IRPosition::callsite_function(*CI));
5583 }
5584 if (auto *SI = dyn_cast<StoreInst>(&I)) {
5585 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
5586 A.getOrCreateAAFor<AAAddressSpace>(
5587 IRPosition::value(*SI->getPointerOperand()));
5588 continue;
5589 }
5590 if (auto *FI = dyn_cast<FenceInst>(&I)) {
5591 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
5592 continue;
5593 }
5594 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5595 if (II->getIntrinsicID() == Intrinsic::assume) {
5596 A.getOrCreateAAFor<AAPotentialValues>(
5597 IRPosition::value(*II->getArgOperand(0)));
5598 continue;
5599 }
5600 }
5601 }
5602}
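// Editor's note (illustrative only, not part of the pass): for a device
// function containing
//   %v = load i32, ptr %p
// the loop above seeds an AAAddressSpace on %p (and on store pointer
// operands), letting the Attributor infer a concrete address space for the
// access, for example shared memory introduced by HeapToShared, and
// specialize it.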
5603
5604const char AAICVTracker::ID = 0;
5605const char AAKernelInfo::ID = 0;
5606const char AAExecutionDomain::ID = 0;
5607const char AAHeapToShared::ID = 0;
5608const char AAFoldRuntimeCall::ID = 0;
5609
5610AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5611 Attributor &A) {
5612 AAICVTracker *AA = nullptr;
5613 switch (IRP.getPositionKind()) {
5614 case IRPosition::IRP_INVALID:
5615 case IRPosition::IRP_FLOAT:
5616 case IRPosition::IRP_ARGUMENT:
5617 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5618 llvm_unreachable("ICVTracker can only be created for function position!");
5619 case IRPosition::IRP_RETURNED:
5620 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5621 break;
5622 case IRPosition::IRP_CALL_SITE_RETURNED:
5623 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5624 break;
5625 case IRPosition::IRP_CALL_SITE:
5626 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5627 break;
5628 case IRPosition::IRP_FUNCTION:
5629 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5630 break;
5631 }
5632
5633 return *AA;
5634}
5635
5636 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
5637 Attributor &A) {
5638 AAExecutionDomainFunction *AA = nullptr;
5639 switch (IRP.getPositionKind()) {
5640 case IRPosition::IRP_INVALID:
5641 case IRPosition::IRP_FLOAT:
5642 case IRPosition::IRP_ARGUMENT:
5643 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5644 case IRPosition::IRP_RETURNED:
5645 case IRPosition::IRP_CALL_SITE_RETURNED:
5646 case IRPosition::IRP_CALL_SITE:
5647 llvm_unreachable(
5648 "AAExecutionDomain can only be created for function position!");
5649 case IRPosition::IRP_FUNCTION:
5650 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5651 break;
5652 }
5653
5654 return *AA;
5655}
5656
5657AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5658 Attributor &A) {
5659 AAHeapToSharedFunction *AA = nullptr;
5660 switch (IRP.getPositionKind()) {
5661 case IRPosition::IRP_INVALID:
5662 case IRPosition::IRP_FLOAT:
5663 case IRPosition::IRP_ARGUMENT:
5664 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5665 case IRPosition::IRP_RETURNED:
5666 case IRPosition::IRP_CALL_SITE_RETURNED:
5667 case IRPosition::IRP_CALL_SITE:
5668 llvm_unreachable(
5669 "AAHeapToShared can only be created for function position!");
5670 case IRPosition::IRP_FUNCTION:
5671 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5672 break;
5673 }
5674
5675 return *AA;
5676}
5677
5678AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5679 Attributor &A) {
5680 AAKernelInfo *AA = nullptr;
5681 switch (IRP.getPositionKind()) {
5682 case IRPosition::IRP_INVALID:
5683 case IRPosition::IRP_FLOAT:
5684 case IRPosition::IRP_ARGUMENT:
5685 case IRPosition::IRP_RETURNED:
5686 case IRPosition::IRP_CALL_SITE_RETURNED:
5687 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5688 llvm_unreachable("KernelInfo can only be created for function position!");
5689 case IRPosition::IRP_CALL_SITE:
5690 AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5691 break;
5692 case IRPosition::IRP_FUNCTION:
5693 AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5694 break;
5695 }
5696
5697 return *AA;
5698}
5699
5700AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5701 Attributor &A) {
5702 AAFoldRuntimeCall *AA = nullptr;
5703 switch (IRP.getPositionKind()) {
5704 case IRPosition::IRP_INVALID:
5705 case IRPosition::IRP_FLOAT:
5706 case IRPosition::IRP_ARGUMENT:
5707 case IRPosition::IRP_RETURNED:
5708 case IRPosition::IRP_FUNCTION:
5709 case IRPosition::IRP_CALL_SITE:
5710 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5711 llvm_unreachable("KernelInfo can only be created for call site position!");
5712 case IRPosition::IRP_CALL_SITE_RETURNED:
5713 AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5714 break;
5715 }
5716
5717 return *AA;
5718}
5719
5720 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
5721 if (!containsOpenMP(M))
5722 return PreservedAnalyses::all();
5723 if (DisableOpenMPOptimizations)
5724 return PreservedAnalyses::all();
5725
5726 FunctionAnalysisManager &FAM =
5727 AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
5728 KernelSet Kernels = getDeviceKernels(M);
5729
5730 if (PrintModuleBeforeOptimizations)
5731 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
5732
5733 auto IsCalled = [&](Function &F) {
5734 if (Kernels.contains(&F))
5735 return true;
5736 for (const User *U : F.users())
5737 if (!isa<BlockAddress>(U))
5738 return true;
5739 return false;
5740 };
5741
5742 auto EmitRemark = [&](Function &F) {
5743 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
5744 ORE.emit([&]() {
5745 OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5746 return ORA << "Could not internalize function. "
5747 << "Some optimizations may not be possible. [OMP140]";
5748 });
5749 };
5750
5751 bool Changed = false;
5752
5753 // Create internal copies of each function if this is a kernel module. This
5754 // allows interprocedural passes to see every call edge.
5755 DenseMap<Function *, Function *> InternalizedMap;
5756 if (isOpenMPDevice(M)) {
5757 SmallPtrSet<Function *, 16> InternalizeFns;
5758 for (Function &F : M)
5759 if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
5760 !DisableInternalization) {
5761 if (Attributor::isInternalizable(F)) {
5762 InternalizeFns.insert(&F);
5763 } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5764 EmitRemark(F);
5765 }
5766 }
5767
5768 Changed |=
5769 Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5770 }
5771
5772 // Look at every function in the Module unless it was internalized.
5773 SetVector<Function *> Functions;
5774 SmallVector<Function *, 16> SCC;
5775 for (Function &F : M)
5776 if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
5777 SCC.push_back(&F);
5778 Functions.insert(&F);
5779 }
5780
5781 if (SCC.empty())
5782 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
5783
5784 AnalysisGetter AG(FAM);
5785
5786 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5787 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5788 };
5789
5790 BumpPtrAllocator Allocator;
5791 CallGraphUpdater CGUpdater;
5792
5793 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5794 LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink;
5795 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);
5796
5797 unsigned MaxFixpointIterations =
5798 (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5799
5800 AttributorConfig AC(CGUpdater);
5802 AC.IsModulePass = true;
5803 AC.RewriteSignatures = false;
5804 AC.MaxFixpointIterations = MaxFixpointIterations;
5805 AC.OREGetter = OREGetter;
5806 AC.PassName = DEBUG_TYPE;
5807 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5808 AC.IPOAmendableCB = [](const Function &F) {
5809 return F.hasFnAttribute("kernel");
5810 };
5811
5812 Attributor A(Functions, InfoCache, AC);
5813
5814 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5815 Changed |= OMPOpt.run(true);
5816
5817 // Optionally inline device functions for potentially better performance.
5818 if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
5819 for (Function &F : M)
5820 if (!F.isDeclaration() && !Kernels.contains(&F) &&
5821 !F.hasFnAttribute(Attribute::NoInline))
5822 F.addFnAttr(Attribute::AlwaysInline);
5823
5824 if (PrintModuleAfterOptimizations)
5825 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5826
5827 if (Changed)
5828 return PreservedAnalyses::none();
5829
5830 return PreservedAnalyses::all();
5831}
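// Editor's note (assumption about the usual pass-registry naming, not taken
// from this file): the module pass above is normally reached through the
// default optimization pipelines; for experiments it can typically be invoked
// directly with something like `opt -passes=openmp-opt input.ll`, and the
// cl::opt flags defined in this file steer its behavior.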
5832
5833 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
5834 CGSCCAnalysisManager &AM,
5835 LazyCallGraph &CG,
5836 CGSCCUpdateResult &UR) {
5837 if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5838 return PreservedAnalyses::all();
5839 if (DisableOpenMPOptimizations)
5840 return PreservedAnalyses::all();
5841
5842 SmallVector<Function *, 16> SCC;
5843 // If there are kernels in the module, we have to run on all SCC's.
5844 for (LazyCallGraph::Node &N : C) {
5845 Function *Fn = &N.getFunction();
5846 SCC.push_back(Fn);
5847 }
5848
5849 if (SCC.empty())
5850 return PreservedAnalyses::all();
5851
5852 Module &M = *C.begin()->getFunction().getParent();
5853
5854 if (PrintModuleBeforeOptimizations)
5855 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5856
5857 KernelSet Kernels = getDeviceKernels(M);
5858
5859 FunctionAnalysisManager &FAM =
5860 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5861
5862 AnalysisGetter AG(FAM);
5863
5864 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5865 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5866 };
5867
5868 BumpPtrAllocator Allocator;
5869 CallGraphUpdater CGUpdater;
5870 CGUpdater.initialize(CG, C, AM, UR);
5871
5872 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5873 LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink;
5874 SetVector<Function *> Functions(SCC.begin(), SCC.end());
5875 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5876 /*CGSCC*/ &Functions, PostLink);
5877
5878 unsigned MaxFixpointIterations =
5879 (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5880
5881 AttributorConfig AC(CGUpdater);
5883 AC.IsModulePass = false;
5884 AC.RewriteSignatures = false;
5885 AC.MaxFixpointIterations = MaxFixpointIterations;
5886 AC.OREGetter = OREGetter;
5887 AC.PassName = DEBUG_TYPE;
5888 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5889
5890 Attributor A(Functions, InfoCache, AC);
5891
5892 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5893 bool Changed = OMPOpt.run(false);
5894
5895 if (PrintModuleAfterOptimizations)
5896 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5897
5898 if (Changed)
5899 return PreservedAnalyses::none();
5900
5901 return PreservedAnalyses::all();
5902}
5903
5904 bool llvm::omp::isOpenMPKernel(Function &Fn) {
5905 return Fn.hasFnAttribute("kernel");
5906}
5907
5908static bool isKernelCC(Function &F) {
5909 switch (F.getCallingConv()) {
5910 default:
5911 return false;
5912 case CallingConv::AMDGPU_KERNEL:
5913 case CallingConv::SPIR_KERNEL:
5914 case CallingConv::PTX_Kernel:
5915 return true;
5916 }
5917}
5918
5919 KernelSet llvm::omp::getDeviceKernels(Module &M) {
5920 // TODO: Create a more cross-platform way of determining device kernels.
5921 KernelSet Kernels;
5922
5923 DenseSet<const Function *> SeenKernels;
5924 auto ProcessKernel = [&](Function &KF) {
5925 if (SeenKernels.insert(&KF).second) {
5926 // We are only interested in OpenMP target regions. Others, such as
5927 // kernels generated by CUDA but linked together, are not interesting to
5928 // this pass.
5929 if (isOpenMPKernel(KF)) {
5930 ++NumOpenMPTargetRegionKernels;
5931 Kernels.insert(&KF);
5932 } else
5933 ++NumNonOpenMPTargetRegionKernels;
5934 }
5935 };
5936
5937 if (NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations"))
5938 for (auto *Op : MD->operands()) {
5939 if (Op->getNumOperands() < 2)
5940 continue;
5941 MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
5942 if (!KindID || KindID->getString() != "kernel")
5943 continue;
5944
5945 if (auto *KernelFn =
5946 mdconst::dyn_extract_or_null<Function>(Op->getOperand(0)))
5947 ProcessKernel(*KernelFn);
5948 }
5949
5950 for (Function &F : M)
5951 if (isKernelCC(F))
5952 ProcessKernel(F);
5953
5954 return Kernels;
5955}
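// Editor's note (illustrative only; the metadata shape is the usual NVVM
// annotation form, shown as an assumption): a function is treated as a device
// kernel here if it is named in NVVM annotations metadata, roughly
//   !nvvm.annotations = !{!0}
//   !0 = !{ptr @kern, !"kernel", i32 1}
// or if it uses a kernel calling convention; it is additionally counted as an
// OpenMP target region kernel only when it carries the "kernel" function
// attribute (see isOpenMPKernel above).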
5956
5957 bool llvm::omp::containsOpenMP(Module &M) {
5958 Metadata *MD = M.getModuleFlag("openmp");
5959 if (!MD)
5960 return false;
5961
5962 return true;
5963}
5964
5965 bool llvm::omp::isOpenMPDevice(Module &M) {
5966 Metadata *MD = M.getModuleFlag("openmp-device");
5967 if (!MD)
5968 return false;
5969
5970 return true;
5971}
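// Editor's note (illustrative module-flag shapes, assuming the usual
// clang-emitted form): containsOpenMP and isOpenMPDevice above only test for
// the presence of module flags along the lines of
//   !llvm.module.flags = !{..., !1, !2}
//   !1 = !{i32 7, !"openmp", i32 51}
//   !2 = !{i32 7, !"openmp-device", i32 51}
// where the trailing integer encodes the OpenMP version; the value itself is
// not inspected here.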
@ Generic
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static cl::opt< unsigned > SetFixpointIterations("attributor-max-iterations", cl::Hidden, cl::desc("Maximal number of fixpoint iterations."), cl::init(32))
static const Function * getParent(const Value *V)
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
This file provides interfaces used to manipulate a call graph, regardless if it is a "old style" Call...
This file provides interfaces used to build and manipulate a call graph, which is a very useful tool ...
This file contains the declarations for the subclasses of Constant, which represent the different fla...
dxil pretty DXIL Metadata Pretty Printer
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(...)
Definition: Debug.h:106
This file defines the DenseSet and SmallDenseSet classes.
bool End
Definition: ELF_riscv.cpp:480
This file defines an array type that can be indexed using scoped enum values.
#define DEBUG_TYPE
static void emitRemark(const Function &F, OptimizationRemarkEmitter &ORE, bool Skip)
static bool lookup(const GsymReader &GR, DataExtractor &Data, uint64_t &Offset, uint64_t BaseAddr, uint64_t Addr, SourceLocations &SrcLocs, llvm::Error &Err)
A Lookup helper functions.
Definition: InlineInfo.cpp:108
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
This file provides utility analysis objects describing memory locations.
uint64_t IntrinsicInst * II
This file defines constans and helpers used when dealing with OpenMP.
This file defines constans that will be used by both host and device compilation.
static constexpr auto TAG
Definition: OpenMPOpt.cpp:186
static cl::opt< bool > HideMemoryTransferLatency("openmp-hide-memory-transfer-latency", cl::desc("[WIP] Tries to hide the latency of host to device memory" " transfers"), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptStateMachineRewrite("openmp-opt-disable-state-machine-rewrite", cl::desc("Disable OpenMP optimizations that replace the state machine."), cl::Hidden, cl::init(false))
static cl::opt< bool > EnableParallelRegionMerging("openmp-opt-enable-merging", cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleAfterOptimizations("openmp-opt-print-module-after", cl::desc("Print the current module after OpenMP optimizations."), cl::Hidden, cl::init(false))
#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)
Definition: OpenMPOpt.cpp:242
#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)
Definition: OpenMPOpt.cpp:219
#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER)
Definition: OpenMPOpt.cpp:3678
static cl::opt< bool > PrintOpenMPKernels("openmp-print-gpu-kernels", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptFolding("openmp-opt-disable-folding", cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden, cl::init(false))
static bool isKernelCC(Function &F)
Definition: OpenMPOpt.cpp:5908
static cl::opt< bool > PrintModuleBeforeOptimizations("openmp-opt-print-module-before", cl::desc("Print the current module before OpenMP optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden, cl::desc("Maximal number of attributor iterations."), cl::init(256))
static cl::opt< bool > DisableInternalization("openmp-opt-disable-internalization", cl::desc("Disable function internalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintICVValues("openmp-print-icv-values", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptimizations("openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden, cl::desc("Maximum amount of shared memory to use."), cl::init(std::numeric_limits< unsigned >::max()))
static cl::opt< bool > EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::desc("Enables more verbose remarks."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptDeglobalization("openmp-opt-disable-deglobalization", cl::desc("Disable OpenMP optimizations involving deglobalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptBarrierElimination("openmp-opt-disable-barrier-elimination", cl::desc("Disable OpenMP optimizations that eliminate barriers."), cl::Hidden, cl::init(false))
#define DEBUG_TYPE
Definition: OpenMPOpt.cpp:68
static cl::opt< bool > DeduceICVValues("openmp-deduce-icv-values", cl::init(false), cl::Hidden)
#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)
Definition: OpenMPOpt.cpp:211
#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)
Definition: OpenMPOpt.cpp:232
static cl::opt< bool > DisableOpenMPOptSPMDization("openmp-opt-disable-spmdization", cl::desc("Disable OpenMP optimizations involving SPMD-ization."), cl::Hidden, cl::init(false))
static cl::opt< bool > AlwaysInlineDeviceFunctions("openmp-opt-inline-device", cl::desc("Inline all applicable functions on the device."), cl::Hidden, cl::init(false))
#define P(N)
FunctionAnalysisManager FAM
if(PassOpts->AAPipeline)
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
static StringRef getName(Value *V)
R600 Clause Merge
Basic Register Allocator
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static bool contains(SmallPtrSetImpl< ConstantExpr * > &Cache, ConstantExpr *Expr, Constant *C)
Definition: Value.cpp:469
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:166
This file contains some functions that are useful when dealing with strings.
static const int BlockSize
Definition: TarWriter.cpp:33
static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, ArrayRef< StringLiteral > StandardNames)
Initialize the set of available library functions based on the specified target triple.
Value * RHS
static cl::opt< unsigned > MaxThreads("xcore-max-threads", cl::Optional, cl::desc("Maximum number of threads (for emulation thread-local storage)"), cl::Hidden, cl::value_desc("number"), cl::init(8))
AbstractCallSite.
This class represents a conversion between pointers from one address space to another.
an instruction to allocate memory on the stack
Definition: Instructions.h:63
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:410
This class represents an incoming formal argument to a Function.
Definition: Argument.h:31
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:168
AttributeSet getParamAttrs(unsigned ArgNo) const
The attributes for the argument or parameter at the given index are returned.
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
iterator end()
Definition: BasicBlock.h:464
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:451
const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
Definition: BasicBlock.cpp:426
reverse_iterator rbegin()
Definition: BasicBlock.h:467
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:213
BasicBlock * splitBasicBlock(iterator I, const Twine &BBName="", bool Before=false)
Split the basic block into two basic blocks at the specified instruction.
Definition: BasicBlock.cpp:589
const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
Definition: BasicBlock.cpp:509
InstListType::reverse_iterator reverse_iterator
Definition: BasicBlock.h:179
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:220
reverse_iterator rend()
Definition: BasicBlock.h:469
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:240
Conditional or Unconditional Branch instruction.
bool isConditional() const
static BranchInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
BasicBlock * getSuccessor(unsigned i) const
Value * getCondition() const
Allocate memory in an ever growing pool, as if by bump-pointer.
Definition: Allocator.h:66
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Definition: InstrTypes.h:1112
void setCallingConv(CallingConv::ID CC)
Definition: InstrTypes.h:1403
bool arg_empty() const
Definition: InstrTypes.h:1283
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1341
bool doesNotAccessMemory(unsigned OpNo) const
Definition: InstrTypes.h:1715
bool isIndirectCall() const
Return true if the callsite is an indirect call.
bool isCallee(Value::const_user_iterator UI) const
Determine whether the passed iterator points to the callee operand's Use.
Definition: InstrTypes.h:1352
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1286
void setArgOperand(unsigned i, Value *v)
Definition: InstrTypes.h:1291
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
Definition: InstrTypes.h:1277
unsigned getArgOperandNo(const Use *U) const
Given a use for a arg operand, get the arg operand number that corresponds to it.
Definition: InstrTypes.h:1317
unsigned arg_size() const
Definition: InstrTypes.h:1284
AttributeList getAttributes() const
Return the attributes for this call.
Definition: InstrTypes.h:1417
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind)
Adds the attribute to the indicated argument.
Definition: InstrTypes.h:1494
bool isArgOperand(const Use *U) const
Definition: InstrTypes.h:1306
bool hasOperandBundles() const
Return true if this User has any operand bundles.
Definition: InstrTypes.h:1971
Function * getCaller()
Helper to get the caller (the parent function).
Wrapper to unify "old style" CallGraph and "new style" LazyCallGraph.
void initialize(LazyCallGraph &LCG, LazyCallGraph::SCC &SCC, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR)
Initializers for usage outside of a CGSCC pass, inside a CGSCC pass in the old and new pass manager (...
This class represents a function call, abstracting a target machine's calling convention.
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:702
@ ICMP_EQ
equal
Definition: InstrTypes.h:694
@ ICMP_NE
not equal
Definition: InstrTypes.h:695
static Constant * getPointerCast(Constant *C, Type *Ty)
Create a BitCast, AddrSpaceCast, or a PtrToInt cast constant expression.
Definition: Constants.cpp:2253
static Constant * getPointerBitCastOrAddrSpaceCast(Constant *C, Type *Ty)
Create a BitCast or AddrSpaceCast for a pointer type depending on the address space.
Definition: Constants.cpp:2268
This is the shared class of boolean and integer constants.
Definition: Constants.h:83
IntegerType * getIntegerType() const
Variant of the getType() method to always return an IntegerType, which reduces the amount of casting ...
Definition: Constants.h:187
static ConstantInt * getTrue(LLVMContext &Context)
Definition: Constants.cpp:866
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition: Constants.h:208
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
Definition: Constants.h:163
This is an important base class in LLVM.
Definition: Constant.h:42
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Definition: Constants.cpp:373
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
A debug info location.
Definition: DebugLoc.h:33
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:194
iterator begin()
Definition: DenseMap.h:75
iterator end()
Definition: DenseMap.h:84
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:211
Implements a dense probed hash-table based set.
Definition: DenseSet.h:278
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
Instruction * findNearestCommonDominator(Instruction *I1, Instruction *I2) const
Find the nearest instruction I that dominates both I1 and I2, in the sense that a result produced bef...
Definition: Dominators.cpp:344
static ErrorSuccess success()
Create a success value.
Definition: Error.h:337
An instruction for ordering other memory operations.
Definition: Instructions.h:424
AtomicOrdering getOrdering() const
Returns the ordering constraint of this fence instruction.
Definition: Instructions.h:449
A proxy from a FunctionAnalysisManager to an SCC.
A handy container for a FunctionType+Callee-pointer pair, which can be passed around as a single enti...
Definition: DerivedTypes.h:170
const BasicBlock & getEntryBlock() const
Definition: Function.h:809
const BasicBlock & front() const
Definition: Function.h:860
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition: Function.cpp:369
Argument * getArg(unsigned i) const
Definition: Function.h:886
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
Definition: Function.cpp:731
bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
Definition: Globals.cpp:296
bool hasLocalLinkage() const
Definition: GlobalValue.h:529
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:657
@ PrivateLinkage
Like Internal, but omit from symbol table.
Definition: GlobalValue.h:60
@ InternalLinkage
Rename collisions when linking (static functions).
Definition: GlobalValue.h:59
const Constant * getInitializer() const
getInitializer - Return the initializer for this global variable.
void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
Definition: Globals.cpp:492
InsertPoint - A saved insertion point.
Definition: IRBuilder.h:276
BasicBlock * getBlock() const
Definition: IRBuilder.h:291
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Definition: IRBuilder.h:1781
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:199
Value * CreateAddrSpaceCast(Value *V, Type *DestTy, const Twine &Name="")
Definition: IRBuilder.h:2157
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2705
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
Definition: PassManager.h:567
bool isLifetimeStartOrEnd() const LLVM_READONLY
Return true if the instruction is a llvm.lifetime.start or llvm.lifetime.end marker.
bool mayWriteToMemory() const LLVM_READONLY
Return true if this instruction may modify memory.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
Definition: Instruction.h:492
const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr it the function does not...
Definition: Instruction.cpp:68
const Instruction * getPrevNonDebugInstruction(bool SkipPseudoOp=false) const
Return a pointer to the previous non-debug instruction in the same basic block as 'this',...
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:94
const Function * getFunction() const
Return the function this instruction belongs to.
Definition: Instruction.cpp:72
bool mayHaveSideEffects() const LLVM_READONLY
Return true if the instruction may have side effects.
bool mayReadFromMemory() const LLVM_READONLY
Return true if this instruction may read memory.
const Instruction * getNextNonDebugInstruction(bool SkipPseudoOp=false) const
Return a pointer to the next non-debug instruction in the same basic block as 'this',...
void setDebugLoc(DebugLoc Loc)
Set the debug location information for this instruction.
Definition: Instruction.h:489
void setSuccessor(unsigned Idx, BasicBlock *BB)
Update the specified successor to point at the provided block.
void moveBefore(Instruction *MovePos)
Unlink this instruction from its current basic block and insert it into the basic block that MovePos ...
A node in the call graph.
An SCC of the call graph.
A lazily constructed view of the call graph of a module.
An instruction for reading from memory.
Definition: Instructions.h:176
A single uniqued string.
Definition: Metadata.h:724
StringRef getString() const
Definition: Metadata.cpp:616
void eraseFromParent()
This method unlinks 'this' from the containing function and deletes it.
Root of the metadata hierarchy.
Definition: Metadata.h:62
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
const std::string & getTargetTriple() const
Get the target triple which is a string describing the target host.
Definition: Module.h:298
MutableArrayRef - Represent a mutable reference to an array (0 or more elements consecutively in memo...
Definition: ArrayRef.h:310
A tuple of MDNodes.
Definition: Metadata.h:1737
An interface to create LLVM-IR for OpenMP directives.
Definition: OMPIRBuilder.h:474
static std::pair< int32_t, int32_t > readThreadBoundsForKernel(const Triple &T, Function &Kernel)
}
void addAttributes(omp::RuntimeFunction FnID, Function &Fn)
Add attributes known for FnID to Fn.
IRBuilder<>::InsertPoint InsertPointTy
Type used throughout for insertion points.
Definition: OMPIRBuilder.h:520
static std::pair< int32_t, int32_t > readTeamBoundsForKernel(const Triple &T, Function &Kernel)
Read/write the bounds on teams for Kernel.
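Based only on the two static signatures shown above, a hedged sketch of querying a kernel's launch bounds; the wrapper and the assumption that each pair is {min, max} are mine, not this file's:

#include <array>
#include <cstdint>
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/TargetParser/Triple.h"

using namespace llvm;

// Hypothetical query: {MinThreads, MaxThreads, MinTeams, MaxTeams}, assuming
// each reader returns a {min, max} pair recorded on the kernel.
static std::array<int32_t, 4> readKernelLaunchBounds(Function &Kernel) {
  Triple T(Kernel.getParent()->getTargetTriple());
  auto [MinThreads, MaxThreads] =
      OpenMPIRBuilder::readThreadBoundsForKernel(T, Kernel);
  auto [MinTeams, MaxTeams] =
      OpenMPIRBuilder::readTeamBoundsForKernel(T, Kernel);
  return {MinThreads, MaxThreads, MinTeams, MaxTeams};
}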
PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR)
Definition: OpenMPOpt.cpp:5833
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
Definition: OpenMPOpt.cpp:5720
Diagnostic information for optimization analysis remarks.
The optimization diagnostic interface.
Diagnostic information for missed-optimization remarks.
Diagnostic information for applied optimization remarks.
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
Definition: Constants.cpp:1878
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition: Analysis.h:114
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
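The two run overloads above are the module and CGSCC entry points of OpenMPOpt; the PreservedAnalyses factories describe what such passes hand back. A hypothetical minimal pass showing the usual return pattern (not the actual OpenMPOpt logic):

#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"

using namespace llvm;

// Hypothetical new-PM module pass: report all() when the IR is untouched,
// none() after any modification, so cached analyses get invalidated.
struct ExampleModulePass : PassInfoMixin<ExampleModulePass> {
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM) {
    bool Changed = false;
    // ... transform M here, setting Changed accordingly ...
    return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
  }
};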
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, InsertPosition InsertBefore=nullptr)
A vector that has set insertion semantics.
Definition: SetVector.h:57
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:98
const value_type & back() const
Return the last element of the SetVector.
Definition: SetVector.h:149
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:264
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:162
size_type size() const
Definition: SmallPtrSet.h:94
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:363
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:452
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:384
iterator begin() const
Definition: SmallPtrSet.h:472
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:519
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:370
bool empty() const
Definition: SmallVector.h:81
size_t size() const
Definition: SmallVector.h:78
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:573
void assign(size_type NumElts, ValueParamT Elt)
Definition: SmallVector.h:704
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:937
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:683
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
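The container entries above (SetVector, SmallPtrSet, SmallVector) are the building blocks of the worklist-style traversals in this file. A hypothetical, self-contained sketch of that idiom:

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"

using namespace llvm;

// Hypothetical traversal: collect every call in F reachable from the entry
// block, visiting each basic block exactly once.
static SmallVector<CallBase *, 16> collectReachableCalls(Function &F) {
  SmallVector<CallBase *, 16> Calls;
  SmallPtrSet<BasicBlock *, 8> Visited;
  SmallSetVector<BasicBlock *, 8> Worklist;
  Worklist.insert(&F.getEntryBlock());
  while (!Worklist.empty()) {
    BasicBlock *BB = Worklist.pop_back_val();
    if (!Visited.insert(BB).second)
      continue; // Already processed.
    for (Instruction &I : *BB)
      if (auto *CB = dyn_cast<CallBase>(&I))
        Calls.push_back(CB);
    for (BasicBlock *Succ : successors(BB))
      Worklist.insert(Succ); // SetVector: insertion order, no duplicates.
  }
  return Calls;
}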
An instruction for storing to memory.
Definition: Instructions.h:292
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:51
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition: StringRef.h:265
Triple - Helper class for working with autoconf configuration names.
Definition: Triple.h:44
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
static Type * getVoidTy(LLVMContext &C)
static IntegerType * getInt16Ty(LLVMContext &C)
static IntegerType * getInt8Ty(LLVMContext &C)
static IntegerType * getInt32Ty(LLVMContext &C)
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1859
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Definition: User.cpp:21
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
void setName(const Twine &Name)
Change the name of the value.
Definition: Value.cpp:377
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
iterator_range< user_iterator > users()
Definition: Value.h:421
User * user_back()
Definition: Value.h:407
const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs and address space casts.
Definition: Value.cpp:694
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
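A hypothetical sketch combining the Value queries above, in the way a transformation rewires users from an old value to its replacement (the helper name is illustrative):

#include <cassert>
#include "llvm/IR/Value.h"

using namespace llvm;

// Hypothetical rewrite: transfer Old's name to New, then point every user of
// Old at New. Callers may use stripPointerCasts()/hasOneUse() beforehand to
// decide whether the replacement is worthwhile.
static void replaceAndRename(Value &Old, Value &New) {
  assert(Old.getType() == New.getType() &&
         "replaceAllUsesWith requires matching types");
  New.takeName(&Old);           // Old becomes unnamed, New keeps the name.
  Old.replaceAllUsesWith(&New); // Every use edge now references New.
}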
std::pair< iterator, bool > insert(const ValueT &V)
Definition: DenseSet.h:213
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition: ilist_node.h:32
self_iterator getIterator()
Definition: ilist_node.h:132
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition: ilist_node.h:353
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:661
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
GlobalVariable * getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB)
Definition: OpenMPOpt.cpp:261
ConstantStruct * getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB)
Definition: OpenMPOpt.cpp:268
bool isValidAtPosition(const ValueAndContext &VAC, InformationCache &InfoCache)
Return true if the value of VAC is valid at the position of VAC, that is, a constant,...
Definition: Attributor.cpp:290
bool isPotentiallyAffectedByBarrier(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is potentially affected by a barrier.
Definition: Attributor.cpp:889
@ Interprocedural
Definition: Attributor.h:182
bool isNoSyncInst(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is a nosync instruction.
Definition: Attributor.cpp:205
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
E & operator^=(E &LHS, E RHS)
Definition: BitmaskEnum.h:195
@ Entry
Definition: COFF.h:844
@ AMDGPU_KERNEL
Used for AMDGPU code object kernels.
Definition: CallingConv.h:200
@ SPIR_KERNEL
Used for SPIR kernel functions.
Definition: CallingConv.h:144
@ PTX_Kernel
Call to a PTX kernel. Passes all arguments in parameter space.
Definition: CallingConv.h:125
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
PointerTypeMap run(const Module &M)
Compute the PointerTypeMap for the module M.
bool isOpenMPDevice(Module &M)
Helper to determine if M is an OpenMP target offloading device module.
Definition: OpenMPOpt.cpp:5965
bool containsOpenMP(Module &M)
Helper to determine if M contains OpenMP.
Definition: OpenMPOpt.cpp:5957
InternalControlVar
IDs for all Internal Control Variables (ICVs).
Definition: OMPConstants.h:26
RuntimeFunction
IDs for all omp runtime library (RTL) functions.
Definition: OMPConstants.h:45
KernelSet getDeviceKernels(Module &M)
Get OpenMP device kernels in M.
Definition: OpenMPOpt.cpp:5919
Function * Kernel
Summary of a kernel (=entry point for target offloading).
Definition: OpenMPOpt.h:21
@ OMP_TGT_EXEC_MODE_GENERIC_SPMD
bool isOpenMPKernel(Function &Fn)
Return true iff Fn is an OpenMP GPU kernel; Fn has the "kernel" attribute.
Definition: OpenMPOpt.cpp:5904
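A hedged sketch of how the omp:: helpers above fit together when a transformation only wants to look at OpenMP device kernels; the gating function and loop body are illustrative:

#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/IPO/OpenMPOpt.h"

using namespace llvm;

// Hypothetical gate: skip host or non-OpenMP modules, then visit each target
// offloading entry point (Kernel == Function *) exactly once.
static void forEachDeviceKernel(Module &M) {
  if (!omp::containsOpenMP(M) || !omp::isOpenMPDevice(M))
    return;
  for (Function *K : omp::getDeviceKernels(M)) {
    // ... inspect or transform the kernel K here ...
    (void)K;
  }
}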
DiagnosticInfoOptimizationBase::Argument NV
const_iterator begin(StringRef path LLVM_LIFETIME_BOUND, Style style=Style::native)
Get begin iterator over path.
Definition: Path.cpp:226
const_iterator end(StringRef path LLVM_LIFETIME_BOUND)
Get end iterator over path.
Definition: Path.cpp:235
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:480
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1739
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1697
bool succ_empty(const Instruction *I)
Definition: CFG.h:255
std::string to_string(const T &Value)
Definition: ScopedPrinter.h:85
bool operator!=(uint64_t V1, const APInt &V2)
Definition: APInt.h:2082
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....
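The two ValueTracking helpers above are commonly paired; a hypothetical sketch of decomposing a pointer into an underlying object plus a constant byte offset:

#include <cstdint>
#include <utility>
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Value.h"

using namespace llvm;

// Hypothetical query: peel constant GEPs and casts off Ptr, then strip down to
// the underlying allocation/global/argument. If the offset is not a constant,
// GetPointerBaseWithConstantOffset returns Ptr itself with Offset == 0.
static std::pair<const Value *, int64_t>
decomposePointer(Value *Ptr, const DataLayout &DL) {
  int64_t Offset = 0;
  Value *Base = GetPointerBaseWithConstantOffset(Ptr, Offset, DL);
  const Value *Object = getUnderlyingObject(Base);
  return {Object, Offset};
}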
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
@ CGSCC
Definition: Attributor.h:6486
@ FullLTOPostLink
Full LTO postlink (backend compile) phase.
@ ThinLTOPreLink
ThinLTO prelink (summary) phase.
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
void cantFail(Error Err, const char *Msg=nullptr)
Report a fatal error if Err is a failure value.
Definition: Error.h:756
bool operator&=(SparseBitVector< ElementSize > *LHS, const SparseBitVector< ElementSize > &RHS)
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition: STLExtras.h:1938
BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
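A hypothetical use of SplitBlock as declared above: split a block right before an instruction while keeping the dominator tree consistent; the block name is illustrative:

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;

// Hypothetical split: everything from I onward moves into the new block that
// is returned; BB keeps the instructions before I and branches to the tail.
static BasicBlock *splitBefore(BasicBlock *BB, Instruction *I,
                               DominatorTree *DT) {
  return SplitBlock(BB, I->getIterator(), DT, /*LI=*/nullptr,
                    /*MSSAU=*/nullptr, "split.tail");
}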
auto predecessors(const MachineBasicBlock *BB)
ChangeStatus
Definition: Attributor.h:489
const char * toString(DWARFSectionKind Kind)
Constant * ConstantFoldInsertValueInstruction(Constant *Agg, Constant *Val, ArrayRef< unsigned > Idxs)
ConstantFoldInsertValueInstruction - Attempt to constant fold an insertvalue instruction with the spe...
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
#define N
An abstract interface for address space information.
Definition: Attributor.h:6294
An abstract attribute for getting assumption information.
Definition: Attributor.h:6220
An abstract state for querying live call edges.
Definition: Attributor.h:5495
virtual const SetVector< Function * > & getOptimisticEdges() const =0
Get the optimistic edges.
bool isExecutedByInitialThreadOnly(const Instruction &I) const
Check if an instruction is executed only by the initial thread.
Definition: Attributor.h:5655
static AAExecutionDomain & createForPosition(const IRPosition &IRP, Attributor &A)
Create an abstract attribute view for the position IRP.
Definition: OpenMPOpt.cpp:5636
virtual ExecutionDomainTy getFunctionExecutionDomain() const =0
virtual ExecutionDomainTy getExecutionDomain(const BasicBlock &) const =0
virtual bool isExecutedInAlignedRegion(Attributor &A, const Instruction &I) const =0
Check if the instruction I is executed in an aligned region, that is, the synchronizing effects befor...
virtual bool isNoOpFence(const FenceInst &FI) const =0
Helper function to determine if FI is a no-op given the information about its execution from ExecDoma...
static const char ID
Unique ID (due to the unique address)
Definition: Attributor.h:5686
An abstract interface for indirect call information interference.
Definition: Attributor.h:6413
An abstract interface for liveness abstract attribute.
Definition: Attributor.h:3985
An abstract interface for all memory location attributes (readnone/argmemonly/inaccessiblememonly/ina...
Definition: Attributor.h:4713
AccessKind
Simple enum to distinguish read/write/read-write accesses.
Definition: Attributor.h:4849
StateType::base_t MemoryLocationsKind
Definition: Attributor.h:4714
static bool isAlignedBarrier(const CallBase &CB, bool ExecutedAligned)
Helper function to determine if CB is an aligned (GPU) barrier.
An abstract Attribute for determining the necessity of the convergent attribute.
Definition: Attributor.h:5731
An abstract attribute for getting all assumption underlying objects.
Definition: Attributor.h:6252
Base struct for all "concrete attribute" deductions.
Definition: Attributor.h:3293
virtual ChangeStatus manifest(Attributor &A)
Hook for the Attributor to trigger the manifestation of the information represented by the abstract a...
Definition: Attributor.h:3408
virtual void initialize(Attributor &A)
Initialize the state with the information in the Attributor A.
Definition: Attributor.h:3357
virtual const std::string getAsStr(Attributor *A) const =0
This function should return the "summarized" assumed state as string.
virtual ChangeStatus updateImpl(Attributor &A)=0
The actual update/transfer function which has to be implemented by the derived classes.
virtual void trackStatistics() const =0
Hook to enable custom statistic tracking, called after manifest that resulted in a change if statisti...
virtual const char * getIdAddr() const =0
This function should return the address of the ID of the AbstractAttribute.
An interface to query the internal state of an abstract attribute.
Definition: Attributor.h:2613
virtual ChangeStatus indicatePessimisticFixpoint()=0
Indicate that the abstract state should converge to the pessimistic state.
virtual bool isAtFixpoint() const =0
Return if this abstract state is fixed, thus does not need to be updated if information changes as it...
virtual bool isValidState() const =0
Return if this abstract state is in a valid state.
virtual ChangeStatus indicateOptimisticFixpoint()=0
Indicate that the abstract state should converge to the optimistic state.
Wrapper for FunctionAnalysisManager.
Definition: Attributor.h:1127
Configuration for the Attributor.
Definition: Attributor.h:1431
std::function< void(Attributor &A, const Function &F)> InitializationCallback
Callback function to be invoked on internal functions marked live.
Definition: Attributor.h:1462
std::optional< unsigned > MaxFixpointIterations
Maximum number of iterations to run until fixpoint.
Definition: Attributor.h:1478
bool RewriteSignatures
Flag to determine if we rewrite function signatures.
Definition: Attributor.h:1448
const char * PassName
Definition: Attributor.h:1488
OptimizationRemarkGetter OREGetter
Definition: Attributor.h:1484
IPOAmendableCBTy IPOAmendableCB
Definition: Attributor.h:1491
bool IsModulePass
Is the user of the Attributor a module pass or not.
Definition: Attributor.h:1442
bool DefaultInitializeLiveInternals
Flag to determine if we want to initialize all default AAs for an internal function marked live.
Definition: Attributor.h:1452
The fixpoint analysis framework that orchestrates the attribute deduction.
Definition: Attributor.h:1525
static bool isInternalizable(Function &F)
Returns true if the function F can be internalized.
std::function< std::optional< Constant * >(const GlobalVariable &, const AbstractAttribute *, bool &)> GlobalVariableSimplifictionCallbackTy
Register CB as a simplification callback.
Definition: Attributor.h:2042
std::function< bool(Attributor &, const AbstractAttribute *)> VirtualUseCallbackTy
Definition: Attributor.h:2072
static bool internalizeFunctions(SmallPtrSetImpl< Function * > &FnSet, DenseMap< Function *, Function * > &FnMap)
Make copies of each function in the set FnSet such that the copied version has internal linkage after...
std::function< std::optional< Value * >(const IRPosition &, const AbstractAttribute *, bool &)> SimplifictionCallbackTy
Register CB as a simplification callback.
Definition: Attributor.h:2026
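A hedged sketch of filling in the AttributorConfig fields listed above; the constructor argument and the chosen values are assumptions for illustration, not the exact setup used by this pass:

#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"

using namespace llvm;

// Hypothetical configuration, touching only fields documented above.
static AttributorConfig makeExampleConfig(CallGraphUpdater &CGUpdater) {
  AttributorConfig AC(CGUpdater);        // Assumed ctor taking the updater.
  AC.IsModulePass = true;                // Driven from a module pass.
  AC.RewriteSignatures = false;          // Keep function signatures stable.
  AC.DefaultInitializeLiveInternals = false;
  AC.MaxFixpointIterations = 32;         // Illustrative iteration cap.
  AC.PassName = "example-attributor";    // Used when emitting remarks.
  return AC;
}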
Simple wrapper for a single bit (boolean) state.
Definition: Attributor.h:2897
Support structure for SCC passes to communicate updates to the call graph back to the CGSCC pass manager...
Helper to describe and deal with positions in the LLVM-IR.
Definition: Attributor.h:586
static const IRPosition callsite_returned(const CallBase &CB)
Create a position describing the returned value of CB.
Definition: Attributor.h:654
static const IRPosition returned(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the returned value of F.
Definition: Attributor.h:636
static const IRPosition value(const Value &V, const CallBaseContext *CBContext=nullptr)
Create a position describing the value of V.
Definition: Attributor.h:610
static const IRPosition inst(const Instruction &I, const CallBaseContext *CBContext=nullptr)
Create a position describing the instruction I.
Definition: Attributor.h:622
@ IRP_ARGUMENT
An attribute for a function argument.
Definition: Attributor.h:600
@ IRP_RETURNED
An attribute for the function return value.
Definition: Attributor.h:596
@ IRP_CALL_SITE
An attribute for a call site (function scope).
Definition: Attributor.h:599
@ IRP_CALL_SITE_RETURNED
An attribute for a call site return value.
Definition: Attributor.h:597
@ IRP_FUNCTION
An attribute for a function (scope).
Definition: Attributor.h:598
@ IRP_FLOAT
A position that is not associated with a spot suitable for attributes.
Definition: Attributor.h:594
@ IRP_CALL_SITE_ARGUMENT
An attribute for a call site argument.
Definition: Attributor.h:601
@ IRP_INVALID
An invalid position.
Definition: Attributor.h:593
static const IRPosition function(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the function scope of F.
Definition: Attributor.h:629
Kind getPositionKind() const
Return the associated position kind.
Definition: Attributor.h:882
static const IRPosition callsite_function(const CallBase &CB)
Create a position describing the function scope of CB.
Definition: Attributor.h:649
Function * getAnchorScope() const
Return the Function surrounding the anchor value.
Definition: Attributor.h:758
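A hypothetical illustration of the IRPosition factory functions above, which is how abstract attributes name the IR entity they reason about:

#include <cassert>
#include "llvm/IR/InstrTypes.h"
#include "llvm/Transforms/IPO/Attributor.h"

using namespace llvm;

// Hypothetical walkthrough: build positions for a callee and one of its call
// sites, and check the resulting position kinds.
static void describePositions(Function &Callee, CallBase &CB) {
  IRPosition FnPos = IRPosition::function(Callee);          // Function scope.
  IRPosition RetPos = IRPosition::returned(Callee);         // Callee return value.
  IRPosition CSPos = IRPosition::callsite_function(CB);     // Call-site scope.
  IRPosition CSRetPos = IRPosition::callsite_returned(CB);  // Value returned by CB.
  assert(FnPos.getPositionKind() == IRPosition::IRP_FUNCTION);
  assert(CSRetPos.getPositionKind() == IRPosition::IRP_CALL_SITE_RETURNED);
  assert(FnPos.getAnchorScope() == &Callee);
  (void)FnPos; (void)RetPos; (void)CSPos; (void)CSRetPos;
}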
Data structure to hold cached (LLVM-IR) information.
Definition: Attributor.h:1203
bool isValidState() const override
See AbstractState::isValidState() NOTE: For now we simply pretend that the worst possible state is in...
Definition: Attributor.h:2672
ChangeStatus indicatePessimisticFixpoint() override
See AbstractState::indicatePessimisticFixpoint(...)
Definition: Attributor.h:2684
Direction
An enum for the direction of the loop.
Definition: LoopInfo.h:215
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:117
Description of an LLVM-IR insertion point (IP) and a debug/source location (filename,...
Definition: OMPIRBuilder.h:645
Helper to tie an abstract state implementation to an abstract attribute.
Definition: Attributor.h:3182
StateType & getState() override
See AbstractAttribute::getState(...).
Definition: Attributor.h:3190
Defines various target-specific GPU grid values that must be consistent between host RTL (plugin),...
Definition: OMPGridValues.h:57