LLVM 23.0.0git
OpenMPOpt.cpp
Go to the documentation of this file.
1//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpenMP specific optimizations:
10//
11// - Deduplication of runtime calls, e.g., omp_get_thread_num.
12// - Replacing globalized device memory with stack memory.
13// - Replacing globalized device memory with shared memory.
14// - Parallel region merging.
15// - Transforming generic-mode device kernels to SPMD mode.
16// - Specializing the state machine for generic-mode device kernels.
17//
18//===----------------------------------------------------------------------===//
19
21
22#include "llvm/ADT/DenseSet.h"
25#include "llvm/ADT/SetVector.h"
28#include "llvm/ADT/Statistic.h"
30#include "llvm/ADT/StringRef.h"
39#include "llvm/IR/Assumptions.h"
40#include "llvm/IR/BasicBlock.h"
41#include "llvm/IR/Constants.h"
43#include "llvm/IR/Dominators.h"
44#include "llvm/IR/Function.h"
45#include "llvm/IR/GlobalValue.h"
47#include "llvm/IR/InstrTypes.h"
48#include "llvm/IR/Instruction.h"
51#include "llvm/IR/IntrinsicsAMDGPU.h"
52#include "llvm/IR/IntrinsicsNVPTX.h"
53#include "llvm/IR/LLVMContext.h"
56#include "llvm/Support/Debug.h"
60
61#include <algorithm>
62#include <optional>
63#include <string>
64
65using namespace llvm;
66using namespace omp;
67
68#define DEBUG_TYPE "openmp-opt"
69
71 "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
72 cl::Hidden, cl::init(false));
73
75 "openmp-opt-enable-merging",
76 cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
77 cl::init(false));
78
79static cl::opt<bool>
80 DisableInternalization("openmp-opt-disable-internalization",
81 cl::desc("Disable function internalization."),
82 cl::Hidden, cl::init(false));
83
84static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
85 cl::init(false), cl::Hidden);
86static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
88static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
89 cl::init(false), cl::Hidden);
90
92 "openmp-hide-memory-transfer-latency",
93 cl::desc("[WIP] Tries to hide the latency of host to device memory"
94 " transfers"),
95 cl::Hidden, cl::init(false));
96
98 "openmp-opt-disable-deglobalization",
99 cl::desc("Disable OpenMP optimizations involving deglobalization."),
100 cl::Hidden, cl::init(false));
101
103 "openmp-opt-disable-spmdization",
104 cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
105 cl::Hidden, cl::init(false));
106
108 "openmp-opt-disable-folding",
109 cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
110 cl::init(false));
111
113 "openmp-opt-disable-state-machine-rewrite",
114 cl::desc("Disable OpenMP optimizations that replace the state machine."),
115 cl::Hidden, cl::init(false));
116
118 "openmp-opt-disable-barrier-elimination",
119 cl::desc("Disable OpenMP optimizations that eliminate barriers."),
120 cl::Hidden, cl::init(false));
121
123 "openmp-opt-print-module-after",
124 cl::desc("Print the current module after OpenMP optimizations."),
125 cl::Hidden, cl::init(false));
126
128 "openmp-opt-print-module-before",
129 cl::desc("Print the current module before OpenMP optimizations."),
130 cl::Hidden, cl::init(false));
131
133 "openmp-opt-inline-device",
134 cl::desc("Inline all applicable functions on the device."), cl::Hidden,
135 cl::init(false));
136
137static cl::opt<bool>
138 EnableVerboseRemarks("openmp-opt-verbose-remarks",
139 cl::desc("Enables more verbose remarks."), cl::Hidden,
140 cl::init(false));
141
143 SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
144 cl::desc("Maximal number of attributor iterations."),
145 cl::init(256));
146
148 SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
149 cl::desc("Maximum amount of shared memory to use."),
150 cl::init(std::numeric_limits<unsigned>::max()));
151
152STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
153 "Number of OpenMP runtime calls deduplicated");
154STATISTIC(NumOpenMPParallelRegionsDeleted,
155 "Number of OpenMP parallel regions deleted");
156STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
157 "Number of OpenMP runtime functions identified");
158STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
159 "Number of OpenMP runtime function uses identified");
160STATISTIC(NumOpenMPTargetRegionKernels,
161 "Number of OpenMP target region entry points (=kernels) identified");
162STATISTIC(NumNonOpenMPTargetRegionKernels,
163 "Number of non-OpenMP target region kernels identified");
164STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
165 "Number of OpenMP target region entry points (=kernels) executed in "
166 "SPMD-mode instead of generic-mode");
167STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
168 "Number of OpenMP target region entry points (=kernels) executed in "
169 "generic-mode without a state machines");
170STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
171 "Number of OpenMP target region entry points (=kernels) executed in "
172 "generic-mode with customized state machines with fallback");
173STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
174 "Number of OpenMP target region entry points (=kernels) executed in "
175 "generic-mode with customized state machines without fallback");
177 NumOpenMPParallelRegionsReplacedInGPUStateMachine,
178 "Number of OpenMP parallel regions replaced with ID in GPU state machines");
179STATISTIC(NumOpenMPParallelRegionsMerged,
180 "Number of OpenMP parallel regions merged");
181STATISTIC(NumBytesMovedToSharedMemory,
182 "Amount of memory pushed to shared memory");
183STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
184
185#if !defined(NDEBUG)
186static constexpr auto TAG = "[" DEBUG_TYPE "]";
187#endif
188
189namespace KernelInfo {
190
191// struct ConfigurationEnvironmentTy {
192// uint8_t UseGenericStateMachine;
193// uint8_t MayUseNestedParallelism;
194// llvm::omp::OMPTgtExecModeFlags ExecMode;
195// int32_t MinThreads;
196// int32_t MaxThreads;
197// int32_t MinTeams;
198// int32_t MaxTeams;
199// };
200
201// struct DynamicEnvironmentTy {
202// uint16_t DebugIndentionLevel;
203// };
204
205// struct KernelEnvironmentTy {
206// ConfigurationEnvironmentTy Configuration;
207// IdentTy *Ident;
208// DynamicEnvironmentTy *DynamicEnv;
209// };
210
211#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX) \
212 constexpr unsigned MEMBER##Idx = IDX;
213
214KERNEL_ENVIRONMENT_IDX(Configuration, 0)
216
217#undef KERNEL_ENVIRONMENT_IDX
218
219#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX) \
220 constexpr unsigned MEMBER##Idx = IDX;
221
222KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
223KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
229
230#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX
231
232#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE) \
233 RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
234 return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx)); \
235 }
236
239
240#undef KERNEL_ENVIRONMENT_GETTER
241
242#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER) \
243 ConstantInt *get##MEMBER##FromKernelEnvironment( \
244 ConstantStruct *KernelEnvC) { \
245 ConstantStruct *ConfigC = \
246 getConfigurationFromKernelEnvironment(KernelEnvC); \
247 return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx)); \
248 }
249
250KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
251KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
257
258#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
259
262 constexpr int InitKernelEnvironmentArgNo = 0;
264 KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
266}
267
273} // namespace KernelInfo
274
275namespace {
276
277struct AAHeapToShared;
278
279struct AAICVTracker;
280
281/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
282/// Attributor runs.
283struct OMPInformationCache : public InformationCache {
284 OMPInformationCache(Module &M, AnalysisGetter &AG,
285 BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
286 bool OpenMPPostLink)
287 : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
288 OpenMPPostLink(OpenMPPostLink) {
289
290 OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
291 const Triple T(OMPBuilder.M.getTargetTriple());
292 switch (T.getArch()) {
296 assert(OMPBuilder.Config.IsTargetDevice &&
297 "OpenMP AMDGPU/NVPTX is only prepared to deal with device code.");
298 OMPBuilder.Config.IsGPU = true;
299 break;
300 default:
301 OMPBuilder.Config.IsGPU = false;
302 break;
303 }
304 OMPBuilder.initialize();
305 initializeRuntimeFunctions(M);
306 initializeInternalControlVars();
307 }
308
309 /// Generic information that describes an internal control variable.
310 struct InternalControlVarInfo {
311 /// The kind, as described by InternalControlVar enum.
313
314 /// The name of the ICV.
315 StringRef Name;
316
317 /// Environment variable associated with this ICV.
318 StringRef EnvVarName;
319
320 /// Initial value kind.
321 ICVInitValue InitKind;
322
323 /// Initial value.
324 ConstantInt *InitValue;
325
326 /// Setter RTL function associated with this ICV.
327 RuntimeFunction Setter;
328
329 /// Getter RTL function associated with this ICV.
330 RuntimeFunction Getter;
331
332 /// RTL Function corresponding to the override clause of this ICV
333 RuntimeFunction Clause;
334 };
335
336 /// Generic information that describes a runtime function
337 struct RuntimeFunctionInfo {
338
339 /// The kind, as described by the RuntimeFunction enum.
340 RuntimeFunction Kind;
341
342 /// The name of the function.
343 StringRef Name;
344
345 /// Flag to indicate a variadic function.
346 bool IsVarArg;
347
348 /// The return type of the function.
349 Type *ReturnType;
350
351 /// The argument types of the function.
352 SmallVector<Type *, 8> ArgumentTypes;
353
354 /// The declaration if available.
355 Function *Declaration = nullptr;
356
357 /// Uses of this runtime function per function containing the use.
358 using UseVector = SmallVector<Use *, 16>;
359
360 /// Clear UsesMap for runtime function.
361 void clearUsesMap() { UsesMap.clear(); }
362
363 /// Boolean conversion that is true if the runtime function was found.
364 operator bool() const { return Declaration; }
365
366 /// Return the vector of uses in function \p F.
367 UseVector &getOrCreateUseVector(Function *F) {
368 std::shared_ptr<UseVector> &UV = UsesMap[F];
369 if (!UV)
370 UV = std::make_shared<UseVector>();
371 return *UV;
372 }
373
374 /// Return the vector of uses in function \p F or `nullptr` if there are
375 /// none.
376 const UseVector *getUseVector(Function &F) const {
377 auto I = UsesMap.find(&F);
378 if (I != UsesMap.end())
379 return I->second.get();
380 return nullptr;
381 }
382
383 /// Return how many functions contain uses of this runtime function.
384 size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
385
386 /// Return the number of arguments (or the minimal number for variadic
387 /// functions).
388 size_t getNumArgs() const { return ArgumentTypes.size(); }
389
390 /// Run the callback \p CB on each use and forget the use if the result is
391 /// true. The callback will be fed the function in which the use was
392 /// encountered as second argument.
393 void foreachUse(SmallVectorImpl<Function *> &SCC,
394 function_ref<bool(Use &, Function &)> CB) {
395 for (Function *F : SCC)
396 foreachUse(CB, F);
397 }
398
399 /// Run the callback \p CB on each use within the function \p F and forget
400 /// the use if the result is true.
401 void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
402 SmallVector<unsigned, 8> ToBeDeleted;
403 ToBeDeleted.clear();
404
405 unsigned Idx = 0;
406 UseVector &UV = getOrCreateUseVector(F);
407
408 for (Use *U : UV) {
409 if (CB(*U, *F))
410 ToBeDeleted.push_back(Idx);
411 ++Idx;
412 }
413
414 // Remove the to-be-deleted indices in reverse order as prior
415 // modifications will not modify the smaller indices.
416 while (!ToBeDeleted.empty()) {
417 unsigned Idx = ToBeDeleted.pop_back_val();
418 UV[Idx] = UV.back();
419 UV.pop_back();
420 }
421 }
422
423 private:
424 /// Map from functions to all uses of this runtime function contained in
425 /// them.
426 DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
427
428 public:
429 /// Iterators for the uses of this runtime function.
430 decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
431 decltype(UsesMap)::iterator end() { return UsesMap.end(); }
432 };
433
434 /// An OpenMP-IR-Builder instance
435 OpenMPIRBuilder OMPBuilder;
436
437 /// Map from runtime function kind to the runtime function description.
438 EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
439 RuntimeFunction::OMPRTL___last>
440 RFIs;
441
442 /// Map from function declarations/definitions to their runtime enum type.
443 DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
444
445 /// Map from ICV kind to the ICV description.
446 EnumeratedArray<InternalControlVarInfo, InternalControlVar,
447 InternalControlVar::ICV___last>
448 ICVs;
449
450 /// Helper to initialize all internal control variable information for those
451 /// defined in OMPKinds.def.
452 void initializeInternalControlVars() {
453#define ICV_RT_SET(_Name, RTL) \
454 { \
455 auto &ICV = ICVs[_Name]; \
456 ICV.Setter = RTL; \
457 }
458#define ICV_RT_GET(Name, RTL) \
459 { \
460 auto &ICV = ICVs[Name]; \
461 ICV.Getter = RTL; \
462 }
463#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
464 { \
465 auto &ICV = ICVs[Enum]; \
466 ICV.Name = _Name; \
467 ICV.Kind = Enum; \
468 ICV.InitKind = Init; \
469 ICV.EnvVarName = _EnvVarName; \
470 switch (ICV.InitKind) { \
471 case ICV_IMPLEMENTATION_DEFINED: \
472 ICV.InitValue = nullptr; \
473 break; \
474 case ICV_ZERO: \
475 ICV.InitValue = ConstantInt::get( \
476 Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
477 break; \
478 case ICV_FALSE: \
479 ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
480 break; \
481 case ICV_LAST: \
482 break; \
483 } \
484 }
485#include "llvm/Frontend/OpenMP/OMPKinds.def"
486 }
487
488 /// Returns true if the function declaration \p F matches the runtime
489 /// function types, that is, return type \p RTFRetType, and argument types
490 /// \p RTFArgTypes.
491 static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
492 SmallVector<Type *, 8> &RTFArgTypes) {
493 // TODO: We should output information to the user (under debug output
494 // and via remarks).
495
496 if (!F)
497 return false;
498 if (F->getReturnType() != RTFRetType)
499 return false;
500 if (F->arg_size() != RTFArgTypes.size())
501 return false;
502
503 auto *RTFTyIt = RTFArgTypes.begin();
504 for (Argument &Arg : F->args()) {
505 if (Arg.getType() != *RTFTyIt)
506 return false;
507
508 ++RTFTyIt;
509 }
510
511 return true;
512 }
513
514 // Helper to collect all uses of the declaration in the UsesMap.
515 unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
516 unsigned NumUses = 0;
517 if (!RFI.Declaration)
518 return NumUses;
519 OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
520
521 if (CollectStats) {
522 NumOpenMPRuntimeFunctionsIdentified += 1;
523 NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
524 }
525
526 // TODO: We directly convert uses into proper calls and unknown uses.
527 for (Use &U : RFI.Declaration->uses()) {
528 if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
529 if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
530 RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
531 ++NumUses;
532 }
533 } else {
534 RFI.getOrCreateUseVector(nullptr).push_back(&U);
535 ++NumUses;
536 }
537 }
538 return NumUses;
539 }
540
541 // Helper function to recollect uses of a runtime function.
542 void recollectUsesForFunction(RuntimeFunction RTF) {
543 auto &RFI = RFIs[RTF];
544 RFI.clearUsesMap();
545 collectUses(RFI, /*CollectStats*/ false);
546 }
547
548 // Helper function to recollect uses of all runtime functions.
549 void recollectUses() {
550 for (int Idx = 0; Idx < RFIs.size(); ++Idx)
551 recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
552 }
553
554 // Helper function to inherit the calling convention of the function callee.
555 void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
556 if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
557 CI->setCallingConv(Fn->getCallingConv());
558 }
559
560 // Helper function to determine if it's legal to create a call to the runtime
561 // functions.
562 bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
563 // We can always emit calls if we haven't yet linked in the runtime.
564 if (!OpenMPPostLink)
565 return true;
566
567 // Once the runtime has been already been linked in we cannot emit calls to
568 // any undefined functions.
569 for (RuntimeFunction Fn : Fns) {
570 RuntimeFunctionInfo &RFI = RFIs[Fn];
571
572 if (!RFI.Declaration || RFI.Declaration->isDeclaration())
573 return false;
574 }
575 return true;
576 }
577
/// Helper to initialize all runtime function information for those defined
/// in OMPKinds.def.
580 void initializeRuntimeFunctions(Module &M) {
581
582 // Helper macros for handling __VA_ARGS__ in OMP_RTL
583#define OMP_TYPE(VarName, ...) \
584 Type *VarName = OMPBuilder.VarName; \
585 (void)VarName;
586
587#define OMP_ARRAY_TYPE(VarName, ...) \
588 ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
589 (void)VarName##Ty; \
590 PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
591 (void)VarName##PtrTy;
592
593#define OMP_FUNCTION_TYPE(VarName, ...) \
594 FunctionType *VarName = OMPBuilder.VarName; \
595 (void)VarName; \
596 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
597 (void)VarName##Ptr;
598
599#define OMP_STRUCT_TYPE(VarName, ...) \
600 StructType *VarName = OMPBuilder.VarName; \
601 (void)VarName; \
602 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
603 (void)VarName##Ptr;
604
605#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
606 { \
607 SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
608 Function *F = M.getFunction(_Name); \
609 RTLFunctions.insert(F); \
610 if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
611 RuntimeFunctionIDMap[F] = _Enum; \
612 auto &RFI = RFIs[_Enum]; \
613 RFI.Kind = _Enum; \
614 RFI.Name = _Name; \
615 RFI.IsVarArg = _IsVarArg; \
616 RFI.ReturnType = OMPBuilder._ReturnType; \
617 RFI.ArgumentTypes = std::move(ArgsTypes); \
618 RFI.Declaration = F; \
619 unsigned NumUses = collectUses(RFI); \
620 (void)NumUses; \
621 LLVM_DEBUG({ \
622 dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
623 << " found\n"; \
624 if (RFI.Declaration) \
625 dbgs() << TAG << "-> got " << NumUses << " uses in " \
626 << RFI.getNumFunctionsWithUses() \
627 << " different functions.\n"; \
628 }); \
629 } \
630 }
631#include "llvm/Frontend/OpenMP/OMPKinds.def"
632
633 // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
634 // functions, except if `optnone` is present.
635 if (isOpenMPDevice(M)) {
636 for (Function &F : M) {
637 for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
638 if (F.hasFnAttribute(Attribute::NoInline) &&
639 F.getName().starts_with(Prefix) &&
640 !F.hasFnAttribute(Attribute::OptimizeNone))
641 F.removeFnAttr(Attribute::NoInline);
642 }
643 }
644
645 // TODO: We should attach the attributes defined in OMPKinds.def.
646 }
647
/// Collection of known OpenMP runtime functions.
649 DenseSet<const Function *> RTLFunctions;
650
651 /// Indicates if we have already linked in the OpenMP device library.
652 bool OpenMPPostLink = false;
653};
654
655template <typename Ty, bool InsertInvalidates = true>
656struct BooleanStateWithSetVector : public BooleanState {
657 bool contains(const Ty &Elem) const { return Set.contains(Elem); }
658 bool insert(const Ty &Elem) {
659 if (InsertInvalidates)
660 BooleanState::indicatePessimisticFixpoint();
661 return Set.insert(Elem);
662 }
663
664 const Ty &operator[](int Idx) const { return Set[Idx]; }
665 bool operator==(const BooleanStateWithSetVector &RHS) const {
666 return BooleanState::operator==(RHS) && Set == RHS.Set;
667 }
668 bool operator!=(const BooleanStateWithSetVector &RHS) const {
669 return !(*this == RHS);
670 }
671
672 bool empty() const { return Set.empty(); }
673 size_t size() const { return Set.size(); }
674
675 /// "Clamp" this state with \p RHS.
676 BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
677 BooleanState::operator^=(RHS);
678 Set.insert_range(RHS.Set);
679 return *this;
680 }
681
682private:
683 /// A set to keep track of elements.
684 SetVector<Ty> Set;
685
686public:
687 typename decltype(Set)::iterator begin() { return Set.begin(); }
688 typename decltype(Set)::iterator end() { return Set.end(); }
689 typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
690 typename decltype(Set)::const_iterator end() const { return Set.end(); }
691};
692
693template <typename Ty, bool InsertInvalidates = true>
694using BooleanStateWithPtrSetVector =
695 BooleanStateWithSetVector<Ty *, InsertInvalidates>;
696
697struct KernelInfoState : AbstractState {
698 /// Flag to track if we reached a fixpoint.
699 bool IsAtFixpoint = false;
700
701 /// The parallel regions (identified by the outlined parallel functions) that
702 /// can be reached from the associated function.
703 BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
704 ReachedKnownParallelRegions;
705
706 /// State to track what parallel region we might reach.
707 BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
708
709 /// State to track if we are in SPMD-mode, assumed or know, and why we decided
710 /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
711 /// false.
712 BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
713
714 /// The __kmpc_target_init call in this kernel, if any. If we find more than
715 /// one we abort as the kernel is malformed.
716 CallBase *KernelInitCB = nullptr;
717
718 /// The constant kernel environement as taken from and passed to
719 /// __kmpc_target_init.
720 ConstantStruct *KernelEnvC = nullptr;
721
722 /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
723 /// one we abort as the kernel is malformed.
724 CallBase *KernelDeinitCB = nullptr;
725
726 /// Flag to indicate if the associated function is a kernel entry.
727 bool IsKernelEntry = false;
728
729 /// State to track what kernel entries can reach the associated function.
730 BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
731
732 /// State to indicate if we can track parallel level of the associated
733 /// function. We will give up tracking if we encounter unknown caller or the
734 /// caller is __kmpc_parallel_60.
735 BooleanStateWithSetVector<uint8_t> ParallelLevels;
736
737 /// Flag that indicates if the kernel has nested Parallelism
738 bool NestedParallelism = false;
739
740 /// Abstract State interface
741 ///{
742
743 KernelInfoState() = default;
744 KernelInfoState(bool BestState) {
745 if (!BestState)
746 indicatePessimisticFixpoint();
747 }
748
749 /// See AbstractState::isValidState(...)
750 bool isValidState() const override { return true; }
751
752 /// See AbstractState::isAtFixpoint(...)
753 bool isAtFixpoint() const override { return IsAtFixpoint; }
754
755 /// See AbstractState::indicatePessimisticFixpoint(...)
756 ChangeStatus indicatePessimisticFixpoint() override {
757 IsAtFixpoint = true;
758 ParallelLevels.indicatePessimisticFixpoint();
759 ReachingKernelEntries.indicatePessimisticFixpoint();
760 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
761 ReachedKnownParallelRegions.indicatePessimisticFixpoint();
762 ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
763 NestedParallelism = true;
764 return ChangeStatus::CHANGED;
765 }
766
767 /// See AbstractState::indicateOptimisticFixpoint(...)
768 ChangeStatus indicateOptimisticFixpoint() override {
769 IsAtFixpoint = true;
770 ParallelLevels.indicateOptimisticFixpoint();
771 ReachingKernelEntries.indicateOptimisticFixpoint();
772 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
773 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
774 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
775 return ChangeStatus::UNCHANGED;
776 }
777
778 /// Return the assumed state
779 KernelInfoState &getAssumed() { return *this; }
780 const KernelInfoState &getAssumed() const { return *this; }
781
782 bool operator==(const KernelInfoState &RHS) const {
783 if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
784 return false;
785 if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
786 return false;
787 if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
788 return false;
789 if (ReachingKernelEntries != RHS.ReachingKernelEntries)
790 return false;
791 if (ParallelLevels != RHS.ParallelLevels)
792 return false;
793 if (NestedParallelism != RHS.NestedParallelism)
794 return false;
795 return true;
796 }
797
798 /// Returns true if this kernel contains any OpenMP parallel regions.
799 bool mayContainParallelRegion() {
800 return !ReachedKnownParallelRegions.empty() ||
801 !ReachedUnknownParallelRegions.empty();
802 }
803
804 /// Return empty set as the best state of potential values.
805 static KernelInfoState getBestState() { return KernelInfoState(true); }
806
807 static KernelInfoState getBestState(KernelInfoState &KIS) {
808 return getBestState();
809 }
810
811 /// Return full set as the worst state of potential values.
812 static KernelInfoState getWorstState() { return KernelInfoState(false); }
813
814 /// "Clamp" this state with \p KIS.
815 KernelInfoState operator^=(const KernelInfoState &KIS) {
816 // Do not merge two different _init and _deinit call sites.
817 if (KIS.KernelInitCB) {
818 if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
819 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
820 "assumptions.");
821 KernelInitCB = KIS.KernelInitCB;
822 }
823 if (KIS.KernelDeinitCB) {
824 if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
825 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
826 "assumptions.");
827 KernelDeinitCB = KIS.KernelDeinitCB;
828 }
829 if (KIS.KernelEnvC) {
830 if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
831 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
832 "assumptions.");
833 KernelEnvC = KIS.KernelEnvC;
834 }
835 SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
836 ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
837 ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
838 NestedParallelism |= KIS.NestedParallelism;
839 return *this;
840 }
841
842 KernelInfoState operator&=(const KernelInfoState &KIS) {
843 return (*this ^= KIS);
844 }
845
846 ///}
847};
848
849/// Used to map the values physically (in the IR) stored in an offload
850/// array, to a vector in memory.
851struct OffloadArray {
852 /// Physical array (in the IR).
853 AllocaInst *Array = nullptr;
854 /// Mapped values.
855 SmallVector<Value *, 8> StoredValues;
856 /// Last stores made in the offload array.
857 SmallVector<StoreInst *, 8> LastAccesses;
858
859 OffloadArray() = default;
860
861 /// Initializes the OffloadArray with the values stored in \p Array before
862 /// instruction \p Before is reached. Returns false if the initialization
863 /// fails.
864 /// This MUST be used immediately after the construction of the object.
865 bool initialize(AllocaInst &Array, Instruction &Before) {
866 if (!getValues(Array, Before))
867 return false;
868
869 this->Array = &Array;
870 return true;
871 }
872
873 static const unsigned DeviceIDArgNum = 1;
874 static const unsigned BasePtrsArgNum = 3;
875 static const unsigned PtrsArgNum = 4;
876 static const unsigned SizesArgNum = 5;
877
878private:
879 /// Traverses the BasicBlock where \p Array is, collecting the stores made to
880 /// \p Array, leaving StoredValues with the values stored before the
881 /// instruction \p Before is reached.
882 bool getValues(AllocaInst &Array, Instruction &Before) {
883 // Initialize containers.
884 const DataLayout &DL = Array.getDataLayout();
885 std::optional<TypeSize> ArraySize = Array.getAllocationSize(DL);
886 if (!ArraySize || !ArraySize->isFixed())
887 return false;
888 const unsigned int PointerSize = DL.getPointerSize();
889 const uint64_t NumValues = ArraySize->getFixedValue() / PointerSize;
890 StoredValues.assign(NumValues, nullptr);
891 LastAccesses.assign(NumValues, nullptr);
892
893 // TODO: This assumes the instruction \p Before is in the same
894 // BasicBlock as Array. Make it general, for any control flow graph.
895 BasicBlock *BB = Array.getParent();
896 if (BB != Before.getParent())
897 return false;
898
899 for (Instruction &I : *BB) {
900 if (&I == &Before)
901 break;
902
903 if (!isa<StoreInst>(&I))
904 continue;
905
906 auto *S = cast<StoreInst>(&I);
907 int64_t Offset = -1;
908 auto *Dst =
909 GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
910 if (Dst == &Array) {
911 int64_t Idx = Offset / PointerSize;
912 // Ignore updates that must be UB (probably in dead code at runtime)
913 if ((uint64_t)Idx < NumValues) {
914 StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
915 LastAccesses[Idx] = S;
916 }
917 }
918 }
919
920 return isFilled();
921 }
922
923 /// Returns true if all values in StoredValues and
924 /// LastAccesses are not nullptrs.
925 bool isFilled() {
926 const unsigned NumValues = StoredValues.size();
927 for (unsigned I = 0; I < NumValues; ++I) {
928 if (!StoredValues[I] || !LastAccesses[I])
929 return false;
930 }
931
932 return true;
933 }
934};
935
936struct OpenMPOpt {
937
938 using OptimizationRemarkGetter =
939 function_ref<OptimizationRemarkEmitter &(Function *)>;
940
941 OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
942 OptimizationRemarkGetter OREGetter,
943 OMPInformationCache &OMPInfoCache, Attributor &A)
944 : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
945 OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
946
947 /// Check if any remarks are enabled for openmp-opt
948 bool remarksEnabled() {
949 auto &Ctx = M.getContext();
950 return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
951 }
952
953 /// Run all OpenMP optimizations on the underlying SCC.
954 bool run(bool IsModulePass) {
955 if (SCC.empty())
956 return false;
957
958 bool Changed = false;
959
960 LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
961 << " functions\n");
962
963 if (IsModulePass) {
964 Changed |= runAttributor(IsModulePass);
965
966 // Recollect uses, in case Attributor deleted any.
967 OMPInfoCache.recollectUses();
968
969 // TODO: This should be folded into buildCustomStateMachine.
970 Changed |= rewriteDeviceCodeStateMachine();
971
972 if (remarksEnabled())
973 analysisGlobalization();
974 } else {
975 if (PrintICVValues)
976 printICVs();
978 printKernels();
979
980 Changed |= runAttributor(IsModulePass);
981
982 // Recollect uses, in case Attributor deleted any.
983 OMPInfoCache.recollectUses();
984
985 Changed |= deleteParallelRegions();
986
988 Changed |= hideMemTransfersLatency();
989 Changed |= deduplicateRuntimeCalls();
991 if (mergeParallelRegions()) {
992 deduplicateRuntimeCalls();
993 Changed = true;
994 }
995 }
996 }
997
998 if (OMPInfoCache.OpenMPPostLink)
999 Changed |= removeRuntimeSymbols();
1000
1001 return Changed;
1002 }
1003
1004 /// Print initial ICV values for testing.
1005 /// FIXME: This should be done from the Attributor once it is added.
1006 void printICVs() const {
1007 InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
1008 ICV_proc_bind};
1009
1010 for (Function *F : SCC) {
1011 for (auto ICV : ICVs) {
1012 auto ICVInfo = OMPInfoCache.ICVs[ICV];
1013 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1014 return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
1015 << " Value: "
1016 << (ICVInfo.InitValue
1017 ? toString(ICVInfo.InitValue->getValue(), 10, true)
1018 : "IMPLEMENTATION_DEFINED");
1019 };
1020
1021 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
1022 }
1023 }
1024 }
1025
1026 /// Print OpenMP GPU kernels for testing.
1027 void printKernels() const {
1028 for (Function *F : SCC) {
1029 if (!omp::isOpenMPKernel(*F))
1030 continue;
1031
1032 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1033 return ORA << "OpenMP GPU kernel "
1034 << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
1035 };
1036
1038 }
1039 }
1040
1041 /// Return the call if \p U is a callee use in a regular call. If \p RFI is
1042 /// given it has to be the callee or a nullptr is returned.
1043 static CallInst *getCallIfRegularCall(
1044 Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1045 CallInst *CI = dyn_cast<CallInst>(U.getUser());
1046 if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
1047 (!RFI ||
1048 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1049 return CI;
1050 return nullptr;
1051 }
1052
1053 /// Return the call if \p V is a regular call. If \p RFI is given it has to be
1054 /// the callee or a nullptr is returned.
1055 static CallInst *getCallIfRegularCall(
1056 Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1057 CallInst *CI = dyn_cast<CallInst>(&V);
1058 if (CI && !CI->hasOperandBundles() &&
1059 (!RFI ||
1060 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1061 return CI;
1062 return nullptr;
1063 }
1064
1065private:
1066 /// Merge parallel regions when it is safe.
1067 bool mergeParallelRegions() {
1068 const unsigned CallbackCalleeOperand = 2;
1069 const unsigned CallbackFirstArgOperand = 3;
1070 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1071
1072 // Check if there are any __kmpc_fork_call calls to merge.
1073 OMPInformationCache::RuntimeFunctionInfo &RFI =
1074 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1075
1076 if (!RFI.Declaration)
1077 return false;
1078
1079 // Unmergable calls that prevent merging a parallel region.
1080 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
1081 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
1082 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
1083 };
1084
1085 bool Changed = false;
1086 LoopInfo *LI = nullptr;
1087 DominatorTree *DT = nullptr;
1088
1089 SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
1090
1091 BasicBlock *StartBB = nullptr, *EndBB = nullptr;
1092 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
1093 ArrayRef<BasicBlock *> DeallocBlocks) {
1094 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1095 BasicBlock *CGEndBB =
1096 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1097 assert(StartBB != nullptr && "StartBB should not be null");
1098 CGStartBB->getTerminator()->setSuccessor(0, StartBB);
1099 assert(EndBB != nullptr && "EndBB should not be null");
1100 EndBB->getTerminator()->setSuccessor(0, CGEndBB);
1101 return Error::success();
1102 };
1103
1104 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
1105 Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
1106 ReplacementValue = &Inner;
1107 return CodeGenIP;
1108 };
1109
1110 auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };
1111
1112 /// Create a sequential execution region within a merged parallel region,
1113 /// encapsulated in a master construct with a barrier for synchronization.
1114 auto CreateSequentialRegion = [&](Function *OuterFn,
1115 BasicBlock *OuterPredBB,
1116 Instruction *SeqStartI,
1117 Instruction *SeqEndI) {
1118 // Isolate the instructions of the sequential region to a separate
1119 // block.
1120 BasicBlock *ParentBB = SeqStartI->getParent();
1121 BasicBlock *SeqEndBB =
1122 SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
1123 BasicBlock *SeqAfterBB =
1124 SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
1125 BasicBlock *SeqStartBB =
1126 SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
1127
1128 assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
1129 "Expected a different CFG");
1130 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
1131 ParentBB->getTerminator()->eraseFromParent();
1132
1133 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
1134 ArrayRef<BasicBlock *> DeallocBlocks) {
1135 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1136 BasicBlock *CGEndBB =
1137 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1138 assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
1139 CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
1140 assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
1141 SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
1142 return Error::success();
1143 };
1144 auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };
1145
1146 // Find outputs from the sequential region to outside users and
1147 // broadcast their values to them.
1148 for (Instruction &I : *SeqStartBB) {
1149 SmallPtrSet<Instruction *, 4> OutsideUsers;
1150 for (User *Usr : I.users()) {
1151 Instruction &UsrI = *cast<Instruction>(Usr);
1152 // Ignore outputs to LT intrinsics, code extraction for the merged
1153 // parallel region will fix them.
1154 if (UsrI.isLifetimeStartOrEnd())
1155 continue;
1156
1157 if (UsrI.getParent() != SeqStartBB)
1158 OutsideUsers.insert(&UsrI);
1159 }
1160
1161 if (OutsideUsers.empty())
1162 continue;
1163
1164 // Emit an alloca in the outer region to store the broadcasted
1165 // value.
1166 const DataLayout &DL = M.getDataLayout();
1167 AllocaInst *AllocaI = new AllocaInst(
1168 I.getType(), DL.getAllocaAddrSpace(), nullptr,
1169 I.getName() + ".seq.output.alloc", OuterFn->front().begin());
1170
1171 // Emit a store instruction in the sequential BB to update the
1172 // value.
1173 new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());
1174
1175 // Emit a load instruction and replace the use of the output value
1176 // with it.
1177 for (Instruction *UsrI : OutsideUsers) {
1178 LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,
1179 I.getName() + ".seq.output.load",
1180 UsrI->getIterator());
1181 UsrI->replaceUsesOfWith(&I, LoadI);
1182 }
1183 }
1184
1185 OpenMPIRBuilder::LocationDescription Loc(
1186 InsertPointTy(ParentBB, ParentBB->end()), DL);
1188 OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB));
1189 cantFail(
1190 OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel));
1191
1192 UncondBrInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1193
1194 LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1195 << "\n");
1196 };
1197
1198 // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1199 // contained in BB and only separated by instructions that can be
1200 // redundantly executed in parallel. The block BB is split before the first
1201 // call (in MergableCIs) and after the last so the entire region we merge
1202 // into a single parallel region is contained in a single basic block
1203 // without any other instructions. We use the OpenMPIRBuilder to outline
1204 // that block and call the resulting function via __kmpc_fork_call.
1205 auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1206 BasicBlock *BB) {
1207 // TODO: Change the interface to allow single CIs expanded, e.g, to
1208 // include an outer loop.
1209 assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1210
1211 auto Remark = [&](OptimizationRemark OR) {
1212 OR << "Parallel region merged with parallel region"
1213 << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1214 for (auto *CI : llvm::drop_begin(MergableCIs)) {
1215 OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1216 if (CI != MergableCIs.back())
1217 OR << ", ";
1218 }
1219 return OR << ".";
1220 };
1221
1222 emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1223
1224 Function *OriginalFn = BB->getParent();
1225 LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1226 << " parallel regions in " << OriginalFn->getName()
1227 << "\n");
1228
1229 // Isolate the calls to merge in a separate block.
1230 EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1231 BasicBlock *AfterBB =
1232 SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1233 StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1234 "omp.par.merged");
1235
1236 assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1237 const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1238 BB->getTerminator()->eraseFromParent();
1239
1240 // Create sequential regions for sequential instructions that are
1241 // in-between mergable parallel regions.
1242 for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1243 It != End; ++It) {
1244 Instruction *ForkCI = *It;
1245 Instruction *NextForkCI = *(It + 1);
1246
1247 // Continue if there are not in-between instructions.
1248 if (ForkCI->getNextNode() == NextForkCI)
1249 continue;
1250
1251 CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1252 NextForkCI->getPrevNode());
1253 }
1254
1255 OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1256 DL);
1257 IRBuilder<>::InsertPoint AllocaIP(
1258 &OriginalFn->getEntryBlock(),
1259 OriginalFn->getEntryBlock().getFirstInsertionPt());
1260 // Create the merged parallel region with default proc binding, to
1261 // avoid overriding binding settings, and without explicit cancellation.
1263 cantFail(OMPInfoCache.OMPBuilder.createParallel(
1264 Loc, AllocaIP, /* DeallocBlocks */ {}, BodyGenCB, PrivCB, FiniCB,
1265 nullptr, nullptr, OMP_PROC_BIND_default,
1266 /* IsCancellable */ false));
1267 UncondBrInst::Create(AfterBB, AfterIP.getBlock());
1268
1269 // Perform the actual outlining.
1270 OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1271
1272 Function *OutlinedFn = MergableCIs.front()->getCaller();
1273
1274 // Replace the __kmpc_fork_call calls with direct calls to the outlined
1275 // callbacks.
1276 SmallVector<Value *, 8> Args;
1277 for (auto *CI : MergableCIs) {
1278 Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1279 FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1280 Args.clear();
1281 Args.push_back(OutlinedFn->getArg(0));
1282 Args.push_back(OutlinedFn->getArg(1));
1283 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1284 ++U)
1285 Args.push_back(CI->getArgOperand(U));
1286
1287 CallInst *NewCI =
1288 CallInst::Create(FT, Callee, Args, "", CI->getIterator());
1289 if (CI->getDebugLoc())
1290 NewCI->setDebugLoc(CI->getDebugLoc());
1291
1292 // Forward parameter attributes from the callback to the callee.
1293 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1294 ++U)
1295 for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1296 NewCI->addParamAttr(
1297 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1298
1299 // Emit an explicit barrier to replace the implicit fork-join barrier.
1300 if (CI != MergableCIs.back()) {
1301 // TODO: Remove barrier if the merged parallel region includes the
1302 // 'nowait' clause.
1303 cantFail(OMPInfoCache.OMPBuilder.createBarrier(
1304 InsertPointTy(NewCI->getParent(),
1305 NewCI->getNextNode()->getIterator()),
1306 OMPD_parallel));
1307 }
1308
1309 CI->eraseFromParent();
1310 }
1311
1312 assert(OutlinedFn != OriginalFn && "Outlining failed");
1313 CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1314 CGUpdater.reanalyzeFunction(*OriginalFn);
1315
1316 NumOpenMPParallelRegionsMerged += MergableCIs.size();
1317
1318 return true;
1319 };
1320
1321 // Helper function that identifes sequences of
1322 // __kmpc_fork_call uses in a basic block.
1323 auto DetectPRsCB = [&](Use &U, Function &F) {
1324 CallInst *CI = getCallIfRegularCall(U, &RFI);
1325 BB2PRMap[CI->getParent()].insert(CI);
1326
1327 return false;
1328 };
1329
1330 BB2PRMap.clear();
1331 RFI.foreachUse(SCC, DetectPRsCB);
1332 SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1333 // Find mergable parallel regions within a basic block that are
1334 // safe to merge, that is any in-between instructions can safely
1335 // execute in parallel after merging.
1336 // TODO: support merging across basic-blocks.
1337 for (auto &It : BB2PRMap) {
1338 auto &CIs = It.getSecond();
1339 if (CIs.size() < 2)
1340 continue;
1341
1342 BasicBlock *BB = It.getFirst();
1343 SmallVector<CallInst *, 4> MergableCIs;
1344
1345 /// Returns true if the instruction is mergable, false otherwise.
1346 /// A terminator instruction is unmergable by definition since merging
1347 /// works within a BB. Instructions before the mergable region are
1348 /// mergable if they are not calls to OpenMP runtime functions that may
1349 /// set different execution parameters for subsequent parallel regions.
1350 /// Instructions in-between parallel regions are mergable if they are not
1351 /// calls to any non-intrinsic function since that may call a non-mergable
1352 /// OpenMP runtime function.
1353 auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1354 // We do not merge across BBs, hence return false (unmergable) if the
1355 // instruction is a terminator.
1356 if (I.isTerminator())
1357 return false;
1358
1359 if (!isa<CallInst>(&I))
1360 return true;
1361
1362 CallInst *CI = cast<CallInst>(&I);
1363 if (IsBeforeMergableRegion) {
1364 Function *CalledFunction = CI->getCalledFunction();
1365 if (!CalledFunction)
1366 return false;
1367 // Return false (unmergable) if the call before the parallel
1368 // region calls an explicit affinity (proc_bind) or number of
1369 // threads (num_threads) compiler-generated function. Those settings
1370 // may be incompatible with following parallel regions.
1371 // TODO: ICV tracking to detect compatibility.
1372 for (const auto &RFI : UnmergableCallsInfo) {
1373 if (CalledFunction == RFI.Declaration)
1374 return false;
1375 }
1376 } else {
1377 // Return false (unmergable) if there is a call instruction
1378 // in-between parallel regions when it is not an intrinsic. It
1379 // may call an unmergable OpenMP runtime function in its callpath.
1380 // TODO: Keep track of possible OpenMP calls in the callpath.
1381 if (!isa<IntrinsicInst>(CI))
1382 return false;
1383 }
1384
1385 return true;
1386 };
1387 // Find maximal number of parallel region CIs that are safe to merge.
1388 for (auto It = BB->begin(), End = BB->end(); It != End;) {
1389 Instruction &I = *It;
1390 ++It;
1391
1392 if (CIs.count(&I)) {
1393 MergableCIs.push_back(cast<CallInst>(&I));
1394 continue;
1395 }
1396
1397 // Continue expanding if the instruction is mergable.
1398 if (IsMergable(I, MergableCIs.empty()))
1399 continue;
1400
1401 // Forward the instruction iterator to skip the next parallel region
1402 // since there is an unmergable instruction which can affect it.
1403 for (; It != End; ++It) {
1404 Instruction &SkipI = *It;
1405 if (CIs.count(&SkipI)) {
1406 LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1407 << " due to " << I << "\n");
1408 ++It;
1409 break;
1410 }
1411 }
1412
1413 // Store mergable regions found.
1414 if (MergableCIs.size() > 1) {
1415 MergableCIsVector.push_back(MergableCIs);
1416 LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1417 << " parallel regions in block " << BB->getName()
1418 << " of function " << BB->getParent()->getName()
1419 << "\n";);
1420 }
1421
1422 MergableCIs.clear();
1423 }
1424
1425 if (!MergableCIsVector.empty()) {
1426 Changed = true;
1427
1428 for (auto &MergableCIs : MergableCIsVector)
1429 Merge(MergableCIs, BB);
1430 MergableCIsVector.clear();
1431 }
1432 }
1433
1434 if (Changed) {
1435 /// Re-collect use for fork calls, emitted barrier calls, and
1436 /// any emitted master/end_master calls.
1437 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1438 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1439 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1440 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1441 }
1442
1443 return Changed;
1444 }
1445
1446 /// Try to delete parallel regions if possible.
1447 bool deleteParallelRegions() {
1448 const unsigned CallbackCalleeOperand = 2;
1449
1450 OMPInformationCache::RuntimeFunctionInfo &RFI =
1451 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1452
1453 if (!RFI.Declaration)
1454 return false;
1455
1456 bool Changed = false;
1457 auto DeleteCallCB = [&](Use &U, Function &) {
1458 CallInst *CI = getCallIfRegularCall(U);
1459 if (!CI)
1460 return false;
1461 auto *Fn = dyn_cast<Function>(
1462 CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1463 if (!Fn)
1464 return false;
1465 if (!Fn->onlyReadsMemory())
1466 return false;
1467 if (!Fn->hasFnAttribute(Attribute::WillReturn))
1468 return false;
1469
1470 LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1471 << CI->getCaller()->getName() << "\n");
1472
1473 auto Remark = [&](OptimizationRemark OR) {
1474 return OR << "Removing parallel region with no side-effects.";
1475 };
1477
1478 CI->eraseFromParent();
1479 Changed = true;
1480 ++NumOpenMPParallelRegionsDeleted;
1481 return true;
1482 };
1483
1484 RFI.foreachUse(SCC, DeleteCallCB);
1485
1486 return Changed;
1487 }
1488
1489 /// Try to eliminate runtime calls by reusing existing ones.
1490 bool deduplicateRuntimeCalls() {
1491 bool Changed = false;
1492
1493 RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1494 OMPRTL_omp_get_num_threads,
1495 OMPRTL_omp_in_parallel,
1496 OMPRTL_omp_get_cancellation,
1497 OMPRTL_omp_get_supported_active_levels,
1498 OMPRTL_omp_get_level,
1499 OMPRTL_omp_get_ancestor_thread_num,
1500 OMPRTL_omp_get_team_size,
1501 OMPRTL_omp_get_active_level,
1502 OMPRTL_omp_in_final,
1503 OMPRTL_omp_get_proc_bind,
1504 OMPRTL_omp_get_num_places,
1505 OMPRTL_omp_get_num_procs,
1506 OMPRTL_omp_get_place_num,
1507 OMPRTL_omp_get_partition_num_places,
1508 OMPRTL_omp_get_partition_place_nums};
1509
1510 // Global-tid is handled separately.
1511 SmallSetVector<Value *, 16> GTIdArgs;
1512 collectGlobalThreadIdArguments(GTIdArgs);
1513 LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1514 << " global thread ID arguments\n");
1515
1516 for (Function *F : SCC) {
1517 for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1518 Changed |= deduplicateRuntimeCalls(
1519 *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1520
1521 // __kmpc_global_thread_num is special as we can replace it with an
1522 // argument in enough cases to make it worth trying.
1523 Value *GTIdArg = nullptr;
1524 for (Argument &Arg : F->args())
1525 if (GTIdArgs.count(&Arg)) {
1526 GTIdArg = &Arg;
1527 break;
1528 }
1529 Changed |= deduplicateRuntimeCalls(
1530 *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1531 }
1532
1533 return Changed;
1534 }
1535
1536 /// Tries to remove known runtime symbols that are optional from the module.
1537 bool removeRuntimeSymbols() {
1538 // The RPC client symbol is defined in `libc` and indicates that something
1539 // required an RPC server. If its users were all optimized out then we can
1540 // safely remove it.
1541 // TODO: This should be somewhere more common in the future.
1542 if (GlobalVariable *GV = M.getNamedGlobal("__llvm_rpc_client")) {
1543 if (GV->hasNUsesOrMore(1))
1544 return false;
1545
1546 GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
1547 GV->eraseFromParent();
1548 return true;
1549 }
1550 return false;
1551 }
1552
1553 /// Tries to hide the latency of runtime calls that involve host to
1554 /// device memory transfers by splitting them into their "issue" and "wait"
1555 /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1556 /// moved downards as much as possible. The "issue" issues the memory transfer
1557 /// asynchronously, returning a handle. The "wait" waits in the returned
1558 /// handle for the memory transfer to finish.
1559 bool hideMemTransfersLatency() {
1560 auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1561 bool Changed = false;
1562 auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1563 auto *RTCall = getCallIfRegularCall(U, &RFI);
1564 if (!RTCall)
1565 return false;
1566
1567 OffloadArray OffloadArrays[3];
1568 if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1569 return false;
1570
1571 LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1572
1573 // TODO: Check if can be moved upwards.
1574 bool WasSplit = false;
1575 Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1576 if (WaitMovementPoint)
1577 WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1578
1579 Changed |= WasSplit;
1580 return WasSplit;
1581 };
1582 if (OMPInfoCache.runtimeFnsAvailable(
1583 {OMPRTL___tgt_target_data_begin_mapper_issue,
1584 OMPRTL___tgt_target_data_begin_mapper_wait}))
1585 RFI.foreachUse(SCC, SplitMemTransfers);
1586
1587 return Changed;
1588 }
1589
1590 void analysisGlobalization() {
1591 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1592
1593 auto CheckGlobalization = [&](Use &U, Function &Decl) {
1594 if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1595 auto Remark = [&](OptimizationRemarkMissed ORM) {
1596 return ORM
1597 << "Found thread data sharing on the GPU. "
1598 << "Expect degraded performance due to data globalization.";
1599 };
1601 }
1602
1603 return false;
1604 };
1605
1606 RFI.foreachUse(SCC, CheckGlobalization);
1607 }
1608
1609 /// Maps the values stored in the offload arrays passed as arguments to
1610 /// \p RuntimeCall into the offload arrays in \p OAs.
1611 bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1613 assert(OAs.size() == 3 && "Need space for three offload arrays!");
1614
1615 // A runtime call that involves memory offloading looks something like:
1616 // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1617 // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1618 // ...)
1619 // So, the idea is to access the allocas that allocate space for these
1620 // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1621 // Therefore:
1622 // i8** %offload_baseptrs.
1623 Value *BasePtrsArg =
1624 RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1625 // i8** %offload_ptrs.
1626 Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1627 // i8** %offload_sizes.
1628 Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1629
1630 // Get values stored in **offload_baseptrs.
1631 auto *V = getUnderlyingObject(BasePtrsArg);
1632 if (!isa<AllocaInst>(V))
1633 return false;
1634 auto *BasePtrsArray = cast<AllocaInst>(V);
1635 if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1636 return false;
1637
1638 // Get values stored in **offload_baseptrs.
1639 V = getUnderlyingObject(PtrsArg);
1640 if (!isa<AllocaInst>(V))
1641 return false;
1642 auto *PtrsArray = cast<AllocaInst>(V);
1643 if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1644 return false;
1645
1646 // Get values stored in **offload_sizes.
1647 V = getUnderlyingObject(SizesArg);
1648 // If it's a [constant] global array don't analyze it.
1649 if (isa<GlobalValue>(V))
1650 return isa<Constant>(V);
1651 if (!isa<AllocaInst>(V))
1652 return false;
1653
1654 auto *SizesArray = cast<AllocaInst>(V);
1655 if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1656 return false;
1657
1658 return true;
1659 }
1660
1661 /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1662 /// For now this is a way to test that the function getValuesInOffloadArrays
1663 /// is working properly.
1664 /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1665 void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1666 assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1667
1668 LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1669 std::string ValuesStr;
1670 raw_string_ostream Printer(ValuesStr);
1671 std::string Separator = " --- ";
1672
1673 for (auto *BP : OAs[0].StoredValues) {
1674 BP->print(Printer);
1675 Printer << Separator;
1676 }
1677 LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << ValuesStr << "\n");
1678 ValuesStr.clear();
1679
1680 for (auto *P : OAs[1].StoredValues) {
1681 P->print(Printer);
1682 Printer << Separator;
1683 }
1684 LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << ValuesStr << "\n");
1685 ValuesStr.clear();
1686
1687 for (auto *S : OAs[2].StoredValues) {
1688 S->print(Printer);
1689 Printer << Separator;
1690 }
1691 LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << ValuesStr << "\n");
1692 }
1693
1694 /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
1695 /// moved. Returns nullptr if the movement is not possible, or not worth it.
1696 Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1697 // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1698 // Make it traverse the CFG.
1699
1700 Instruction *CurrentI = &RuntimeCall;
1701 bool IsWorthIt = false;
1702 while ((CurrentI = CurrentI->getNextNode())) {
1703
1704 // TODO: Once we detect the regions to be offloaded we should use the
1705 // alias analysis manager to check if CurrentI may modify one of
1706 // the offloaded regions.
1707 if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1708 if (IsWorthIt)
1709 return CurrentI;
1710
1711 return nullptr;
1712 }
1713
1714 // FIXME: For now if we move it over anything without side effect
1715 // is worth it.
1716 IsWorthIt = true;
1717 }
1718
1719 // Return end of BasicBlock.
1720 return RuntimeCall.getParent()->getTerminator();
1721 }
1722
1723 /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1724 bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1725 Instruction &WaitMovementPoint) {
1726 // Create stack allocated handle (__tgt_async_info) at the beginning of the
1727 // function. Used for storing information of the async transfer, allowing to
1728 // wait on it later.
1729 auto &IRBuilder = OMPInfoCache.OMPBuilder;
1730 Function *F = RuntimeCall.getCaller();
1731 BasicBlock &Entry = F->getEntryBlock();
1732 IRBuilder.Builder.SetInsertPoint(&Entry,
1733 Entry.getFirstNonPHIOrDbgOrAlloca());
1734 Value *Handle = IRBuilder.Builder.CreateAlloca(
1735 IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1736 Handle =
1737 IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1738
1739 // Add "issue" runtime call declaration:
1740 // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1741 // i8**, i8**, i64*, i64*)
1742 FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1743 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1744
1745 // Change RuntimeCall call site for its asynchronous version.
1746 SmallVector<Value *, 16> Args;
1747 for (auto &Arg : RuntimeCall.args())
1748 Args.push_back(Arg.get());
1749 Args.push_back(Handle);
1750
1751 CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"",
1752 RuntimeCall.getIterator());
1753 OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1754 RuntimeCall.eraseFromParent();
1755
1756 // Add "wait" runtime call declaration:
1757 // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1758 FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1759 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1760
1761 Value *WaitParams[2] = {
1762 IssueCallsite->getArgOperand(
1763 OffloadArray::DeviceIDArgNum), // device_id.
1764 Handle // handle to wait on.
1765 };
1766 CallInst *WaitCallsite = CallInst::Create(
1767 WaitDecl, WaitParams, /*NameStr=*/"", WaitMovementPoint.getIterator());
1768 OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1769
1770 return true;
1771 }
1772
1773 static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1774 bool GlobalOnly, bool &SingleChoice) {
1775 if (CurrentIdent == NextIdent)
1776 return CurrentIdent;
1777
1778 // TODO: Figure out how to actually combine multiple debug locations. For
1779 // now we just keep an existing one if there is a single choice.
1780 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1781 SingleChoice = !CurrentIdent;
1782 return NextIdent;
1783 }
1784 return nullptr;
1785 }
1786
1787 /// Return an `struct ident_t*` value that represents the ones used in the
1788 /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1789 /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1790 /// return value we create one from scratch. We also do not yet combine
1791 /// information, e.g., the source locations, see combinedIdentStruct.
1792 Value *
1793 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1794 Function &F, bool GlobalOnly) {
1795 bool SingleChoice = true;
1796 Value *Ident = nullptr;
1797 auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1798 CallInst *CI = getCallIfRegularCall(U, &RFI);
1799 if (!CI || &F != &Caller)
1800 return false;
1801 Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1802 /* GlobalOnly */ true, SingleChoice);
1803 return false;
1804 };
1805 RFI.foreachUse(SCC, CombineIdentStruct);
1806
1807 if (!Ident || !SingleChoice) {
1808 // The IRBuilder uses the insertion block to get to the module, this is
1809 // unfortunate but we work around it for now.
1810 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1811 OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1812 &F.getEntryBlock(), F.getEntryBlock().begin()));
1813 // Create a fallback location if non was found.
1814 // TODO: Use the debug locations of the calls instead.
1815 uint32_t SrcLocStrSize;
1816 Constant *Loc =
1817 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1818 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1819 }
1820 return Ident;
1821 }
1822
1823 /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1824 /// \p ReplVal if given.
1825 bool deduplicateRuntimeCalls(Function &F,
1826 OMPInformationCache::RuntimeFunctionInfo &RFI,
1827 Value *ReplVal = nullptr) {
1828 auto *UV = RFI.getUseVector(F);
1829 if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1830 return false;
1831
1832 LLVM_DEBUG(
1833 dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1834 << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1835
1836 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1837 cast<Argument>(ReplVal)->getParent() == &F)) &&
1838 "Unexpected replacement value!");
1839
1840 // TODO: Use dominance to find a good position instead.
1841 auto CanBeMoved = [this](CallBase &CB) {
1842 unsigned NumArgs = CB.arg_size();
1843 if (NumArgs == 0)
1844 return true;
1845 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1846 return false;
1847 for (unsigned U = 1; U < NumArgs; ++U)
1848 if (isa<Instruction>(CB.getArgOperand(U)))
1849 return false;
1850 return true;
1851 };
1852
1853 if (!ReplVal) {
1854 auto *DT =
1855 OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1856 if (!DT)
1857 return false;
1858 Instruction *IP = nullptr;
1859 for (Use *U : *UV) {
1860 if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1861 if (IP)
1862 IP = DT->findNearestCommonDominator(IP, CI);
1863 else
1864 IP = CI;
1865 if (!CanBeMoved(*CI))
1866 continue;
1867 if (!ReplVal)
1868 ReplVal = CI;
1869 }
1870 }
1871 if (!ReplVal)
1872 return false;
1873 assert(IP && "Expected insertion point!");
1874 cast<Instruction>(ReplVal)->moveBefore(IP->getIterator());
1875 }
1876
1877 // If we use a call as a replacement value we need to make sure the ident is
1878 // valid at the new location. For now we just pick a global one, either
1879 // existing and used by one of the calls, or created from scratch.
1880 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1881 if (!CI->arg_empty() &&
1882 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1883 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1884 /* GlobalOnly */ true);
1885 CI->setArgOperand(0, Ident);
1886 }
1887 }
1888
1889 bool Changed = false;
1890 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1891 CallInst *CI = getCallIfRegularCall(U, &RFI);
1892 if (!CI || CI == ReplVal || &F != &Caller)
1893 return false;
1894 assert(CI->getCaller() == &F && "Unexpected call!");
1895
1896 auto Remark = [&](OptimizationRemark OR) {
1897 return OR << "OpenMP runtime call "
1898 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1899 };
1900 if (CI->getDebugLoc())
1902 else
1904
1905 CI->replaceAllUsesWith(ReplVal);
1906 CI->eraseFromParent();
1907 ++NumOpenMPRuntimeCallsDeduplicated;
1908 Changed = true;
1909 return true;
1910 };
1911 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1912
1913 return Changed;
1914 }
1915
1916 /// Collect arguments that represent the global thread id in \p GTIdArgs.
1917 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1918 // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1919 // initialization. We could define an AbstractAttribute instead and
1920 // run the Attributor here once it can be run as an SCC pass.
1921
1922 // Helper to check the argument \p ArgNo at all call sites of \p F for
1923 // a GTId.
1924 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1925 if (!F.hasLocalLinkage())
1926 return false;
1927 for (Use &U : F.uses()) {
1928 if (CallInst *CI = getCallIfRegularCall(U)) {
1929 Value *ArgOp = CI->getArgOperand(ArgNo);
1930 if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1931 getCallIfRegularCall(
1932 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1933 continue;
1934 }
1935 return false;
1936 }
1937 return true;
1938 };
1939
1940 // Helper to identify uses of a GTId as GTId arguments.
1941 auto AddUserArgs = [&](Value &GTId) {
1942 for (Use &U : GTId.uses())
1943 if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1944 if (CI->isArgOperand(&U))
1945 if (Function *Callee = CI->getCalledFunction())
1946 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1947 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1948 };
1949
1950 // The argument users of __kmpc_global_thread_num calls are GTIds.
1951 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1952 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1953
1954 GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1955 if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1956 AddUserArgs(*CI);
1957 return false;
1958 });
1959
1960 // Transitively search for more arguments by looking at the users of the
1961 // ones we know already. During the search the GTIdArgs vector is extended
1962 // so we cannot cache the size nor can we use a range based for.
1963 for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1964 AddUserArgs(*GTIdArgs[U]);
1965 }
1966
1967 /// Kernel (=GPU) optimizations and utility functions
1968 ///
1969 ///{{
1970
1971 /// Cache to remember the unique kernel for a function.
1972 DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
1973
1974 /// Find the unique kernel that will execute \p F, if any.
1975 Kernel getUniqueKernelFor(Function &F);
1976
1977 /// Find the unique kernel that will execute \p I, if any.
1978 Kernel getUniqueKernelFor(Instruction &I) {
1979 return getUniqueKernelFor(*I.getFunction());
1980 }
1981
1982 /// Rewrite the device (=GPU) code state machine create in non-SPMD mode in
1983 /// the cases we can avoid taking the address of a function.
1984 bool rewriteDeviceCodeStateMachine();
1985
1986 ///
1987 ///}}
1988
1989 /// Emit a remark generically
1990 ///
1991 /// This template function can be used to generically emit a remark. The
1992 /// RemarkKind should be one of the following:
1993 /// - OptimizationRemark to indicate a successful optimization attempt
1994 /// - OptimizationRemarkMissed to report a failed optimization attempt
1995 /// - OptimizationRemarkAnalysis to provide additional information about an
1996 /// optimization attempt
1997 ///
1998 /// The remark is built using a callback function provided by the caller that
1999 /// takes a RemarkKind as input and returns a RemarkKind.
2000 template <typename RemarkKind, typename RemarkCallBack>
2001 void emitRemark(Instruction *I, StringRef RemarkName,
2002 RemarkCallBack &&RemarkCB) const {
2003 Function *F = I->getParent()->getParent();
2004 auto &ORE = OREGetter(F);
2005
2006 if (RemarkName.starts_with("OMP"))
2007 ORE.emit([&]() {
2008 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2009 << " [" << RemarkName << "]";
2010 });
2011 else
2012 ORE.emit(
2013 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2014 }
2015
2016 /// Emit a remark on a function.
2017 template <typename RemarkKind, typename RemarkCallBack>
2018 void emitRemark(Function *F, StringRef RemarkName,
2019 RemarkCallBack &&RemarkCB) const {
2020 auto &ORE = OREGetter(F);
2021
2022 if (RemarkName.starts_with("OMP"))
2023 ORE.emit([&]() {
2024 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2025 << " [" << RemarkName << "]";
2026 });
2027 else
2028 ORE.emit(
2029 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2030 }
2031
2032 /// The underlying module.
2033 Module &M;
2034
2035 /// The SCC we are operating on.
2036 SmallVectorImpl<Function *> &SCC;
2037
2038 /// Callback to update the call graph, the first argument is a removed call,
2039 /// the second an optional replacement call.
2040 CallGraphUpdater &CGUpdater;
2041
2042 /// Callback to get an OptimizationRemarkEmitter from a Function *
2043 OptimizationRemarkGetter OREGetter;
2044
2045 /// OpenMP-specific information cache. Also Used for Attributor runs.
2046 OMPInformationCache &OMPInfoCache;
2047
2048 /// Attributor instance.
2049 Attributor &A;
2050
2051 /// Helper function to run Attributor on SCC.
2052 bool runAttributor(bool IsModulePass) {
2053 if (SCC.empty())
2054 return false;
2055
2056 registerAAs(IsModulePass);
2057
2058 ChangeStatus Changed = A.run();
2059
2060 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2061 << " functions, result: " << Changed << ".\n");
2062
2063 if (Changed == ChangeStatus::CHANGED)
2064 OMPInfoCache.invalidateAnalyses();
2065
2066 return Changed == ChangeStatus::CHANGED;
2067 }
2068
2069 void registerFoldRuntimeCall(RuntimeFunction RF);
2070
2071 /// Populate the Attributor with abstract attribute opportunities in the
2072 /// functions.
2073 void registerAAs(bool IsModulePass);
2074
2075public:
2076 /// Callback to register AAs for live functions, including internal functions
2077 /// marked live during the traversal.
2078 static void registerAAsForFunction(Attributor &A, const Function &F);
2079};
2080
2081Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2082 if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2083 !OMPInfoCache.CGSCC->contains(&F))
2084 return nullptr;
2085
2086 // Use a scope to keep the lifetime of the CachedKernel short.
2087 {
2088 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2089 if (CachedKernel)
2090 return *CachedKernel;
2091
2092 // TODO: We should use an AA to create an (optimistic and callback
2093 // call-aware) call graph. For now we stick to simple patterns that
2094 // are less powerful, basically the worst fixpoint.
2095 if (isOpenMPKernel(F)) {
2096 CachedKernel = Kernel(&F);
2097 return *CachedKernel;
2098 }
2099
2100 CachedKernel = nullptr;
2101 if (!F.hasLocalLinkage()) {
2102
2103 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2104 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2105 return ORA << "Potentially unknown OpenMP target region caller.";
2106 };
2108
2109 return nullptr;
2110 }
2111 }
2112
2113 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2114 if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2115 // Allow use in equality comparisons.
2116 if (Cmp->isEquality())
2117 return getUniqueKernelFor(*Cmp);
2118 return nullptr;
2119 }
2120 if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2121 // Allow direct calls.
2122 if (CB->isCallee(&U))
2123 return getUniqueKernelFor(*CB);
2124
2125 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2126 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_60];
2127 // Allow the use in __kmpc_parallel_60 calls.
2128 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2129 return getUniqueKernelFor(*CB);
2130 return nullptr;
2131 }
2132 // Disallow every other use.
2133 return nullptr;
2134 };
2135
2136 // TODO: In the future we want to track more than just a unique kernel.
2137 SmallPtrSet<Kernel, 2> PotentialKernels;
2138 OMPInformationCache::foreachUse(F, [&](const Use &U) {
2139 PotentialKernels.insert(GetUniqueKernelForUse(U));
2140 });
2141
2142 Kernel K = nullptr;
2143 if (PotentialKernels.size() == 1)
2144 K = *PotentialKernels.begin();
2145
2146 // Cache the result.
2147 UniqueKernelMap[&F] = K;
2148
2149 return K;
2150}
2151
2152bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2153 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2154 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_60];
2155
2156 bool Changed = false;
2157 if (!KernelParallelRFI)
2158 return Changed;
2159
2160 // If we have disabled state machine changes, exit
2162 return Changed;
2163
2164 for (Function *F : SCC) {
2165
2166 // Check if the function is a use in a __kmpc_parallel_60 call at
2167 // all.
2168 bool UnknownUse = false;
2169 bool KernelParallelUse = false;
2170 unsigned NumDirectCalls = 0;
2171
2172 SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2173 OMPInformationCache::foreachUse(*F, [&](Use &U) {
2174 if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2175 if (CB->isCallee(&U)) {
2176 ++NumDirectCalls;
2177 return;
2178 }
2179
2180 if (isa<ICmpInst>(U.getUser())) {
2181 ToBeReplacedStateMachineUses.push_back(&U);
2182 return;
2183 }
2184
2185 // Find wrapper functions that represent parallel kernels.
2186 CallInst *CI =
2187 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2188 const unsigned int WrapperFunctionArgNo = 6;
2189 if (!KernelParallelUse && CI &&
2190 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2191 KernelParallelUse = true;
2192 ToBeReplacedStateMachineUses.push_back(&U);
2193 return;
2194 }
2195 UnknownUse = true;
2196 });
2197
2198 // Do not emit a remark if we haven't seen a __kmpc_parallel_60
2199 // use.
2200 if (!KernelParallelUse)
2201 continue;
2202
2203 // If this ever hits, we should investigate.
2204 // TODO: Checking the number of uses is not a necessary restriction and
2205 // should be lifted.
2206 if (UnknownUse || NumDirectCalls != 1 ||
2207 ToBeReplacedStateMachineUses.size() > 2) {
2208 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2209 return ORA << "Parallel region is used in "
2210 << (UnknownUse ? "unknown" : "unexpected")
2211 << " ways. Will not attempt to rewrite the state machine.";
2212 };
2214 continue;
2215 }
2216
2217 // Even if we have __kmpc_parallel_60 calls, we (for now) give
2218 // up if the function is not called from a unique kernel.
2219 Kernel K = getUniqueKernelFor(*F);
2220 if (!K) {
2221 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2222 return ORA << "Parallel region is not called from a unique kernel. "
2223 "Will not attempt to rewrite the state machine.";
2224 };
2226 continue;
2227 }
2228
2229 // We now know F is a parallel body function called only from the kernel K.
2230 // We also identified the state machine uses in which we replace the
2231 // function pointer by a new global symbol for identification purposes. This
2232 // ensures only direct calls to the function are left.
2233
2234 Module &M = *F->getParent();
2235 Type *Int8Ty = Type::getInt8Ty(M.getContext());
2236
2237 auto *ID = new GlobalVariable(
2238 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2239 UndefValue::get(Int8Ty), F->getName() + ".ID");
2240
2241 for (Use *U : ToBeReplacedStateMachineUses)
2243 ID, U->get()->getType()));
2244
2245 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2246
2247 Changed = true;
2248 }
2249
2250 return Changed;
2251}
2252
2253/// Abstract Attribute for tracking ICV values.
2254struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2255 using Base = StateWrapper<BooleanState, AbstractAttribute>;
2256 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2257
2258 /// Returns true if value is assumed to be tracked.
2259 bool isAssumedTracked() const { return getAssumed(); }
2260
2261 /// Returns true if value is known to be tracked.
2262 bool isKnownTracked() const { return getAssumed(); }
2263
2264 /// Create an abstract attribute biew for the position \p IRP.
2265 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2266
2267 /// Return the value with which \p I can be replaced for specific \p ICV.
2268 virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2269 const Instruction *I,
2270 Attributor &A) const {
2271 return std::nullopt;
2272 }
2273
2274 /// Return an assumed unique ICV value if a single candidate is found. If
2275 /// there cannot be one, return a nullptr. If it is not clear yet, return
2276 /// std::nullopt.
2277 virtual std::optional<Value *>
2278 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2279
2280 // Currently only nthreads is being tracked.
2281 // this array will only grow with time.
2282 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2283
2284 /// See AbstractAttribute::getName()
2285 StringRef getName() const override { return "AAICVTracker"; }
2286
2287 /// See AbstractAttribute::getIdAddr()
2288 const char *getIdAddr() const override { return &ID; }
2289
2290 /// This function should return true if the type of the \p AA is AAICVTracker
2291 static bool classof(const AbstractAttribute *AA) {
2292 return (AA->getIdAddr() == &ID);
2293 }
2294
2295 static const char ID;
2296};
2297
2298struct AAICVTrackerFunction : public AAICVTracker {
2299 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2300 : AAICVTracker(IRP, A) {}
2301
2302 // FIXME: come up with better string.
2303 const std::string getAsStr(Attributor *) const override {
2304 return "ICVTrackerFunction";
2305 }
2306
2307 // FIXME: come up with some stats.
2308 void trackStatistics() const override {}
2309
2310 /// We don't manifest anything for this AA.
2311 ChangeStatus manifest(Attributor &A) override {
2312 return ChangeStatus::UNCHANGED;
2313 }
2314
2315 // Map of ICV to their values at specific program point.
2316 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2317 InternalControlVar::ICV___last>
2318 ICVReplacementValuesMap;
2319
2320 ChangeStatus updateImpl(Attributor &A) override {
2321 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2322
2323 Function *F = getAnchorScope();
2324
2325 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2326
2327 for (InternalControlVar ICV : TrackableICVs) {
2328 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2329
2330 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2331 auto TrackValues = [&](Use &U, Function &) {
2332 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2333 if (!CI)
2334 return false;
2335
2336 // FIXME: handle setters with more that 1 arguments.
2337 /// Track new value.
2338 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2339 HasChanged = ChangeStatus::CHANGED;
2340
2341 return false;
2342 };
2343
2344 auto CallCheck = [&](Instruction &I) {
2345 std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2346 if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2347 HasChanged = ChangeStatus::CHANGED;
2348
2349 return true;
2350 };
2351
2352 // Track all changes of an ICV.
2353 SetterRFI.foreachUse(TrackValues, F);
2354
2355 bool UsedAssumedInformation = false;
2356 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2357 UsedAssumedInformation,
2358 /* CheckBBLivenessOnly */ true);
2359
2360 /// TODO: Figure out a way to avoid adding entry in
2361 /// ICVReplacementValuesMap
2362 Instruction *Entry = &F->getEntryBlock().front();
2363 if (HasChanged == ChangeStatus::CHANGED)
2364 ValuesMap.try_emplace(Entry);
2365 }
2366
2367 return HasChanged;
2368 }
2369
2370 /// Helper to check if \p I is a call and get the value for it if it is
2371 /// unique.
2372 std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2373 InternalControlVar &ICV) const {
2374
2375 const auto *CB = dyn_cast<CallBase>(&I);
2376 if (!CB || CB->hasFnAttr("no_openmp") ||
2377 CB->hasFnAttr("no_openmp_routines") ||
2378 CB->hasFnAttr("no_openmp_constructs"))
2379 return std::nullopt;
2380
2381 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2382 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2383 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2384 Function *CalledFunction = CB->getCalledFunction();
2385
2386 // Indirect call, assume ICV changes.
2387 if (CalledFunction == nullptr)
2388 return nullptr;
2389 if (CalledFunction == GetterRFI.Declaration)
2390 return std::nullopt;
2391 if (CalledFunction == SetterRFI.Declaration) {
2392 if (ICVReplacementValuesMap[ICV].count(&I))
2393 return ICVReplacementValuesMap[ICV].lookup(&I);
2394
2395 return nullptr;
2396 }
2397
2398 // Since we don't know, assume it changes the ICV.
2399 if (CalledFunction->isDeclaration())
2400 return nullptr;
2401
2402 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2403 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2404
2405 if (ICVTrackingAA->isAssumedTracked()) {
2406 std::optional<Value *> URV =
2407 ICVTrackingAA->getUniqueReplacementValue(ICV);
2408 if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2409 OMPInfoCache)))
2410 return URV;
2411 }
2412
2413 // If we don't know, assume it changes.
2414 return nullptr;
2415 }
2416
2417 // We don't check unique value for a function, so return std::nullopt.
2418 std::optional<Value *>
2419 getUniqueReplacementValue(InternalControlVar ICV) const override {
2420 return std::nullopt;
2421 }
2422
2423 /// Return the value with which \p I can be replaced for specific \p ICV.
2424 std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2425 const Instruction *I,
2426 Attributor &A) const override {
2427 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2428 if (ValuesMap.count(I))
2429 return ValuesMap.lookup(I);
2430
2432 SmallPtrSet<const Instruction *, 16> Visited;
2433 Worklist.push_back(I);
2434
2435 std::optional<Value *> ReplVal;
2436
2437 while (!Worklist.empty()) {
2438 const Instruction *CurrInst = Worklist.pop_back_val();
2439 if (!Visited.insert(CurrInst).second)
2440 continue;
2441
2442 const BasicBlock *CurrBB = CurrInst->getParent();
2443
2444 // Go up and look for all potential setters/calls that might change the
2445 // ICV.
2446 while ((CurrInst = CurrInst->getPrevNode())) {
2447 if (ValuesMap.count(CurrInst)) {
2448 std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2449 // Unknown value, track new.
2450 if (!ReplVal) {
2451 ReplVal = NewReplVal;
2452 break;
2453 }
2454
2455 // If we found a new value, we can't know the icv value anymore.
2456 if (NewReplVal)
2457 if (ReplVal != NewReplVal)
2458 return nullptr;
2459
2460 break;
2461 }
2462
2463 std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2464 if (!NewReplVal)
2465 continue;
2466
2467 // Unknown value, track new.
2468 if (!ReplVal) {
2469 ReplVal = NewReplVal;
2470 break;
2471 }
2472
2473 // if (NewReplVal.hasValue())
2474 // We found a new value, we can't know the icv value anymore.
2475 if (ReplVal != NewReplVal)
2476 return nullptr;
2477 }
2478
2479 // If we are in the same BB and we have a value, we are done.
2480 if (CurrBB == I->getParent() && ReplVal)
2481 return ReplVal;
2482
2483 // Go through all predecessors and add terminators for analysis.
2484 for (const BasicBlock *Pred : predecessors(CurrBB))
2485 if (const Instruction *Terminator = Pred->getTerminator())
2486 Worklist.push_back(Terminator);
2487 }
2488
2489 return ReplVal;
2490 }
2491};
2492
2493struct AAICVTrackerFunctionReturned : AAICVTracker {
2494 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2495 : AAICVTracker(IRP, A) {}
2496
2497 // FIXME: come up with better string.
2498 const std::string getAsStr(Attributor *) const override {
2499 return "ICVTrackerFunctionReturned";
2500 }
2501
2502 // FIXME: come up with some stats.
2503 void trackStatistics() const override {}
2504
2505 /// We don't manifest anything for this AA.
2506 ChangeStatus manifest(Attributor &A) override {
2507 return ChangeStatus::UNCHANGED;
2508 }
2509
2510 // Map of ICV to their values at specific program point.
2511 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2512 InternalControlVar::ICV___last>
2513 ICVReplacementValuesMap;
2514
2515 /// Return the value with which \p I can be replaced for specific \p ICV.
2516 std::optional<Value *>
2517 getUniqueReplacementValue(InternalControlVar ICV) const override {
2518 return ICVReplacementValuesMap[ICV];
2519 }
2520
2521 ChangeStatus updateImpl(Attributor &A) override {
2522 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2523 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2524 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2525
2526 if (!ICVTrackingAA->isAssumedTracked())
2527 return indicatePessimisticFixpoint();
2528
2529 for (InternalControlVar ICV : TrackableICVs) {
2530 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2531 std::optional<Value *> UniqueICVValue;
2532
2533 auto CheckReturnInst = [&](Instruction &I) {
2534 std::optional<Value *> NewReplVal =
2535 ICVTrackingAA->getReplacementValue(ICV, &I, A);
2536
2537 // If we found a second ICV value there is no unique returned value.
2538 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2539 return false;
2540
2541 UniqueICVValue = NewReplVal;
2542
2543 return true;
2544 };
2545
2546 bool UsedAssumedInformation = false;
2547 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2548 UsedAssumedInformation,
2549 /* CheckBBLivenessOnly */ true))
2550 UniqueICVValue = nullptr;
2551
2552 if (UniqueICVValue == ReplVal)
2553 continue;
2554
2555 ReplVal = UniqueICVValue;
2556 Changed = ChangeStatus::CHANGED;
2557 }
2558
2559 return Changed;
2560 }
2561};
2562
2563struct AAICVTrackerCallSite : AAICVTracker {
2564 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2565 : AAICVTracker(IRP, A) {}
2566
2567 void initialize(Attributor &A) override {
2568 assert(getAnchorScope() && "Expected anchor function");
2569
2570 // We only initialize this AA for getters, so we need to know which ICV it
2571 // gets.
2572 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2573 for (InternalControlVar ICV : TrackableICVs) {
2574 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2575 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2576 if (Getter.Declaration == getAssociatedFunction()) {
2577 AssociatedICV = ICVInfo.Kind;
2578 return;
2579 }
2580 }
2581
2582 /// Unknown ICV.
2583 indicatePessimisticFixpoint();
2584 }
2585
2586 ChangeStatus manifest(Attributor &A) override {
2587 if (!ReplVal || !*ReplVal)
2588 return ChangeStatus::UNCHANGED;
2589
2590 A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2591 A.deleteAfterManifest(*getCtxI());
2592
2593 return ChangeStatus::CHANGED;
2594 }
2595
2596 // FIXME: come up with better string.
2597 const std::string getAsStr(Attributor *) const override {
2598 return "ICVTrackerCallSite";
2599 }
2600
2601 // FIXME: come up with some stats.
2602 void trackStatistics() const override {}
2603
2604 InternalControlVar AssociatedICV;
2605 std::optional<Value *> ReplVal;
2606
2607 ChangeStatus updateImpl(Attributor &A) override {
2608 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2609 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2610
2611 // We don't have any information, so we assume it changes the ICV.
2612 if (!ICVTrackingAA->isAssumedTracked())
2613 return indicatePessimisticFixpoint();
2614
2615 std::optional<Value *> NewReplVal =
2616 ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
2617
2618 if (ReplVal == NewReplVal)
2619 return ChangeStatus::UNCHANGED;
2620
2621 ReplVal = NewReplVal;
2622 return ChangeStatus::CHANGED;
2623 }
2624
2625 // Return the value with which associated value can be replaced for specific
2626 // \p ICV.
2627 std::optional<Value *>
2628 getUniqueReplacementValue(InternalControlVar ICV) const override {
2629 return ReplVal;
2630 }
2631};
2632
2633struct AAICVTrackerCallSiteReturned : AAICVTracker {
2634 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2635 : AAICVTracker(IRP, A) {}
2636
2637 // FIXME: come up with better string.
2638 const std::string getAsStr(Attributor *) const override {
2639 return "ICVTrackerCallSiteReturned";
2640 }
2641
2642 // FIXME: come up with some stats.
2643 void trackStatistics() const override {}
2644
2645 /// We don't manifest anything for this AA.
2646 ChangeStatus manifest(Attributor &A) override {
2647 return ChangeStatus::UNCHANGED;
2648 }
2649
2650 // Map of ICV to their values at specific program point.
2651 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2652 InternalControlVar::ICV___last>
2653 ICVReplacementValuesMap;
2654
2655 /// Return the value with which associated value can be replaced for specific
2656 /// \p ICV.
2657 std::optional<Value *>
2658 getUniqueReplacementValue(InternalControlVar ICV) const override {
2659 return ICVReplacementValuesMap[ICV];
2660 }
2661
2662 ChangeStatus updateImpl(Attributor &A) override {
2663 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2664 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2665 *this, IRPosition::returned(*getAssociatedFunction()),
2666 DepClassTy::REQUIRED);
2667
2668 // We don't have any information, so we assume it changes the ICV.
2669 if (!ICVTrackingAA->isAssumedTracked())
2670 return indicatePessimisticFixpoint();
2671
2672 for (InternalControlVar ICV : TrackableICVs) {
2673 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2674 std::optional<Value *> NewReplVal =
2675 ICVTrackingAA->getUniqueReplacementValue(ICV);
2676
2677 if (ReplVal == NewReplVal)
2678 continue;
2679
2680 ReplVal = NewReplVal;
2681 Changed = ChangeStatus::CHANGED;
2682 }
2683 return Changed;
2684 }
2685};
2686
2687/// Determines if \p BB exits the function unconditionally itself or reaches a
2688/// block that does through only unique successors.
2689static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2690 if (succ_empty(BB))
2691 return true;
2692 const BasicBlock *const Successor = BB->getUniqueSuccessor();
2693 if (!Successor)
2694 return false;
2695 return hasFunctionEndAsUniqueSuccessor(Successor);
2696}
2697
2698struct AAExecutionDomainFunction : public AAExecutionDomain {
2699 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2700 : AAExecutionDomain(IRP, A) {}
2701
2702 ~AAExecutionDomainFunction() override { delete RPOT; }
2703
2704 void initialize(Attributor &A) override {
2705 Function *F = getAnchorScope();
2706 assert(F && "Expected anchor function");
2707 RPOT = new ReversePostOrderTraversal<Function *>(F);
2708 }
2709
2710 const std::string getAsStr(Attributor *) const override {
2711 unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2712 for (auto &It : BEDMap) {
2713 if (!It.getFirst())
2714 continue;
2715 TotalBlocks++;
2716 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2717 AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2718 It.getSecond().IsReachingAlignedBarrierOnly;
2719 }
2720 return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
2721 std::to_string(AlignedBlocks) + " of " +
2722 std::to_string(TotalBlocks) +
2723 " executed by initial thread / aligned";
2724 }
2725
2726 /// See AbstractAttribute::trackStatistics().
2727 void trackStatistics() const override {}
2728
2729 ChangeStatus manifest(Attributor &A) override {
2730 LLVM_DEBUG({
2731 for (const BasicBlock &BB : *getAnchorScope()) {
2732 if (!isExecutedByInitialThreadOnly(BB))
2733 continue;
2734 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2735 << BB.getName() << " is executed by a single thread.\n";
2736 }
2737 });
2738
2739 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2740
2742 return Changed;
2743
2744 SmallPtrSet<CallBase *, 16> DeletedBarriers;
2745 auto HandleAlignedBarrier = [&](CallBase *CB) {
2746 const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
2747 if (!ED.IsReachedFromAlignedBarrierOnly ||
2748 ED.EncounteredNonLocalSideEffect)
2749 return;
2750 if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
2751 return;
2752
2753 // We can remove this barrier, if it is one, or aligned barriers reaching
2754 // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
2755 // end should only be removed if the kernel end is their unique successor;
2756 // otherwise, they may have side-effects that aren't accounted for in the
2757 // kernel end in their other successors. If those barriers have other
2758 // barriers reaching them, those can be transitively removed as well as
2759 // long as the kernel end is also their unique successor.
2760 if (CB) {
2761 DeletedBarriers.insert(CB);
2762 A.deleteAfterManifest(*CB);
2763 ++NumBarriersEliminated;
2764 Changed = ChangeStatus::CHANGED;
2765 } else if (!ED.AlignedBarriers.empty()) {
2766 Changed = ChangeStatus::CHANGED;
2767 SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
2768 ED.AlignedBarriers.end());
2769 SmallSetVector<CallBase *, 16> Visited;
2770 while (!Worklist.empty()) {
2771 CallBase *LastCB = Worklist.pop_back_val();
2772 if (!Visited.insert(LastCB))
2773 continue;
2774 if (LastCB->getFunction() != getAnchorScope())
2775 continue;
2776 if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
2777 continue;
2778 if (!DeletedBarriers.count(LastCB)) {
2779 ++NumBarriersEliminated;
2780 A.deleteAfterManifest(*LastCB);
2781 continue;
2782 }
2783 // The final aligned barrier (LastCB) reaching the kernel end was
2784 // removed already. This means we can go one step further and remove
2785 // the barriers encoutered last before (LastCB).
2786 const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2787 Worklist.append(LastED.AlignedBarriers.begin(),
2788 LastED.AlignedBarriers.end());
2789 }
2790 }
2791
2792 // If we actually eliminated a barrier we need to eliminate the associated
2793 // llvm.assumes as well to avoid creating UB.
2794 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2795 for (auto *AssumeCB : ED.EncounteredAssumes)
2796 A.deleteAfterManifest(*AssumeCB);
2797 };
2798
2799 for (auto *CB : AlignedBarriers)
2800 HandleAlignedBarrier(CB);
2801
2802 // Handle the "kernel end barrier" for kernels too.
2803 if (omp::isOpenMPKernel(*getAnchorScope()))
2804 HandleAlignedBarrier(nullptr);
2805
2806 return Changed;
2807 }
2808
2809 bool isNoOpFence(const FenceInst &FI) const override {
2810 return getState().isValidState() && !NonNoOpFences.count(&FI);
2811 }
2812
2813 /// Merge barrier and assumption information from \p PredED into the successor
2814 /// \p ED.
2815 void
2816 mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
2817 const ExecutionDomainTy &PredED);
2818
2819 /// Merge all information from \p PredED into the successor \p ED. If
2820 /// \p InitialEdgeOnly is set, only the initial edge will enter the block
2821 /// represented by \p ED from this predecessor.
2822 bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
2823 const ExecutionDomainTy &PredED,
2824 bool InitialEdgeOnly = false);
2825
2826 /// Accumulate information for the entry block in \p EntryBBED.
2827 bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);
2828
2829 /// See AbstractAttribute::updateImpl.
2830 ChangeStatus updateImpl(Attributor &A) override;
2831
2832 /// Query interface, see AAExecutionDomain
2833 ///{
2834 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2835 if (!isValidState())
2836 return false;
2837 assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
2838 return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2839 }
2840
  /// Return true if \p I is known to execute in an "aligned region".
  ///
  /// Two local scans are performed. The forward scan starts at \p I itself and
  /// stops at the first call with recorded PRE information (or the block end).
  /// The backward scan starts at \p I and stops at the first call with recorded
  /// POST information (or the block begin). Finding a different aligned
  /// barrier in either direction answers true immediately. Otherwise the
  /// per-call (CEDMap) and per-block (BEDMap) flags decide. An invalid state
  /// answers false.
  bool isExecutedInAlignedRegion(Attributor &A,
                                 const Instruction &I) const override {
    assert(I.getFunction() == getAnchorScope() &&
           "Instruction is out of scope!");
    if (!isValidState())
      return false;

    // Outcome of the forward scan. It is only acted upon after the backward
    // scan, which may still find an aligned barrier and answer true.
    bool ForwardIsOk = true;
    const Instruction *CurI;

    // Check forward until a call or the block end is reached.
    CurI = &I;
    do {
      auto *CB = dyn_cast<CallBase>(CurI);
      if (!CB)
        continue;
      // Another aligned barrier bounds the region from below.
      if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
        return true;
      // Calls without recorded PRE information are stepped over.
      const auto &It = CEDMap.find({CB, PRE});
      if (It == CEDMap.end())
        continue;
      if (!It->getSecond().IsReachingAlignedBarrierOnly)
        ForwardIsOk = false;
      break;
    } while ((CurI = CurI->getNextNode()));

    // Ran off the end of the block: fall back to the block-level flag.
    if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
      ForwardIsOk = false;

    // Check backward until a call or the block beginning is reached.
    CurI = &I;
    do {
      auto *CB = dyn_cast<CallBase>(CurI);
      if (!CB)
        continue;
      // Another aligned barrier bounds the region from above.
      if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
        return true;
      // Calls without recorded POST information are stepped over.
      const auto &It = CEDMap.find({CB, POST});
      if (It == CEDMap.end())
        continue;
      if (It->getSecond().IsReachedFromAlignedBarrierOnly)
        break;
      return false;
    } while ((CurI = CurI->getPrevNode()));

    // Delayed decision on the forward pass to allow aligned barrier detection
    // in the backwards traversal.
    if (!ForwardIsOk)
      return false;

    // Ran off the beginning of the block. For the entry block, consult the
    // function-level information (stored under the nullptr key). Otherwise
    // every predecessor block must be reached from aligned barriers only.
    if (!CurI) {
      const BasicBlock *BB = I.getParent();
      if (BB == &BB->getParent()->getEntryBlock())
        return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
      if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
            return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
          })) {
        return false;
      }
    }

    // In neither traversal did we find anything but aligned barriers.
    return true;
  }
2905
2906 ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
2907 assert(isValidState() &&
2908 "No request should be made against an invalid state!");
2909 return BEDMap.lookup(&BB);
2910 }
2911 std::pair<ExecutionDomainTy, ExecutionDomainTy>
2912 getExecutionDomain(const CallBase &CB) const override {
2913 assert(isValidState() &&
2914 "No request should be made against an invalid state!");
2915 return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
2916 }
2917 ExecutionDomainTy getFunctionExecutionDomain() const override {
2918 assert(isValidState() &&
2919 "No request should be made against an invalid state!");
2920 return InterProceduralED;
2921 }
2922 ///}
2923
  // Check if the edge into the successor block contains a condition that only
  // lets the main thread execute it.
  //
  // Only the "true" successor (successor 0) of the conditional branch
  // \p Edge is considered. The condition must be an equality compare
  // against a ConstantInt on the right-hand side. Two patterns are matched:
  // "-1 == __kmpc_target_init(...)" and "0 == <thread id x intrinsic>".
  static bool isInitialThreadOnlyEdge(Attributor &A, CondBrInst *Edge,
                                      BasicBlock &SuccessorBB) {
    if (!Edge)
      return false;
    if (Edge->getSuccessor(0) != &SuccessorBB)
      return false;

    auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
    if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
      return false;

    ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
    if (!C)
      return false;

    // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
    if (C->isAllOnesValue()) {
      auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
      auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
      CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
      if (!CB)
        return false;
      ConstantStruct *KernelEnvC =
      ConstantInt *ExecModeC =
          KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
      // True when the GENERIC bit is set in the kernel's exec mode.
      // NOTE(review): this is a bit test, so a mode that combines GENERIC with
      // other bits also passes. Confirm that matches the "non-SPMD only"
      // intent stated above.
      return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
    }

    if (C->isZero()) {
      // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
      if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
        if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
          return true;

      // Match: 0 == llvm.amdgcn.workitem.id.x()
      if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
        if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
          return true;
    }

    return false;
  };
2970
2971 /// Mapping containing information about the function for other AAs.
2972 ExecutionDomainTy InterProceduralED;
2973
2974 enum Direction { PRE = 0, POST = 1 };
2975 /// Mapping containing information per block.
2976 DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
2977 DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>
2978 CEDMap;
2979 SmallSetVector<CallBase *, 16> AlignedBarriers;
2980
2981 ReversePostOrderTraversal<Function *> *RPOT = nullptr;
2982
2983 /// Set \p R to \V and report true if that changed \p R.
2984 static bool setAndRecord(bool &R, bool V) {
2985 bool Eq = (R == V);
2986 R = V;
2987 return !Eq;
2988 }
2989
2990 /// Collection of fences known to be non-no-opt. All fences not in this set
2991 /// can be assumed no-opt.
2992 SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
2993};
2994
2995void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2996 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2997 for (auto *EA : PredED.EncounteredAssumes)
2998 ED.addAssumeInst(A, *EA);
2999
3000 for (auto *AB : PredED.AlignedBarriers)
3001 ED.addAlignedBarrier(A, *AB);
3002}
3003
3004bool AAExecutionDomainFunction::mergeInPredecessor(
3005 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
3006 bool InitialEdgeOnly) {
3007
3008 bool Changed = false;
3009 Changed |=
3010 setAndRecord(ED.IsExecutedByInitialThreadOnly,
3011 InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3012 ED.IsExecutedByInitialThreadOnly));
3013
3014 Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3015 ED.IsReachedFromAlignedBarrierOnly &&
3016 PredED.IsReachedFromAlignedBarrierOnly);
3017 Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3018 ED.EncounteredNonLocalSideEffect |
3019 PredED.EncounteredNonLocalSideEffect);
3020 // Do not track assumptions and barriers as part of Changed.
3021 if (ED.IsReachedFromAlignedBarrierOnly)
3022 mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3023 else
3024 ED.clearAssumeInstAndAlignedBarriers();
3025 return Changed;
3026}
3027
/// Initialize \p EntryBBED, the execution domain on entry to the anchor
/// function's entry block, from all of the function's call sites. Then fold
/// the result into the function-level entry of BEDMap (key nullptr).
/// Returns true if that function-level entry changed.
///
/// If not all call sites can be visited, fixed defaults are used instead.
/// OpenMP kernels and other functions get different defaults (see below).
bool AAExecutionDomainFunction::handleCallees(Attributor &A,
                                              ExecutionDomainTy &EntryBBED) {
  // For every call site, query the caller's AAExecutionDomain and record the
  // (PRE, POST) execution domains of the call. Fails (returns false) if that
  // AA is unavailable or invalid.
  auto PredForCallSite = [&](AbstractCallSite ACS) {
    const auto *EDAA = A.getAAFor<AAExecutionDomain>(
        *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
        DepClassTy::OPTIONAL);
    if (!EDAA || !EDAA->getState().isValidState())
      return false;
    CallSiteEDs.emplace_back(
        EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
    return true;
  };

  // Only IsReachingAlignedBarrierOnly of ExitED is used below.
  ExecutionDomainTy ExitED;
  // Presumably filled in by checkForAllCallSites; not read afterwards.
  bool AllCallSitesKnown;
  if (A.checkForAllCallSites(PredForCallSite, *this,
                             /* RequiresAllCallSites */ true,
                             AllCallSitesKnown)) {
    for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
      // The entry state is the merge over the PRE domains of all call sites.
      // The return value is not used; changes are reported via FnED below.
      mergeInPredecessor(A, EntryBBED, CSInED);
      // The exit flag is the conjunction over all call sites' POST domains,
      // starting from ExitED's default value.
      ExitED.IsReachingAlignedBarrierOnly &=
          CSOutED.IsReachingAlignedBarrierOnly;
    }

  } else {
    // We could not find all predecessors, so this is either a kernel or a
    // function with external linkage (or with some other weird uses).
    if (omp::isOpenMPKernel(*getAnchorScope())) {
      // Kernel entry: treated as reached from aligned barriers only, with no
      // prior non-local side effects. The exit is not assumed to reach
      // aligned barriers only.
      EntryBBED.IsExecutedByInitialThreadOnly = false;
      EntryBBED.IsReachedFromAlignedBarrierOnly = true;
      EntryBBED.EncounteredNonLocalSideEffect = false;
      ExitED.IsReachingAlignedBarrierOnly = false;
    } else {
      // Unknown callers: assume the most pessimistic entry and exit state.
      EntryBBED.IsExecutedByInitialThreadOnly = false;
      EntryBBED.IsReachedFromAlignedBarrierOnly = false;
      EntryBBED.EncounteredNonLocalSideEffect = true;
      ExitED.IsReachingAlignedBarrierOnly = false;
    }
  }

  // Fold entry/exit information into the function-level domain. The first two
  // flags can only go from true to false here. IsExecutedByInitialThreadOnly
  // is overwritten with the entry value.
  bool Changed = false;
  auto &FnED = BEDMap[nullptr];
  Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
                          FnED.IsReachedFromAlignedBarrierOnly &
                              EntryBBED.IsReachedFromAlignedBarrierOnly);
  Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
                          FnED.IsReachingAlignedBarrierOnly &
                              ExitED.IsReachingAlignedBarrierOnly);
  Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
                          EntryBBED.IsExecutedByInitialThreadOnly);
  return Changed;
}
3081
/// One fixpoint iteration of the execution-domain analysis.
///
/// 1. A forward pass over the live blocks in reverse post-order (RPOT)
///    computes a domain per block (BEDMap) and per call (CEDMap, PRE/POST).
///    Along the way it collects aligned barriers and non-no-op fences.
/// 2. A backward pass starts from the non-aligned synchronizing instructions
///    collected in the first pass. It clears IsReachingAlignedBarrierOnly
///    until an aligned barrier, an already-cleared call, or the entry is
///    reached.
///
/// Returns CHANGED if any tracked flag changed. Collected assumptions and
/// aligned-barrier sets do not count as changes.
ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {

  bool Changed = false;

  // Helper to deal with an aligned barrier encountered during the forward
  // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
  // it was encountered.
  auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
    Changed |= AlignedBarriers.insert(&CB);
    // First, update the barrier ED kept in the separate CEDMap.
    auto &CallInED = CEDMap[{&CB, PRE}];
    Changed |= mergeInPredecessor(A, CallInED, ED);
    CallInED.IsReachingAlignedBarrierOnly = true;
    // Next adjust the ED we use for the traversal.
    ED.EncounteredNonLocalSideEffect = false;
    ED.IsReachedFromAlignedBarrierOnly = true;
    // Aligned barrier collection has to come last.
    ED.clearAssumeInstAndAlignedBarriers();
    ED.addAlignedBarrier(A, CB);
    auto &CallOutED = CEDMap[{&CB, POST}];
    Changed |= mergeInPredecessor(A, CallOutED, ED);
  };

  auto *LivenessAA =
      A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);

  Function *F = getAnchorScope();
  BasicBlock &EntryBB = F->getEntryBlock();
  bool IsKernel = omp::isOpenMPKernel(*F);

  // Instructions from which "not reaching aligned barriers only" is propagated
  // backwards once the forward pass is done.
  SmallVector<Instruction *> SyncInstWorklist;
  for (auto &RIt : *RPOT) {
    BasicBlock &BB = *RIt;

    bool IsEntryBB = &BB == &EntryBB;
    // TODO: We use local reasoning since we don't have a divergence analysis
    // running as well. We could basically allow uniform branches here.
    bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
    bool IsExplicitlyAligned = IsEntryBB && IsKernel;
    ExecutionDomainTy ED;
    // Propagate "incoming edges" into information about this block.
    if (IsEntryBB) {
      Changed |= handleCallees(A, ED);
    } else {
      // For live non-entry blocks we only propagate
      // information via live edges.
      if (LivenessAA && LivenessAA->isAssumedDead(&BB))
        continue;

      for (auto *PredBB : predecessors(&BB)) {
        if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
          continue;
        bool InitialEdgeOnly = isInitialThreadOnlyEdge(
            A, dyn_cast<CondBrInst>(PredBB->getTerminator()), BB);
        mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
      }
    }

    // Now we traverse the block, accumulate effects in ED and attach
    // information to calls.
    for (Instruction &I : BB) {
      bool UsedAssumedInformation;
      if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
                          /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
                          /* CheckForDeadStore */ true))
        continue;

      // Assumes and "assume-like" (dbg, lifetime, ...) are handled first, the
      // former is collected the latter is ignored.
      if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
        if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
          ED.addAssumeInst(A, *AI);
          continue;
        }
        // TODO: Should we also collect and delete lifetime markers?
        if (II->isAssumeLikeIntrinsic())
          continue;
      }

      if (auto *FI = dyn_cast<FenceInst>(&I)) {
        if (!ED.EncounteredNonLocalSideEffect) {
          // An aligned fence without non-local side-effects is a no-op.
          if (ED.IsReachedFromAlignedBarrierOnly)
            continue;
          // A non-aligned fence without non-local side-effects is a no-op
          // if the ordering only publishes non-local side-effects (or less).
          // Note: `continue` in this switch moves on to the next instruction
          // (the fence is a no-op). `break` only leaves the switch, so the
          // fence is recorded in NonNoOpFences below.
          switch (FI->getOrdering()) {
          case AtomicOrdering::NotAtomic:
            continue;
          case AtomicOrdering::Unordered:
            continue;
          case AtomicOrdering::Monotonic:
            continue;
          case AtomicOrdering::Acquire:
            break;
          case AtomicOrdering::Release:
            continue;
          case AtomicOrdering::AcquireRelease:
            break;
          case AtomicOrdering::SequentiallyConsistent:
            break;
          };
        }
        NonNoOpFences.insert(FI);
      }

      auto *CB = dyn_cast<CallBase>(&I);
      bool IsNoSync = AA::isNoSyncInst(A, I, *this);
      bool IsAlignedBarrier =
          !IsNoSync && CB &&
          AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);

      // Any synchronizing instruction ends the "aligned barrier is last"
      // streak until the next aligned barrier re-establishes it.
      AlignedBarrierLastInBlock &= IsNoSync;
      IsExplicitlyAligned &= IsNoSync;

      // Next we check for calls. Aligned barriers are handled
      // explicitly, everything else is kept for the backward traversal and will
      // also affect our state.
      if (CB) {
        if (IsAlignedBarrier) {
          HandleAlignedBarrier(*CB, ED);
          AlignedBarrierLastInBlock = true;
          IsExplicitlyAligned = true;
          continue;
        }

        // Check the pointer(s) of a memory intrinsic explicitly.
        if (isa<MemIntrinsic>(&I)) {
          if (!ED.EncounteredNonLocalSideEffect &&
            ED.EncounteredNonLocalSideEffect = true;
          if (!IsNoSync) {
            ED.IsReachedFromAlignedBarrierOnly = false;
            SyncInstWorklist.push_back(&I);
          }
          continue;
        }

        // Record how we entered the call, then accumulate the effect of the
        // call in ED for potential use by the callee.
        auto &CallInED = CEDMap[{CB, PRE}];
        Changed |= mergeInPredecessor(A, CallInED, ED);

        // If we have a sync-definition we can check if it starts/ends in an
        // aligned barrier. If we are unsure we assume any sync breaks
        // alignment.
        if (!IsNoSync && Callee && !Callee->isDeclaration()) {
          const auto *EDAA = A.getAAFor<AAExecutionDomain>(
              *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
          if (EDAA && EDAA->getState().isValidState()) {
            // Continue with the callee's function-level domain.
            const auto &CalleeED = EDAA->getFunctionExecutionDomain();
            ED.IsReachedFromAlignedBarrierOnly =
                CalleeED.IsReachedFromAlignedBarrierOnly;
            AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
            if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
              ED.EncounteredNonLocalSideEffect |=
                  CalleeED.EncounteredNonLocalSideEffect;
            else
              ED.EncounteredNonLocalSideEffect =
                  CalleeED.EncounteredNonLocalSideEffect;
            if (!CalleeED.IsReachingAlignedBarrierOnly) {
              Changed |=
                  setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
              SyncInstWorklist.push_back(&I);
            }
            if (CalleeED.IsReachedFromAlignedBarrierOnly)
              mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
            auto &CallOutED = CEDMap[{CB, POST}];
            Changed |= mergeInPredecessor(A, CallOutED, ED);
            continue;
          }
        }
        // Unknown or unanalyzable synchronizing call: it breaks alignment and
        // seeds the backward propagation.
        if (!IsNoSync) {
          ED.IsReachedFromAlignedBarrierOnly = false;
          Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
          SyncInstWorklist.push_back(&I);
        }
        AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
        ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
        auto &CallOutED = CEDMap[{CB, POST}];
        Changed |= mergeInPredecessor(A, CallOutED, ED);
      }

      if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
        continue;

      // If we have a callee we try to use fine-grained information to
      // determine local side-effects.
      if (CB) {
        const auto *MemAA = A.getAAFor<AAMemoryLocation>(
            *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);

        auto AccessPred = [&](const Instruction *I, const Value *Ptr,
          return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
        };
        if (MemAA && MemAA->getState().isValidState() &&
            MemAA->checkForAllAccessesToMemoryKind(
          continue;
      }

      // Side-effect free instructions only feeding assumes do not count.
      auto &InfoCache = A.getInfoCache();
      if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
        continue;

      // Invariant loads do not count either.
      if (auto *LI = dyn_cast<LoadInst>(&I))
        if (LI->hasMetadata(LLVMContext::MD_invariant_load))
          continue;

      if (!ED.EncounteredNonLocalSideEffect &&
        ED.EncounteredNonLocalSideEffect = true;
    }

    // A block whose terminator has no successors and is not `unreachable` is a
    // function exit. Its state feeds the interprocedural and function-level
    // domains.
    bool IsEndAndNotReachingAlignedBarriersOnly = false;
    if (!isa<UnreachableInst>(BB.getTerminator()) &&
        !BB.getTerminator()->getNumSuccessors()) {

      Changed |= mergeInPredecessor(A, InterProceduralED, ED);

      auto &FnED = BEDMap[nullptr];
      if (IsKernel && !IsExplicitlyAligned)
        FnED.IsReachingAlignedBarrierOnly = false;
      Changed |= mergeInPredecessor(A, FnED, ED);

      if (!FnED.IsReachingAlignedBarrierOnly) {
        IsEndAndNotReachingAlignedBarriersOnly = true;
        SyncInstWorklist.push_back(BB.getTerminator());
        auto &BBED = BEDMap[&BB];
        Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
      }
    }

    ExecutionDomainTy &StoredED = BEDMap[&BB];
    ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
                                      !IsEndAndNotReachingAlignedBarriersOnly;

    // Check if we computed anything different as part of the forward
    // traversal. We do not take assumptions and aligned barriers into account
    // as they do not influence the state we iterate. Backward traversal values
    // are handled later on.
    if (ED.IsExecutedByInitialThreadOnly !=
            StoredED.IsExecutedByInitialThreadOnly ||
        ED.IsReachedFromAlignedBarrierOnly !=
            StoredED.IsReachedFromAlignedBarrierOnly ||
        ED.EncounteredNonLocalSideEffect !=
            StoredED.EncounteredNonLocalSideEffect)
      Changed = true;

    // Update the state with the new value.
    StoredED = std::move(ED);
  }

  // Propagate (non-aligned) sync instruction effects backwards until the
  // entry is hit or an aligned barrier.
  // Visited is shared by all worklist items, so each predecessor block is
  // updated at most once in this loop.
  SmallSetVector<BasicBlock *, 16> Visited;
  while (!SyncInstWorklist.empty()) {
    Instruction *SyncInst = SyncInstWorklist.pop_back_val();
    Instruction *CurInst = SyncInst;
    bool HitAlignedBarrierOrKnownEnd = false;
    // Walk up the block: every call before the sync instruction no longer
    // reaches aligned barriers only after it (POST). Stop at an aligned
    // barrier or at a call whose PRE flag is already cleared.
    while ((CurInst = CurInst->getPrevNode())) {
      auto *CB = dyn_cast<CallBase>(CurInst);
      if (!CB)
        continue;
      auto &CallOutED = CEDMap[{CB, POST}];
      Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
      auto &CallInED = CEDMap[{CB, PRE}];
      HitAlignedBarrierOrKnownEnd =
          AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
      if (HitAlignedBarrierOrKnownEnd)
        break;
      Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
    }
    if (HitAlignedBarrierOrKnownEnd)
      continue;
    // Reached the block start: continue in the live predecessors. Their
    // terminators are queued only if their flag actually changed.
    BasicBlock *SyncBB = SyncInst->getParent();
    for (auto *PredBB : predecessors(SyncBB)) {
      if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
        continue;
      if (!Visited.insert(PredBB))
        continue;
      auto &PredED = BEDMap[PredBB];
      if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
        Changed = true;
        SyncInstWorklist.push_back(PredBB->getTerminator());
      }
    }
    if (SyncBB != &EntryBB)
      continue;
    Changed |=
        setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
  }

  return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
}
3380
3381/// Try to replace memory allocation calls called by a single thread with a
3382/// static buffer of shared memory.
3383struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
3384 using Base = StateWrapper<BooleanState, AbstractAttribute>;
3385 AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3386
3387 /// Create an abstract attribute view for the position \p IRP.
3388 static AAHeapToShared &createForPosition(const IRPosition &IRP,
3389 Attributor &A);
3390
3391 /// Returns true if HeapToShared conversion is assumed to be possible.
3392 virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3393
3394 /// Returns true if HeapToShared conversion is assumed and the CB is a
3395 /// callsite to a free operation to be removed.
3396 virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3397
3398 /// See AbstractAttribute::getName().
3399 StringRef getName() const override { return "AAHeapToShared"; }
3400
3401 /// See AbstractAttribute::getIdAddr().
3402 const char *getIdAddr() const override { return &ID; }
3403
3404 /// This function should return true if the type of the \p AA is
3405 /// AAHeapToShared.
3406 static bool classof(const AbstractAttribute *AA) {
3407 return (AA->getIdAddr() == &ID);
3408 }
3409
3410 /// Unique ID (due to the unique address)
3411 static const char ID;
3412};
3413
/// Function-level implementation of AAHeapToShared. It tracks the
/// __kmpc_alloc_shared calls of the anchor function in MallocCalls. At
/// manifest time it replaces eligible ones with an internal global variable in
/// the shared address space and deletes their unique __kmpc_free_shared call.
struct AAHeapToSharedFunction : public AAHeapToShared {
  AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
      : AAHeapToShared(IRP, A) {}

  const std::string getAsStr(Attributor *) const override {
    return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
           " malloc calls eligible.";
  }

  /// See AbstractAttribute::trackStatistics().
  void trackStatistics() const override {}

  /// This function finds free calls that will be removed by the
  /// HeapToShared transformation.
  ///
  /// PotentialRemovedFreeCalls is rebuilt from scratch. For each tracked
  /// allocation, its __kmpc_free_shared user is recorded only if it is the
  /// single such user.
  void findPotentialRemovedFreeCalls(Attributor &A) {
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];

    PotentialRemovedFreeCalls.clear();
    // Update free call users of found malloc calls.
    for (CallBase *CB : MallocCalls) {
      for (auto *U : CB->users()) {
        CallBase *C = dyn_cast<CallBase>(U);
        if (C && C->getCalledFunction() == FreeRFI.Declaration)
          FreeCalls.push_back(C);
      }

      // Only a unique free call can be removed together with the allocation.
      if (FreeCalls.size() != 1)
        continue;

      PotentialRemovedFreeCalls.insert(FreeCalls.front());
    }
  }

  /// Collect all __kmpc_alloc_shared call sites in the anchor function into
  /// MallocCalls, register a simplification callback for each call's returned
  /// value, and compute the free calls that would be removed.
  void initialize(Attributor &A) override {
      indicatePessimisticFixpoint();
      return;
    }

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
    if (!RFI.Declaration)
      return;

    // Simplification callback for the allocations' returned values; it
    // always yields nullptr. Presumably this keeps other AAs from simplifying
    // a value that manifest may replace. Confirm against Attributor docs.
        [](const IRPosition &, const AbstractAttribute *,
           bool &) -> std::optional<Value *> { return nullptr; };

    Function *F = getAnchorScope();
    const OMPInformationCache::RuntimeFunctionInfo::UseVector *Uses =
        RFI.getUseVector(*F);
    if (!Uses)
      return;

    for (Use *U : *Uses)
      if (CallBase *CB = dyn_cast<CallBase>(U->getUser())) {
        MallocCalls.insert(CB);
        A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
                                         SCB);
      }

    findPotentialRemovedFreeCalls(A);
  }

  bool isAssumedHeapToShared(CallBase &CB) const override {
    return isValidState() && MallocCalls.count(&CB);
  }

  bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
    return isValidState() && PotentialRemovedFreeCalls.count(&CB);
  }

  /// Replace each remaining allocation that (a) is not claimed by
  /// AAHeapToStack, (b) has exactly one __kmpc_free_shared user, and (c) fits
  /// within SharedMemoryLimit together with SharedMemoryUsed. The replacement
  /// is an internal, poison-initialized [N x i8] global in the shared address
  /// space, aligned to the call's return alignment.
  ChangeStatus manifest(Attributor &A) override {
    if (MallocCalls.empty())
      return ChangeStatus::UNCHANGED;

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];

    Function *F = getAnchorScope();
    auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
                                            DepClassTy::OPTIONAL);

    ChangeStatus Changed = ChangeStatus::UNCHANGED;
    for (CallBase *CB : MallocCalls) {
      // Skip replacing this if HeapToStack has already claimed it.
      if (HS && HS->isAssumedHeapToStack(*CB))
        continue;

      // Find the unique free call to remove it.
      for (auto *U : CB->users()) {
        CallBase *C = dyn_cast<CallBase>(U);
        if (C && C->getCalledFunction() == FreeCall.Declaration)
          FreeCalls.push_back(C);
      }
      if (FreeCalls.size() != 1)
        continue;

      // updateImpl removes calls whose size operand is not a ConstantInt,
      // which this cast relies on.
      auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));

      if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
        LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
                          << " with shared memory."
                          << " Shared memory usage is limited to "
                          << SharedMemoryLimit << " bytes\n");
        continue;
      }

      LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
                        << " with " << AllocSize->getZExtValue()
                        << " bytes of shared memory\n");

      // Create a new shared memory buffer of the same size as the allocation
      // and replace all the uses of the original allocation with it.
      Module *M = CB->getModule();
      Type *Int8Ty = Type::getInt8Ty(M->getContext());
      Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
      auto *SharedMem = new GlobalVariable(
          *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
          PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
          static_cast<unsigned>(AddressSpace::Shared));
      auto *NewBuffer = ConstantExpr::getPointerCast(
          SharedMem, PointerType::getUnqual(M->getContext()));

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Replaced globalized variable with "
                  << ore::NV("SharedMemory", AllocSize->getZExtValue())
                  << (AllocSize->isOne() ? " byte " : " bytes ")
                  << "of shared memory.";
      };
      A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);

      MaybeAlign Alignment = CB->getRetAlign();
      assert(Alignment &&
             "HeapToShared on allocation without alignment attribute");
      SharedMem->setAlignment(*Alignment);

      // Register the replacement and the deletion of the alloc/free pair with
      // the Attributor (the *AfterManifest APIs).
      A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
      A.deleteAfterManifest(*CB);
      A.deleteAfterManifest(*FreeCalls.front());

      // Running total of bytes moved to shared memory by this attribute.
      SharedMemoryUsed += AllocSize->getZExtValue();
      NumBytesMovedToSharedMemory = SharedMemoryUsed;
      Changed = ChangeStatus::CHANGED;
    }

    return Changed;
  }

  /// Drop ineligible allocations from MallocCalls. Ineligible means a
  /// non-constant size, or not known to execute on the initial thread only.
  /// Returns CHANGED if any were dropped. An empty MallocCalls set reaches
  /// the pessimistic fixpoint.
  ChangeStatus updateImpl(Attributor &A) override {
    if (MallocCalls.empty())
      return indicatePessimisticFixpoint();
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
    if (!RFI.Declaration)
      return ChangeStatus::UNCHANGED;

    Function *F = getAnchorScope();

    auto NumMallocCalls = MallocCalls.size();

    // Only consider malloc calls executed by a single thread with a constant.
    for (User *U : RFI.Declaration->users()) {
      if (CallBase *CB = dyn_cast<CallBase>(U)) {
        if (CB->getCaller() != F)
          continue;
        if (!MallocCalls.count(CB))
          continue;
        if (!isa<ConstantInt>(CB->getArgOperand(0))) {
          MallocCalls.remove(CB);
          continue;
        }
        // NOTE(review): the same function-level AAExecutionDomain is queried
        // for every surviving candidate.
        const auto *ED = A.getAAFor<AAExecutionDomain>(
            *this, IRPosition::function(*F), DepClassTy::REQUIRED);
        if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
          MallocCalls.remove(CB);
      }
    }

    findPotentialRemovedFreeCalls(A);

    if (NumMallocCalls != MallocCalls.size())
      return ChangeStatus::CHANGED;

    return ChangeStatus::UNCHANGED;
  }

  /// Collection of all malloc calls in a function.
  SmallSetVector<CallBase *, 4> MallocCalls;
  /// Collection of potentially removed free calls in a function.
  SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
  /// The total amount of shared memory that has been used for HeapToShared.
  unsigned SharedMemoryUsed = 0;
};
3612
3613struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3614 using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3615 AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3616
3617 /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
3618 /// unknown callees.
3619 static bool requiresCalleeForCallBase() { return false; }
3620
3621 /// Statistics are tracked as part of manifest for now.
3622 void trackStatistics() const override {}
3623
3624 /// See AbstractAttribute::getAsStr()
3625 const std::string getAsStr(Attributor *) const override {
3626 if (!isValidState())
3627 return "<invalid>";
3628 return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3629 : "generic") +
3630 std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3631 : "") +
3632 std::string(" #PRs: ") +
3633 (ReachedKnownParallelRegions.isValidState()
3634 ? std::to_string(ReachedKnownParallelRegions.size())
3635 : "<invalid>") +
3636 ", #Unknown PRs: " +
3637 (ReachedUnknownParallelRegions.isValidState()
3638 ? std::to_string(ReachedUnknownParallelRegions.size())
3639 : "<invalid>") +
3640 ", #Reaching Kernels: " +
3641 (ReachingKernelEntries.isValidState()
3642 ? std::to_string(ReachingKernelEntries.size())
3643 : "<invalid>") +
3644 ", #ParLevels: " +
3645 (ParallelLevels.isValidState()
3646 ? std::to_string(ParallelLevels.size())
3647 : "<invalid>") +
3648 ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3649 }
3650
3651 /// Create an abstract attribute biew for the position \p IRP.
3652 static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3653
3654 /// See AbstractAttribute::getName()
3655 StringRef getName() const override { return "AAKernelInfo"; }
3656
3657 /// See AbstractAttribute::getIdAddr()
3658 const char *getIdAddr() const override { return &ID; }
3659
3660 /// This function should return true if the type of the \p AA is AAKernelInfo
3661 static bool classof(const AbstractAttribute *AA) {
3662 return (AA->getIdAddr() == &ID);
3663 }
3664
3665 static const char ID;
3666};
3667
3668/// The function kernel info abstract attribute, basically, what can we say
3669/// about a function with regards to the KernelInfoState.
3670struct AAKernelInfoFunction : AAKernelInfo {
3671 AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3672 : AAKernelInfo(IRP, A) {}
3673
3674 SmallPtrSet<Instruction *, 4> GuardedInstructions;
3675
3676 SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3677 return GuardedInstructions;
3678 }
3679
3680 void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3682 KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3683 assert(NewKernelEnvC && "Failed to create new kernel environment");
3684 KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3685 }
3686
/// Define `set<MEMBER>OfKernelEnvironment(ConstantInt *NewVal)`, which
/// replaces field KernelInfo::<MEMBER>Idx of the configuration struct in
/// KernelEnvC with \p NewVal (by constant-folding an insertvalue) and stores
/// the updated configuration via setConfigurationOfKernelEnvironment().
3687#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \
3688 void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \
3689 ConstantStruct *ConfigC = \
3690 KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \
3691 Constant *NewConfigC = ConstantFoldInsertValueInstruction( \
3692 ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \
3693 assert(NewConfigC && "Failed to create new configuration environment"); \
3694 setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \
3695 }
3696
// Instantiate one setter per configuration field.
3697 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
3698 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
3704
3705#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
3706
3707 /// See AbstractAttribute::initialize(...).
///
/// Only functions that contain a __kmpc_target_init and a
/// __kmpc_target_deinit call (at most one of each, enforced by assertions)
/// are treated as kernel entries. For those, this:
///  - records the two calls in KernelInitCB / KernelDeinitCB,
///  - registers a simplification callback that answers queries about the
///    kernel environment global with the assumed KernelEnvC,
///  - seeds the assumed configuration (exec mode, thread/team bounds, nested
///    parallelism, generic state machine use), and
///  - registers virtual uses for runtime functions that a later rewrite
///    (custom state machine or SPMD-ization) may insert calls to.
3708 void initialize(Attributor &A) override {
3709 // This is a high-level transform that might change the constant arguments
3710 // of the init and deinit calls. We need to tell the Attributor about this
3711 // to avoid other parts using the current constant value for simplification.
3712 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3713
3714 Function *Fn = getAnchorScope();
3715
3716 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3717 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3718 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3719 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3720
3721 // For kernels we perform more initialization work, first we find the init
3722 // and deinit calls.
// StoreCallBase records the call behind use \p U in \p Storage; it asserts
// that the use is a regular call and that no call was stored before.
3723 auto StoreCallBase = [](Use &U,
3724 OMPInformationCache::RuntimeFunctionInfo &RFI,
3725 CallBase *&Storage) {
3726 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3727 assert(CB &&
3728 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3729 assert(!Storage &&
3730 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3731 Storage = CB;
3732 return false;
3733 };
3734 InitRFI.foreachUse(
3735 [&](Use &U, Function &) {
3736 StoreCallBase(U, InitRFI, KernelInitCB);
3737 return false;
3738 },
3739 Fn);
3740 DeinitRFI.foreachUse(
3741 [&](Use &U, Function &) {
3742 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3743 return false;
3744 },
3745 Fn);
3746
3747 // Ignore kernels without initializers such as global constructors.
3748 if (!KernelInitCB || !KernelDeinitCB)
3749 return;
3750
3751 // Add this function to its own reaching kernel entries and mark it as a
// kernel entry.
3752 ReachingKernelEntries.insert(Fn);
3753 IsKernelEntry = true;
3754
// Cache the kernel environment constant (KernelEnvC) and the global variable
// whose initializer holds it (KernelEnvGV).
3755 KernelEnvC =
3757 GlobalVariable *KernelEnvGV =
3759
// Answer simplification queries for the kernel environment global with the
// assumed KernelEnvC. Before this AA reaches a fixpoint the value is only
// assumed: queries without a querying AA get nullptr; otherwise the querying
// AA is told it used assumed information and an optional dependence is
// recorded so it is revisited when this AA changes.
3761 KernelConfigurationSimplifyCB =
3762 [&](const GlobalVariable &GV, const AbstractAttribute *AA,
3763 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3764 if (!isAtFixpoint()) {
3765 if (!AA)
3766 return nullptr;
3767 UsedAssumedInformation = true;
3768 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3769 }
3770 return KernelEnvC;
3771 };
3772
3773 A.registerGlobalVariableSimplificationCallback(
3774 *KernelEnvGV, KernelConfigurationSimplifyCB);
3775
3776 // We cannot change to SPMD mode if the runtime functions aren't available.
3777 bool CanChangeToSPMD = OMPInfoCache.runtimeFnsAvailable(
3778 {OMPRTL___kmpc_get_hardware_thread_id_in_block,
3779 OMPRTL___kmpc_barrier_simple_spmd});
3780
3781 // Check if we know we are in SPMD-mode already. If so, SPMD compatibility is
// trivially known; otherwise, unless SPMD-ization is disabled or impossible,
// record the assumed exec mode (AssumedExecModeC) in KernelEnvC.
3782 ConstantInt *ExecModeC =
3783 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3784 ConstantInt *AssumedExecModeC = ConstantInt::get(
3785 ExecModeC->getIntegerType(),
3787 if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
3788 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3789 else if (DisableOpenMPOptSPMDization || !CanChangeToSPMD)
3790 // This is a generic region but SPMDization is disabled (or the required
3791 // runtime functions are unavailable), so stop tracking.
3792 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3793 else
3794 setExecModeOfKernelEnvironment(AssumedExecModeC);
3795
// Seed the thread and team bounds; only non-zero bounds overwrite the values
// stored in the kernel environment.
3796 const Triple T(Fn->getParent()->getTargetTriple());
3797 auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
3798 auto [MinThreads, MaxThreads] =
3800 if (MinThreads)
3801 setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3802 if (MaxThreads)
3803 setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3804 auto [MinTeams, MaxTeams] =
3806 if (MinTeams)
3807 setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3808 if (MaxTeams)
3809 setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3810
// Assume the current NestedParallelism state for the
// may-use-nested-parallelism flag.
3811 ConstantInt *MayUseNestedParallelismC =
3812 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3813 ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3814 MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
3815 setMayUseNestedParallelismOfKernelEnvironment(
3816 AssumedMayUseNestedParallelismC);
3817
// Optimistically assume the generic state machine is not needed.
3819 ConstantInt *UseGenericStateMachineC =
3820 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3821 KernelEnvC);
3822 ConstantInt *AssumedUseGenericStateMachineC =
3823 ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3824 setUseGenericStateMachineOfKernelEnvironment(
3825 AssumedUseGenericStateMachineC);
3826 }
3827
3828 // Register virtual uses of functions we might need to preserve.
// Nothing is registered if the runtime function has no declaration.
3829 auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3831 if (!OMPInfoCache.RFIs[RFKind].Declaration)
3832 return;
3833 A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3834 };
3835
3836 // Add a dependence to ensure updates if the state changes.
3837 auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3838 const AbstractAttribute *QueryingAA) {
3839 if (QueryingAA) {
3840 A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3841 }
3842 return true;
3843 };
3844
// In the virtual-use callbacks below, returning true (after recording a
// dependence) is used for "the virtual use is not needed", and false keeps
// the runtime function alive.
// NOTE(review): these callbacks capture `this` and the local `AddDependence`
// lambda by reference ([&]). If the Attributor invokes them after
// initialize() returns, the reference to `AddDependence` dangles. That is
// benign in practice because the lambda is captureless, but it is formally
// undefined behavior; capturing it by value would avoid this.
3845 Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3846 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3847 // Whenever we create a custom state machine we will insert calls to
3848 // __kmpc_get_hardware_num_threads_in_block,
3849 // __kmpc_get_warp_size,
3850 // __kmpc_barrier_simple_generic,
3851 // __kmpc_kernel_parallel, and
3852 // __kmpc_kernel_end_parallel.
3853 // Not needed if we are on track for SPMDization.
3854 if (SPMDCompatibilityTracker.isValidState())
3855 return AddDependence(A, this, QueryingAA);
3856 // Not needed if we can't rewrite due to an invalid state.
3857 if (!ReachedKnownParallelRegions.isValidState())
3858 return AddDependence(A, this, QueryingAA);
3859 return false;
3860 };
3861
3862 // Not needed if we are pre-runtime merge (the __kmpc_target_init callee
// is still only a declaration).
3863 if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3864 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3865 CustomStateMachineUseCB);
3866 RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3867 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3868 CustomStateMachineUseCB);
3869 RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3870 CustomStateMachineUseCB);
3871 RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3872 CustomStateMachineUseCB);
3873 }
3874
3875 // If we do not perform SPMDization we do not need the virtual uses below.
3876 if (SPMDCompatibilityTracker.isAtFixpoint())
3877 return;
3878
3879 Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3880 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3881 // Whenever we perform SPMDization we will insert
3882 // __kmpc_get_hardware_thread_id_in_block calls.
3883 if (!SPMDCompatibilityTracker.isValidState())
3884 return AddDependence(A, this, QueryingAA);
3885 return false;
3886 };
3887 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3888 HWThreadIdUseCB);
3889
3890 Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3891 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3892 // Whenever we perform SPMDization with guarding we will insert
3893 // __kmpc_barrier_simple_spmd calls. If SPMDization failed, there is
3894 // nothing to guard, or there are no parallel regions, we don't need
3895 // the calls.
3896 if (!SPMDCompatibilityTracker.isValidState())
3897 return AddDependence(A, this, QueryingAA);
3898 if (SPMDCompatibilityTracker.empty())
3899 return AddDependence(A, this, QueryingAA);
3900 if (!mayContainParallelRegion())
3901 return AddDependence(A, this, QueryingAA);
3902 return false;
3903 };
3904 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3905 }
3906
/// Return \p S with every character outside [A-Za-z0-9_] replaced by '.',
/// so that the result can be used as (part of) a global symbol name.
static std::string sanitizeForGlobalName(std::string S) {
  for (char &Ch : S) {
    const bool IsLower = Ch >= 'a' && Ch <= 'z';
    const bool IsUpper = Ch >= 'A' && Ch <= 'Z';
    const bool IsDigit = Ch >= '0' && Ch <= '9';
    if (!IsLower && !IsUpper && !IsDigit && Ch != '_')
      Ch = '.';
  }
  return S;
}
3918
3919 /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3920 /// finished now.
///
/// First tries to turn the kernel into an SPMD kernel. If that fails and the
/// __kmpc_target_init callee is defined (not just declared), tries to build
/// a custom state machine. If no state machine was built, the
/// UseGenericStateMachine field optimistically cleared during initialize()
/// is restored from the existing kernel environment. Finally, the
/// (possibly updated) KernelEnvC is written back as the initializer of the
/// kernel environment global.
///
/// \returns CHANGED if any IR or the kernel environment was modified.
3921 ChangeStatus manifest(Attributor &A) override {
3922 // If we are not looking at a kernel with __kmpc_target_init and
3923 // __kmpc_target_deinit call we cannot actually manifest the information.
3924 if (!KernelInitCB || !KernelDeinitCB)
3925 return ChangeStatus::UNCHANGED;
3926
3927 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3928
3929 bool HasBuiltStateMachine = true;
3930 if (!changeToSPMDMode(A, Changed)) {
3931 if (!KernelInitCB->getCalledFunction()->isDeclaration())
3932 HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3933 else
3934 HasBuiltStateMachine = false;
3935 }
3936
3937 // We need to reset KernelEnvC if specific rewriting is not done.
3938 ConstantStruct *ExistingKernelEnvC =
3940 ConstantInt *OldUseGenericStateMachineVal =
3941 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3942 ExistingKernelEnvC);
3943 if (!HasBuiltStateMachine)
3944 setUseGenericStateMachineOfKernelEnvironment(
3945 OldUseGenericStateMachineVal);
3946
3947 // Finally, write the updated KernelEnvC back to the kernel environment
// global if it differs from the current initializer.
3948 GlobalVariable *KernelEnvGV =
3950 if (KernelEnvGV->getInitializer() != KernelEnvC) {
3951 KernelEnvGV->setInitializer(KernelEnvC);
3952 Changed = ChangeStatus::CHANGED;
3953 }
3954
3955 return Changed;
3956 }
3957
/// Wrap the instructions recorded in SPMDCompatibilityTracker into guarded
/// regions so that, once the kernel runs in SPMD mode, only the thread whose
/// hardware thread id in the block is 0 executes them. Values defined inside
/// a guarded region and used outside of it are broadcast to all threads
/// through shared-memory globals, synchronized with SPMD barriers.
3958 void insertInstructionGuardsHelper(Attributor &A) {
3959 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3960
// Build the guarded-region CFG around [RegionStartI, RegionEndI]; both
// instructions are expected to live in the same basic block.
3961 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3962 Instruction *RegionEndI) {
3963 LoopInfo *LI = nullptr;
3964 DominatorTree *DT = nullptr;
3965 MemorySSAUpdater *MSU = nullptr;
3966 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3967
3968 BasicBlock *ParentBB = RegionStartI->getParent();
3969 Function *Fn = ParentBB->getParent();
3970 Module &M = *Fn->getParent();
3971
3972 // Create all the blocks and logic.
3973 // ParentBB:
3974 // goto RegionCheckTidBB
3975 // RegionCheckTidBB:
3976 // Tid = __kmpc_get_hardware_thread_id_in_block()
3977 // if (Tid != 0)
3978 // goto RegionBarrierBB
3979 // RegionStartBB:
3980 // <execute instructions guarded>
3981 // goto RegionEndBB
3982 // RegionEndBB:
3983 // <store escaping values to shared mem>
3984 // goto RegionBarrierBB
3985 // RegionBarrierBB:
3986 // __kmpc_barrier_simple_spmd()
3987 // // second barrier is omitted if lacking escaping values.
3988 // <load escaping values from shared mem>
3989 // __kmpc_barrier_simple_spmd()
3990 // goto RegionExitBB
3991 // RegionExitBB:
3992 // <execute rest of instructions>
3993
3994 BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3995 DT, LI, MSU, "region.guarded.end");
3996 BasicBlock *RegionBarrierBB =
3997 SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3998 MSU, "region.barrier");
3999 BasicBlock *RegionExitBB =
4000 SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
4001 DT, LI, MSU, "region.exit");
4002 BasicBlock *RegionStartBB =
4003 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
4004
4005 assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
4006 "Expected a different CFG");
4007
4008 BasicBlock *RegionCheckTidBB = SplitBlock(
4009 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
4010
4011 // Register basic blocks with the Attributor.
4012 A.registerManifestAddedBasicBlock(*RegionEndBB);
4013 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
4014 A.registerManifestAddedBasicBlock(*RegionExitBB);
4015 A.registerManifestAddedBasicBlock(*RegionStartBB);
4016 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4017
4018 bool HasBroadcastValues = false;
4019 // Find escaping outputs from the guarded region to outside users and
4020 // broadcast their values to them.
4021 for (Instruction &I : *RegionStartBB) {
4022 SmallVector<Use *, 4> OutsideUses;
4023 for (Use &U : I.uses()) {
4024 Instruction &UsrI = *cast<Instruction>(U.getUser());
4025 if (UsrI.getParent() != RegionStartBB)
4026 OutsideUses.push_back(&U);
4027 }
4028
4029 if (OutsideUses.empty())
4030 continue;
4031
4032 HasBroadcastValues = true;
4033
4034 // Emit a global variable in shared memory to store the broadcasted
4035 // value.
4036 auto *SharedMem = new GlobalVariable(
4037 M, I.getType(), /* IsConstant */ false,
4039 sanitizeForGlobalName(
4040 (I.getName() + ".guarded.output.alloc").str()),
4042 static_cast<unsigned>(AddressSpace::Shared));
4043
4044 // Emit a store instruction at the end of the guarded region to update
// the shared value.
4045 new StoreInst(&I, SharedMem,
4046 RegionEndBB->getTerminator()->getIterator());
4047
// Emit a load of the shared value at the end of the barrier block.
4048 LoadInst *LoadI = new LoadInst(
4049 I.getType(), SharedMem, I.getName() + ".guarded.output.load",
4050 RegionBarrierBB->getTerminator()->getIterator());
4051
4052 // Redirect all outside uses of the output value to the loaded value.
4053 for (Use *U : OutsideUses)
4054 A.changeUseAfterManifest(*U, *LoadI);
4055 }
4056
// NOTE(review): redundant; this redeclares (and shadows) the identical
// OMPInfoCache reference of the enclosing function.
4057 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4058
4059 // Go to tid check BB in ParentBB.
4060 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
4061 ParentBB->getTerminator()->eraseFromParent();
4062 OpenMPIRBuilder::LocationDescription Loc(
4063 InsertPointTy(ParentBB, ParentBB->end()), DL);
4064 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4065 uint32_t SrcLocStrSize;
4066 auto *SrcLocStr =
4067 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4068 Value *Ident =
4069 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4070 UncondBrInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
4071
4072 // Add check for Tid in RegionCheckTidBB: thread 0 enters the guarded
// region, all other threads go straight to the barrier.
4073 RegionCheckTidBB->getTerminator()->eraseFromParent();
4074 OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
4075 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4076 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4077 FunctionCallee HardwareTidFn =
4078 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4079 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4080 CallInst *Tid =
4081 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4082 Tid->setDebugLoc(DL);
4083 OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4084 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4085 OMPInfoCache.OMPBuilder.Builder
4086 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4087 ->setDebugLoc(DL);
4088
4089 // First barrier for synchronization, ensures main thread has updated
4090 // values.
4091 FunctionCallee BarrierFn =
4092 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4093 M, OMPRTL___kmpc_barrier_simple_spmd);
4094 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4095 RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
4096 CallInst *Barrier =
4097 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4098 Barrier->setDebugLoc(DL);
4099 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4100
4101 // Second barrier ensures workers have read broadcast values.
4102 if (HasBroadcastValues) {
4103 CallInst *Barrier =
4104 CallInst::Create(BarrierFn, {Ident, Tid}, "",
4105 RegionBarrierBB->getTerminator()->getIterator());
4106 Barrier->setDebugLoc(DL);
4107 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4108 }
4109 };
4110
// For every block that contains an instruction to guard, walk it backwards
// (starting before the last instruction) and move user-less, to-be-guarded
// effects (instructions that may have side effects or read memory) next to
// the following such effect. Calls to __kmpc_alloc_shared and instructions
// without effects do not interrupt such a chain; any other effect does. This
// makes the guarded instructions contiguous, so the region-building loop
// below creates fewer guarded regions.
4111 auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4112 SmallPtrSet<BasicBlock *, 8> Visited;
4113 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4114 BasicBlock *BB = GuardedI->getParent();
4115 if (!Visited.insert(BB).second)
4116 continue;
4117
4119 Instruction *LastEffect = nullptr;
4120 BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
4121 while (++IP != IPEnd) {
4122 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4123 continue;
4124 Instruction *I = &*IP;
4125 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4126 continue;
4127 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4128 LastEffect = nullptr;
4129 continue;
4130 }
4131 if (LastEffect)
4132 Reorders.push_back({I, LastEffect});
4133 LastEffect = &*IP;
4134 }
4135 for (auto &Reorder : Reorders)
4136 Reorder.first->moveBefore(Reorder.second->getIterator());
4137 }
4138
4140
// Collect maximal runs of consecutive to-be-guarded instructions per block
// as [start, end] regions. Instructions are marked as guarded in the
// AAKernelInfoFunction of their own function so they are not guarded again.
// NOTE(review): a run is only recorded when a non-guarded instruction
// follows it; this assumes block terminators are never in
// SPMDCompatibilityTracker, otherwise a trailing run would be dropped.
4141 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4142 BasicBlock *BB = GuardedI->getParent();
4143 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4144 IRPosition::function(*GuardedI->getFunction()), nullptr,
4145 DepClassTy::NONE);
4146 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4147 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4148 // Continue if instruction is already guarded.
4149 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4150 continue;
4151
4152 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4153 for (Instruction &I : *BB) {
4154 // If instruction I needs to be guarded update the guarded region
4155 // bounds.
4156 if (SPMDCompatibilityTracker.contains(&I)) {
4157 CalleeAAFunction.getGuardedInstructions().insert(&I);
4158 if (GuardedRegionStart)
4159 GuardedRegionEnd = &I;
4160 else
4161 GuardedRegionStart = GuardedRegionEnd = &I;
4162
4163 continue;
4164 }
4165
4166 // Instruction I does not need guarding, store
4167 // any region found and reset bounds.
4168 if (GuardedRegionStart) {
4169 GuardedRegions.push_back(
4170 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4171 GuardedRegionStart = nullptr;
4172 GuardedRegionEnd = nullptr;
4173 }
4174 }
4175 }
4176
4177 for (auto &GR : GuardedRegions)
4178 CreateGuardedRegion(GR.first, GR.second);
4179 }
4180
4181 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4182 // Only allow 1 thread per workgroup to continue executing the user code.
4183 //
4184 // InitCB = __kmpc_target_init(...)
4185 // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4186 // if (ThreadIdInBlock != 0) return;
4187 // UserCode:
4188 // // user code
4189 //
4190 auto &Ctx = getAnchorValue().getContext();
4191 Function *Kernel = getAssociatedFunction();
4192 assert(Kernel && "Expected an associated function!");
4193
4194 // Create block for user code to branch to from initial block.
4195 BasicBlock *InitBB = KernelInitCB->getParent();
4196 BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4197 KernelInitCB->getNextNode(), "main.thread.user_code");
4198 BasicBlock *ReturnBB =
4199 BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
4200
4201 // Register blocks with attributor:
4202 A.registerManifestAddedBasicBlock(*InitBB);
4203 A.registerManifestAddedBasicBlock(*UserCodeBB);
4204 A.registerManifestAddedBasicBlock(*ReturnBB);
4205
4206 // Debug location:
4207 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4208 ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
4209 InitBB->getTerminator()->eraseFromParent();
4210
4211 // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4212 Module &M = *Kernel->getParent();
4213 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4214 FunctionCallee ThreadIdInBlockFn =
4215 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4216 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4217
4218 // Get thread ID in block.
4219 CallInst *ThreadIdInBlock =
4220 CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
4221 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4222 ThreadIdInBlock->setDebugLoc(DLoc);
4223
4224 // Eliminate all threads in the block with ID not equal to 0:
4225 Instruction *IsMainThread =
4226 ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
4227 ConstantInt::get(ThreadIdInBlock->getType(), 0),
4228 "thread.is_main", InitBB);
4229 IsMainThread->setDebugLoc(DLoc);
4230 CondBrInst::Create(IsMainThread, ReturnBB, UserCodeBB, InitBB);
4231 }
4232
/// Try to execute this (generic-mode) kernel in SPMD mode.
///
/// If SPMD compatibility could not be assumed, an OMP121 remark is emitted
/// for every incompatible instruction (calls to known OpenMP runtime
/// functions excepted) and false is returned. If the kernel is not in plain
/// generic mode, nothing is done and true is returned. Otherwise the IR is
/// rewritten: instructions are guarded if a parallel region may be reached,
/// else only thread 0 of each block runs the user code. The exec mode then
/// gets OMP_TGT_EXEC_MODE_GENERIC_SPMD or'ed in, \p Changed is set to
/// CHANGED, and true is returned.
///
/// \returns false iff the kernel has to stay in generic mode.
4233 bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
4234 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4235
4236 if (!SPMDCompatibilityTracker.isAssumed()) {
4237 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4238 if (!NonCompatibleI)
4239 continue;
4240
4241 // Skip diagnostics on calls to known OpenMP runtime functions for now.
4242 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4243 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4244 continue;
4245
4246 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4247 ORA << "Value has potential side effects preventing SPMD-mode "
4248 "execution";
4249 if (isa<CallBase>(NonCompatibleI)) {
4250 ORA << ". Add `[[omp::assume(\"ompx_spmd_amenable\")]]` to "
4251 "the called function to override";
4252 }
4253 return ORA << ".";
4254 };
4255 A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
4256 Remark);
4257
4258 LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
4259 << *NonCompatibleI << "\n");
4260 }
4261
4262 return false;
4263 }
4264
4265 // Get the actual kernel, could be the caller of the anchor scope if we have
4266 // a debug wrapper. (Kernel is only consulted by the assertion below.)
4267 Function *Kernel = getAnchorScope();
4268 if (Kernel->hasLocalLinkage()) {
4269 assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
4270 auto *CB = cast<CallBase>(Kernel->user_back());
4271 Kernel = CB->getCaller();
4272 }
4273 assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
4274
4275 // If the kernel is not in plain generic mode (e.g., it is already SPMD or
// generic-SPMD), there is nothing to do; report success.
4276 ConstantStruct *ExistingKernelEnvC =
4278 auto *ExecModeC =
4279 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4280 const int8_t ExecModeVal = ExecModeC->getSExtValue();
4281 if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4282 return true;
4283
4284 // We will now unconditionally modify the IR, indicate a change.
4285 Changed = ChangeStatus::CHANGED;
4286
4287 // Without a (possibly) reachable parallel region, do not guard individual
4288 // instructions; let only one thread per block run the user code instead.
4289 if (mayContainParallelRegion())
4290 insertInstructionGuardsHelper(A);
4291 else
4292 forceSingleThreadPerWorkgroupHelper(A);
4293
4294 // Adjust the global exec mode flag that tells the runtime what mode this
4295 // kernel is executed in.
4296 assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4297 "Initially non-SPMD kernel has SPMD exec mode!");
4298 setExecModeOfKernelEnvironment(
4299 ConstantInt::get(ExecModeC->getIntegerType(),
4300 ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
4301
4302 ++NumOpenMPTargetRegionKernelsSPMD;
4303
4304 auto Remark = [&](OptimizationRemark OR) {
4305 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4306 };
4307 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
4308 return true;
// NOTE(review): the ';' after this member function's closing brace is a
// redundant empty declaration and can be dropped.
4309 };
4310
4311 bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4312 // If we have disabled state machine rewrites, don't make a custom one
4314 return false;
4315
4316 // Don't rewrite the state machine if we are not in a valid state.
4317 if (!ReachedKnownParallelRegions.isValidState())
4318 return false;
4319
4320 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4321 if (!OMPInfoCache.runtimeFnsAvailable(
4322 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4323 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4324 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4325 return false;
4326
4327 ConstantStruct *ExistingKernelEnvC =
4329
4330 // Check if the current configuration is non-SPMD and generic state machine.
4331 // If we already have SPMD mode or a custom state machine we do not need to
4332 // go any further. If it is anything but a constant something is weird and
4333 // we give up.
4334 ConstantInt *UseStateMachineC =
4335 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4336 ExistingKernelEnvC);
4337 ConstantInt *ModeC =
4338 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4339
4340 // If we are stuck with generic mode, try to create a custom device (=GPU)
4341 // state machine which is specialized for the parallel regions that are
4342 // reachable by the kernel.
4343 if (UseStateMachineC->isZero() ||
4345 return false;
4346
4347 Changed = ChangeStatus::CHANGED;
4348
4349 // If not SPMD mode, indicate we use a custom state machine now.
4350 setUseGenericStateMachineOfKernelEnvironment(
4351 ConstantInt::get(UseStateMachineC->getIntegerType(), false));
4352
4353 // If we don't actually need a state machine we are done here. This can
4354 // happen if there simply are no parallel regions. In the resulting kernel
4355 // all worker threads will simply exit right away, leaving the main thread
4356 // to do the work alone.
4357 if (!mayContainParallelRegion()) {
4358 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4359
4360 auto Remark = [&](OptimizationRemark OR) {
4361 return OR << "Removing unused state machine from generic-mode kernel.";
4362 };
4363 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4364
4365 return true;
4366 }
4367
4368 // Keep track in the statistics of our new shiny custom state machine.
4369 if (ReachedUnknownParallelRegions.empty()) {
4370 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4371
4372 auto Remark = [&](OptimizationRemark OR) {
4373 return OR << "Rewriting generic-mode kernel with a customized state "
4374 "machine.";
4375 };
4376 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4377 } else {
4378 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4379
4380 auto Remark = [&](OptimizationRemarkAnalysis OR) {
4381 return OR << "Generic-mode kernel is executed with a customized state "
4382 "machine that requires a fallback.";
4383 };
4384 A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4385
4386 // Tell the user why we ended up with a fallback.
4387 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4388 if (!UnknownParallelRegionCB)
4389 continue;
4390 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4391 return ORA << "Call may contain unknown parallel regions. Use "
4392 << "`[[omp::assume(\"omp_no_parallelism\")]]` to "
4393 "override.";
4394 };
4395 A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4396 "OMP133", Remark);
4397 }
4398 }
4399
4400 // Create all the blocks:
4401 //
4402 // InitCB = __kmpc_target_init(...)
4403 // BlockHwSize =
4404 // __kmpc_get_hardware_num_threads_in_block();
4405 // WarpSize = __kmpc_get_warp_size();
4406 // BlockSize = BlockHwSize - WarpSize;
4407 // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
4408 // if (IsWorker) {
4409 // if (InitCB >= BlockSize) return;
4410 // SMBeginBB: __kmpc_barrier_simple_generic(...);
4411 // void *WorkFn;
4412 // bool Active = __kmpc_kernel_parallel(&WorkFn);
4413 // if (!WorkFn) return;
4414 // SMIsActiveCheckBB: if (Active) {
4415 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
4416 // ParFn0(...);
4417 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
4418 // ParFn1(...);
4419 // ...
4420 // SMIfCascadeCurrentBB: else
4421 // ((WorkFnTy*)WorkFn)(...);
4422 // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
4423 // }
4424 // SMDoneBB: __kmpc_barrier_simple_generic(...);
4425 // goto SMBeginBB;
4426 // }
4427 // UserCodeEntryBB: // user code
4428 // __kmpc_target_deinit(...)
4429 //
4430 auto &Ctx = getAnchorValue().getContext();
4431 Function *Kernel = getAssociatedFunction();
4432 assert(Kernel && "Expected an associated function!");
4433
4434 BasicBlock *InitBB = KernelInitCB->getParent();
4435 BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4436 KernelInitCB->getNextNode(), "thread.user_code.check");
4437 BasicBlock *IsWorkerCheckBB =
4438 BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4439 BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4440 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4441 BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4442 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4443 BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4444 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4445 BasicBlock *StateMachineIfCascadeCurrentBB =
4446 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4447 Kernel, UserCodeEntryBB);
4448 BasicBlock *StateMachineEndParallelBB =
4449 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4450 Kernel, UserCodeEntryBB);
4451 BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4452 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4453 A.registerManifestAddedBasicBlock(*InitBB);
4454 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4455 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4456 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4457 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4458 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4459 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4460 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4461 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4462
4463 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4464 ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4465 InitBB->getTerminator()->eraseFromParent();
4466
4467 Instruction *IsWorker =
4468 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4469 ConstantInt::getAllOnesValue(KernelInitCB->getType()),
4470 "thread.is_worker", InitBB);
4471 IsWorker->setDebugLoc(DLoc);
4472 CondBrInst::Create(IsWorker, IsWorkerCheckBB, UserCodeEntryBB, InitBB);
4473
4474 Module &M = *Kernel->getParent();
4475 FunctionCallee BlockHwSizeFn =
4476 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4477 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4478 FunctionCallee WarpSizeFn =
4479 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4480 M, OMPRTL___kmpc_get_warp_size);
4481 CallInst *BlockHwSize =
4482 CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4483 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4484 BlockHwSize->setDebugLoc(DLoc);
4485 CallInst *WarpSize =
4486 CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4487 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4488 WarpSize->setDebugLoc(DLoc);
4489 Instruction *BlockSize = BinaryOperator::CreateSub(
4490 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4491 BlockSize->setDebugLoc(DLoc);
4492 Instruction *IsMainOrWorker = ICmpInst::Create(
4493 ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4494 "thread.is_main_or_worker", IsWorkerCheckBB);
4495 IsMainOrWorker->setDebugLoc(DLoc);
4496 CondBrInst::Create(IsMainOrWorker, StateMachineBeginBB,
4497 StateMachineFinishedBB, IsWorkerCheckBB);
4498
4499 // Create local storage for the work function pointer.
4500 const DataLayout &DL = M.getDataLayout();
4501 Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4502 Instruction *WorkFnAI =
4503 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4504 "worker.work_fn.addr", Kernel->getEntryBlock().begin());
4505 WorkFnAI->setDebugLoc(DLoc);
4506
4507 OMPInfoCache.OMPBuilder.updateToLocation(
4508 OpenMPIRBuilder::LocationDescription(
4509 IRBuilder<>::InsertPoint(StateMachineBeginBB,
4510 StateMachineBeginBB->end()),
4511 DLoc));
4512
4513 Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4514 Value *GTid = KernelInitCB;
4515
4516 FunctionCallee BarrierFn =
4517 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4518 M, OMPRTL___kmpc_barrier_simple_generic);
4519 CallInst *Barrier =
4520 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4521 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4522 Barrier->setDebugLoc(DLoc);
4523
4524 if (WorkFnAI->getType()->getPointerAddressSpace() !=
4525 (unsigned int)AddressSpace::Generic) {
4526 WorkFnAI = new AddrSpaceCastInst(
4527 WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4528 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4529 WorkFnAI->setDebugLoc(DLoc);
4530 }
4531
4532 FunctionCallee KernelParallelFn =
4533 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4534 M, OMPRTL___kmpc_kernel_parallel);
4535 CallInst *IsActiveWorker = CallInst::Create(
4536 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4537 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4538 IsActiveWorker->setDebugLoc(DLoc);
4539 Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4540 StateMachineBeginBB);
4541 WorkFn->setDebugLoc(DLoc);
4542
4543 FunctionType *ParallelRegionFnTy = FunctionType::get(
4544 Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
4545 false);
4546
4547 Instruction *IsDone =
4548 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4549 Constant::getNullValue(VoidPtrTy), "worker.is_done",
4550 StateMachineBeginBB);
4551 IsDone->setDebugLoc(DLoc);
4552 CondBrInst::Create(IsDone, StateMachineFinishedBB,
4553 StateMachineIsActiveCheckBB, StateMachineBeginBB)
4554 ->setDebugLoc(DLoc);
4555
4556 CondBrInst::Create(IsActiveWorker, StateMachineIfCascadeCurrentBB,
4557 StateMachineDoneBarrierBB, StateMachineIsActiveCheckBB)
4558 ->setDebugLoc(DLoc);
4559
4560 Value *ZeroArg =
4561 Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4562
4563 const unsigned int WrapperFunctionArgNo = 6;
4564
4565 // Now that we have most of the CFG skeleton it is time for the if-cascade
4566 // that checks the function pointer we got from the runtime against the
4567 // parallel regions we expect, if there are any.
4568 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4569 auto *CB = ReachedKnownParallelRegions[I];
4570 auto *ParallelRegion = dyn_cast<Function>(
4571 CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4572 BasicBlock *PRExecuteBB = BasicBlock::Create(
4573 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4574 StateMachineEndParallelBB);
4575 CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4576 ->setDebugLoc(DLoc);
4577 UncondBrInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4578 ->setDebugLoc(DLoc);
4579
4580 BasicBlock *PRNextBB =
4581 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4582 Kernel, StateMachineEndParallelBB);
4583 A.registerManifestAddedBasicBlock(*PRExecuteBB);
4584 A.registerManifestAddedBasicBlock(*PRNextBB);
4585
4586 // Check if we need to compare the pointer at all or if we can just
4587 // call the parallel region function.
4588 Value *IsPR;
4589 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4590 Instruction *CmpI = ICmpInst::Create(
4591 ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
4592 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4593 CmpI->setDebugLoc(DLoc);
4594 IsPR = CmpI;
4595 } else {
4596 IsPR = ConstantInt::getTrue(Ctx);
4597 }
4598
4599 CondBrInst::Create(IsPR, PRExecuteBB, PRNextBB,
4600 StateMachineIfCascadeCurrentBB)
4601 ->setDebugLoc(DLoc);
4602 StateMachineIfCascadeCurrentBB = PRNextBB;
4603 }
4604
4605 // At the end of the if-cascade we place the indirect function pointer call
4606 // in case we might need it, that is if there can be parallel regions we
4607 // have not handled in the if-cascade above.
4608 if (!ReachedUnknownParallelRegions.empty()) {
4609 StateMachineIfCascadeCurrentBB->setName(
4610 "worker_state_machine.parallel_region.fallback.execute");
4611 CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
4612 StateMachineIfCascadeCurrentBB)
4613 ->setDebugLoc(DLoc);
4614 }
4615 UncondBrInst::Create(StateMachineEndParallelBB,
4616 StateMachineIfCascadeCurrentBB)
4617 ->setDebugLoc(DLoc);
4618
4619 FunctionCallee EndParallelFn =
4620 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4621 M, OMPRTL___kmpc_kernel_end_parallel);
4622 CallInst *EndParallel =
4623 CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4624 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4625 EndParallel->setDebugLoc(DLoc);
4626 UncondBrInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4627 ->setDebugLoc(DLoc);
4628
4629 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4630 ->setDebugLoc(DLoc);
4631 UncondBrInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4632 ->setDebugLoc(DLoc);
4633
4634 return true;
4635 }
4636
4637 /// Fixpoint iteration update function. Will be called every time a dependence
4638 /// changed its state (and in the beginning).
///
/// Re-derives the SPMD-compatibility and reached-parallel-region information
/// of this function from its read/write instructions and from the
/// AAKernelInfo of its call-like instructions; for non-kernel functions it
/// also refreshes the parallel-level and reaching-kernel information from the
/// callers. Returns CHANGED iff the KernelInfoState differs from the state on
/// entry.
4639 ChangeStatus updateImpl(Attributor &A) override {
4640 KernelInfoState StateBefore = getState();
4641
4642 // When we leave this function this RAII will make sure the member
4643 // KernelEnvC is updated properly depending on the state. That member is
4644 // used for simplification of values and needs to be up to date at all
4645 // times.
// An invalid AA resets KernelEnvC to ExistingKernelEnvC. Otherwise only the
// fields whose tracking sub-state became invalid are taken back from
// ExistingKernelEnvC, and the nested-parallelism flag is always re-encoded
// from NestedParallelism.
4646 struct UpdateKernelEnvCRAII {
4647 AAKernelInfoFunction &AA;
4648
4649 UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4650
4651 ~UpdateKernelEnvCRAII() {
4652 if (!AA.KernelEnvC)
4653 return;
4654
// NOTE(review): the initializer of ExistingKernelEnvC is missing from this
// listing (the numbering jumps from 4655 to 4657); restore it before building.
4655 ConstantStruct *ExistingKernelEnvC =
4657
4658 if (!AA.isValidState()) {
4659 AA.KernelEnvC = ExistingKernelEnvC;
4660 return;
4661 }
4662
// Fall back to the existing generic-state-machine flag when the reached
// parallel region information could not be tracked.
4663 if (!AA.ReachedKnownParallelRegions.isValidState())
4664 AA.setUseGenericStateMachineOfKernelEnvironment(
4665 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4666 ExistingKernelEnvC));
4667
// Fall back to the existing execution mode when SPMD compatibility could
// not be tracked.
4668 if (!AA.SPMDCompatibilityTracker.isValidState())
4669 AA.setExecModeOfKernelEnvironment(
4670 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4671
// Re-encode NestedParallelism with the integer type of the existing field.
4672 ConstantInt *MayUseNestedParallelismC =
4673 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4674 AA.KernelEnvC);
4675 ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4676 MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4677 AA.setMayUseNestedParallelismOfKernelEnvironment(
4678 NewMayUseNestedParallelismC);
4679 }
4680 } RAII(*this);
4681
4682 // Callback to check a read/write instruction.
// Stores whose underlying objects are all assumed thread-local, or are
// allocations AAHeapToStack assumes to move to the stack, need no guarding;
// every other writing non-call instruction is recorded in
// SPMDCompatibilityTracker.
4683 auto CheckRWInst = [&](Instruction &I) {
4684 // We handle calls later.
4685 if (isa<CallBase>(I))
4686 return true;
4687 // We only care about write effects.
4688 if (!I.mayWriteToMemory())
4689 return true;
4690 if (auto *SI = dyn_cast<StoreInst>(&I)) {
4691 const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4692 *this, IRPosition::value(*SI->getPointerOperand()),
4693 DepClassTy::OPTIONAL);
4694 auto *HS = A.getAAFor<AAHeapToStack>(
4695 *this, IRPosition::function(*I.getFunction()),
4696 DepClassTy::OPTIONAL);
4697 if (UnderlyingObjsAA &&
4698 UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4699 if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4700 return true;
4701 // Check for AAHeapToStack moved objects which must not be
4702 // guarded.
4703 auto *CB = dyn_cast<CallBase>(&Obj);
4704 return CB && HS && HS->isAssumedHeapToStack(*CB);
4705 }))
4706 return true;
4707 }
4708
4709 // Insert instruction that needs guarding.
4710 SPMDCompatibilityTracker.insert(&I);
4711 return true;
4712 };
4713
4714 bool UsedAssumedInformationInCheckRWInst = false;
4715 if (!SPMDCompatibilityTracker.isAtFixpoint())
4716 if (!A.checkForAllReadWriteInstructions(
4717 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4718 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4719
// Kernel entries have no callers to learn parallel levels or reaching
// kernels from; only non-kernel functions refresh that information here.
4720 bool UsedAssumedInformationFromReachingKernels = false;
4721 if (!IsKernelEntry) {
4722 updateParallelLevels(A);
4723
4724 bool AllReachingKernelsKnown = true;
4725 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4726 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4727
4728 if (!SPMDCompatibilityTracker.empty()) {
4729 if (!ParallelLevels.isValidState())
4730 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4731 else if (!ReachingKernelEntries.isValidState())
4732 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4733 else {
4734 // Check if all reaching kernels agree on the mode as we can otherwise
4735 // not guard instructions. We might not be sure about the mode so we
4736 // cannot fix the internal SPMD-ization state either.
4737 int SPMD = 0, Generic = 0;
4738 for (auto *Kernel : ReachingKernelEntries) {
4739 auto *CBAA = A.getAAFor<AAKernelInfo>(
4740 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4741 if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4742 CBAA->SPMDCompatibilityTracker.isAssumed())
4743 ++SPMD;
4744 else
4745 ++Generic;
4746 if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4747 UsedAssumedInformationFromReachingKernels = true;
4748 }
4749 if (SPMD != 0 && Generic != 0)
4750 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4751 }
4752 }
4753 }
4754
4755 // Callback to check a call instruction.
// Merges the state of each call site's AAKernelInfo into ours and records
// whether the merged sub-states were already fixed. A missing call-site AA
// aborts the traversal.
4756 bool AllParallelRegionStatesWereFixed = true;
4757 bool AllSPMDStatesWereFixed = true;
4758 auto CheckCallInst = [&](Instruction &I) {
4759 auto &CB = cast<CallBase>(I);
4760 auto *CBAA = A.getAAFor<AAKernelInfo>(
4761 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4762 if (!CBAA)
4763 return false;
4764 getState() ^= CBAA->getState();
4765 AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4766 AllParallelRegionStatesWereFixed &=
4767 CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4768 AllParallelRegionStatesWereFixed &=
4769 CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4770 return true;
4771 };
4772
4773 bool UsedAssumedInformationInCheckCallInst = false;
4774 if (!A.checkForAllCallLikeInstructions(
4775 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4776 LLVM_DEBUG(dbgs() << TAG
4777 << "Failed to visit all call-like instructions!\n";);
4778 return indicatePessimisticFixpoint();
4779 }
4780
4781 // If we haven't used any assumed information for the reached parallel
4782 // region states we can fix it.
4783 if (!UsedAssumedInformationInCheckCallInst &&
4784 AllParallelRegionStatesWereFixed) {
4785 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4786 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4787 }
4788
4789 // If we haven't used any assumed information for the SPMD state we can fix
4790 // it.
4791 if (!UsedAssumedInformationInCheckRWInst &&
4792 !UsedAssumedInformationInCheckCallInst &&
4793 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4794 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4795
4796 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4797 : ChangeStatus::CHANGED;
4798 }
4799
4800private:
4801 /// Update info regarding reaching kernels.
4802 void updateReachingKernelEntries(Attributor &A,
4803 bool &AllReachingKernelsKnown) {
4804 auto PredCallSite = [&](AbstractCallSite ACS) {
4805 Function *Caller = ACS.getInstruction()->getFunction();
4806
4807 assert(Caller && "Caller is nullptr");
4808
4809 auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4810 IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4811 if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4812 ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4813 return true;
4814 }
4815
4816 // We lost track of the caller of the associated function, any kernel
4817 // could reach now.
4818 ReachingKernelEntries.indicatePessimisticFixpoint();
4819
4820 return true;
4821 };
4822
4823 if (!A.checkForAllCallSites(PredCallSite, *this,
4824 true /* RequireAllCallSites */,
4825 AllReachingKernelsKnown))
4826 ReachingKernelEntries.indicatePessimisticFixpoint();
4827 }
4828
4829 /// Update info regarding parallel levels.
4830 void updateParallelLevels(Attributor &A) {
4831 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4832 OMPInformationCache::RuntimeFunctionInfo &Parallel60RFI =
4833 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_60];
4834
4835 auto PredCallSite = [&](AbstractCallSite ACS) {
4836 Function *Caller = ACS.getInstruction()->getFunction();
4837
4838 assert(Caller && "Caller is nullptr");
4839
4840 auto *CAA =
4841 A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4842 if (CAA && CAA->ParallelLevels.isValidState()) {
4843 // Any function that is called by `__kmpc_parallel_60` will not be
4844 // folded as the parallel level in the function is updated. In order to
4845 // get it right, all the analysis would depend on the implentation. That
4846 // said, if in the future any change to the implementation, the analysis
4847 // could be wrong. As a consequence, we are just conservative here.
4848 if (Caller == Parallel60RFI.Declaration) {
4849 ParallelLevels.indicatePessimisticFixpoint();
4850 return true;
4851 }
4852
4853 ParallelLevels ^= CAA->ParallelLevels;
4854
4855 return true;
4856 }
4857
4858 // We lost track of the caller of the associated function, any kernel
4859 // could reach now.
4860 ParallelLevels.indicatePessimisticFixpoint();
4861
4862 return true;
4863 };
4864
4865 bool AllCallSitesKnown = true;
4866 if (!A.checkForAllCallSites(PredCallSite, *this,
4867 true /* RequireAllCallSites */,
4868 AllCallSitesKnown))
4869 ParallelLevels.indicatePessimisticFixpoint();
4870 }
4871};
4872
4873/// The call site kernel info abstract attribute, basically, what can we say
4874/// about a call site with regards to the KernelInfoState. For now this simply
4875/// forwards the information from the callee.
/// Known OpenMP runtime calls are modeled directly in initialize() and
/// updateImpl(); for other callees the callee's AAKernelInfo state is
/// propagated in updateImpl().
4876struct AAKernelInfoCallSite : AAKernelInfo {
4877 AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4878 : AAKernelInfo(IRP, A) {}
4879
4880 /// See AbstractAttribute::initialize(...).
4881 void initialize(Attributor &A) override {
4882 AAKernelInfo::initialize(A);
4883
4884 CallBase &CB = cast<CallBase>(getAssociatedValue());
4885 auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4886 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4887
4888 // Check for SPMD-mode assumptions.
4889 if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4890 indicateOptimisticFixpoint();
4891 return;
4892 }
4893
4894 // First weed out calls we do not care about, that is calls that do not
4895 // write memory (readonly/readnone) and intrinsics. Neither of these can
4896 // reach a parallel region or anything else we are looking for.
4897 if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4898 indicateOptimisticFixpoint();
4899 return;
4900 }
4901
4902 // Next we check if we know the callee. If it is a known OpenMP function
4903 // we will handle them explicitly in the switch below. If it is not, we
4904 // will use an AAKernelInfo object on the callee to gather information and
4905 // merge that into the current state. The latter happens in the updateImpl.
4906 auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
4907 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4908 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4909 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4910 // Unknown callees or declarations are not analyzable, we give up.
4911 if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4912
4913 // Unknown callees might contain parallel regions, except if they have
4914 // an appropriate assumption attached.
4915 if (!AssumptionAA ||
4916 !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4917 AssumptionAA->hasAssumption("omp_no_parallelism")))
4918 ReachedUnknownParallelRegions.insert(&CB);
4919
4920 // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4921 // idea we can run something unknown in SPMD-mode.
4922 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4923 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4924 SPMDCompatibilityTracker.insert(&CB);
4925 }
4926
4927 // We have updated the state for this unknown call properly, there
4928 // won't be any change so we indicate a fixpoint.
4929 indicateOptimisticFixpoint();
4930 }
4931 // If the callee is known and can be used in IPO, we will update the
4932 // state based on the callee state in updateImpl.
4933 return;
4934 }
// A known runtime function that is only one of several possible callees is
// not modeled; give up on the whole call site.
4935 if (NumCallees > 1) {
4936 indicatePessimisticFixpoint();
4937 return;
4938 }
4939
4940 RuntimeFunction RF = It->getSecond();
4941 switch (RF) {
4942 // All the functions we know are compatible with SPMD mode.
4943 case OMPRTL___kmpc_is_spmd_exec_mode:
4944 case OMPRTL___kmpc_distribute_static_fini:
4945 case OMPRTL___kmpc_for_static_fini:
4946 case OMPRTL___kmpc_global_thread_num:
4947 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4948 case OMPRTL___kmpc_get_hardware_num_blocks:
4949 case OMPRTL___kmpc_single:
4950 case OMPRTL___kmpc_end_single:
4951 case OMPRTL___kmpc_master:
4952 case OMPRTL___kmpc_end_master:
4953 case OMPRTL___kmpc_barrier:
4954 case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4955 case OMPRTL___kmpc_gpu_xteam_reduce_nowait:
4956 case OMPRTL___kmpc_error:
4957 case OMPRTL___kmpc_flush:
4958 case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4959 case OMPRTL___kmpc_get_warp_size:
4960 case OMPRTL_omp_get_thread_num:
4961 case OMPRTL_omp_get_num_threads:
4962 case OMPRTL_omp_get_max_threads:
4963 case OMPRTL_omp_in_parallel:
4964 case OMPRTL_omp_get_dynamic:
4965 case OMPRTL_omp_get_cancellation:
4966 case OMPRTL_omp_get_nested:
4967 case OMPRTL_omp_get_schedule:
4968 case OMPRTL_omp_get_thread_limit:
4969 case OMPRTL_omp_get_supported_active_levels:
4970 case OMPRTL_omp_get_max_active_levels:
4971 case OMPRTL_omp_get_level:
4972 case OMPRTL_omp_get_ancestor_thread_num:
4973 case OMPRTL_omp_get_team_size:
4974 case OMPRTL_omp_get_active_level:
4975 case OMPRTL_omp_in_final:
4976 case OMPRTL_omp_get_proc_bind:
4977 case OMPRTL_omp_get_num_places:
4978 case OMPRTL_omp_get_num_procs:
4979 case OMPRTL_omp_get_place_proc_ids:
4980 case OMPRTL_omp_get_place_num:
4981 case OMPRTL_omp_get_partition_num_places:
4982 case OMPRTL_omp_get_partition_place_nums:
4983 case OMPRTL_omp_get_wtime:
4984 break;
4985 case OMPRTL___kmpc_distribute_static_init_4:
4986 case OMPRTL___kmpc_distribute_static_init_4u:
4987 case OMPRTL___kmpc_distribute_static_init_8:
4988 case OMPRTL___kmpc_distribute_static_init_8u:
4989 case OMPRTL___kmpc_for_static_init_4:
4990 case OMPRTL___kmpc_for_static_init_4u:
4991 case OMPRTL___kmpc_for_static_init_8:
4992 case OMPRTL___kmpc_for_static_init_8u: {
4993 // Check the schedule and allow static schedule in SPMD mode.
// A non-constant schedule operand is treated as schedule value 0.
4994 unsigned ScheduleArgOpNo = 2;
4995 auto *ScheduleTypeCI =
4996 dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4997 unsigned ScheduleTypeVal =
4998 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4999 switch (OMPScheduleType(ScheduleTypeVal)) {
5000 case OMPScheduleType::UnorderedStatic:
5001 case OMPScheduleType::UnorderedStaticChunked:
5002 case OMPScheduleType::OrderedDistribute:
5003 case OMPScheduleType::OrderedDistributeChunked:
5004 break;
5005 default:
5006 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5007 SPMDCompatibilityTracker.insert(&CB);
5008 break;
5009 };
5010 } break;
5011 case OMPRTL___kmpc_target_init:
5012 KernelInitCB = &CB;
5013 break;
5014 case OMPRTL___kmpc_target_deinit:
5015 KernelDeinitCB = &CB;
5016 break;
5017 case OMPRTL___kmpc_parallel_60:
5018 if (!handleParallel60(A, CB))
5019 indicatePessimisticFixpoint();
5020 return;
5021 case OMPRTL___kmpc_omp_task:
5022 // We do not look into tasks right now, just give up.
5023 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5024 SPMDCompatibilityTracker.insert(&CB);
5025 ReachedUnknownParallelRegions.insert(&CB);
5026 break;
5027 case OMPRTL___kmpc_alloc_shared:
5028 case OMPRTL___kmpc_free_shared:
5029 // Return without setting a fixpoint, to be resolved in updateImpl.
5030 return;
5031 case OMPRTL___kmpc_distribute_static_loop_4:
5032 case OMPRTL___kmpc_distribute_static_loop_4u:
5033 case OMPRTL___kmpc_distribute_static_loop_8:
5034 case OMPRTL___kmpc_distribute_static_loop_8u:
5035 case OMPRTL___kmpc_distribute_for_static_loop_4:
5036 case OMPRTL___kmpc_distribute_for_static_loop_4u:
5037 case OMPRTL___kmpc_distribute_for_static_loop_8:
5038 case OMPRTL___kmpc_distribute_for_static_loop_8u:
5039 case OMPRTL___kmpc_for_static_loop_4:
5040 case OMPRTL___kmpc_for_static_loop_4u:
5041 case OMPRTL___kmpc_for_static_loop_8:
5042 case OMPRTL___kmpc_for_static_loop_8u:
5043 // Parallel regions might be reached by these calls, as they take a
5044 // callback argument potentially containing arbitrary user-provided
5045 // code.
5046 ReachedUnknownParallelRegions.insert(&CB);
5047 // TODO: The presence of these calls on their own does not prevent a
5048 // kernel from being SPMD-izable. We nevertheless mark it as not
5049 // SPMD-izable because further changes are needed to also consider the
5050 // contents of the callbacks passed to them.
5051 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5052 SPMDCompatibilityTracker.insert(&CB);
5053 break;
5054 default:
5055 // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
5056 // generally. However, they do not hide parallel regions.
5057 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5058 SPMDCompatibilityTracker.insert(&CB);
5059 break;
5060 }
5061 // All other OpenMP runtime calls will not reach parallel regions so they
5062 // can be safely ignored for now. Since it is a known OpenMP runtime call
5063 // we have now modeled all effects and there is no need for any update.
5064 indicateOptimisticFixpoint();
5065 };
5066
// Without usable call-edge information, fall back to the directly associated
// callee (possibly null). Otherwise check every optimistic callee, stopping
// as soon as this AA reaches a fixpoint.
5067 const auto *AACE =
5068 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5069 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5070 CheckCallee(getAssociatedFunction(), 1);
5071 return;
5072 }
5073 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5074 for (auto *Callee : OptimisticEdges) {
5075 CheckCallee(Callee, OptimisticEdges.size());
5076 if (isAtFixpoint())
5077 break;
5078 }
5079 }
5080
5081 ChangeStatus updateImpl(Attributor &A) override {
5082 // TODO: Once we have call site specific value information we can provide
5083 // call site specific liveness information and then it makes
5084 // sense to specialize attributes for call sites arguments instead of
5085 // redirecting requests to the callee argument.
5086 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5087 KernelInfoState StateBefore = getState();
5088
// The ChangeStatus returned by CheckCallee is not used below; the result of
// updateImpl is computed by comparing against StateBefore.
5089 auto CheckCallee = [&](Function *F, int NumCallees) {
5090 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5091
5092 // If F is not a runtime function, propagate the AAKernelInfo of the
5093 // callee.
5094 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5095 const IRPosition &FnPos = IRPosition::function(*F);
5096 auto *FnAA =
5097 A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5098 if (!FnAA)
5099 return indicatePessimisticFixpoint();
5100 if (getState() == FnAA->getState())
5101 return ChangeStatus::UNCHANGED;
// NOTE(review): with several optimistic callees, each non-runtime callee
// overwrites (rather than merges into) the state, so only the last callee
// checked survives. Confirm whether a merge (`^=`) was intended.
5102 getState() = FnAA->getState();
5103 return ChangeStatus::CHANGED;
5104 }
5105 if (NumCallees > 1)
5106 return indicatePessimisticFixpoint();
5107
5108 CallBase &CB = cast<CallBase>(getAssociatedValue());
5109 if (It->getSecond() == OMPRTL___kmpc_parallel_60) {
5110 if (!handleParallel60(A, CB))
5111 return indicatePessimisticFixpoint();
5112 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5113 : ChangeStatus::CHANGED;
5114 }
5115
5116 // F is a runtime function that allocates or frees memory, check
5117 // AAHeapToStack and AAHeapToShared.
5118 assert(
5119 (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5120 It->getSecond() == OMPRTL___kmpc_free_shared) &&
5121 "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5122
5123 auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
5124 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5125 auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5126 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5127
5128 RuntimeFunction RF = It->getSecond();
5129
5130 switch (RF) {
5131 // If neither HeapToStack nor HeapToShared assume the call is removed,
5132 // assume SPMD incompatibility.
5133 case OMPRTL___kmpc_alloc_shared:
5134 if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5135 (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5136 SPMDCompatibilityTracker.insert(&CB);
5137 break;
5138 case OMPRTL___kmpc_free_shared:
5139 if ((!HeapToStackAA ||
5140 !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5141 (!HeapToSharedAA ||
5142 !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5143 SPMDCompatibilityTracker.insert(&CB);
5144 break;
5145 default:
5146 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5147 SPMDCompatibilityTracker.insert(&CB);
5148 }
5149 return ChangeStatus::CHANGED;
5150 };
5151
5152 const auto *AACE =
5153 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5154 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5155 if (Function *F = getAssociatedFunction())
5156 CheckCallee(F, /*NumCallees=*/1);
5157 } else {
5158 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5159 for (auto *Callee : OptimisticEdges) {
5160 CheckCallee(Callee, OptimisticEdges.size());
5161 if (isAtFixpoint())
5162 break;
5163 }
5164 }
5165
5166 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5167 : ChangeStatus::CHANGED;
5168 }
5169
5170 /// Deal with a __kmpc_parallel_60 call (\p CB). Returns true if the call was
5171 /// handled; if a problem occurred (the parallel region operand is not a
/// Function after stripping pointer casts), false is returned.
5172 bool handleParallel60(Attributor &A, CallBase &CB) {
// While the kernel is assumed SPMD-compatible, the non-wrapper function
// operand (5) is inspected; otherwise the wrapper function operand (6).
5173 const unsigned int NonWrapperFunctionArgNo = 5;
5174 const unsigned int WrapperFunctionArgNo = 6;
5175 auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5176 ? NonWrapperFunctionArgNo
5177 : WrapperFunctionArgNo;
5178
5179 auto *ParallelRegion = dyn_cast<Function>(
5180 CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
5181 if (!ParallelRegion)
5182 return false;
5183
5184 ReachedKnownParallelRegions.insert(&CB);
5185 /// Check nested parallelism
// Nested parallelism is assumed unless the parallel region's AAKernelInfo is
// valid and reaches no known or unknown parallel regions.
5186 auto *FnAA = A.getAAFor<AAKernelInfo>(
5187 *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
5188 NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5189 !FnAA->ReachedKnownParallelRegions.empty() ||
5190 !FnAA->ReachedKnownParallelRegions.isValidState() ||
5191 !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5192 !FnAA->ReachedUnknownParallelRegions.empty();
5193 return true;
5194 }
5195};
5196
5197struct AAFoldRuntimeCall
5198 : public StateWrapper<BooleanState, AbstractAttribute> {
5199 using Base = StateWrapper<BooleanState, AbstractAttribute>;
5200
5201 AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
5202
5203 /// Statistics are tracked as part of manifest for now.
5204 void trackStatistics() const override {}
5205
5206 /// Create an abstract attribute biew for the position \p IRP.
5207 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5208 Attributor &A);
5209
5210 /// See AbstractAttribute::getName()
5211 StringRef getName() const override { return "AAFoldRuntimeCall"; }
5212
5213 /// See AbstractAttribute::getIdAddr()
5214 const char *getIdAddr() const override { return &ID; }
5215
5216 /// This function should return true if the type of the \p AA is
5217 /// AAFoldRuntimeCall
5218 static bool classof(const AbstractAttribute *AA) {
5219 return (AA->getIdAddr() == &ID);
5220 }
5221
5222 static const char ID;
5223};
5224
5225struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5226 AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
5227 : AAFoldRuntimeCall(IRP, A) {}
5228
5229 /// See AbstractAttribute::getAsStr()
5230 const std::string getAsStr(Attributor *) const override {
5231 if (!isValidState())
5232 return "<invalid>";
5233
5234 std::string Str("simplified value: ");
5235
5236 if (!SimplifiedValue)
5237 return Str + std::string("none");
5238
5239 if (!*SimplifiedValue)
5240 return Str + std::string("nullptr");
5241
5242 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5243 return Str + std::to_string(CI->getSExtValue());
5244
5245 return Str + std::string("unknown");
5246 }
5247
5248 void initialize(Attributor &A) override {
5250 indicatePessimisticFixpoint();
5251
5252 Function *Callee = getAssociatedFunction();
5253
5254 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5255 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5256 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5257 "Expected a known OpenMP runtime function");
5258
5259 RFKind = It->getSecond();
5260
5261 CallBase &CB = cast<CallBase>(getAssociatedValue());
5262 A.registerSimplificationCallback(
5264 [&](const IRPosition &IRP, const AbstractAttribute *AA,
5265 bool &UsedAssumedInformation) -> std::optional<Value *> {
5266 assert((isValidState() || SimplifiedValue == nullptr) &&
5267 "Unexpected invalid state!");
5268
5269 if (!isAtFixpoint()) {
5270 UsedAssumedInformation = true;
5271 if (AA)
5272 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5273 }
5274 return SimplifiedValue;
5275 });
5276 }
5277
5278 ChangeStatus updateImpl(Attributor &A) override {
5279 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5280 switch (RFKind) {
5281 case OMPRTL___kmpc_is_spmd_exec_mode:
5282 Changed |= foldIsSPMDExecMode(A);
5283 break;
5284 case OMPRTL___kmpc_parallel_level:
5285 Changed |= foldParallelLevel(A);
5286 break;
5287 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5288 Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
5289 break;
5290 case OMPRTL___kmpc_get_hardware_num_blocks:
5291 Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
5292 break;
5293 default:
5294 llvm_unreachable("Unhandled OpenMP runtime function!");
5295 }
5296
5297 return Changed;
5298 }
5299
5300 ChangeStatus manifest(Attributor &A) override {
5301 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5302
5303 if (SimplifiedValue && *SimplifiedValue) {
5304 Instruction &I = *getCtxI();
5305 A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
5306 A.deleteAfterManifest(I);
5307
5308 CallBase *CB = dyn_cast<CallBase>(&I);
5309 auto Remark = [&](OptimizationRemark OR) {
5310 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5311 return OR << "Replacing OpenMP runtime call "
5312 << CB->getCalledFunction()->getName() << " with "
5313 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5314 return OR << "Replacing OpenMP runtime call "
5315 << CB->getCalledFunction()->getName() << ".";
5316 };
5317
5318 if (CB && EnableVerboseRemarks)
5319 A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
5320
5321 LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5322 << **SimplifiedValue << "\n");
5323
5324 Changed = ChangeStatus::CHANGED;
5325 }
5326
5327 return Changed;
5328 }
5329
5330 ChangeStatus indicatePessimisticFixpoint() override {
5331 SimplifiedValue = nullptr;
5332 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5333 }
5334
5335private:
5336 /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
5337 ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5338 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5339
5340 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5341 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5342 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5343 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5344
5345 if (!CallerKernelInfoAA ||
5346 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5347 return indicatePessimisticFixpoint();
5348
5349 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5350 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5351 DepClassTy::REQUIRED);
5352
5353 if (!AA || !AA->isValidState()) {
5354 SimplifiedValue = nullptr;
5355 return indicatePessimisticFixpoint();
5356 }
5357
5358 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5359 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5360 ++KnownSPMDCount;
5361 else
5362 ++AssumedSPMDCount;
5363 } else {
5364 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5365 ++KnownNonSPMDCount;
5366 else
5367 ++AssumedNonSPMDCount;
5368 }
5369 }
5370
5371 if ((AssumedSPMDCount + KnownSPMDCount) &&
5372 (AssumedNonSPMDCount + KnownNonSPMDCount))
5373 return indicatePessimisticFixpoint();
5374
5375 auto &Ctx = getAnchorValue().getContext();
5376 if (KnownSPMDCount || AssumedSPMDCount) {
5377 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5378 "Expected only SPMD kernels!");
5379 // All reaching kernels are in SPMD mode. Update all function calls to
5380 // __kmpc_is_spmd_exec_mode to 1.
5381 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
5382 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5383 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5384 "Expected only non-SPMD kernels!");
5385 // All reaching kernels are in non-SPMD mode. Update all function
5386 // calls to __kmpc_is_spmd_exec_mode to 0.
5387 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
5388 } else {
5389 // We have empty reaching kernels, therefore we cannot tell if the
5390 // associated call site can be folded. At this moment, SimplifiedValue
5391 // must be none.
5392 assert(!SimplifiedValue && "SimplifiedValue should be none");
5393 }
5394
5395 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5396 : ChangeStatus::CHANGED;
5397 }
5398
5399 /// Fold __kmpc_parallel_level into a constant if possible.
5400 ChangeStatus foldParallelLevel(Attributor &A) {
5401 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5402
5403 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5404 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5405
5406 if (!CallerKernelInfoAA ||
5407 !CallerKernelInfoAA->ParallelLevels.isValidState())
5408 return indicatePessimisticFixpoint();
5409
5410 if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5411 return indicatePessimisticFixpoint();
5412
5413 if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5414 assert(!SimplifiedValue &&
5415 "SimplifiedValue should keep none at this point");
5416 return ChangeStatus::UNCHANGED;
5417 }
5418
5419 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5420 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5421 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5422 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5423 DepClassTy::REQUIRED);
5424 if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5425 return indicatePessimisticFixpoint();
5426
5427 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5428 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5429 ++KnownSPMDCount;
5430 else
5431 ++AssumedSPMDCount;
5432 } else {
5433 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5434 ++KnownNonSPMDCount;
5435 else
5436 ++AssumedNonSPMDCount;
5437 }
5438 }
5439
5440 if ((AssumedSPMDCount + KnownSPMDCount) &&
5441 (AssumedNonSPMDCount + KnownNonSPMDCount))
5442 return indicatePessimisticFixpoint();
5443
5444 auto &Ctx = getAnchorValue().getContext();
5445 // If the caller can only be reached by SPMD kernel entries, the parallel
5446 // level is 1. Similarly, if the caller can only be reached by non-SPMD
5447 // kernel entries, it is 0.
5448 if (AssumedSPMDCount || KnownSPMDCount) {
5449 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5450 "Expected only SPMD kernels!");
5451 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
5452 } else {
5453 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5454 "Expected only non-SPMD kernels!");
5455 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
5456 }
5457 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5458 : ChangeStatus::CHANGED;
5459 }
5460
5461 ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
5462 // Specialize only if all the calls agree with the attribute constant value
5463 int32_t CurrentAttrValue = -1;
5464 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5465
5466 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5467 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5468
5469 if (!CallerKernelInfoAA ||
5470 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5471 return indicatePessimisticFixpoint();
5472
5473 // Iterate over the kernels that reach this function
5474 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5475 int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
5476
5477 if (NextAttrVal == -1 ||
5478 (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5479 return indicatePessimisticFixpoint();
5480 CurrentAttrValue = NextAttrVal;
5481 }
5482
5483 if (CurrentAttrValue != -1) {
5484 auto &Ctx = getAnchorValue().getContext();
5485 SimplifiedValue =
5486 ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
5487 }
5488 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5489 : ChangeStatus::CHANGED;
5490 }
5491
5492 /// An optional value the associated value is assumed to fold to. That is, we
5493 /// assume the associated value (which is a call) can be replaced by this
5494 /// simplified value.
5495 std::optional<Value *> SimplifiedValue;
5496
5497 /// The runtime function kind of the callee of the associated call site.
5498 RuntimeFunction RFKind;
5499};
5500
5501} // namespace
5502
5503/// Register folding callsite
5504void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5505 auto &RFI = OMPInfoCache.RFIs[RF];
5506 RFI.foreachUse(SCC, [&](Use &U, Function &F) {
5507 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5508 if (!CI)
5509 return false;
5510 A.getOrCreateAAFor<AAFoldRuntimeCall>(
5511 IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
5512 DepClassTy::NONE, /* ForceUpdate */ false,
5513 /* UpdateAfterInit */ false);
5514 return false;
5515 });
5516}
5517
5518void OpenMPOpt::registerAAs(bool IsModulePass) {
5519 if (SCC.empty())
5520 return;
5521
5522 if (IsModulePass) {
5523 // Ensure we create the AAKernelInfo AAs first and without triggering an
5524 // update. This will make sure we register all value simplification
5525 // callbacks before any other AA has the chance to create an AAValueSimplify
5526 // or similar.
5527 auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5528 A.getOrCreateAAFor<AAKernelInfo>(
5529 IRPosition::function(Kernel), /* QueryingAA */ nullptr,
5530 DepClassTy::NONE, /* ForceUpdate */ false,
5531 /* UpdateAfterInit */ false);
5532 return false;
5533 };
5534 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5535 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5536 InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5537
5538 registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5539 registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5540 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5541 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5542 }
5543
5544 // Create CallSite AA for all Getters.
5545 if (DeduceICVValues) {
5546 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5547 auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5548
5549 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5550
5551 auto CreateAA = [&](Use &U, Function &Caller) {
5552 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5553 if (!CI)
5554 return false;
5555
5556 auto &CB = cast<CallBase>(*CI);
5557
5558 IRPosition CBPos = IRPosition::callsite_function(CB);
5559 A.getOrCreateAAFor<AAICVTracker>(CBPos);
5560 return false;
5561 };
5562
5563 GetterRFI.foreachUse(SCC, CreateAA);
5564 }
5565 }
5566
5567 // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5568 // every function if there is a device kernel.
5569 if (!isOpenMPDevice(M))
5570 return;
5571
5572 for (auto *F : SCC) {
5573 if (F->isDeclaration())
5574 continue;
5575
5576 // We look at internal functions only on-demand but if any use is not a
5577 // direct call or outside the current set of analyzed functions, we have
5578 // to do it eagerly.
5579 if (F->hasLocalLinkage()) {
5580 if (llvm::all_of(F->uses(), [this](const Use &U) {
5581 const auto *CB = dyn_cast<CallBase>(U.getUser());
5582 return CB && CB->isCallee(&U) &&
5583 A.isRunOn(const_cast<Function *>(CB->getCaller()));
5584 }))
5585 continue;
5586 }
5587 registerAAsForFunction(A, *F);
5588 }
5589}
5590
5591void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
5592 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5593
5594 IRPosition FPos = IRPosition::function(F);
5595 A.getOrCreateAAFor<AAExecutionDomain>(FPos);
5596 if (F.hasFnAttribute(Attribute::Convergent))
5597 A.getOrCreateAAFor<AANonConvergent>(FPos);
5598
5599 bool FunctionUsesSharedAlloc = false;
5601 const OMPInformationCache::RuntimeFunctionInfo::UseVector *SharedAllocUses =
5602 OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared].getUseVector(
5603 const_cast<Function &>(F));
5604 FunctionUsesSharedAlloc = SharedAllocUses && !SharedAllocUses->empty();
5605 }
5606 bool HasHeapToStackCandidate = false;
5607 const TargetLibraryInfo *TLI = nullptr;
5608
5609 for (auto &I : instructions(F)) {
5610 if (auto *LI = dyn_cast<LoadInst>(&I)) {
5611 bool UsedAssumedInformation = false;
5612 A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
5613 UsedAssumedInformation, AA::Interprocedural);
5614 A.getOrCreateAAFor<AAAddressSpace>(
5615 IRPosition::value(*LI->getPointerOperand()));
5616 continue;
5617 }
5618 if (auto *CI = dyn_cast<CallBase>(&I)) {
5619 if (!DisableOpenMPOptDeglobalization && !HasHeapToStackCandidate) {
5620 if (!TLI)
5621 TLI = A.getInfoCache().getTargetLibraryInfoForFunction(F);
5622 HasHeapToStackCandidate =
5623 isRemovableAlloc(CI, TLI) || getFreedOperand(CI, TLI);
5624 }
5625 if (CI->isIndirectCall())
5626 A.getOrCreateAAFor<AAIndirectCallInfo>(
5628 }
5629 if (auto *SI = dyn_cast<StoreInst>(&I)) {
5630 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
5631 A.getOrCreateAAFor<AAAddressSpace>(
5632 IRPosition::value(*SI->getPointerOperand()));
5633 continue;
5634 }
5635 if (auto *FI = dyn_cast<FenceInst>(&I)) {
5636 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
5637 continue;
5638 }
5639 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5640 if (II->getIntrinsicID() == Intrinsic::assume) {
5641 A.getOrCreateAAFor<AAPotentialValues>(
5642 IRPosition::value(*II->getArgOperand(0)));
5643 continue;
5644 }
5645 }
5646 }
5647
5648 if (FunctionUsesSharedAlloc)
5649 A.getOrCreateAAFor<AAHeapToShared>(FPos);
5650 if (HasHeapToStackCandidate)
5651 A.getOrCreateAAFor<AAHeapToStack>(FPos);
5652}
5653
5654const char AAICVTracker::ID = 0;
5655const char AAKernelInfo::ID = 0;
5656const char AAExecutionDomain::ID = 0;
5657const char AAHeapToShared::ID = 0;
5658const char AAFoldRuntimeCall::ID = 0;
5659
5660AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5661 Attributor &A) {
5662 AAICVTracker *AA = nullptr;
5663 switch (IRP.getPositionKind()) {
5668 llvm_unreachable("ICVTracker can only be created for function position!");
5670 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5671 break;
5673 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5674 break;
5676 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5677 break;
5679 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5680 break;
5681 }
5682
5683 return *AA;
5684}
5685
5687 Attributor &A) {
5688 AAExecutionDomainFunction *AA = nullptr;
5689 switch (IRP.getPositionKind()) {
5698 "AAExecutionDomain can only be created for function position!");
5700 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5701 break;
5702 }
5703
5704 return *AA;
5705}
5706
5707AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5708 Attributor &A) {
5709 AAHeapToSharedFunction *AA = nullptr;
5710 switch (IRP.getPositionKind()) {
5719 "AAHeapToShared can only be created for function position!");
5721 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5722 break;
5723 }
5724
5725 return *AA;
5726}
5727
5728AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5729 Attributor &A) {
5730 AAKernelInfo *AA = nullptr;
5731 switch (IRP.getPositionKind()) {
5738 llvm_unreachable("KernelInfo can only be created for function position!");
5740 AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5741 break;
5743 AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5744 break;
5745 }
5746
5747 return *AA;
5748}
5749
5750AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5751 Attributor &A) {
5752 AAFoldRuntimeCall *AA = nullptr;
5753 switch (IRP.getPositionKind()) {
5761 llvm_unreachable("KernelInfo can only be created for call site position!");
5763 AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5764 break;
5765 }
5766
5767 return *AA;
5768}
5769
5771 if (!containsOpenMP(M))
5772 return PreservedAnalyses::all();
5774 return PreservedAnalyses::all();
5775
5778 KernelSet Kernels = getDeviceKernels(M);
5779
5781 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
5782
5783 auto IsCalled = [&](Function &F) {
5784 if (Kernels.contains(&F))
5785 return true;
5786 return !F.use_empty();
5787 };
5788
5789 auto EmitRemark = [&](Function &F) {
5790 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
5791 ORE.emit([&]() {
5792 OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5793 return ORA << "Could not internalize function. "
5794 << "Some optimizations may not be possible. [OMP140]";
5795 });
5796 };
5797
5798 bool Changed = false;
5799
5800 // Create internal copies of each function if this is a kernel Module. This
5801 // allows iterprocedural passes to see every call edge.
5802 DenseMap<Function *, Function *> InternalizedMap;
5803 if (isOpenMPDevice(M)) {
5804 SmallPtrSet<Function *, 16> InternalizeFns;
5805 for (Function &F : M)
5806 if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
5809 InternalizeFns.insert(&F);
5810 } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5811 EmitRemark(F);
5812 }
5813 }
5814
5815 Changed |=
5816 Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5817 }
5818
5819 // Look at every function in the Module unless it was internalized.
5820 SetVector<Function *> Functions;
5822 for (Function &F : M)
5823 if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
5824 SCC.push_back(&F);
5825 Functions.insert(&F);
5826 }
5827
5828 if (SCC.empty())
5830
5831 AnalysisGetter AG(FAM);
5832
5833 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5834 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5835 };
5836
5837 BumpPtrAllocator Allocator;
5838 CallGraphUpdater CGUpdater;
5839
5840 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5843 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);
5844
5845 unsigned MaxFixpointIterations =
5847
5848 AttributorConfig AC(CGUpdater);
5850 AC.IsModulePass = true;
5851 AC.RewriteSignatures = false;
5852 AC.MaxFixpointIterations = MaxFixpointIterations;
5853 AC.OREGetter = OREGetter;
5854 AC.PassName = DEBUG_TYPE;
5855 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5856 AC.IPOAmendableCB = [](const Function &F) {
5857 return F.hasFnAttribute("kernel");
5858 };
5859
5860 Attributor A(Functions, InfoCache, AC);
5861
5862 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5863 Changed |= OMPOpt.run(true);
5864
5865 // Optionally inline device functions for potentially better performance.
5867 for (Function &F : M)
5868 if (!F.isDeclaration() && !Kernels.contains(&F) &&
5869 !F.hasFnAttribute(Attribute::NoInline))
5870 F.addFnAttr(Attribute::AlwaysInline);
5871
5873 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5874
5875 if (Changed)
5876 return PreservedAnalyses::none();
5877
5878 return PreservedAnalyses::all();
5879}
5880
5883 LazyCallGraph &CG,
5884 CGSCCUpdateResult &UR) {
5885 if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5886 return PreservedAnalyses::all();
5888 return PreservedAnalyses::all();
5889
5891 // If there are kernels in the module, we have to run on all SCC's.
5892 for (LazyCallGraph::Node &N : C) {
5893 Function *Fn = &N.getFunction();
5894 SCC.push_back(Fn);
5895 }
5896
5897 if (SCC.empty())
5898 return PreservedAnalyses::all();
5899
5900 Module &M = *C.begin()->getFunction().getParent();
5901
5903 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5904
5906 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5907
5908 AnalysisGetter AG(FAM);
5909
5910 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5911 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5912 };
5913
5914 BumpPtrAllocator Allocator;
5915 CallGraphUpdater CGUpdater;
5916 CGUpdater.initialize(CG, C, AM, UR);
5917
5918 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5922 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5923 /*CGSCC*/ &Functions, PostLink);
5924
5925 unsigned MaxFixpointIterations =
5927
5928 AttributorConfig AC(CGUpdater);
5930 AC.IsModulePass = false;
5931 AC.RewriteSignatures = false;
5932 AC.MaxFixpointIterations = MaxFixpointIterations;
5933 AC.OREGetter = OREGetter;
5934 AC.PassName = DEBUG_TYPE;
5935 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5936
5937 Attributor A(Functions, InfoCache, AC);
5938
5939 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5940 bool Changed = OMPOpt.run(false);
5941
5943 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5944
5945 if (Changed)
5946 return PreservedAnalyses::none();
5947
5948 return PreservedAnalyses::all();
5949}
5950
5952 return Fn.hasFnAttribute("kernel");
5953}
5954
5956 KernelSet Kernels;
5957
5958 for (Function &F : M)
5959 if (F.hasKernelCallingConv()) {
5960 // We are only interested in OpenMP target regions. Others, such as
5961 // kernels generated by CUDA but linked together, are not interesting to
5962 // this pass.
5963 if (isOpenMPKernel(F)) {
5964 ++NumOpenMPTargetRegionKernels;
5965 Kernels.insert(&F);
5966 } else
5967 ++NumNonOpenMPTargetRegionKernels;
5968 }
5969
5970 return Kernels;
5971}
5972
5974 Metadata *MD = M.getModuleFlag("openmp");
5975 if (!MD)
5976 return false;
5977
5978 return true;
5979}
5980
5982 Metadata *MD = M.getModuleFlag("openmp-device");
5983 if (!MD)
5984 return false;
5985
5986 return true;
5987}
@ Generic
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
amdgpu next use AMDGPU Next Use Analysis Printer
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static cl::opt< unsigned > SetFixpointIterations("attributor-max-iterations", cl::Hidden, cl::desc("Maximal number of fixpoint iterations."), cl::init(32))
static const Function * getParent(const Value *V)
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
This file provides interfaces used to manipulate a call graph, regardless if it is a "old style" Call...
This file provides interfaces used to build and manipulate a call graph, which is a very useful tool ...
This file contains the declarations for the subclasses of Constant, which represent the different fla...
This file defines the DenseSet and SmallDenseSet classes.
This file defines an array type that can be indexed using scoped enum values.
#define DEBUG_TYPE
static void emitRemark(const Function &F, OptimizationRemarkEmitter &ORE, bool Skip)
Loop::LoopBounds::Direction Direction
Definition LoopInfo.cpp:253
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
Machine Check Debug Module
This file provides utility analysis objects describing memory locations.
#define T
uint64_t IntrinsicInst * II
This file defines constans and helpers used when dealing with OpenMP.
This file defines constans that will be used by both host and device compilation.
static constexpr auto TAG
static cl::opt< bool > HideMemoryTransferLatency("openmp-hide-memory-transfer-latency", cl::desc("[WIP] Tries to hide the latency of host to device memory" " transfers"), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptStateMachineRewrite("openmp-opt-disable-state-machine-rewrite", cl::desc("Disable OpenMP optimizations that replace the state machine."), cl::Hidden, cl::init(false))
static cl::opt< bool > EnableParallelRegionMerging("openmp-opt-enable-merging", cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleAfterOptimizations("openmp-opt-print-module-after", cl::desc("Print the current module after OpenMP optimizations."), cl::Hidden, cl::init(false))
#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)
#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)
#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER)
static cl::opt< bool > PrintOpenMPKernels("openmp-print-gpu-kernels", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptFolding("openmp-opt-disable-folding", cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleBeforeOptimizations("openmp-opt-print-module-before", cl::desc("Print the current module before OpenMP optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden, cl::desc("Maximal number of attributor iterations."), cl::init(256))
static cl::opt< bool > DisableInternalization("openmp-opt-disable-internalization", cl::desc("Disable function internalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintICVValues("openmp-print-icv-values", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptimizations("openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden, cl::desc("Maximum amount of shared memory to use."), cl::init(std::numeric_limits< unsigned >::max()))
static cl::opt< bool > EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::desc("Enables more verbose remarks."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptDeglobalization("openmp-opt-disable-deglobalization", cl::desc("Disable OpenMP optimizations involving deglobalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptBarrierElimination("openmp-opt-disable-barrier-elimination", cl::desc("Disable OpenMP optimizations that eliminate barriers."), cl::Hidden, cl::init(false))
#define DEBUG_TYPE
Definition OpenMPOpt.cpp:68
static cl::opt< bool > DeduceICVValues("openmp-deduce-icv-values", cl::init(false), cl::Hidden)
#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)
#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)
static cl::opt< bool > DisableOpenMPOptSPMDization("openmp-opt-disable-spmdization", cl::desc("Disable OpenMP optimizations involving SPMD-ization."), cl::Hidden, cl::init(false))
static cl::opt< bool > AlwaysInlineDeviceFunctions("openmp-opt-inline-device", cl::desc("Inline all applicable functions on the device."), cl::Hidden, cl::init(false))
#define P(N)
FunctionAnalysisManager FAM
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
static StringRef getName(Value *V)
R600 Clause Merge
Basic Register Allocator
Remove Loads Into Fake Uses
std::pair< BasicBlock *, BasicBlock * > Edge
static bool contains(SmallPtrSetImpl< ConstantExpr * > &Cache, ConstantExpr *Expr, Constant *C)
Definition Value.cpp:484
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:119
static const int BlockSize
Definition TarWriter.cpp:33
static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, const llvm::StringTable &StandardNames, VectorLibrary VecLib)
Initialize the set of available library functions based on the specified target triple.
Value * RHS
static cl::opt< unsigned > MaxThreads("xcore-max-threads", cl::Optional, cl::desc("Maximum number of threads (for emulation thread-local storage)"), cl::Hidden, cl::value_desc("number"), cl::init(8))
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
size_t size() const
Get the array size.
Definition ArrayRef.h:141
iterator end()
Definition BasicBlock.h:474
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:461
LLVM_ABI const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
LLVM_ABI BasicBlock * splitBasicBlock(iterator I, const Twine &BBName="")
Split the basic block into two basic blocks at the specified instruction.
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
reverse_iterator rbegin()
Definition BasicBlock.h:477
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition BasicBlock.h:206
LLVM_ABI const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
InstListType::reverse_iterator reverse_iterator
Definition BasicBlock.h:172
reverse_iterator rend()
Definition BasicBlock.h:479
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
Definition BasicBlock.h:237
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
void setCallingConv(CallingConv::ID CC)
bool arg_empty() const
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
bool doesNotAccessMemory(unsigned OpNo) const
LLVM_ABI bool isIndirectCall() const
Return true if the callsite is an indirect call.
bool isCallee(Value::const_user_iterator UI) const
Determine whether the passed iterator points to the callee operand's Use.
Value * getArgOperand(unsigned i) const
void setArgOperand(unsigned i, Value *v)
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
unsigned getArgOperandNo(const Use *U) const
Given a use for a arg operand, get the arg operand number that corresponds to it.
unsigned arg_size() const
AttributeList getAttributes() const
Return the attributes for this call.
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind)
Adds the attribute to the indicated argument.
bool isArgOperand(const Use *U) const
bool hasOperandBundles() const
Return true if this User has any operand bundles.
LLVM_ABI Function * getCaller()
Helper to get the caller (the parent function).
Wrapper to unify "old style" CallGraph and "new style" LazyCallGraph.
void initialize(LazyCallGraph &LCG, LazyCallGraph::SCC &SCC, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR)
Initializers for usage outside of a CGSCC pass, inside a CGSCC pass in the old and new pass manager (...
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
@ ICMP_SLT
signed less than
Definition InstrTypes.h:769
@ ICMP_NE
not equal
Definition InstrTypes.h:762
static CondBrInst * Create(Value *Cond, BasicBlock *IfTrue, BasicBlock *IfFalse, InsertPosition InsertBefore=nullptr)
static LLVM_ABI Constant * getPointerCast(Constant *C, Type *Ty)
Create a BitCast, AddrSpaceCast, or a PtrToInt cast constant expression.
static LLVM_ABI Constant * getPointerBitCastOrAddrSpaceCast(Constant *C, Type *Ty)
Create a BitCast or AddrSpaceCast for a pointer type depending on the address space.
This is the shared class of boolean and integer constants.
Definition Constants.h:87
IntegerType * getIntegerType() const
Variant of the getType() method to always return an IntegerType, which reduces the amount of casting ...
Definition Constants.h:198
static LLVM_ABI ConstantInt * getTrue(LLVMContext &Context)
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
Definition Constants.h:174
This is an important base class in LLVM.
Definition Constant.h:43
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
ValueT lookup(const_arg_type_t< KeyT > Val) const
Return the entry for the specified key, or a default constructed value if no such entry exists.
Definition DenseMap.h:252
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:286
LLVM_ABI Instruction * findNearestCommonDominator(Instruction *I1, Instruction *I2) const
Find the nearest instruction I that dominates both I1 and I2, in the sense that a result produced bef...
static ErrorSuccess success()
Create a success value.
Definition Error.h:336
AtomicOrdering getOrdering() const
Returns the ordering constraint of this fence instruction.
A proxy from a FunctionAnalysisManager to an SCC.
const BasicBlock & getEntryBlock() const
Definition Function.h:783
const BasicBlock & front() const
Definition Function.h:834
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition Function.cpp:353
Argument * getArg(unsigned i) const
Definition Function.h:860
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
Definition Function.cpp:723
LLVM_ABI bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
Definition Globals.cpp:346
bool hasLocalLinkage() const
Module * getParent()
Get the module that this global value is contained inside of...
@ PrivateLinkage
Like Internal, but omit from symbol table.
Definition GlobalValue.h:61
@ InternalLinkage
Rename collisions when linking (static functions).
Definition GlobalValue.h:60
const Constant * getInitializer() const
getInitializer - Return the initializer for this global variable.
LLVM_ABI void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
Definition Globals.cpp:551
BasicBlock * getBlock() const
Definition IRBuilder.h:261
CondBrInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a conditional 'br Cond, TrueDest, FalseDest' instruction.
Definition IRBuilder.h:1216
CallInst * CreateCall(FunctionType *FTy, Value *Callee, ArrayRef< Value * > Args={}, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:2543
Value * CreateIsNull(Value *Arg, const Twine &Name="")
Return a boolean value testing if Arg == 0.
Definition IRBuilder.h:2692
LLVM_ABI bool isLifetimeStartOrEnd() const LLVM_READONLY
Return true if the instruction is a llvm.lifetime.start or llvm.lifetime.end marker.
LLVM_ABI bool mayWriteToMemory() const LLVM_READONLY
Return true if this instruction may modify memory.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
LLVM_ABI const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr if the function does not...
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool mayHaveSideEffects() const LLVM_READONLY
Return true if the instruction may have side effects.
LLVM_ABI bool mayReadFromMemory() const LLVM_READONLY
Return true if this instruction may read memory.
void setDebugLoc(DebugLoc Loc)
Set the debug location information for this instruction.
LLVM_ABI void setSuccessor(unsigned Idx, BasicBlock *BB)
Update the specified successor to point at the provided block.
A node in the call graph.
An SCC of the call graph.
A lazily constructed view of the call graph of a module.
LLVM_ABI void eraseFromParent()
This method unlinks 'this' from the containing function and deletes it.
LLVM_ABI StringRef getName() const
Return the name of the corresponding LLVM basic block, or an empty string.
Root of the metadata hierarchy.
Definition Metadata.h:64
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
const Triple & getTargetTriple() const
Get the target triple which is a string describing the target host.
Definition Module.h:283
LLVM_ABI Constant * getOrCreateIdent(Constant *SrcLocStr, uint32_t SrcLocStrSize, omp::IdentFlag Flags=omp::IdentFlag(0), unsigned Reserve2Flags=0)
Return an ident_t* encoding the source location SrcLocStr and Flags.
LLVM_ABI FunctionCallee getOrCreateRuntimeFunction(Module &M, omp::RuntimeFunction FnID)
Return the function declaration for the runtime function with FnID.
static LLVM_ABI std::pair< int32_t, int32_t > readThreadBoundsForKernel(const Triple &T, Function &Kernel)
}
LLVM_ABI Constant * getOrCreateSrcLocStr(StringRef LocStr, uint32_t &SrcLocStrSize)
Return the (LLVM-IR) string describing the source location LocStr.
IRBuilder<>::InsertPoint InsertPointTy
Type used throughout for insertion points.
IRBuilder Builder
The LLVM-IR Builder used to create IR.
static LLVM_ABI std::pair< int32_t, int32_t > readTeamBoundsForKernel(const Triple &T, Function &Kernel)
Read/write a bounds on teams for Kernel.
bool updateToLocation(const LocationDescription &Loc)
Update the internal location to Loc.
LLVM_ABI PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR)
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
Diagnostic information for optimization analysis remarks.
The optimization diagnostic interface.
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition Analysis.h:115
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, InsertPosition InsertBefore=nullptr)
A vector that has set insertion semantics.
Definition SetVector.h:57
size_type size() const
Determine the number of elements in the SetVector.
Definition SetVector.h:103
size_type count(const_arg_type key) const
Count the number of elements of a given key in the SetVector.
Definition SetVector.h:262
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition SetVector.h:151
size_type size() const
Definition SmallPtrSet.h:99
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
iterator begin() const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
reference emplace_back(ArgTypes &&... Args)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition StringRef.h:258
Triple - Helper class for working with autoconf configuration names.
Definition Triple.h:47
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:309
LLVM_ABI unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
static UncondBrInst * Create(BasicBlock *Target, InsertPosition InsertBefore=nullptr)
static LLVM_ABI UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
LLVM_ABI bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Definition User.cpp:25
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
LLVM_ABI void setName(const Twine &Name)
Change the name of the value.
Definition Value.cpp:394
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition Value.h:439
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:553
iterator_range< user_iterator > users()
Definition Value.h:426
User * user_back()
Definition Value.h:412
LLVM_ABI const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs and address space casts.
Definition Value.cpp:713
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:319
const ParentTy * getParent() const
Definition ilist_node.h:34
self_iterator getIterator()
Definition ilist_node.h:123
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
GlobalVariable * getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB)
ConstantStruct * getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB)
Abstract Attribute helper functions.
Definition Attributor.h:165
LLVM_ABI bool isValidAtPosition(const ValueAndContext &VAC, InformationCache &InfoCache)
Return true if the value of VAC is valid at the position of VAC, that is a constant,...
LLVM_ABI bool isPotentiallyAffectedByBarrier(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is potentially affected by a barrier.
@ Interprocedural
Definition Attributor.h:196
LLVM_ABI bool isNoSyncInst(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is a nosync instruction.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
E & operator^=(E &LHS, E RHS)
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
initializer< Ty > init(const Ty &Val)
DXILDebugInfoMap run(Module &M)
LLVM_ABI bool isOpenMPDevice(Module &M)
Helper to determine if M is an OpenMP target offloading device module.
LLVM_ABI bool containsOpenMP(Module &M)
Helper to determine if M contains OpenMP.
InternalControlVar
IDs for all Internal Control Variables (ICVs).
RuntimeFunction
IDs for all omp runtime library (RTL) functions.
LLVM_ABI KernelSet getDeviceKernels(Module &M)
Get OpenMP device kernels in M.
@ OMP_TGT_EXEC_MODE_GENERIC_SPMD
SetVector< Kernel > KernelSet
Set of kernels in the module.
Definition OpenMPOpt.h:24
Function * Kernel
Summary of a kernel (=entry point for target offloading).
Definition OpenMPOpt.h:21
LLVM_ABI bool isOpenMPKernel(Function &Fn)
Return true iff Fn is an OpenMP GPU kernel; Fn has the "kernel" attribute.
DiagnosticInfoOptimizationBase::Argument NV
NodeAddr< UseNode * > Use
Definition RDFGraph.h:385
bool empty() const
Definition BasicBlock.h:101
iterator end() const
Definition BasicBlock.h:89
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
Definition BasicBlock.h:73
LLVM_ABI iterator begin() const
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:315
@ Offset
Definition DWP.cpp:573
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1739
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition STLExtras.h:1669
bool succ_empty(const Instruction *I)
Definition CFG.h:141
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool isRemovableAlloc(const CallBase *V, const TargetLibraryInfo *TLI)
Return true if this is a call to an allocation function that does not have side effects that we are r...
bool operator!=(uint64_t V1, const APInt &V2)
Definition APInt.h:2142
constexpr from_range_t from_range
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
RelativeUniformCounterPtr ValuesPtrExpr VTableAddr Value
Definition InstrProf.h:143
AnalysisManager< LazyCallGraph::SCC, LazyCallGraph & > CGSCCAnalysisManager
The CGSCC analysis manager.
@ ThinLTOPostLink
ThinLTO postlink (backend compile) phase.
Definition Pass.h:83
@ FullLTOPostLink
Full LTO postlink (backend compile) phase.
Definition Pass.h:87
@ ThinLTOPreLink
ThinLTO prelink (summary) phase.
Definition Pass.h:81
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
MutableArrayRef(T &OneElt) -> MutableArrayRef< T >
void cantFail(Error Err, const char *Msg=nullptr)
Report a fatal error if Err is a failure value.
Definition Error.h:769
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
bool operator&=(SparseBitVector< ElementSize > *LHS, const SparseBitVector< ElementSize > &RHS)
LLVM_ABI BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="")
Split the specified block at the specified instruction.
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2012
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI Value * getFreedOperand(const CallBase *CB, const TargetLibraryInfo *TLI)
If this is a call to a free function, return the freed operand.
std::string toString(const APInt &I, unsigned Radix, bool Signed, bool formatAsCLiteral=false, bool UpperCase=true, bool InsertSeparators=false)
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
auto predecessors(const MachineBasicBlock *BB)
ChangeStatus
{
Definition Attributor.h:485
LLVM_ABI Constant * ConstantFoldInsertValueInstruction(Constant *Agg, Constant *Val, ArrayRef< unsigned > Idxs)
Attempt to constant fold an insertvalue instruction with the specified operands and indices.
@ OPTIONAL
The target may be valid if the source is not.
Definition Attributor.h:497
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
BumpPtrAllocatorImpl<> BumpPtrAllocator
The standard BumpPtrAllocator which just uses the default template parameters.
Definition Allocator.h:390
LLVM_ABI const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=MaxLookupSearchDepth)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
Definition MIRParser.h:39
#define N
static LLVM_ABI AAExecutionDomain & createForPosition(const IRPosition &IRP, Attributor &A)
Create an abstract attribute view for the position IRP.
AAExecutionDomain(const IRPosition &IRP, Attributor &A)
static LLVM_ABI const char ID
Unique ID (due to the unique address)
AccessKind
Simple enum to distinguish read/write/read-write accesses.
StateType::base_t MemoryLocationsKind
static LLVM_ABI bool isAlignedBarrier(const CallBase &CB, bool ExecutedAligned)
Helper function to determine if CB is an aligned (GPU) barrier.
Base struct for all "concrete attribute" deductions.
virtual const char * getIdAddr() const =0
This function should return the address of the ID of the AbstractAttribute.
An interface to query the internal state of an abstract attribute.
Wrapper for FunctionAnalysisManager.
Configuration for the Attributor.
std::function< void(Attributor &A, const Function &F)> InitializationCallback
Callback function to be invoked on internal functions marked live.
std::optional< unsigned > MaxFixpointIterations
Maximum number of iterations to run until fixpoint.
bool RewriteSignatures
Flag to determine if we rewrite function signatures.
const char * PassName
}
OptimizationRemarkGetter OREGetter
IPOAmendableCBTy IPOAmendableCB
bool IsModulePass
Is the user of the Attributor a module pass or not.
bool DefaultInitializeLiveInternals
Flag to determine if we want to initialize all default AAs for an internal function marked live.
The fixpoint analysis framework that orchestrates the attribute deduction.
static LLVM_ABI bool isInternalizable(Function &F)
Returns true if the function F can be internalized.
std::function< std::optional< Value * >( const IRPosition &, const AbstractAttribute *, bool &)> SimplifictionCallbackTy
Register CB as a simplification callback.
std::function< std::optional< Constant * >( const GlobalVariable &, const AbstractAttribute *, bool &)> GlobalVariableSimplifictionCallbackTy
Register CB as a simplification callback.
std::function< bool(Attributor &, const AbstractAttribute *)> VirtualUseCallbackTy
static LLVM_ABI bool internalizeFunctions(SmallPtrSetImpl< Function * > &FnSet, DenseMap< Function *, Function * > &FnMap)
Make copies of each function in the set FnSet such that the copied version has internal linkage after...
Simple wrapper for a single bit (boolean) state.
Support structure for SCC passes to communicate updates the call graph back to the CGSCC pass manager...
Helper to describe and deal with positions in the LLVM-IR.
Definition Attributor.h:582
static const IRPosition callsite_returned(const CallBase &CB)
Create a position describing the returned value of CB.
Definition Attributor.h:650
static const IRPosition returned(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the returned value of F.
Definition Attributor.h:632
static const IRPosition value(const Value &V, const CallBaseContext *CBContext=nullptr)
Create a position describing the value of V.
Definition Attributor.h:606
static const IRPosition inst(const Instruction &I, const CallBaseContext *CBContext=nullptr)
Create a position describing the instruction I.
Definition Attributor.h:618
@ IRP_ARGUMENT
An attribute for a function argument.
Definition Attributor.h:596
@ IRP_RETURNED
An attribute for the function return value.
Definition Attributor.h:592
@ IRP_CALL_SITE
An attribute for a call site (function scope).
Definition Attributor.h:595
@ IRP_CALL_SITE_RETURNED
An attribute for a call site return value.
Definition Attributor.h:593
@ IRP_FUNCTION
An attribute for a function (scope).
Definition Attributor.h:594
@ IRP_FLOAT
A position that is not associated with a spot suitable for attributes.
Definition Attributor.h:590
@ IRP_CALL_SITE_ARGUMENT
An attribute for a call site argument.
Definition Attributor.h:597
@ IRP_INVALID
An invalid position.
Definition Attributor.h:589
static const IRPosition function(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the function scope of F.
Definition Attributor.h:625
Kind getPositionKind() const
Return the associated position kind.
Definition Attributor.h:878
static const IRPosition callsite_function(const CallBase &CB)
Create a position describing the function scope of CB.
Definition Attributor.h:645
Data structure to hold cached (LLVM-IR) information.
Defines various target-specific GPU grid values that must be consistent between host RTL (plugin),...