1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 // - Replacing globalized device memory with stack memory.
13 // - Replacing globalized device memory with shared memory.
14 // - Parallel region merging.
15 // - Transforming generic-mode device kernels to SPMD mode.
16 // - Specializing the state machine for generic-mode device kernels.
17 //
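// As a brief illustration (user-level OpenMP code, not part of this file),
// the runtime-call deduplication listed above turns
//
//   int A = omp_get_thread_num();
//   ...                        // nothing here can change the thread number
//   int B = omp_get_thread_num();
//
// into a single runtime call whose result feeds both A and B.
//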
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/IPO/OpenMPOpt.h"
21 
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/Statistic.h"
26 #include "llvm/ADT/StringRef.h"
34 #include "llvm/IR/Assumptions.h"
35 #include "llvm/IR/Constants.h"
36 #include "llvm/IR/DiagnosticInfo.h"
37 #include "llvm/IR/GlobalValue.h"
38 #include "llvm/IR/GlobalVariable.h"
39 #include "llvm/IR/Instruction.h"
40 #include "llvm/IR/Instructions.h"
41 #include "llvm/IR/IntrinsicInst.h"
42 #include "llvm/IR/IntrinsicsAMDGPU.h"
43 #include "llvm/IR/IntrinsicsNVPTX.h"
44 #include "llvm/IR/LLVMContext.h"
45 #include "llvm/InitializePasses.h"
46 #include "llvm/Support/CommandLine.h"
47 #include "llvm/Support/Debug.h"
48 #include "llvm/Transforms/IPO.h"
52 
53 #include <algorithm>
54 
55 using namespace llvm;
56 using namespace omp;
57 
58 #define DEBUG_TYPE "openmp-opt"
59 
60 static cl::opt<bool> DisableOpenMPOptimizations(
 61  "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
62  cl::Hidden, cl::init(false));
63 
64 static cl::opt<bool> EnableParallelRegionMerging(
 65  "openmp-opt-enable-merging",
66  cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
67  cl::init(false));
68 
69 static cl::opt<bool>
70  DisableInternalization("openmp-opt-disable-internalization",
71  cl::desc("Disable function internalization."),
72  cl::Hidden, cl::init(false));
73 
74 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
75  cl::Hidden);
76 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
77  cl::init(false), cl::Hidden);
78 
79 static cl::opt<bool> HideMemoryTransferLatency(
 80  "openmp-hide-memory-transfer-latency",
81  cl::desc("[WIP] Tries to hide the latency of host to device memory"
82  " transfers"),
83  cl::Hidden, cl::init(false));
84 
85 static cl::opt<bool> DisableOpenMPOptDeglobalization(
 86  "openmp-opt-disable-deglobalization",
87  cl::desc("Disable OpenMP optimizations involving deglobalization."),
88  cl::Hidden, cl::init(false));
89 
90 static cl::opt<bool> DisableOpenMPOptSPMDization(
 91  "openmp-opt-disable-spmdization",
92  cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
93  cl::Hidden, cl::init(false));
94 
95 static cl::opt<bool> DisableOpenMPOptFolding(
 96  "openmp-opt-disable-folding",
97  cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
98  cl::init(false));
99 
100 static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
 101  "openmp-opt-disable-state-machine-rewrite",
102  cl::desc("Disable OpenMP optimizations that replace the state machine."),
103  cl::Hidden, cl::init(false));
104 
105 static cl::opt<bool> DisableOpenMPOptBarrierElimination(
 106  "openmp-opt-disable-barrier-elimination",
107  cl::desc("Disable OpenMP optimizations that eliminate barriers."),
108  cl::Hidden, cl::init(false));
109 
110 static cl::opt<bool> PrintModuleAfterOptimizations(
 111  "openmp-opt-print-module-after",
112  cl::desc("Print the current module after OpenMP optimizations."),
113  cl::Hidden, cl::init(false));
114 
115 static cl::opt<bool> PrintModuleBeforeOptimizations(
 116  "openmp-opt-print-module-before",
117  cl::desc("Print the current module before OpenMP optimizations."),
118  cl::Hidden, cl::init(false));
119 
120 static cl::opt<bool> AlwaysInlineDeviceFunctions(
 121  "openmp-opt-inline-device",
 122  cl::desc("Inline all applicable functions on the device."), cl::Hidden,
123  cl::init(false));
124 
125 static cl::opt<bool>
126  EnableVerboseRemarks("openmp-opt-verbose-remarks",
127  cl::desc("Enables more verbose remarks."), cl::Hidden,
128  cl::init(false));
129 
130 static cl::opt<unsigned>
131  SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
132  cl::desc("Maximal number of attributor iterations."),
133  cl::init(256));
134 
135 static cl::opt<unsigned>
136  SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
137  cl::desc("Maximum amount of shared memory to use."),
138  cl::init(512));
 139 
140 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
141  "Number of OpenMP runtime calls deduplicated");
142 STATISTIC(NumOpenMPParallelRegionsDeleted,
143  "Number of OpenMP parallel regions deleted");
144 STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
145  "Number of OpenMP runtime functions identified");
146 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
147  "Number of OpenMP runtime function uses identified");
148 STATISTIC(NumOpenMPTargetRegionKernels,
149  "Number of OpenMP target region entry points (=kernels) identified");
150 STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
151  "Number of OpenMP target region entry points (=kernels) executed in "
152  "SPMD-mode instead of generic-mode");
153 STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
154  "Number of OpenMP target region entry points (=kernels) executed in "
155  "generic-mode without a state machine");
156 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
157  "Number of OpenMP target region entry points (=kernels) executed in "
158  "generic-mode with customized state machines with fallback");
159 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
160  "Number of OpenMP target region entry points (=kernels) executed in "
161  "generic-mode with customized state machines without fallback");
162 STATISTIC(
163  NumOpenMPParallelRegionsReplacedInGPUStateMachine,
164  "Number of OpenMP parallel regions replaced with ID in GPU state machines");
165 STATISTIC(NumOpenMPParallelRegionsMerged,
166  "Number of OpenMP parallel regions merged");
167 STATISTIC(NumBytesMovedToSharedMemory,
168  "Amount of memory pushed to shared memory");
169 STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
170 
171 #if !defined(NDEBUG)
172 static constexpr auto TAG = "[" DEBUG_TYPE "]";
173 #endif
174 
175 namespace {
176 
177 struct AAHeapToShared;
178 
179 struct AAICVTracker;
180 
181 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
182 /// Attributor runs.
183 struct OMPInformationCache : public InformationCache {
184  OMPInformationCache(Module &M, AnalysisGetter &AG,
185  BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
 186  KernelSet &Kernels)
 187  : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
188  Kernels(Kernels) {
189 
190  OMPBuilder.initialize();
191  initializeRuntimeFunctions(M);
192  initializeInternalControlVars();
193  }
194 
195  /// Generic information that describes an internal control variable.
196  struct InternalControlVarInfo {
197  /// The kind, as described by InternalControlVar enum.
198  InternalControlVar Kind;
199 
200  /// The name of the ICV.
201  StringRef Name;
202 
203  /// Environment variable associated with this ICV.
204  StringRef EnvVarName;
205 
206  /// Initial value kind.
207  ICVInitValue InitKind;
208 
209  /// Initial value.
210  ConstantInt *InitValue;
211 
212  /// Setter RTL function associated with this ICV.
213  RuntimeFunction Setter;
214 
215  /// Getter RTL function associated with this ICV.
216  RuntimeFunction Getter;
217 
218  /// RTL Function corresponding to the override clause of this ICV
219  RuntimeFunction Clause;
 220  };
221 
222  /// Generic information that describes a runtime function
223  struct RuntimeFunctionInfo {
224 
225  /// The kind, as described by the RuntimeFunction enum.
226  RuntimeFunction Kind;
227 
228  /// The name of the function.
229  StringRef Name;
230 
231  /// Flag to indicate a variadic function.
232  bool IsVarArg;
233 
234  /// The return type of the function.
235  Type *ReturnType;
236 
237  /// The argument types of the function.
238  SmallVector<Type *, 8> ArgumentTypes;
239 
240  /// The declaration if available.
241  Function *Declaration = nullptr;
242 
243  /// Uses of this runtime function per function containing the use.
244  using UseVector = SmallVector<Use *, 16>;
245 
246  /// Clear UsesMap for runtime function.
247  void clearUsesMap() { UsesMap.clear(); }
248 
249  /// Boolean conversion that is true if the runtime function was found.
250  operator bool() const { return Declaration; }
251 
252  /// Return the vector of uses in function \p F.
253  UseVector &getOrCreateUseVector(Function *F) {
254  std::shared_ptr<UseVector> &UV = UsesMap[F];
255  if (!UV)
256  UV = std::make_shared<UseVector>();
257  return *UV;
258  }
259 
260  /// Return the vector of uses in function \p F or `nullptr` if there are
261  /// none.
262  const UseVector *getUseVector(Function &F) const {
263  auto I = UsesMap.find(&F);
264  if (I != UsesMap.end())
265  return I->second.get();
266  return nullptr;
267  }
268 
269  /// Return how many functions contain uses of this runtime function.
270  size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
271 
272  /// Return the number of arguments (or the minimal number for variadic
273  /// functions).
274  size_t getNumArgs() const { return ArgumentTypes.size(); }
275 
276  /// Run the callback \p CB on each use and forget the use if the result is
277  /// true. The callback will be fed the function in which the use was
278  /// encountered as second argument.
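  /// A minimal usage sketch (illustrative only):
  ///   RFI.foreachUse(SCC, [](Use &U, Function &F) {
  ///     // Inspect the use here; return true to forget it afterwards.
  ///     return false;
  ///   });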
279  void foreachUse(SmallVectorImpl<Function *> &SCC,
280  function_ref<bool(Use &, Function &)> CB) {
281  for (Function *F : SCC)
282  foreachUse(CB, F);
283  }
284 
285  /// Run the callback \p CB on each use within the function \p F and forget
286  /// the use if the result is true.
287  void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
288  SmallVector<unsigned, 8> ToBeDeleted;
289  ToBeDeleted.clear();
290 
291  unsigned Idx = 0;
292  UseVector &UV = getOrCreateUseVector(F);
293 
294  for (Use *U : UV) {
295  if (CB(*U, *F))
296  ToBeDeleted.push_back(Idx);
297  ++Idx;
298  }
299 
300  // Remove the to-be-deleted indices in reverse order as prior
301  // modifications will not modify the smaller indices.
302  while (!ToBeDeleted.empty()) {
303  unsigned Idx = ToBeDeleted.pop_back_val();
304  UV[Idx] = UV.back();
305  UV.pop_back();
306  }
307  }
308 
309  private:
310  /// Map from functions to all uses of this runtime function contained in
311  /// them.
312  DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
 313 
314  public:
315  /// Iterators for the uses of this runtime function.
316  decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
317  decltype(UsesMap)::iterator end() { return UsesMap.end(); }
318  };
319 
320  /// An OpenMP-IR-Builder instance
321  OpenMPIRBuilder OMPBuilder;
322 
323  /// Map from runtime function kind to the runtime function description.
324  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
325  RuntimeFunction::OMPRTL___last>
326  RFIs;
327 
328  /// Map from function declarations/definitions to their runtime enum type.
329  DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
330 
331  /// Map from ICV kind to the ICV description.
332  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
333  InternalControlVar::ICV___last>
334  ICVs;
335 
336  /// Helper to initialize all internal control variable information for those
337  /// defined in OMPKinds.def.
338  void initializeInternalControlVars() {
339 #define ICV_RT_SET(_Name, RTL) \
340  { \
341  auto &ICV = ICVs[_Name]; \
342  ICV.Setter = RTL; \
343  }
344 #define ICV_RT_GET(Name, RTL) \
345  { \
346  auto &ICV = ICVs[Name]; \
347  ICV.Getter = RTL; \
348  }
349 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
350  { \
351  auto &ICV = ICVs[Enum]; \
352  ICV.Name = _Name; \
353  ICV.Kind = Enum; \
354  ICV.InitKind = Init; \
355  ICV.EnvVarName = _EnvVarName; \
356  switch (ICV.InitKind) { \
357  case ICV_IMPLEMENTATION_DEFINED: \
358  ICV.InitValue = nullptr; \
359  break; \
360  case ICV_ZERO: \
361  ICV.InitValue = ConstantInt::get( \
362  Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
363  break; \
364  case ICV_FALSE: \
365  ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
366  break; \
367  case ICV_LAST: \
368  break; \
369  } \
370  }
371 #include "llvm/Frontend/OpenMP/OMPKinds.def"
372  }
373 
374  /// Returns true if the function declaration \p F matches the runtime
375  /// function types, that is, return type \p RTFRetType, and argument types
376  /// \p RTFArgTypes.
377  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
378  SmallVector<Type *, 8> &RTFArgTypes) {
379  // TODO: We should output information to the user (under debug output
380  // and via remarks).
381 
382  if (!F)
383  return false;
384  if (F->getReturnType() != RTFRetType)
385  return false;
386  if (F->arg_size() != RTFArgTypes.size())
387  return false;
388 
389  auto *RTFTyIt = RTFArgTypes.begin();
390  for (Argument &Arg : F->args()) {
391  if (Arg.getType() != *RTFTyIt)
392  return false;
393 
394  ++RTFTyIt;
395  }
396 
397  return true;
398  }
399 
400  // Helper to collect all uses of the declaration in the UsesMap.
401  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
402  unsigned NumUses = 0;
403  if (!RFI.Declaration)
404  return NumUses;
405  OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
406 
407  if (CollectStats) {
408  NumOpenMPRuntimeFunctionsIdentified += 1;
409  NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
410  }
411 
412  // TODO: We directly convert uses into proper calls and unknown uses.
413  for (Use &U : RFI.Declaration->uses()) {
414  if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
415  if (ModuleSlice.empty() || ModuleSlice.count(UserI->getFunction())) {
416  RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
417  ++NumUses;
418  }
419  } else {
420  RFI.getOrCreateUseVector(nullptr).push_back(&U);
421  ++NumUses;
422  }
423  }
424  return NumUses;
425  }
426 
427  // Helper function to recollect uses of a runtime function.
428  void recollectUsesForFunction(RuntimeFunction RTF) {
429  auto &RFI = RFIs[RTF];
430  RFI.clearUsesMap();
431  collectUses(RFI, /*CollectStats*/ false);
432  }
433 
434  // Helper function to recollect uses of all runtime functions.
435  void recollectUses() {
436  for (int Idx = 0; Idx < RFIs.size(); ++Idx)
437  recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
438  }
439 
440  // Helper function to inherit the calling convention of the function callee.
441  void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
442  if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
443  CI->setCallingConv(Fn->getCallingConv());
444  }
445 
446  /// Helper to initialize all runtime function information for those defined
447  /// in OpenMPKinds.def.
448  void initializeRuntimeFunctions(Module &M) {
449 
450  // Helper macros for handling __VA_ARGS__ in OMP_RTL
451 #define OMP_TYPE(VarName, ...) \
452  Type *VarName = OMPBuilder.VarName; \
453  (void)VarName;
454 
455 #define OMP_ARRAY_TYPE(VarName, ...) \
456  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
457  (void)VarName##Ty; \
458  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
459  (void)VarName##PtrTy;
460 
461 #define OMP_FUNCTION_TYPE(VarName, ...) \
462  FunctionType *VarName = OMPBuilder.VarName; \
463  (void)VarName; \
464  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
465  (void)VarName##Ptr;
466 
467 #define OMP_STRUCT_TYPE(VarName, ...) \
468  StructType *VarName = OMPBuilder.VarName; \
469  (void)VarName; \
470  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
471  (void)VarName##Ptr;
472 
473 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
474  { \
475  SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
476  Function *F = M.getFunction(_Name); \
477  RTLFunctions.insert(F); \
478  if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
479  RuntimeFunctionIDMap[F] = _Enum; \
480  auto &RFI = RFIs[_Enum]; \
481  RFI.Kind = _Enum; \
482  RFI.Name = _Name; \
483  RFI.IsVarArg = _IsVarArg; \
484  RFI.ReturnType = OMPBuilder._ReturnType; \
485  RFI.ArgumentTypes = std::move(ArgsTypes); \
486  RFI.Declaration = F; \
487  unsigned NumUses = collectUses(RFI); \
488  (void)NumUses; \
489  LLVM_DEBUG({ \
490  dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
491  << " found\n"; \
492  if (RFI.Declaration) \
493  dbgs() << TAG << "-> got " << NumUses << " uses in " \
494  << RFI.getNumFunctionsWithUses() \
495  << " different functions.\n"; \
496  }); \
497  } \
498  }
499 #include "llvm/Frontend/OpenMP/OMPKinds.def"
500 
501  // Remove the `noinline` attribute from `__kmpc`, `_OMP::` and `omp_`
502  // functions, except if `optnone` is present.
503  if (isOpenMPDevice(M)) {
504  for (Function &F : M) {
505  for (StringRef Prefix : {"__kmpc", "_ZN4_OMP", "omp_"})
506  if (F.hasFnAttribute(Attribute::NoInline) &&
507  F.getName().startswith(Prefix) &&
508  !F.hasFnAttribute(Attribute::OptimizeNone))
509  F.removeFnAttr(Attribute::NoInline);
510  }
511  }
512 
513  // TODO: We should attach the attributes defined in OMPKinds.def.
514  }
515 
516  /// Collection of known kernels (\see Kernel) in the module.
517  KernelSet &Kernels;
 518 
519  /// Collection of known OpenMP runtime functions.
520  DenseSet<const Function *> RTLFunctions;
521 };
522 
523 template <typename Ty, bool InsertInvalidates = true>
524 struct BooleanStateWithSetVector : public BooleanState {
525  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
526  bool insert(const Ty &Elem) {
527  if (InsertInvalidates)
528  BooleanState::indicatePessimisticFixpoint();
 529  return Set.insert(Elem);
530  }
531 
532  const Ty &operator[](int Idx) const { return Set[Idx]; }
533  bool operator==(const BooleanStateWithSetVector &RHS) const {
534  return BooleanState::operator==(RHS) && Set == RHS.Set;
535  }
536  bool operator!=(const BooleanStateWithSetVector &RHS) const {
537  return !(*this == RHS);
538  }
539 
540  bool empty() const { return Set.empty(); }
541  size_t size() const { return Set.size(); }
542 
543  /// "Clamp" this state with \p RHS.
544  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
545  BooleanState::operator^=(RHS);
 546  Set.insert(RHS.Set.begin(), RHS.Set.end());
547  return *this;
548  }
549 
550 private:
551  /// A set to keep track of elements.
552  SetVector<Ty> Set;
553 
554 public:
555  typename decltype(Set)::iterator begin() { return Set.begin(); }
556  typename decltype(Set)::iterator end() { return Set.end(); }
557  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
558  typename decltype(Set)::const_iterator end() const { return Set.end(); }
559 };
560 
561 template <typename Ty, bool InsertInvalidates = true>
562 using BooleanStateWithPtrSetVector =
563  BooleanStateWithSetVector<Ty *, InsertInvalidates>;
564 
565 struct KernelInfoState : AbstractState {
566  /// Flag to track if we reached a fixpoint.
567  bool IsAtFixpoint = false;
568 
569  /// The parallel regions (identified by the outlined parallel functions) that
570  /// can be reached from the associated function.
571  BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
572  ReachedKnownParallelRegions;
573 
574  /// State to track what parallel region we might reach.
575  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
576 
577  /// State to track if we are in SPMD-mode, assumed or known, and why we decided
578  /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
579  /// false.
580  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
581 
582  /// The __kmpc_target_init call in this kernel, if any. If we find more than
583  /// one we abort as the kernel is malformed.
584  CallBase *KernelInitCB = nullptr;
585 
586  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
587  /// one we abort as the kernel is malformed.
588  CallBase *KernelDeinitCB = nullptr;
589 
590  /// Flag to indicate if the associated function is a kernel entry.
591  bool IsKernelEntry = false;
592 
593  /// State to track what kernel entries can reach the associated function.
594  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
595 
596  /// State to indicate if we can track parallel level of the associated
597  /// function. We will give up tracking if we encounter unknown caller or the
598  /// caller is __kmpc_parallel_51.
599  BooleanStateWithSetVector<uint8_t> ParallelLevels;
600 
601  /// Abstract State interface
602  ///{
603 
604  KernelInfoState() = default;
605  KernelInfoState(bool BestState) {
606  if (!BestState)
607  indicatePessimisticFixpoint();
608  }
609 
610  /// See AbstractState::isValidState(...)
611  bool isValidState() const override { return true; }
612 
613  /// See AbstractState::isAtFixpoint(...)
614  bool isAtFixpoint() const override { return IsAtFixpoint; }
615 
616  /// See AbstractState::indicatePessimisticFixpoint(...)
617  ChangeStatus indicatePessimisticFixpoint() override {
618  IsAtFixpoint = true;
619  ReachingKernelEntries.indicatePessimisticFixpoint();
620  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
621  ReachedKnownParallelRegions.indicatePessimisticFixpoint();
622  ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
623  return ChangeStatus::CHANGED;
624  }
625 
626  /// See AbstractState::indicateOptimisticFixpoint(...)
627  ChangeStatus indicateOptimisticFixpoint() override {
628  IsAtFixpoint = true;
629  ReachingKernelEntries.indicateOptimisticFixpoint();
630  SPMDCompatibilityTracker.indicateOptimisticFixpoint();
631  ReachedKnownParallelRegions.indicateOptimisticFixpoint();
632  ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
633  return ChangeStatus::UNCHANGED;
 634  }
635 
636  /// Return the assumed state
637  KernelInfoState &getAssumed() { return *this; }
638  const KernelInfoState &getAssumed() const { return *this; }
639 
640  bool operator==(const KernelInfoState &RHS) const {
641  if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
642  return false;
643  if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
644  return false;
645  if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
646  return false;
647  if (ReachingKernelEntries != RHS.ReachingKernelEntries)
648  return false;
649  return true;
650  }
651 
652  /// Returns true if this kernel contains any OpenMP parallel regions.
653  bool mayContainParallelRegion() {
654  return !ReachedKnownParallelRegions.empty() ||
655  !ReachedUnknownParallelRegions.empty();
656  }
657 
658  /// Return empty set as the best state of potential values.
659  static KernelInfoState getBestState() { return KernelInfoState(true); }
660 
661  static KernelInfoState getBestState(KernelInfoState &KIS) {
662  return getBestState();
663  }
664 
665  /// Return full set as the worst state of potential values.
666  static KernelInfoState getWorstState() { return KernelInfoState(false); }
667 
668  /// "Clamp" this state with \p KIS.
669  KernelInfoState operator^=(const KernelInfoState &KIS) {
670  // Do not merge two different _init and _deinit call sites.
671  if (KIS.KernelInitCB) {
672  if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
673  llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
674  "assumptions.");
675  KernelInitCB = KIS.KernelInitCB;
676  }
677  if (KIS.KernelDeinitCB) {
678  if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
679  llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
680  "assumptions.");
681  KernelDeinitCB = KIS.KernelDeinitCB;
682  }
683  SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
684  ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
685  ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
686  return *this;
687  }
688 
689  KernelInfoState operator&=(const KernelInfoState &KIS) {
690  return (*this ^= KIS);
691  }
692 
693  ///}
694 };
695 
696 /// Used to map the values physically (in the IR) stored in an offload
697 /// array, to a vector in memory.
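/// For illustration (a simplified IR pattern, not taken from a real module),
/// given
///   %array = alloca [2 x i8*]
///   store i8* %p0, i8** %gep.0
///   store i8* %p1, i8** %gep.1
/// a successfully initialized OffloadArray has StoredValues = {%p0, %p1} (as
/// underlying objects) and LastAccesses = {the two stores}.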
698 struct OffloadArray {
699  /// Physical array (in the IR).
700  AllocaInst *Array = nullptr;
701  /// Mapped values.
702  SmallVector<Value *, 8> StoredValues;
703  /// Last stores made in the offload array.
704  SmallVector<StoreInst *, 8> LastAccesses;
705 
706  OffloadArray() = default;
707 
708  /// Initializes the OffloadArray with the values stored in \p Array before
709  /// instruction \p Before is reached. Returns false if the initialization
710  /// fails.
711  /// This MUST be used immediately after the construction of the object.
712  bool initialize(AllocaInst &Array, Instruction &Before) {
713  if (!Array.getAllocatedType()->isArrayTy())
714  return false;
715 
716  if (!getValues(Array, Before))
717  return false;
718 
719  this->Array = &Array;
720  return true;
721  }
722 
723  static const unsigned DeviceIDArgNum = 1;
724  static const unsigned BasePtrsArgNum = 3;
725  static const unsigned PtrsArgNum = 4;
726  static const unsigned SizesArgNum = 5;
727 
728 private:
729  /// Traverses the BasicBlock where \p Array is, collecting the stores made to
730  /// \p Array, leaving StoredValues with the values stored before the
731  /// instruction \p Before is reached.
732  bool getValues(AllocaInst &Array, Instruction &Before) {
733  // Initialize container.
734  const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
735  StoredValues.assign(NumValues, nullptr);
736  LastAccesses.assign(NumValues, nullptr);
737 
738  // TODO: This assumes the instruction \p Before is in the same
739  // BasicBlock as Array. Make it general, for any control flow graph.
740  BasicBlock *BB = Array.getParent();
741  if (BB != Before.getParent())
742  return false;
743 
744  const DataLayout &DL = Array.getModule()->getDataLayout();
745  const unsigned int PointerSize = DL.getPointerSize();
746 
747  for (Instruction &I : *BB) {
748  if (&I == &Before)
749  break;
750 
751  if (!isa<StoreInst>(&I))
752  continue;
753 
754  auto *S = cast<StoreInst>(&I);
755  int64_t Offset = -1;
756  auto *Dst =
757  GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
758  if (Dst == &Array) {
759  int64_t Idx = Offset / PointerSize;
760  StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
761  LastAccesses[Idx] = S;
762  }
763  }
764 
765  return isFilled();
766  }
767 
768  /// Returns true if all values in StoredValues and
769  /// LastAccesses are not nullptrs.
770  bool isFilled() {
771  const unsigned NumValues = StoredValues.size();
772  for (unsigned I = 0; I < NumValues; ++I) {
773  if (!StoredValues[I] || !LastAccesses[I])
774  return false;
775  }
776 
777  return true;
778  }
779 };
780 
781 struct OpenMPOpt {
782 
783  using OptimizationRemarkGetter =
784  function_ref<OptimizationRemarkEmitter &(Function *)>;
 785 
786  OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
787  OptimizationRemarkGetter OREGetter,
788  OMPInformationCache &OMPInfoCache, Attributor &A)
789  : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
790  OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
791 
792  /// Check if any remarks are enabled for openmp-opt
793  bool remarksEnabled() {
794  auto &Ctx = M.getContext();
795  return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
796  }
797 
798  /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
799  bool run(bool IsModulePass) {
800  if (SCC.empty())
801  return false;
802 
803  bool Changed = false;
804 
805  LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
806  << " functions in a slice with "
807  << OMPInfoCache.ModuleSlice.size() << " functions\n");
808 
809  if (IsModulePass) {
810  Changed |= runAttributor(IsModulePass);
811 
812  // Recollect uses, in case Attributor deleted any.
813  OMPInfoCache.recollectUses();
814 
815  // TODO: This should be folded into buildCustomStateMachine.
816  Changed |= rewriteDeviceCodeStateMachine();
817 
818  if (remarksEnabled())
819  analysisGlobalization();
820 
821  Changed |= eliminateBarriers();
822  } else {
823  if (PrintICVValues)
824  printICVs();
825  if (PrintOpenMPKernels)
826  printKernels();
827 
828  Changed |= runAttributor(IsModulePass);
829 
830  // Recollect uses, in case Attributor deleted any.
831  OMPInfoCache.recollectUses();
832 
833  Changed |= deleteParallelRegions();
834 
835  if (HideMemoryTransferLatency)
 836  Changed |= hideMemTransfersLatency();
837  Changed |= deduplicateRuntimeCalls();
838  if (EnableParallelRegionMerging) {
 839  if (mergeParallelRegions()) {
840  deduplicateRuntimeCalls();
841  Changed = true;
842  }
843  }
844 
845  Changed |= eliminateBarriers();
846  }
847 
848  return Changed;
849  }
850 
851  /// Print initial ICV values for testing.
852  /// FIXME: This should be done from the Attributor once it is added.
853  void printICVs() const {
854  InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
855  ICV_proc_bind};
856 
857  for (Function *F : SCC) {
858  for (auto ICV : ICVs) {
859  auto ICVInfo = OMPInfoCache.ICVs[ICV];
860  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
861  return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
862  << " Value: "
863  << (ICVInfo.InitValue
864  ? toString(ICVInfo.InitValue->getValue(), 10, true)
865  : "IMPLEMENTATION_DEFINED");
866  };
867 
868  emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
869  }
870  }
871  }
872 
873  /// Print OpenMP GPU kernels for testing.
874  void printKernels() const {
875  for (Function *F : SCC) {
876  if (!OMPInfoCache.Kernels.count(F))
877  continue;
878 
879  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
880  return ORA << "OpenMP GPU kernel "
881  << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
882  };
883 
884  emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
885  }
886  }
887 
888  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
889  /// given it has to be the callee or a nullptr is returned.
890  static CallInst *getCallIfRegularCall(
891  Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
892  CallInst *CI = dyn_cast<CallInst>(U.getUser());
893  if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
894  (!RFI ||
895  (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
896  return CI;
897  return nullptr;
898  }
899 
900  /// Return the call if \p V is a regular call. If \p RFI is given it has to be
901  /// the callee or a nullptr is returned.
902  static CallInst *getCallIfRegularCall(
903  Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
904  CallInst *CI = dyn_cast<CallInst>(&V);
905  if (CI && !CI->hasOperandBundles() &&
906  (!RFI ||
907  (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
908  return CI;
909  return nullptr;
910  }
911 
912 private:
913  /// Merge parallel regions when it is safe.
914  bool mergeParallelRegions() {
915  const unsigned CallbackCalleeOperand = 2;
916  const unsigned CallbackFirstArgOperand = 3;
917  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
918 
919  // Check if there are any __kmpc_fork_call calls to merge.
920  OMPInformationCache::RuntimeFunctionInfo &RFI =
921  OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
922 
923  if (!RFI.Declaration)
924  return false;
925 
926  // Unmergable calls that prevent merging a parallel region.
927  OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
928  OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
929  OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
930  };
931 
932  bool Changed = false;
933  LoopInfo *LI = nullptr;
934  DominatorTree *DT = nullptr;
935 
936  DenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
 937 
938  BasicBlock *StartBB = nullptr, *EndBB = nullptr;
939  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
940  BasicBlock *CGStartBB = CodeGenIP.getBlock();
941  BasicBlock *CGEndBB =
942  SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
943  assert(StartBB != nullptr && "StartBB should not be null");
944  CGStartBB->getTerminator()->setSuccessor(0, StartBB);
945  assert(EndBB != nullptr && "EndBB should not be null");
946  EndBB->getTerminator()->setSuccessor(0, CGEndBB);
947  };
948 
949  auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
950  Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
951  ReplacementValue = &Inner;
952  return CodeGenIP;
953  };
954 
955  auto FiniCB = [&](InsertPointTy CodeGenIP) {};
956 
957  /// Create a sequential execution region within a merged parallel region,
958  /// encapsulated in a master construct with a barrier for synchronization.
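    /// Roughly (illustrative sketch): code that used to run once in-between
    /// two merged parallel regions, e.g. `foo();`, becomes
    ///   #pragma omp master
    ///   foo();
    ///   #pragma omp barrier
    /// inside the merged region, so it still executes exactly once.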
959  auto CreateSequentialRegion = [&](Function *OuterFn,
960  BasicBlock *OuterPredBB,
961  Instruction *SeqStartI,
962  Instruction *SeqEndI) {
963  // Isolate the instructions of the sequential region to a separate
964  // block.
965  BasicBlock *ParentBB = SeqStartI->getParent();
966  BasicBlock *SeqEndBB =
967  SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
968  BasicBlock *SeqAfterBB =
969  SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
970  BasicBlock *SeqStartBB =
971  SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
972 
973  assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
974  "Expected a different CFG");
975  const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
976  ParentBB->getTerminator()->eraseFromParent();
977 
978  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
979  BasicBlock *CGStartBB = CodeGenIP.getBlock();
980  BasicBlock *CGEndBB =
981  SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
982  assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
983  CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
984  assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
985  SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
986  };
987  auto FiniCB = [&](InsertPointTy CodeGenIP) {};
988 
989  // Find outputs from the sequential region to outside users and
990  // broadcast their values to them.
991  for (Instruction &I : *SeqStartBB) {
992  SmallPtrSet<Instruction *, 4> OutsideUsers;
993  for (User *Usr : I.users()) {
994  Instruction &UsrI = *cast<Instruction>(Usr);
995  // Ignore outputs to LT intrinsics, code extraction for the merged
996  // parallel region will fix them.
997  if (UsrI.isLifetimeStartOrEnd())
998  continue;
999 
1000  if (UsrI.getParent() != SeqStartBB)
1001  OutsideUsers.insert(&UsrI);
1002  }
1003 
1004  if (OutsideUsers.empty())
1005  continue;
1006 
1007  // Emit an alloca in the outer region to store the broadcasted
1008  // value.
1009  const DataLayout &DL = M.getDataLayout();
1010  AllocaInst *AllocaI = new AllocaInst(
1011  I.getType(), DL.getAllocaAddrSpace(), nullptr,
1012  I.getName() + ".seq.output.alloc", &OuterFn->front().front());
1013 
1014  // Emit a store instruction in the sequential BB to update the
1015  // value.
1016  new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
1017 
1018  // Emit a load instruction and replace the use of the output value
1019  // with it.
1020  for (Instruction *UsrI : OutsideUsers) {
1021  LoadInst *LoadI = new LoadInst(
1022  I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
1023  UsrI->replaceUsesOfWith(&I, LoadI);
1024  }
1025  }
1026 
1027  OpenMPIRBuilder::LocationDescription Loc(
 1028  InsertPointTy(ParentBB, ParentBB->end()), DL);
1029  InsertPointTy SeqAfterIP =
1030  OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
1031 
1032  OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
1033 
1034  BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1035 
1036  LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1037  << "\n");
1038  };
1039 
1040  // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1041  // contained in BB and only separated by instructions that can be
1042  // redundantly executed in parallel. The block BB is split before the first
1043  // call (in MergableCIs) and after the last so the entire region we merge
1044  // into a single parallel region is contained in a single basic block
1045  // without any other instructions. We use the OpenMPIRBuilder to outline
1046  // that block and call the resulting function via __kmpc_fork_call.
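  // Conceptual sketch (illustrative only):
  //
  //   __kmpc_fork_call(..., @outlined.1, ...)   ; parallel region 1
  //   ... safely re-executable code ...
  //   __kmpc_fork_call(..., @outlined.2, ...)   ; parallel region 2
  //
  // becomes one __kmpc_fork_call of a new outlined function that calls
  // @outlined.1 and @outlined.2 directly, with an explicit barrier in
  // between and the former in-between code guarded by a master construct.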
1047  auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1048  BasicBlock *BB) {
1049  // TODO: Change the interface to allow single CIs expanded, e.g, to
1050  // include an outer loop.
1051  assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1052 
1053  auto Remark = [&](OptimizationRemark OR) {
1054  OR << "Parallel region merged with parallel region"
1055  << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1056  for (auto *CI : llvm::drop_begin(MergableCIs)) {
1057  OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1058  if (CI != MergableCIs.back())
1059  OR << ", ";
1060  }
1061  return OR << ".";
1062  };
1063 
1064  emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1065 
1066  Function *OriginalFn = BB->getParent();
1067  LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1068  << " parallel regions in " << OriginalFn->getName()
1069  << "\n");
1070 
1071  // Isolate the calls to merge in a separate block.
1072  EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1073  BasicBlock *AfterBB =
1074  SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1075  StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1076  "omp.par.merged");
1077 
1078  assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1079  const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1080  BB->getTerminator()->eraseFromParent();
1081 
1082  // Create sequential regions for sequential instructions that are
1083  // in-between mergable parallel regions.
1084  for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1085  It != End; ++It) {
1086  Instruction *ForkCI = *It;
1087  Instruction *NextForkCI = *(It + 1);
1088 
1089  // Continue if there are no in-between instructions.
1090  if (ForkCI->getNextNode() == NextForkCI)
1091  continue;
1092 
1093  CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1094  NextForkCI->getPrevNode());
1095  }
1096 
1097  OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1098  DL);
1099  IRBuilder<>::InsertPoint AllocaIP(
1100  &OriginalFn->getEntryBlock(),
1101  OriginalFn->getEntryBlock().getFirstInsertionPt());
1102  // Create the merged parallel region with default proc binding, to
1103  // avoid overriding binding settings, and without explicit cancellation.
1104  InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1105  Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1106  OMP_PROC_BIND_default, /* IsCancellable */ false);
1107  BranchInst::Create(AfterBB, AfterIP.getBlock());
1108 
1109  // Perform the actual outlining.
1110  OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1111 
1112  Function *OutlinedFn = MergableCIs.front()->getCaller();
1113 
1114  // Replace the __kmpc_fork_call calls with direct calls to the outlined
1115  // callbacks.
1116  SmallVector<Value *, 8> Args;
 1117  for (auto *CI : MergableCIs) {
1118  Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1119  FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1120  Args.clear();
1121  Args.push_back(OutlinedFn->getArg(0));
1122  Args.push_back(OutlinedFn->getArg(1));
1123  for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1124  ++U)
1125  Args.push_back(CI->getArgOperand(U));
1126 
1127  CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1128  if (CI->getDebugLoc())
1129  NewCI->setDebugLoc(CI->getDebugLoc());
1130 
1131  // Forward parameter attributes from the callback to the callee.
1132  for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1133  ++U)
1134  for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1135  NewCI->addParamAttr(
1136  U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1137 
1138  // Emit an explicit barrier to replace the implicit fork-join barrier.
1139  if (CI != MergableCIs.back()) {
1140  // TODO: Remove barrier if the merged parallel region includes the
1141  // 'nowait' clause.
1142  OMPInfoCache.OMPBuilder.createBarrier(
1143  InsertPointTy(NewCI->getParent(),
1144  NewCI->getNextNode()->getIterator()),
1145  OMPD_parallel);
1146  }
1147 
1148  CI->eraseFromParent();
1149  }
1150 
1151  assert(OutlinedFn != OriginalFn && "Outlining failed");
1152  CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1153  CGUpdater.reanalyzeFunction(*OriginalFn);
1154 
1155  NumOpenMPParallelRegionsMerged += MergableCIs.size();
1156 
1157  return true;
1158  };
1159 
1160  // Helper function that identifies sequences of
1161  // __kmpc_fork_call uses in a basic block.
1162  auto DetectPRsCB = [&](Use &U, Function &F) {
1163  CallInst *CI = getCallIfRegularCall(U, &RFI);
1164  BB2PRMap[CI->getParent()].insert(CI);
1165 
1166  return false;
1167  };
1168 
1169  BB2PRMap.clear();
1170  RFI.foreachUse(SCC, DetectPRsCB);
1171  SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1172  // Find mergable parallel regions within a basic block that are
1173  // safe to merge, that is any in-between instructions can safely
1174  // execute in parallel after merging.
1175  // TODO: support merging across basic-blocks.
1176  for (auto &It : BB2PRMap) {
1177  auto &CIs = It.getSecond();
1178  if (CIs.size() < 2)
1179  continue;
1180 
1181  BasicBlock *BB = It.getFirst();
1182  SmallVector<CallInst *, 4> MergableCIs;
1183 
1184  /// Returns true if the instruction is mergable, false otherwise.
1185  /// A terminator instruction is unmergable by definition since merging
1186  /// works within a BB. Instructions before the mergable region are
1187  /// mergable if they are not calls to OpenMP runtime functions that may
1188  /// set different execution parameters for subsequent parallel regions.
1189  /// Instructions in-between parallel regions are mergable if they are not
1190  /// calls to any non-intrinsic function since that may call a non-mergable
1191  /// OpenMP runtime function.
1192  auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1193  // We do not merge across BBs, hence return false (unmergable) if the
1194  // instruction is a terminator.
1195  if (I.isTerminator())
1196  return false;
1197 
1198  if (!isa<CallInst>(&I))
1199  return true;
1200 
1201  CallInst *CI = cast<CallInst>(&I);
1202  if (IsBeforeMergableRegion) {
1203  Function *CalledFunction = CI->getCalledFunction();
1204  if (!CalledFunction)
1205  return false;
1206  // Return false (unmergable) if the call before the parallel
1207  // region calls an explicit affinity (proc_bind) or number of
1208  // threads (num_threads) compiler-generated function. Those settings
1209  // may be incompatible with following parallel regions.
1210  // TODO: ICV tracking to detect compatibility.
1211  for (const auto &RFI : UnmergableCallsInfo) {
1212  if (CalledFunction == RFI.Declaration)
1213  return false;
1214  }
1215  } else {
1216  // Return false (unmergable) if there is a call instruction
1217  // in-between parallel regions when it is not an intrinsic. It
1218  // may call an unmergable OpenMP runtime function in its callpath.
1219  // TODO: Keep track of possible OpenMP calls in the callpath.
1220  if (!isa<IntrinsicInst>(CI))
1221  return false;
1222  }
1223 
1224  return true;
1225  };
1226  // Find maximal number of parallel region CIs that are safe to merge.
1227  for (auto It = BB->begin(), End = BB->end(); It != End;) {
1228  Instruction &I = *It;
1229  ++It;
1230 
1231  if (CIs.count(&I)) {
1232  MergableCIs.push_back(cast<CallInst>(&I));
1233  continue;
1234  }
1235 
1236  // Continue expanding if the instruction is mergable.
1237  if (IsMergable(I, MergableCIs.empty()))
1238  continue;
1239 
1240  // Forward the instruction iterator to skip the next parallel region
1241  // since there is an unmergable instruction which can affect it.
1242  for (; It != End; ++It) {
1243  Instruction &SkipI = *It;
1244  if (CIs.count(&SkipI)) {
1245  LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1246  << " due to " << I << "\n");
1247  ++It;
1248  break;
1249  }
1250  }
1251 
1252  // Store mergable regions found.
1253  if (MergableCIs.size() > 1) {
1254  MergableCIsVector.push_back(MergableCIs);
1255  LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1256  << " parallel regions in block " << BB->getName()
1257  << " of function " << BB->getParent()->getName()
1258  << "\n";);
1259  }
1260 
1261  MergableCIs.clear();
1262  }
1263 
1264  if (!MergableCIsVector.empty()) {
1265  Changed = true;
1266 
1267  for (auto &MergableCIs : MergableCIsVector)
1268  Merge(MergableCIs, BB);
1269  MergableCIsVector.clear();
1270  }
1271  }
1272 
1273  if (Changed) {
1274  /// Re-collect use for fork calls, emitted barrier calls, and
1275  /// any emitted master/end_master calls.
1276  OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1277  OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1278  OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1279  OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1280  }
1281 
1282  return Changed;
1283  }
1284 
1285  /// Try to delete parallel regions if possible.
1286  bool deleteParallelRegions() {
1287  const unsigned CallbackCalleeOperand = 2;
1288 
1289  OMPInformationCache::RuntimeFunctionInfo &RFI =
1290  OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1291 
1292  if (!RFI.Declaration)
1293  return false;
1294 
1295  bool Changed = false;
1296  auto DeleteCallCB = [&](Use &U, Function &) {
1297  CallInst *CI = getCallIfRegularCall(U);
1298  if (!CI)
1299  return false;
1300  auto *Fn = dyn_cast<Function>(
1301  CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1302  if (!Fn)
1303  return false;
1304  if (!Fn->onlyReadsMemory())
1305  return false;
1306  if (!Fn->hasFnAttribute(Attribute::WillReturn))
1307  return false;
1308 
1309  LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1310  << CI->getCaller()->getName() << "\n");
1311 
1312  auto Remark = [&](OptimizationRemark OR) {
1313  return OR << "Removing parallel region with no side-effects.";
1314  };
1315  emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1316 
1317  CGUpdater.removeCallSite(*CI);
1318  CI->eraseFromParent();
1319  Changed = true;
1320  ++NumOpenMPParallelRegionsDeleted;
1321  return true;
1322  };
1323 
1324  RFI.foreachUse(SCC, DeleteCallCB);
1325 
1326  return Changed;
1327  }
1328 
1329  /// Try to eliminate runtime calls by reusing existing ones.
1330  bool deduplicateRuntimeCalls() {
1331  bool Changed = false;
1332 
1333  RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1334  OMPRTL_omp_get_num_threads,
1335  OMPRTL_omp_in_parallel,
1336  OMPRTL_omp_get_cancellation,
1337  OMPRTL_omp_get_thread_limit,
1338  OMPRTL_omp_get_supported_active_levels,
1339  OMPRTL_omp_get_level,
1340  OMPRTL_omp_get_ancestor_thread_num,
1341  OMPRTL_omp_get_team_size,
1342  OMPRTL_omp_get_active_level,
1343  OMPRTL_omp_in_final,
1344  OMPRTL_omp_get_proc_bind,
1345  OMPRTL_omp_get_num_places,
1346  OMPRTL_omp_get_num_procs,
1347  OMPRTL_omp_get_place_num,
1348  OMPRTL_omp_get_partition_num_places,
1349  OMPRTL_omp_get_partition_place_nums};
1350 
1351  // Global-tid is handled separately.
1352  SmallSetVector<Value *, 16> GTIdArgs;
1353  collectGlobalThreadIdArguments(GTIdArgs);
1354  LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1355  << " global thread ID arguments\n");
1356 
1357  for (Function *F : SCC) {
1358  for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1359  Changed |= deduplicateRuntimeCalls(
1360  *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1361 
1362  // __kmpc_global_thread_num is special as we can replace it with an
1363  // argument in enough cases to make it worth trying.
1364  Value *GTIdArg = nullptr;
1365  for (Argument &Arg : F->args())
1366  if (GTIdArgs.count(&Arg)) {
1367  GTIdArg = &Arg;
1368  break;
1369  }
1370  Changed |= deduplicateRuntimeCalls(
1371  *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1372  }
1373 
1374  return Changed;
1375  }
1376 
1377  /// Tries to hide the latency of runtime calls that involve host to
1378  /// device memory transfers by splitting them into their "issue" and "wait"
1379  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1380  /// moved downwards as much as possible. The "issue" issues the memory transfer
1381  /// asynchronously, returning a handle. The "wait" waits in the returned
1382  /// handle for the memory transfer to finish.
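  /// Illustrative sketch (simplified; argument lists are elided):
  ///   call void @__tgt_target_data_begin_mapper(...)
  ///   ; ... independent host code ...
  ///   ; first use of the mapped data
  /// is split into
  ///   %handle = ...
  ///   call void @__tgt_target_data_begin_mapper_issue(..., %handle)
  ///   ; ... independent host code overlaps with the transfer ...
  ///   call void @__tgt_target_data_begin_mapper_wait(..., %handle)
  ///   ; first use of the mapped data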
1383  bool hideMemTransfersLatency() {
1384  auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1385  bool Changed = false;
1386  auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1387  auto *RTCall = getCallIfRegularCall(U, &RFI);
1388  if (!RTCall)
1389  return false;
1390 
1391  OffloadArray OffloadArrays[3];
1392  if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1393  return false;
1394 
1395  LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1396 
1397  // TODO: Check if can be moved upwards.
1398  bool WasSplit = false;
1399  Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1400  if (WaitMovementPoint)
1401  WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1402 
1403  Changed |= WasSplit;
1404  return WasSplit;
1405  };
1406  RFI.foreachUse(SCC, SplitMemTransfers);
1407 
1408  return Changed;
1409  }
1410 
1411  /// Eliminates redundant, aligned barriers in OpenMP offloaded kernels.
1412  /// TODO: Make this an AA and expand it to work across blocks and functions.
1413  bool eliminateBarriers() {
1414  bool Changed = false;
1415 
1416  if (DisableOpenMPOptBarrierElimination)
 1417  return /*Changed=*/false;
1418 
1419  if (OMPInfoCache.Kernels.empty())
1420  return /*Changed=*/false;
1421 
1422  enum ImplicitBarrierType { IBT_ENTRY, IBT_EXIT };
1423 
1424  class BarrierInfo {
1425  Instruction *I;
1426  enum ImplicitBarrierType Type;
1427 
1428  public:
1429  BarrierInfo(enum ImplicitBarrierType Type) : I(nullptr), Type(Type) {}
1430  BarrierInfo(Instruction &I) : I(&I) {}
1431 
1432  bool isImplicit() { return !I; }
1433 
1434  bool isImplicitEntry() { return isImplicit() && Type == IBT_ENTRY; }
1435 
1436  bool isImplicitExit() { return isImplicit() && Type == IBT_EXIT; }
1437 
1438  Instruction *getInstruction() { return I; }
1439  };
1440 
1441  for (Function *Kernel : OMPInfoCache.Kernels) {
1442  for (BasicBlock &BB : *Kernel) {
1443  SmallVector<BarrierInfo, 8> BarriersInBlock;
1444  SmallPtrSet<Instruction *, 8> BarriersToBeDeleted;
1445 
1446  // Add the kernel entry implicit barrier.
1447  if (&Kernel->getEntryBlock() == &BB)
1448  BarriersInBlock.push_back(IBT_ENTRY);
1449 
1450  // Find implicit and explicit aligned barriers in the same basic block.
1451  for (Instruction &I : BB) {
1452  if (isa<ReturnInst>(I)) {
1453  // Add the implicit barrier when exiting the kernel.
1454  BarriersInBlock.push_back(IBT_EXIT);
1455  continue;
1456  }
1457  CallBase *CB = dyn_cast<CallBase>(&I);
1458  if (!CB)
1459  continue;
1460 
1461  auto IsAlignBarrierCB = [&](CallBase &CB) {
1462  switch (CB.getIntrinsicID()) {
1463  case Intrinsic::nvvm_barrier0:
1464  case Intrinsic::nvvm_barrier0_and:
1465  case Intrinsic::nvvm_barrier0_or:
1466  case Intrinsic::nvvm_barrier0_popc:
1467  return true;
1468  default:
1469  break;
1470  }
1471  return hasAssumption(CB,
1472  KnownAssumptionString("ompx_aligned_barrier"));
1473  };
1474 
1475  if (IsAlignBarrierCB(*CB)) {
1476  // Add an explicit aligned barrier.
1477  BarriersInBlock.push_back(I);
1478  }
1479  }
1480 
1481  if (BarriersInBlock.size() <= 1)
1482  continue;
1483 
1484  // A barrier in a barrier pair is removeable if all instructions
1485  // between the barriers in the pair are side-effect free modulo the
1486  // barrier operation.
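        // For example (illustrative), in
        //   call void @llvm.nvvm.barrier0()
        //   %v = add i32 %a, %b            ; no loads or stores in between
        //   call void @llvm.nvvm.barrier0()
        // the two barriers form a pair with no observable memory traffic in
        // between, so one of them is redundant and can be dropped.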
1487  auto IsBarrierRemoveable = [&Kernel](BarrierInfo *StartBI,
1488  BarrierInfo *EndBI) {
1489  assert(
1490  !StartBI->isImplicitExit() &&
1491  "Expected start barrier to be other than a kernel exit barrier");
1492  assert(
1493  !EndBI->isImplicitEntry() &&
1494  "Expected end barrier to be other than a kernel entry barrier");
1495  // If the StartBI instruction is null then this is the implicit
1496  // kernel entry barrier, so iterate from the first instruction in the
1497  // entry block.
1498  Instruction *I = (StartBI->isImplicitEntry())
1499  ? &Kernel->getEntryBlock().front()
1500  : StartBI->getInstruction()->getNextNode();
1501  assert(I && "Expected non-null start instruction");
1502  Instruction *E = (EndBI->isImplicitExit())
1503  ? I->getParent()->getTerminator()
1504  : EndBI->getInstruction();
1505  assert(E && "Expected non-null end instruction");
1506 
1507  for (; I != E; I = I->getNextNode()) {
1508  if (!I->mayHaveSideEffects() && !I->mayReadFromMemory())
1509  continue;
1510 
1511  auto IsPotentiallyAffectedByBarrier =
1512  [](Optional<MemoryLocation> Loc) {
1513  const Value *Obj = (Loc && Loc->Ptr)
1514  ? getUnderlyingObject(Loc->Ptr)
1515  : nullptr;
1516  if (!Obj) {
1517  LLVM_DEBUG(
1518  dbgs()
1519  << "Access to unknown location requires barriers\n");
1520  return true;
1521  }
1522  if (isa<UndefValue>(Obj))
1523  return false;
1524  if (isa<AllocaInst>(Obj))
1525  return false;
1526  if (auto *GV = dyn_cast<GlobalVariable>(Obj)) {
1527  if (GV->isConstant())
1528  return false;
1529  if (GV->isThreadLocal())
1530  return false;
1531  if (GV->getAddressSpace() == (int)AddressSpace::Local)
1532  return false;
1533  if (GV->getAddressSpace() == (int)AddressSpace::Constant)
1534  return false;
1535  }
1536  LLVM_DEBUG(dbgs() << "Access to '" << *Obj
1537  << "' requires barriers\n");
1538  return true;
1539  };
1540 
1541  if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
1542  Optional<MemoryLocation> Loc = MemoryLocation::getForDest(MI);
 1543  if (IsPotentiallyAffectedByBarrier(Loc))
1544  return false;
1545  if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(I)) {
1546  Optional<MemoryLocation> Loc =
 1547  MemoryLocation::getForSource(MTI);
 1548  if (IsPotentiallyAffectedByBarrier(Loc))
1549  return false;
1550  }
1551  continue;
1552  }
1553 
1554  if (auto *LI = dyn_cast<LoadInst>(I))
1555  if (LI->hasMetadata(LLVMContext::MD_invariant_load))
1556  continue;
1557 
1558  Optional<MemoryLocation> Loc = MemoryLocation::getOrNone(I);
 1559  if (IsPotentiallyAffectedByBarrier(Loc))
1560  return false;
1561  }
1562 
1563  return true;
1564  };
1565 
1566  // Iterate barrier pairs and remove an explicit barrier if analysis
1567  // deems it removeable.
1568  for (auto *It = BarriersInBlock.begin(),
1569  *End = BarriersInBlock.end() - 1;
1570  It != End; ++It) {
1571 
1572  BarrierInfo *StartBI = It;
1573  BarrierInfo *EndBI = (It + 1);
1574 
1575  // Cannot remove when both are implicit barriers, continue.
1576  if (StartBI->isImplicit() && EndBI->isImplicit())
1577  continue;
1578 
1579  if (!IsBarrierRemoveable(StartBI, EndBI))
1580  continue;
1581 
1582  assert(!(StartBI->isImplicit() && EndBI->isImplicit()) &&
1583  "Expected at least one explicit barrier to remove.");
1584 
1585  // Remove an explicit barrier, check first, then second.
1586  if (!StartBI->isImplicit()) {
1587  LLVM_DEBUG(dbgs() << "Remove start barrier "
1588  << *StartBI->getInstruction() << "\n");
1589  BarriersToBeDeleted.insert(StartBI->getInstruction());
1590  } else {
1591  LLVM_DEBUG(dbgs() << "Remove end barrier "
1592  << *EndBI->getInstruction() << "\n");
1593  BarriersToBeDeleted.insert(EndBI->getInstruction());
1594  }
1595  }
1596 
1597  if (BarriersToBeDeleted.empty())
1598  continue;
1599 
1600  Changed = true;
1601  for (Instruction *I : BarriersToBeDeleted) {
1602  ++NumBarriersEliminated;
1603  auto Remark = [&](OptimizationRemark OR) {
1604  return OR << "Redundant barrier eliminated.";
1605  };
1606 
1608  emitRemark<OptimizationRemark>(I, "OMP190", Remark);
1609  I->eraseFromParent();
1610  }
1611  }
1612  }
1613 
1614  return Changed;
1615  }
1616 
1617  void analysisGlobalization() {
1618  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1619 
1620  auto CheckGlobalization = [&](Use &U, Function &Decl) {
1621  if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1622  auto Remark = [&](OptimizationRemarkMissed ORM) {
1623  return ORM
1624  << "Found thread data sharing on the GPU. "
1625  << "Expect degraded performance due to data globalization.";
1626  };
1627  emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1628  }
1629 
1630  return false;
1631  };
1632 
1633  RFI.foreachUse(SCC, CheckGlobalization);
1634  }
1635 
1636  /// Maps the values stored in the offload arrays passed as arguments to
1637  /// \p RuntimeCall into the offload arrays in \p OAs.
1638  bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1639  MutableArrayRef<OffloadArray> OAs) {
 1640  assert(OAs.size() == 3 && "Need space for three offload arrays!");
1641 
1642  // A runtime call that involves memory offloading looks something like:
1643  // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1644  // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1645  // ...)
1646  // So, the idea is to access the allocas that allocate space for these
1647  // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1648  // Therefore:
1649  // i8** %offload_baseptrs.
1650  Value *BasePtrsArg =
1651  RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1652  // i8** %offload_ptrs.
1653  Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1654  // i8** %offload_sizes.
1655  Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1656 
1657  // Get values stored in **offload_baseptrs.
1658  auto *V = getUnderlyingObject(BasePtrsArg);
1659  if (!isa<AllocaInst>(V))
1660  return false;
1661  auto *BasePtrsArray = cast<AllocaInst>(V);
1662  if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1663  return false;
1664 
1665  // Get values stored in **offload_ptrs.
1666  V = getUnderlyingObject(PtrsArg);
1667  if (!isa<AllocaInst>(V))
1668  return false;
1669  auto *PtrsArray = cast<AllocaInst>(V);
1670  if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1671  return false;
1672 
1673  // Get values stored in **offload_sizes.
1674  V = getUnderlyingObject(SizesArg);
1675  // If it's a [constant] global array don't analyze it.
1676  if (isa<GlobalValue>(V))
1677  return isa<Constant>(V);
1678  if (!isa<AllocaInst>(V))
1679  return false;
1680 
1681  auto *SizesArray = cast<AllocaInst>(V);
1682  if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1683  return false;
1684 
1685  return true;
1686  }
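// Illustrative IR shape, not taken from this pass, of the pattern the routine
// above matches; value names and array sizes are made up:
//
//   %offload_baseptrs = alloca [2 x i8*]
//   %offload_ptrs     = alloca [2 x i8*]
//   %offload_sizes    = alloca [2 x i64]
//   ; ... stores filling the three arrays ...
//   call void @__tgt_target_data_begin_mapper(..., i8** %baseptrs, i8** %ptrs,
//                                             i64* %sizes, ...)
//
// The routine walks from each pointer argument back to its underlying alloca
// and records the values stored into it via OffloadArray::initialize.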
1687 
1688  /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1689  /// For now this is a way to test that the function getValuesInOffloadArrays
1690  /// is working properly.
1691  /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1692  void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1693  assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1694 
1695  LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1696  std::string ValuesStr;
1697  raw_string_ostream Printer(ValuesStr);
1698  std::string Separator = " --- ";
1699 
1700  for (auto *BP : OAs[0].StoredValues) {
1701  BP->print(Printer);
1702  Printer << Separator;
1703  }
1704  LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1705  ValuesStr.clear();
1706 
1707  for (auto *P : OAs[1].StoredValues) {
1708  P->print(Printer);
1709  Printer << Separator;
1710  }
1711  LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1712  ValuesStr.clear();
1713 
1714  for (auto *S : OAs[2].StoredValues) {
1715  S->print(Printer);
1716  Printer << Separator;
1717  }
1718  LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1719  }
1720 
1721  /// Returns the instruction where the "wait" counterpart of \p RuntimeCall can
1722  /// be moved. Returns nullptr if the movement is not possible or not worth it.
1723  Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1724  // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1725  // Make it traverse the CFG.
1726 
1727  Instruction *CurrentI = &RuntimeCall;
1728  bool IsWorthIt = false;
1729  while ((CurrentI = CurrentI->getNextNode())) {
1730 
1731  // TODO: Once we detect the regions to be offloaded we should use the
1732  // alias analysis manager to check if CurrentI may modify one of
1733  // the offloaded regions.
1734  if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1735  if (IsWorthIt)
1736  return CurrentI;
1737 
1738  return nullptr;
1739  }
1740 
1741  // FIXME: For now, moving it over anything without side effects
1742  // is considered worth it.
1743  IsWorthIt = true;
1744  }
1745 
1746  // Return end of BasicBlock.
1747  return RuntimeCall.getParent()->getTerminator();
1748  }
1749 
1750  /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1751  bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1752  Instruction &WaitMovementPoint) {
1753  // Create a stack-allocated handle (__tgt_async_info) at the beginning of the
1754  // function. It is used to store information about the async transfer,
1755  // allowing us to wait on it later.
1756  auto &IRBuilder = OMPInfoCache.OMPBuilder;
1757  auto *F = RuntimeCall.getCaller();
1758  Instruction *FirstInst = &(F->getEntryBlock().front());
1759  AllocaInst *Handle = new AllocaInst(
1760  IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1761 
1762  // Add "issue" runtime call declaration:
1763  // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1764  // i8**, i8**, i64*, i64*)
1765  FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1766  M, OMPRTL___tgt_target_data_begin_mapper_issue);
1767 
1768  // Change the RuntimeCall call site to its asynchronous version.
1769  SmallVector<Value *, 16> Args;
1770  for (auto &Arg : RuntimeCall.args())
1771  Args.push_back(Arg.get());
1772  Args.push_back(Handle);
1773 
1774  CallInst *IssueCallsite =
1775  CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1776  OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1777  RuntimeCall.eraseFromParent();
1778 
1779  // Add "wait" runtime call declaration:
1780  // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1781  FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1782  M, OMPRTL___tgt_target_data_begin_mapper_wait);
1783 
1784  Value *WaitParams[2] = {
1785  IssueCallsite->getArgOperand(
1786  OffloadArray::DeviceIDArgNum), // device_id.
1787  Handle // handle to wait on.
1788  };
1789  CallInst *WaitCallsite = CallInst::Create(
1790  WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1791  OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1792 
1793  return true;
1794  }
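// Illustrative before/after sketch, not taken from this pass; the argument
// lists are abbreviated:
//
//   ; before
//   call void @__tgt_target_data_begin_mapper(...)
//   ; ... code the transfer could overlap with ...
//
//   ; after
//   %handle = alloca %struct.__tgt_async_info
//   call void @__tgt_target_data_begin_mapper_issue(..., %struct.__tgt_async_info* %handle)
//   ; ... code now overlapping with the transfer ...
//   call void @__tgt_target_data_begin_mapper_wait(i64 %device_id, %struct.__tgt_async_info* %handle)
//
// The wait is placed at the movement point computed by canBeMovedDownwards.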
1795 
1796  static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1797  bool GlobalOnly, bool &SingleChoice) {
1798  if (CurrentIdent == NextIdent)
1799  return CurrentIdent;
1800 
1801  // TODO: Figure out how to actually combine multiple debug locations. For
1802  // now we just keep an existing one if there is a single choice.
1803  if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1804  SingleChoice = !CurrentIdent;
1805  return NextIdent;
1806  }
1807  return nullptr;
1808  }
1809 
1810  /// Return a `struct ident_t*` value that represents the ones used in the
1811  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1812  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1813  /// return value we create one from scratch. We also do not yet combine
1814  /// information, e.g., the source locations, see combinedIdentStruct.
1815  Value *
1816  getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1817  Function &F, bool GlobalOnly) {
1818  bool SingleChoice = true;
1819  Value *Ident = nullptr;
1820  auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1821  CallInst *CI = getCallIfRegularCall(U, &RFI);
1822  if (!CI || &F != &Caller)
1823  return false;
1824  Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1825  /* GlobalOnly */ true, SingleChoice);
1826  return false;
1827  };
1828  RFI.foreachUse(SCC, CombineIdentStruct);
1829 
1830  if (!Ident || !SingleChoice) {
1831  // The IRBuilder uses the insertion block to get to the module; this is
1832  // unfortunate, but we work around it for now.
1833  if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1834  OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1835  &F.getEntryBlock(), F.getEntryBlock().begin()));
1836  // Create a fallback location if none was found.
1837  // TODO: Use the debug locations of the calls instead.
1838  uint32_t SrcLocStrSize;
1839  Constant *Loc =
1840  OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1841  Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1842  }
1843  return Ident;
1844  }
1845 
1846  /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1847  /// \p ReplVal if given.
1848  bool deduplicateRuntimeCalls(Function &F,
1849  OMPInformationCache::RuntimeFunctionInfo &RFI,
1850  Value *ReplVal = nullptr) {
1851  auto *UV = RFI.getUseVector(F);
1852  if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1853  return false;
1854 
1855  LLVM_DEBUG(
1856  dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1857  << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1858 
1859  assert((!ReplVal || (isa<Argument>(ReplVal) &&
1860  cast<Argument>(ReplVal)->getParent() == &F)) &&
1861  "Unexpected replacement value!");
1862 
1863  // TODO: Use dominance to find a good position instead.
1864  auto CanBeMoved = [this](CallBase &CB) {
1865  unsigned NumArgs = CB.arg_size();
1866  if (NumArgs == 0)
1867  return true;
1868  if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1869  return false;
1870  for (unsigned U = 1; U < NumArgs; ++U)
1871  if (isa<Instruction>(CB.getArgOperand(U)))
1872  return false;
1873  return true;
1874  };
1875 
1876  if (!ReplVal) {
1877  for (Use *U : *UV)
1878  if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1879  if (!CanBeMoved(*CI))
1880  continue;
1881 
1882  // If the function is a kernel, dedup will move
1883  // the runtime call right after the kernel init callsite. Otherwise,
1884  // it will move it to the beginning of the caller function.
1885  if (isKernel(F)) {
1886  auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
1887  auto *KernelInitUV = KernelInitRFI.getUseVector(F);
1888 
1889  if (KernelInitUV->empty())
1890  continue;
1891 
1892  assert(KernelInitUV->size() == 1 &&
1893  "Expected a single __kmpc_target_init in kernel\n");
1894 
1895  CallInst *KernelInitCI =
1896  getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
1897  assert(KernelInitCI &&
1898  "Expected a call to __kmpc_target_init in kernel\n");
1899 
1900  CI->moveAfter(KernelInitCI);
1901  } else
1902  CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
1903  ReplVal = CI;
1904  break;
1905  }
1906  if (!ReplVal)
1907  return false;
1908  }
1909 
1910  // If we use a call as a replacement value we need to make sure the ident is
1911  // valid at the new location. For now we just pick a global one, either
1912  // existing and used by one of the calls, or created from scratch.
1913  if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1914  if (!CI->arg_empty() &&
1915  CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1916  Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1917  /* GlobalOnly */ true);
1918  CI->setArgOperand(0, Ident);
1919  }
1920  }
1921 
1922  bool Changed = false;
1923  auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1924  CallInst *CI = getCallIfRegularCall(U, &RFI);
1925  if (!CI || CI == ReplVal || &F != &Caller)
1926  return false;
1927  assert(CI->getCaller() == &F && "Unexpected call!");
1928 
1929  auto Remark = [&](OptimizationRemark OR) {
1930  return OR << "OpenMP runtime call "
1931  << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1932  };
1933  if (CI->getDebugLoc())
1934  emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1935  else
1936  emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1937 
1938  CGUpdater.removeCallSite(*CI);
1939  CI->replaceAllUsesWith(ReplVal);
1940  CI->eraseFromParent();
1941  ++NumOpenMPRuntimeCallsDeduplicated;
1942  Changed = true;
1943  return true;
1944  };
1945  RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1946 
1947  return Changed;
1948  }
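// Illustrative sketch, not taken from this pass, of a deduplication of
// __kmpc_global_thread_num; value names are made up:
//
//   ; before
//   %tid0 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
//   ; ...
//   %tid1 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
//   %use  = add i32 %tid1, 1
//
//   ; after: the second call is erased and its users see %tid0
//   %tid0 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
//   ; ...
//   %use  = add i32 %tid0, 1
//
// The surviving call is moved right after __kmpc_target_init in kernels, or to
// the entry block otherwise, so it dominates all replaced uses.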
1949 
1950  /// Collect arguments that represent the global thread id in \p GTIdArgs.
1951  void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1952  // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1953  // initialization. We could define an AbstractAttribute instead and
1954  // run the Attributor here once it can be run as an SCC pass.
1955 
1956  // Helper to check the argument \p ArgNo at all call sites of \p F for
1957  // a GTId.
1958  auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1959  if (!F.hasLocalLinkage())
1960  return false;
1961  for (Use &U : F.uses()) {
1962  if (CallInst *CI = getCallIfRegularCall(U)) {
1963  Value *ArgOp = CI->getArgOperand(ArgNo);
1964  if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1965  getCallIfRegularCall(
1966  *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1967  continue;
1968  }
1969  return false;
1970  }
1971  return true;
1972  };
1973 
1974  // Helper to identify uses of a GTId as GTId arguments.
1975  auto AddUserArgs = [&](Value &GTId) {
1976  for (Use &U : GTId.uses())
1977  if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1978  if (CI->isArgOperand(&U))
1979  if (Function *Callee = CI->getCalledFunction())
1980  if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1981  GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1982  };
1983 
1984  // The argument users of __kmpc_global_thread_num calls are GTIds.
1985  OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1986  OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1987 
1988  GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1989  if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1990  AddUserArgs(*CI);
1991  return false;
1992  });
1993 
1994  // Transitively search for more arguments by looking at the users of the
1995  // ones we know already. During the search the GTIdArgs vector is extended
1996  // so we cannot cache the size, nor can we use a range-based for loop.
1997  for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1998  AddUserArgs(*GTIdArgs[U]);
1999  }
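// Illustrative sketch, not taken from this pass; the helper name is made up:
//
//   static void helper(ident_t *Loc, int GTId) { /* internal linkage */ }
//
//   int GTId = __kmpc_global_thread_num(&Loc);
//   helper(&Loc, GTId);   // every call site passes a known thread id
//
// Since all call sites of the internal helper pass a value that is already a
// known global thread id, the helper's GTId parameter is added to GTIdArgs,
// and its own users are then searched transitively.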
2000 
2001  /// Kernel (=GPU) optimizations and utility functions
2002  ///
2003  ///{{
2004 
2005  /// Check if \p F is a kernel, hence entry point for target offloading.
2006  bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
2007 
2008  /// Cache to remember the unique kernel for a function.
2009  DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
2010 
2011  /// Find the unique kernel that will execute \p F, if any.
2012  Kernel getUniqueKernelFor(Function &F);
2013 
2014  /// Find the unique kernel that will execute \p I, if any.
2015  Kernel getUniqueKernelFor(Instruction &I) {
2016  return getUniqueKernelFor(*I.getFunction());
2017  }
2018 
2019  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
2020  /// the cases where we can avoid taking the address of a function.
2021  bool rewriteDeviceCodeStateMachine();
2022 
2023  ///
2024  ///}}
2025 
2026  /// Emit a remark generically
2027  ///
2028  /// This template function can be used to generically emit a remark. The
2029  /// RemarkKind should be one of the following:
2030  /// - OptimizationRemark to indicate a successful optimization attempt
2031  /// - OptimizationRemarkMissed to report a failed optimization attempt
2032  /// - OptimizationRemarkAnalysis to provide additional information about an
2033  /// optimization attempt
2034  ///
2035  /// The remark is built using a callback function provided by the caller that
2036  /// takes a RemarkKind as input and returns a RemarkKind.
2037  template <typename RemarkKind, typename RemarkCallBack>
2038  void emitRemark(Instruction *I, StringRef RemarkName,
2039  RemarkCallBack &&RemarkCB) const {
2040  Function *F = I->getParent()->getParent();
2041  auto &ORE = OREGetter(F);
2042 
2043  if (RemarkName.startswith("OMP"))
2044  ORE.emit([&]() {
2045  return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2046  << " [" << RemarkName << "]";
2047  });
2048  else
2049  ORE.emit(
2050  [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2051  }
2052 
2053  /// Emit a remark on a function.
2054  template <typename RemarkKind, typename RemarkCallBack>
2055  void emitRemark(Function *F, StringRef RemarkName,
2056  RemarkCallBack &&RemarkCB) const {
2057  auto &ORE = OREGetter(F);
2058 
2059  if (RemarkName.startswith("OMP"))
2060  ORE.emit([&]() {
2061  return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2062  << " [" << RemarkName << "]";
2063  });
2064  else
2065  ORE.emit(
2066  [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2067  }
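// Illustrative usage, not taken from this pass; remarks become visible when
// remark output is enabled (e.g. -Rpass=openmp-opt, -Rpass-missed=openmp-opt,
// or -Rpass-analysis=openmp-opt, matching DEBUG_TYPE):
//
//   auto Remark = [&](OptimizationRemark OR) {
//     return OR << "OpenMP runtime call deduplicated.";
//   };
//   emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
//
// Names starting with "OMP" get the remark tag appended in brackets, which is
// how they are cross-referenced in the OpenMP optimization remark docs.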
2068 
2069  /// RAII struct to temporarily change an RTL function's linkage to external.
2070  /// This prevents it from being mistakenly removed by other optimizations.
2071  struct ExternalizationRAII {
2072  ExternalizationRAII(OMPInformationCache &OMPInfoCache,
2073  RuntimeFunction RFKind)
2074  : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) {
2075  if (!Declaration)
2076  return;
2077 
2078  LinkageType = Declaration->getLinkage();
2079  Declaration->setLinkage(GlobalValue::ExternalLinkage);
2080  }
2081 
2082  ~ExternalizationRAII() {
2083  if (!Declaration)
2084  return;
2085 
2086  Declaration->setLinkage(LinkageType);
2087  }
2088 
2089  Function *Declaration;
2090  GlobalValue::LinkageTypes LinkageType;
2091  };
2092 
2093  /// The underlying module.
2094  Module &M;
2095 
2096  /// The SCC we are operating on.
2097  SmallVectorImpl<Function *> &SCC;
2098 
2099  /// Callback to update the call graph; the first argument is a removed call,
2100  /// the second an optional replacement call.
2101  CallGraphUpdater &CGUpdater;
2102 
2103  /// Callback to get an OptimizationRemarkEmitter from a Function *
2104  OptimizationRemarkGetter OREGetter;
2105 
2106  /// OpenMP-specific information cache. Also used for Attributor runs.
2107  OMPInformationCache &OMPInfoCache;
2108 
2109  /// Attributor instance.
2110  Attributor &A;
2111 
2112  /// Helper function to run Attributor on SCC.
2113  bool runAttributor(bool IsModulePass) {
2114  if (SCC.empty())
2115  return false;
2116 
2117  // Temporarily make these functions have external linkage so the Attributor
2118  // doesn't remove them when we try to look them up later.
2119  ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
2120  ExternalizationRAII EndParallel(OMPInfoCache,
2121  OMPRTL___kmpc_kernel_end_parallel);
2122  ExternalizationRAII BarrierSPMD(OMPInfoCache,
2123  OMPRTL___kmpc_barrier_simple_spmd);
2124  ExternalizationRAII BarrierGeneric(OMPInfoCache,
2125  OMPRTL___kmpc_barrier_simple_generic);
2126  ExternalizationRAII ThreadId(OMPInfoCache,
2127  OMPRTL___kmpc_get_hardware_thread_id_in_block);
2128  ExternalizationRAII NumThreads(
2129  OMPInfoCache, OMPRTL___kmpc_get_hardware_num_threads_in_block);
2130  ExternalizationRAII WarpSize(OMPInfoCache, OMPRTL___kmpc_get_warp_size);
2131 
2132  registerAAs(IsModulePass);
2133 
2134  ChangeStatus Changed = A.run();
2135 
2136  LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2137  << " functions, result: " << Changed << ".\n");
2138 
2139  return Changed == ChangeStatus::CHANGED;
2140  }
2141 
2142  void registerFoldRuntimeCall(RuntimeFunction RF);
2143 
2144  /// Populate the Attributor with abstract attribute opportunities in the
2145  /// function.
2146  void registerAAs(bool IsModulePass);
2147 };
2148 
2149 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2150  if (!OMPInfoCache.ModuleSlice.empty() && !OMPInfoCache.ModuleSlice.count(&F))
2151  return nullptr;
2152 
2153  // Use a scope to keep the lifetime of the CachedKernel short.
2154  {
2155  Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2156  if (CachedKernel)
2157  return *CachedKernel;
2158 
2159  // TODO: We should use an AA to create an (optimistic and callback
2160  // call-aware) call graph. For now we stick to simple patterns that
2161  // are less powerful, basically the worst fixpoint.
2162  if (isKernel(F)) {
2163  CachedKernel = Kernel(&F);
2164  return *CachedKernel;
2165  }
2166 
2167  CachedKernel = nullptr;
2168  if (!F.hasLocalLinkage()) {
2169 
2170  // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2171  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2172  return ORA << "Potentially unknown OpenMP target region caller.";
2173  };
2174  emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2175 
2176  return nullptr;
2177  }
2178  }
2179 
2180  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2181  if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2182  // Allow use in equality comparisons.
2183  if (Cmp->isEquality())
2184  return getUniqueKernelFor(*Cmp);
2185  return nullptr;
2186  }
2187  if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2188  // Allow direct calls.
2189  if (CB->isCallee(&U))
2190  return getUniqueKernelFor(*CB);
2191 
2192  OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2193  OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2194  // Allow the use in __kmpc_parallel_51 calls.
2195  if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2196  return getUniqueKernelFor(*CB);
2197  return nullptr;
2198  }
2199  // Disallow every other use.
2200  return nullptr;
2201  };
2202 
2203  // TODO: In the future we want to track more than just a unique kernel.
2204  SmallPtrSet<Kernel, 2> PotentialKernels;
2205  OMPInformationCache::foreachUse(F, [&](const Use &U) {
2206  PotentialKernels.insert(GetUniqueKernelForUse(U));
2207  });
2208 
2209  Kernel K = nullptr;
2210  if (PotentialKernels.size() == 1)
2211  K = *PotentialKernels.begin();
2212 
2213  // Cache the result.
2214  UniqueKernelMap[&F] = K;
2215 
2216  return K;
2217 }
2218 
2219 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2220  OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2221  OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2222 
2223  bool Changed = false;
2224  if (!KernelParallelRFI)
2225  return Changed;
2226 
2227  // If we have disabled state machine changes, exit.
2228  if (DisableOpenMPOptStateMachineRewrite)
2229  return Changed;
2230 
2231  for (Function *F : SCC) {
2232 
2233  // Check whether the function is used in a __kmpc_parallel_51 call at
2234  // all.
2235  bool UnknownUse = false;
2236  bool KernelParallelUse = false;
2237  unsigned NumDirectCalls = 0;
2238 
2239  SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2240  OMPInformationCache::foreachUse(*F, [&](Use &U) {
2241  if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2242  if (CB->isCallee(&U)) {
2243  ++NumDirectCalls;
2244  return;
2245  }
2246 
2247  if (isa<ICmpInst>(U.getUser())) {
2248  ToBeReplacedStateMachineUses.push_back(&U);
2249  return;
2250  }
2251 
2252  // Find wrapper functions that represent parallel kernels.
2253  CallInst *CI =
2254  OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2255  const unsigned int WrapperFunctionArgNo = 6;
2256  if (!KernelParallelUse && CI &&
2257  CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2258  KernelParallelUse = true;
2259  ToBeReplacedStateMachineUses.push_back(&U);
2260  return;
2261  }
2262  UnknownUse = true;
2263  });
2264 
2265  // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2266  // use.
2267  if (!KernelParallelUse)
2268  continue;
2269 
2270  // If this ever hits, we should investigate.
2271  // TODO: Checking the number of uses is not a necessary restriction and
2272  // should be lifted.
2273  if (UnknownUse || NumDirectCalls != 1 ||
2274  ToBeReplacedStateMachineUses.size() > 2) {
2275  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2276  return ORA << "Parallel region is used in "
2277  << (UnknownUse ? "unknown" : "unexpected")
2278  << " ways. Will not attempt to rewrite the state machine.";
2279  };
2280  emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2281  continue;
2282  }
2283 
2284  // Even if we have __kmpc_parallel_51 calls, we (for now) give
2285  // up if the function is not called from a unique kernel.
2286  Kernel K = getUniqueKernelFor(*F);
2287  if (!K) {
2288  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2289  return ORA << "Parallel region is not called from a unique kernel. "
2290  "Will not attempt to rewrite the state machine.";
2291  };
2292  emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2293  continue;
2294  }
2295 
2296  // We now know F is a parallel body function called only from the kernel K.
2297  // We also identified the state machine uses in which we replace the
2298  // function pointer by a new global symbol for identification purposes. This
2299  // ensures only direct calls to the function are left.
2300 
2301  Module &M = *F->getParent();
2302  Type *Int8Ty = Type::getInt8Ty(M.getContext());
2303 
2304  auto *ID = new GlobalVariable(
2305  M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2306  UndefValue::get(Int8Ty), F->getName() + ".ID");
2307 
2308  for (Use *U : ToBeReplacedStateMachineUses)
2310  ID, U->get()->getType()));
2311 
2312  ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2313 
2314  Changed = true;
2315  }
2316 
2317  return Changed;
2318 }
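// Illustrative before/after sketch, not taken from this pass; the wrapper name
// is made up and types are abbreviated:
//
//   ; before: the generic-mode state machine compares the work-function pointer
//   ; against the wrapper's address, so the address of the wrapper is taken
//   call void @__kmpc_parallel_51(..., @par_body_wrapper, ...)
//   %match = icmp eq i8* %work_fn, bitcast (... @par_body_wrapper to i8*)
//
//   ; after: a private one-byte global stands in for the function address
//   @par_body_wrapper.ID = private constant i8 undef
//   call void @__kmpc_parallel_51(..., @par_body_wrapper.ID, ...)
//   %match = icmp eq i8* %work_fn, @par_body_wrapper.ID
//
// Only direct calls to the wrapper remain, which later allows a specialized
// state machine without indirect calls through the function pointer.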
2319 
2320 /// Abstract Attribute for tracking ICV values.
2321 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2322  using Base = StateWrapper<BooleanState, AbstractAttribute>;
2323  AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2324 
2325  void initialize(Attributor &A) override {
2326  Function *F = getAnchorScope();
2327  if (!F || !A.isFunctionIPOAmendable(*F))
2328  indicatePessimisticFixpoint();
2329  }
2330 
2331  /// Returns true if value is assumed to be tracked.
2332  bool isAssumedTracked() const { return getAssumed(); }
2333 
2334  /// Returns true if value is known to be tracked.
2335  bool isKnownTracked() const { return getAssumed(); }
2336 
2337  /// Create an abstract attribute view for the position \p IRP.
2338  static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2339 
2340  /// Return the value with which \p I can be replaced for specific \p ICV.
2341  virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
2342  const Instruction *I,
2343  Attributor &A) const {
2344  return None;
2345  }
2346 
2347  /// Return an assumed unique ICV value if a single candidate is found. If
2348  /// there cannot be one, return a nullptr. If it is not clear yet, return the
2349  /// Optional::NoneType.
2350  virtual Optional<Value *>
2351  getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2352 
2353  // Currently only nthreads is being tracked.
2354  // This array will only grow over time.
2355  InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2356 
2357  /// See AbstractAttribute::getName()
2358  const std::string getName() const override { return "AAICVTracker"; }
2359 
2360  /// See AbstractAttribute::getIdAddr()
2361  const char *getIdAddr() const override { return &ID; }
2362 
2363  /// This function should return true if the type of the \p AA is AAICVTracker
2364  static bool classof(const AbstractAttribute *AA) {
2365  return (AA->getIdAddr() == &ID);
2366  }
2367 
2368  static const char ID;
2369 };
2370 
2371 struct AAICVTrackerFunction : public AAICVTracker {
2372  AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2373  : AAICVTracker(IRP, A) {}
2374 
2375  // FIXME: come up with better string.
2376  const std::string getAsStr() const override { return "ICVTrackerFunction"; }
2377 
2378  // FIXME: come up with some stats.
2379  void trackStatistics() const override {}
2380 
2381  /// We don't manifest anything for this AA.
2382  ChangeStatus manifest(Attributor &A) override {
2383  return ChangeStatus::UNCHANGED;
2384  }
2385 
2386  // Map of ICV to their values at specific program point.
2388  InternalControlVar::ICV___last>
2389  ICVReplacementValuesMap;
2390 
2391  ChangeStatus updateImpl(Attributor &A) override {
2392  ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2393 
2394  Function *F = getAnchorScope();
2395 
2396  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2397 
2398  for (InternalControlVar ICV : TrackableICVs) {
2399  auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2400 
2401  auto &ValuesMap = ICVReplacementValuesMap[ICV];
2402  auto TrackValues = [&](Use &U, Function &) {
2403  CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2404  if (!CI)
2405  return false;
2406 
2407  // FIXME: handle setters with more than one argument.
2408  /// Track new value.
2409  if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2410  HasChanged = ChangeStatus::CHANGED;
2411 
2412  return false;
2413  };
2414 
2415  auto CallCheck = [&](Instruction &I) {
2416  Optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2417  if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2418  HasChanged = ChangeStatus::CHANGED;
2419 
2420  return true;
2421  };
2422 
2423  // Track all changes of an ICV.
2424  SetterRFI.foreachUse(TrackValues, F);
2425 
2426  bool UsedAssumedInformation = false;
2427  A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2428  UsedAssumedInformation,
2429  /* CheckBBLivenessOnly */ true);
2430 
2431  /// TODO: Figure out a way to avoid adding entry in
2432  /// ICVReplacementValuesMap
2433  Instruction *Entry = &F->getEntryBlock().front();
2434  if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2435  ValuesMap.insert(std::make_pair(Entry, nullptr));
2436  }
2437 
2438  return HasChanged;
2439  }
2440 
2441  /// Helper to check if \p I is a call and get the value for it if it is
2442  /// unique.
2443  Optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2444  InternalControlVar &ICV) const {
2445 
2446  const auto *CB = dyn_cast<CallBase>(&I);
2447  if (!CB || CB->hasFnAttr("no_openmp") ||
2448  CB->hasFnAttr("no_openmp_routines"))
2449  return None;
2450 
2451  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2452  auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2453  auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2454  Function *CalledFunction = CB->getCalledFunction();
2455 
2456  // Indirect call, assume ICV changes.
2457  if (CalledFunction == nullptr)
2458  return nullptr;
2459  if (CalledFunction == GetterRFI.Declaration)
2460  return None;
2461  if (CalledFunction == SetterRFI.Declaration) {
2462  if (ICVReplacementValuesMap[ICV].count(&I))
2463  return ICVReplacementValuesMap[ICV].lookup(&I);
2464 
2465  return nullptr;
2466  }
2467 
2468  // Since we don't know, assume it changes the ICV.
2469  if (CalledFunction->isDeclaration())
2470  return nullptr;
2471 
2472  const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2474 
2475  if (ICVTrackingAA.isAssumedTracked()) {
2476  Optional<Value *> URV = ICVTrackingAA.getUniqueReplacementValue(ICV);
2477  if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2478  OMPInfoCache)))
2479  return URV;
2480  }
2481 
2482  // If we don't know, assume it changes.
2483  return nullptr;
2484  }
2485 
2486  // We don't check for a unique value for a function, so return None.
2487  Optional<Value *>
2488  getUniqueReplacementValue(InternalControlVar ICV) const override {
2489  return None;
2490  }
2491 
2492  /// Return the value with which \p I can be replaced for specific \p ICV.
2493  Optional<Value *> getReplacementValue(InternalControlVar ICV,
2494  const Instruction *I,
2495  Attributor &A) const override {
2496  const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2497  if (ValuesMap.count(I))
2498  return ValuesMap.lookup(I);
2499 
2502  Worklist.push_back(I);
2503 
2504  Optional<Value *> ReplVal;
2505 
2506  while (!Worklist.empty()) {
2507  const Instruction *CurrInst = Worklist.pop_back_val();
2508  if (!Visited.insert(CurrInst).second)
2509  continue;
2510 
2511  const BasicBlock *CurrBB = CurrInst->getParent();
2512 
2513  // Go up and look for all potential setters/calls that might change the
2514  // ICV.
2515  while ((CurrInst = CurrInst->getPrevNode())) {
2516  if (ValuesMap.count(CurrInst)) {
2517  Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2518  // Unknown value, track new.
2519  if (!ReplVal) {
2520  ReplVal = NewReplVal;
2521  break;
2522  }
2523 
2524  // If we found a new value, we can't know the icv value anymore.
2525  if (NewReplVal)
2526  if (ReplVal != NewReplVal)
2527  return nullptr;
2528 
2529  break;
2530  }
2531 
2532  Optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2533  if (!NewReplVal)
2534  continue;
2535 
2536  // Unknown value, track new.
2537  if (!ReplVal) {
2538  ReplVal = NewReplVal;
2539  break;
2540  }
2541 
2542  // We found a new value, so we can't know the
2543  // ICV value anymore.
2544  if (ReplVal != NewReplVal)
2545  return nullptr;
2546  }
2547 
2548  // If we are in the same BB and we have a value, we are done.
2549  if (CurrBB == I->getParent() && ReplVal)
2550  return ReplVal;
2551 
2552  // Go through all predecessors and add terminators for analysis.
2553  for (const BasicBlock *Pred : predecessors(CurrBB))
2554  if (const Instruction *Terminator = Pred->getTerminator())
2555  Worklist.push_back(Terminator);
2556  }
2557 
2558  return ReplVal;
2559  }
2560 };
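// Illustrative sketch, not taken from this pass, for the nthreads ICV:
//
//   omp_set_num_threads(4);        // setter call: value 4 recorded here
//   ...                            // no calls that might change the ICV
//   int N = omp_get_max_threads(); // getter: assumed replaceable by 4
//
// getReplacementValue walks backwards (and across predecessors) from the
// getter looking for the closest recorded setter or unknown call; only if a
// single value survives on all paths is a replacement proposed.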
2561 
2562 struct AAICVTrackerFunctionReturned : AAICVTracker {
2563  AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2564  : AAICVTracker(IRP, A) {}
2565 
2566  // FIXME: come up with better string.
2567  const std::string getAsStr() const override {
2568  return "ICVTrackerFunctionReturned";
2569  }
2570 
2571  // FIXME: come up with some stats.
2572  void trackStatistics() const override {}
2573 
2574  /// We don't manifest anything for this AA.
2575  ChangeStatus manifest(Attributor &A) override {
2576  return ChangeStatus::UNCHANGED;
2577  }
2578 
2579  // Map of ICV to their values at specific program point.
2581  InternalControlVar::ICV___last>
2582  ICVReplacementValuesMap;
2583 
2584  /// Return the value with which \p I can be replaced for specific \p ICV.
2586  getUniqueReplacementValue(InternalControlVar ICV) const override {
2587  return ICVReplacementValuesMap[ICV];
2588  }
2589 
2590  ChangeStatus updateImpl(Attributor &A) override {
2591  ChangeStatus Changed = ChangeStatus::UNCHANGED;
2592  const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2593  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2594 
2595  if (!ICVTrackingAA.isAssumedTracked())
2596  return indicatePessimisticFixpoint();
2597 
2598  for (InternalControlVar ICV : TrackableICVs) {
2599  Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2600  Optional<Value *> UniqueICVValue;
2601 
2602  auto CheckReturnInst = [&](Instruction &I) {
2603  Optional<Value *> NewReplVal =
2604  ICVTrackingAA.getReplacementValue(ICV, &I, A);
2605 
2606  // If we found a second ICV value there is no unique returned value.
2607  if (UniqueICVValue && UniqueICVValue != NewReplVal)
2608  return false;
2609 
2610  UniqueICVValue = NewReplVal;
2611 
2612  return true;
2613  };
2614 
2615  bool UsedAssumedInformation = false;
2616  if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2617  UsedAssumedInformation,
2618  /* CheckBBLivenessOnly */ true))
2619  UniqueICVValue = nullptr;
2620 
2621  if (UniqueICVValue == ReplVal)
2622  continue;
2623 
2624  ReplVal = UniqueICVValue;
2625  Changed = ChangeStatus::CHANGED;
2626  }
2627 
2628  return Changed;
2629  }
2630 };
2631 
2632 struct AAICVTrackerCallSite : AAICVTracker {
2633  AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2634  : AAICVTracker(IRP, A) {}
2635 
2636  void initialize(Attributor &A) override {
2637  Function *F = getAnchorScope();
2638  if (!F || !A.isFunctionIPOAmendable(*F))
2639  indicatePessimisticFixpoint();
2640 
2641  // We only initialize this AA for getters, so we need to know which ICV it
2642  // gets.
2643  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2644  for (InternalControlVar ICV : TrackableICVs) {
2645  auto ICVInfo = OMPInfoCache.ICVs[ICV];
2646  auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2647  if (Getter.Declaration == getAssociatedFunction()) {
2648  AssociatedICV = ICVInfo.Kind;
2649  return;
2650  }
2651  }
2652 
2653  /// Unknown ICV.
2654  indicatePessimisticFixpoint();
2655  }
2656 
2657  ChangeStatus manifest(Attributor &A) override {
2658  if (!ReplVal || !*ReplVal)
2659  return ChangeStatus::UNCHANGED;
2660 
2661  A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2662  A.deleteAfterManifest(*getCtxI());
2663 
2664  return ChangeStatus::CHANGED;
2665  }
2666 
2667  // FIXME: come up with better string.
2668  const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2669 
2670  // FIXME: come up with some stats.
2671  void trackStatistics() const override {}
2672 
2673  InternalControlVar AssociatedICV;
2674  Optional<Value *> ReplVal;
2675 
2676  ChangeStatus updateImpl(Attributor &A) override {
2677  const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2678  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2679 
2680  // We don't have any information, so we assume it changes the ICV.
2681  if (!ICVTrackingAA.isAssumedTracked())
2682  return indicatePessimisticFixpoint();
2683 
2684  Optional<Value *> NewReplVal =
2685  ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2686 
2687  if (ReplVal == NewReplVal)
2688  return ChangeStatus::UNCHANGED;
2689 
2690  ReplVal = NewReplVal;
2691  return ChangeStatus::CHANGED;
2692  }
2693 
2694  // Return the value with which the associated value can be replaced for a
2695  // specific \p ICV.
2696  Optional<Value *>
2697  getUniqueReplacementValue(InternalControlVar ICV) const override {
2698  return ReplVal;
2699  }
2700 };
2701 
2702 struct AAICVTrackerCallSiteReturned : AAICVTracker {
2703  AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2704  : AAICVTracker(IRP, A) {}
2705 
2706  // FIXME: come up with better string.
2707  const std::string getAsStr() const override {
2708  return "ICVTrackerCallSiteReturned";
2709  }
2710 
2711  // FIXME: come up with some stats.
2712  void trackStatistics() const override {}
2713 
2714  /// We don't manifest anything for this AA.
2715  ChangeStatus manifest(Attributor &A) override {
2716  return ChangeStatus::UNCHANGED;
2717  }
2718 
2719  // Map of ICV to their values at specific program point.
2721  InternalControlVar::ICV___last>
2722  ICVReplacementValuesMap;
2723 
2724  /// Return the value with which the associated value can be replaced for a
2725  /// specific \p ICV.
2726  Optional<Value *>
2727  getUniqueReplacementValue(InternalControlVar ICV) const override {
2728  return ICVReplacementValuesMap[ICV];
2729  }
2730 
2731  ChangeStatus updateImpl(Attributor &A) override {
2732  ChangeStatus Changed = ChangeStatus::UNCHANGED;
2733  const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2734  *this, IRPosition::returned(*getAssociatedFunction()),
2735  DepClassTy::REQUIRED);
2736 
2737  // We don't have any information, so we assume it changes the ICV.
2738  if (!ICVTrackingAA.isAssumedTracked())
2739  return indicatePessimisticFixpoint();
2740 
2741  for (InternalControlVar ICV : TrackableICVs) {
2742  Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2743  Optional<Value *> NewReplVal =
2744  ICVTrackingAA.getUniqueReplacementValue(ICV);
2745 
2746  if (ReplVal == NewReplVal)
2747  continue;
2748 
2749  ReplVal = NewReplVal;
2750  Changed = ChangeStatus::CHANGED;
2751  }
2752  return Changed;
2753  }
2754 };
2755 
2756 struct AAExecutionDomainFunction : public AAExecutionDomain {
2757  AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2758  : AAExecutionDomain(IRP, A) {}
2759 
2760  const std::string getAsStr() const override {
2761  return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2762  "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2763  }
2764 
2765  /// See AbstractAttribute::trackStatistics().
2766  void trackStatistics() const override {}
2767 
2768  void initialize(Attributor &A) override {
2769  Function *F = getAnchorScope();
2770  for (const auto &BB : *F)
2771  SingleThreadedBBs.insert(&BB);
2772  NumBBs = SingleThreadedBBs.size();
2773  }
2774 
2775  ChangeStatus manifest(Attributor &A) override {
2776  LLVM_DEBUG({
2777  for (const BasicBlock *BB : SingleThreadedBBs)
2778  dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2779  << BB->getName() << " is executed by a single thread.\n";
2780  });
2781  return ChangeStatus::UNCHANGED;
2782  }
2783 
2784  ChangeStatus updateImpl(Attributor &A) override;
2785 
2786  /// Check if an instruction is executed by a single thread.
2787  bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
2788  return isExecutedByInitialThreadOnly(*I.getParent());
2789  }
2790 
2791  bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2792  return isValidState() && SingleThreadedBBs.contains(&BB);
2793  }
2794 
2795  /// Set of basic blocks that are executed by a single thread.
2796  SmallSetVector<const BasicBlock *, 16> SingleThreadedBBs;
2797 
2798  /// Total number of basic blocks in this function.
2799  long unsigned NumBBs = 0;
2800 };
2801 
2802 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2803  Function *F = getAnchorScope();
2804  ReversePostOrderTraversal<Function *> RPOT(F);
2805  auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2806 
2807  bool AllCallSitesKnown;
2808  auto PredForCallSite = [&](AbstractCallSite ACS) {
2809  const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2810  *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2811  DepClassTy::REQUIRED);
2812  return ACS.isDirectCall() &&
2813  ExecutionDomainAA.isExecutedByInitialThreadOnly(
2814  *ACS.getInstruction());
2815  };
2816 
2817  if (!A.checkForAllCallSites(PredForCallSite, *this,
2818  /* RequiresAllCallSites */ true,
2819  AllCallSitesKnown))
2820  SingleThreadedBBs.remove(&F->getEntryBlock());
2821 
2822  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2823  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2824 
2825  // Check if the edge into the successor block contains a condition that only
2826  // lets the main thread execute it.
2827  auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2828  if (!Edge || !Edge->isConditional())
2829  return false;
2830  if (Edge->getSuccessor(0) != SuccessorBB)
2831  return false;
2832 
2833  auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2834  if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2835  return false;
2836 
2837  ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2838  if (!C)
2839  return false;
2840 
2841  // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2842  if (C->isAllOnesValue()) {
2843  auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2844  CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2845  if (!CB)
2846  return false;
2847  const int InitModeArgNo = 1;
2848  auto *ModeCI = dyn_cast<ConstantInt>(CB->getOperand(InitModeArgNo));
2849  return ModeCI && (ModeCI->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC);
2850  }
2851 
2852  if (C->isZero()) {
2853  // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2854  if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2855  if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2856  return true;
2857 
2858  // Match: 0 == llvm.amdgcn.workitem.id.x()
2859  if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2860  if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2861  return true;
2862  }
2863 
2864  return false;
2865  };
2866 
2867  // Merge all the predecessor states into the current basic block. A basic
2868  // block is executed by a single thread if all of its predecessors are.
2869  auto MergePredecessorStates = [&](BasicBlock *BB) {
2870  if (pred_empty(BB))
2871  return SingleThreadedBBs.contains(BB);
2872 
2873  bool IsInitialThread = true;
2874  for (BasicBlock *PredBB : predecessors(BB)) {
2875  if (!IsInitialThreadOnly(dyn_cast<BranchInst>(PredBB->getTerminator()),
2876  BB))
2877  IsInitialThread &= SingleThreadedBBs.contains(PredBB);
2878  }
2879 
2880  return IsInitialThread;
2881  };
2882 
2883  for (auto *BB : RPOT) {
2884  if (!MergePredecessorStates(BB))
2885  SingleThreadedBBs.remove(BB);
2886  }
2887 
2888  return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2889  ? ChangeStatus::UNCHANGED
2890  : ChangeStatus::CHANGED;
2891 }
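// Illustrative guard shapes, not taken from this pass, that the predicate
// above recognizes (the exact IR depends on the target):
//
//   %tid   = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()   ; or @llvm.amdgcn.workitem.id.x()
//   %only0 = icmp eq i32 %tid, 0
//   br i1 %only0, label %guarded, label %rest
//
//   ; or, in generic-mode kernels:
//   %init   = call i32 @__kmpc_target_init(...)
//   %ismain = icmp eq i32 %init, -1
//   br i1 %ismain, label %guarded, label %worker
//
// Blocks reachable only through such edges, or only from blocks already known
// to be single-threaded, stay in SingleThreadedBBs.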
2892 
2893 /// Try to replace memory allocation calls called by a single thread with a
2894 /// static buffer of shared memory.
2895 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
2896  using Base = StateWrapper<BooleanState, AbstractAttribute>;
2897  AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2898 
2899  /// Create an abstract attribute view for the position \p IRP.
2900  static AAHeapToShared &createForPosition(const IRPosition &IRP,
2901  Attributor &A);
2902 
2903  /// Returns true if HeapToShared conversion is assumed to be possible.
2904  virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
2905 
2906  /// Returns true if HeapToShared conversion is assumed and the CB is a
2907  /// callsite to a free operation to be removed.
2908  virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
2909 
2910  /// See AbstractAttribute::getName().
2911  const std::string getName() const override { return "AAHeapToShared"; }
2912 
2913  /// See AbstractAttribute::getIdAddr().
2914  const char *getIdAddr() const override { return &ID; }
2915 
2916  /// This function should return true if the type of the \p AA is
2917  /// AAHeapToShared.
2918  static bool classof(const AbstractAttribute *AA) {
2919  return (AA->getIdAddr() == &ID);
2920  }
2921 
2922  /// Unique ID (due to the unique address)
2923  static const char ID;
2924 };
2925 
2926 struct AAHeapToSharedFunction : public AAHeapToShared {
2927  AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
2928  : AAHeapToShared(IRP, A) {}
2929 
2930  const std::string getAsStr() const override {
2931  return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
2932  " malloc calls eligible.";
2933  }
2934 
2935  /// See AbstractAttribute::trackStatistics().
2936  void trackStatistics() const override {}
2937 
2938  /// This function finds free calls that will be removed by the
2939  /// HeapToShared transformation.
2940  void findPotentialRemovedFreeCalls(Attributor &A) {
2941  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2942  auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2943 
2944  PotentialRemovedFreeCalls.clear();
2945  // Update free call users of found malloc calls.
2946  for (CallBase *CB : MallocCalls) {
2947  SmallVector<CallBase *, 4> FreeCalls;
2948  for (auto *U : CB->users()) {
2949  CallBase *C = dyn_cast<CallBase>(U);
2950  if (C && C->getCalledFunction() == FreeRFI.Declaration)
2951  FreeCalls.push_back(C);
2952  }
2953 
2954  if (FreeCalls.size() != 1)
2955  continue;
2956 
2957  PotentialRemovedFreeCalls.insert(FreeCalls.front());
2958  }
2959  }
2960 
2961  void initialize(Attributor &A) override {
2962  if (DisableOpenMPOptDeglobalization) {
2963  indicatePessimisticFixpoint();
2964  return;
2965  }
2966 
2967  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2968  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2969 
2970  Attributor::SimplifictionCallbackTy SCB =
2971  [](const IRPosition &, const AbstractAttribute *,
2972  bool &) -> Optional<Value *> { return nullptr; };
2973  for (User *U : RFI.Declaration->users())
2974  if (CallBase *CB = dyn_cast<CallBase>(U)) {
2975  MallocCalls.insert(CB);
2976  A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
2977  SCB);
2978  }
2979 
2980  findPotentialRemovedFreeCalls(A);
2981  }
2982 
2983  bool isAssumedHeapToShared(CallBase &CB) const override {
2984  return isValidState() && MallocCalls.count(&CB);
2985  }
2986 
2987  bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
2988  return isValidState() && PotentialRemovedFreeCalls.count(&CB);
2989  }
2990 
2991  ChangeStatus manifest(Attributor &A) override {
2992  if (MallocCalls.empty())
2993  return ChangeStatus::UNCHANGED;
2994 
2995  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2996  auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2997 
2998  Function *F = getAnchorScope();
2999  auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3000  DepClassTy::OPTIONAL);
3001 
3002  ChangeStatus Changed = ChangeStatus::UNCHANGED;
3003  for (CallBase *CB : MallocCalls) {
3004  // Skip replacing this if HeapToStack has already claimed it.
3005  if (HS && HS->isAssumedHeapToStack(*CB))
3006  continue;
3007 
3008  // Find the unique free call to remove it.
3009  SmallVector<CallBase *, 4> FreeCalls;
3010  for (auto *U : CB->users()) {
3011  CallBase *C = dyn_cast<CallBase>(U);
3012  if (C && C->getCalledFunction() == FreeCall.Declaration)
3013  FreeCalls.push_back(C);
3014  }
3015  if (FreeCalls.size() != 1)
3016  continue;
3017 
3018  auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3019 
3020  if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3021  LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3022  << " with shared memory."
3023  << " Shared memory usage is limited to "
3024  << SharedMemoryLimit << " bytes\n");
3025  continue;
3026  }
3027 
3028  LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3029  << " with " << AllocSize->getZExtValue()
3030  << " bytes of shared memory\n");
3031 
3032  // Create a new shared memory buffer of the same size as the allocation
3033  // and replace all the uses of the original allocation with it.
3034  Module *M = CB->getModule();
3035  Type *Int8Ty = Type::getInt8Ty(M->getContext());
3036  Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3037  auto *SharedMem = new GlobalVariable(
3038  *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3039  UndefValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3040  GlobalValue::NotThreadLocal,
3041  static_cast<unsigned>(AddressSpace::Shared));
3042  auto *NewBuffer =
3043  ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
3044 
3045  auto Remark = [&](OptimizationRemark OR) {
3046  return OR << "Replaced globalized variable with "
3047  << ore::NV("SharedMemory", AllocSize->getZExtValue())
3048  << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
3049  << "of shared memory.";
3050  };
3051  A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3052 
3053  MaybeAlign Alignment = CB->getRetAlign();
3054  assert(Alignment &&
3055  "HeapToShared on allocation without alignment attribute");
3056  SharedMem->setAlignment(MaybeAlign(Alignment));
3057 
3058  A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3059  A.deleteAfterManifest(*CB);
3060  A.deleteAfterManifest(*FreeCalls.front());
3061 
3062  SharedMemoryUsed += AllocSize->getZExtValue();
3063  NumBytesMovedToSharedMemory = SharedMemoryUsed;
3064  Changed = ChangeStatus::CHANGED;
3065  }
3066 
3067  return Changed;
3068  }
3069 
3070  ChangeStatus updateImpl(Attributor &A) override {
3071  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3072  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3073  Function *F = getAnchorScope();
3074 
3075  auto NumMallocCalls = MallocCalls.size();
3076 
3077  // Only consider malloc calls executed by a single thread with a constant size.
3078  for (User *U : RFI.Declaration->users()) {
3079  const auto &ED = A.getAAFor<AAExecutionDomain>(
3081  if (CallBase *CB = dyn_cast<CallBase>(U))
3082  if (!isa<ConstantInt>(CB->getArgOperand(0)) ||
3083  !ED.isExecutedByInitialThreadOnly(*CB))
3084  MallocCalls.remove(CB);
3085  }
3086 
3087  findPotentialRemovedFreeCalls(A);
3088 
3089  if (NumMallocCalls != MallocCalls.size())
3090  return ChangeStatus::CHANGED;
3091 
3092  return ChangeStatus::UNCHANGED;
3093  }
3094 
3095  /// Collection of all malloc calls in a function.
3096  SmallSetVector<CallBase *, 4> MallocCalls;
3097  /// Collection of potentially removed free calls in a function.
3098  SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3099  /// The total amount of shared memory that has been used for HeapToShared.
3100  unsigned SharedMemoryUsed = 0;
3101 };
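// Illustrative before/after sketch, not taken from this pass; size, alignment,
// and the free-call signature are made up:
//
//   ; before: team storage obtained from the device runtime
//   %p = call i8* @__kmpc_alloc_shared(i64 16)
//   ; ... uses of %p ...
//   call void @__kmpc_free_shared(i8* %p, i64 16)
//
//   ; after: a static buffer in the shared (LDS) address space
//   @x_shared = internal addrspace(3) global [16 x i8] undef, align 8
//   ; uses of %p now refer to a cast of @x_shared; both runtime calls are gone
//
// The allocation size is charged against SharedMemoryLimit so the transform
// stops before exceeding the configured budget.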
3102 
3103 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3104  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3105  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3106 
3107  /// Statistics are tracked as part of manifest for now.
3108  void trackStatistics() const override {}
3109 
3110  /// See AbstractAttribute::getAsStr()
3111  const std::string getAsStr() const override {
3112  if (!isValidState())
3113  return "<invalid>";
3114  return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3115  : "generic") +
3116  std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3117  : "") +
3118  std::string(" #PRs: ") +
3119  (ReachedKnownParallelRegions.isValidState()
3120  ? std::to_string(ReachedKnownParallelRegions.size())
3121  : "<invalid>") +
3122  ", #Unknown PRs: " +
3123  (ReachedUnknownParallelRegions.isValidState()
3124  ? std::to_string(ReachedUnknownParallelRegions.size())
3125  : "<invalid>") +
3126  ", #Reaching Kernels: " +
3127  (ReachingKernelEntries.isValidState()
3128  ? std::to_string(ReachingKernelEntries.size())
3129  : "<invalid>");
3130  }
3131 
3132  /// Create an abstract attribute view for the position \p IRP.
3133  static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3134 
3135  /// See AbstractAttribute::getName()
3136  const std::string getName() const override { return "AAKernelInfo"; }
3137 
3138  /// See AbstractAttribute::getIdAddr()
3139  const char *getIdAddr() const override { return &ID; }
3140 
3141  /// This function should return true if the type of the \p AA is AAKernelInfo
3142  static bool classof(const AbstractAttribute *AA) {
3143  return (AA->getIdAddr() == &ID);
3144  }
3145 
3146  static const char ID;
3147 };
3148 
3149 /// The function kernel info abstract attribute, basically, what can we say
3150 /// about a function with regards to the KernelInfoState.
3151 struct AAKernelInfoFunction : AAKernelInfo {
3152  AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3153  : AAKernelInfo(IRP, A) {}
3154 
3155  SmallPtrSet<Instruction *, 4> GuardedInstructions;
3156 
3157  SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3158  return GuardedInstructions;
3159  }
3160 
3161  /// See AbstractAttribute::initialize(...).
3162  void initialize(Attributor &A) override {
3163  // This is a high-level transform that might change the constant arguments
3164  // of the init and deinit calls. We need to tell the Attributor about this
3165  // to avoid other parts using the current constant value for simplification.
3166  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3167 
3168  Function *Fn = getAnchorScope();
3169 
3170  OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3171  OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3172  OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3173  OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3174 
3175  // For kernels we perform more initialization work; first we find the init
3176  // and deinit calls.
3177  auto StoreCallBase = [](Use &U,
3178  OMPInformationCache::RuntimeFunctionInfo &RFI,
3179  CallBase *&Storage) {
3180  CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3181  assert(CB &&
3182  "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3183  assert(!Storage &&
3184  "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3185  Storage = CB;
3186  return false;
3187  };
3188  InitRFI.foreachUse(
3189  [&](Use &U, Function &) {
3190  StoreCallBase(U, InitRFI, KernelInitCB);
3191  return false;
3192  },
3193  Fn);
3194  DeinitRFI.foreachUse(
3195  [&](Use &U, Function &) {
3196  StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3197  return false;
3198  },
3199  Fn);
3200 
3201  // Ignore kernels without initializers such as global constructors.
3202  if (!KernelInitCB || !KernelDeinitCB)
3203  return;
3204 
3205  // Add itself to the reaching kernel and set IsKernelEntry.
3206  ReachingKernelEntries.insert(Fn);
3207  IsKernelEntry = true;
3208 
3209  // For kernels we might need to initialize/finalize the IsSPMD state and
3210  // we need to register a simplification callback so that the Attributor
3211  // knows the constant arguments to __kmpc_target_init and
3212  // __kmpc_target_deinit might actually change.
3213 
3214  Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
3215  [&](const IRPosition &IRP, const AbstractAttribute *AA,
3216  bool &UsedAssumedInformation) -> Optional<Value *> {
3217  // IRP represents the "use generic state machine" argument of an
3218  // __kmpc_target_init call. We will answer this one with the internal
3219  // state. As long as we are not in an invalid state, we will create a
3220  // custom state machine, so the value should be an `i1 false`. If we are
3221  // in an invalid state, we won't change the value that is in the IR.
3222  if (!ReachedKnownParallelRegions.isValidState())
3223  return nullptr;
3224  // If we have disabled state machine rewrites, don't make a custom one.
3225  if (DisableOpenMPOptStateMachineRewrite)
3226  return nullptr;
3227  if (AA)
3228  A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3229  UsedAssumedInformation = !isAtFixpoint();
3230  auto *FalseVal =
3231  ConstantInt::getBool(IRP.getAnchorValue().getContext(), false);
3232  return FalseVal;
3233  };
3234 
3235  Attributor::SimplifictionCallbackTy ModeSimplifyCB =
3236  [&](const IRPosition &IRP, const AbstractAttribute *AA,
3237  bool &UsedAssumedInformation) -> Optional<Value *> {
3238  // IRP represents the "SPMDCompatibilityTracker" argument of an
3239  // __kmpc_target_init or
3240  // __kmpc_target_deinit call. We will answer this one with the internal
3241  // state.
3242  if (!SPMDCompatibilityTracker.isValidState())
3243  return nullptr;
3244  if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3245  if (AA)
3246  A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3247  UsedAssumedInformation = true;
3248  } else {
3249  UsedAssumedInformation = false;
3250  }
3251  auto *Val = ConstantInt::getSigned(
3252  IntegerType::getInt8Ty(IRP.getAnchorValue().getContext()),
3253  SPMDCompatibilityTracker.isAssumed() ? OMP_TGT_EXEC_MODE_SPMD
3254  : OMP_TGT_EXEC_MODE_GENERIC);
3255  return Val;
3256  };
3257 
3258  Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
3259  [&](const IRPosition &IRP, const AbstractAttribute *AA,
3260  bool &UsedAssumedInformation) -> Optional<Value *> {
3261  // IRP represents the "RequiresFullRuntime" argument of an
3262  // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
3263  // one with the internal state of the SPMDCompatibilityTracker, so if
3264  // generic then true, if SPMD then false.
3265  if (!SPMDCompatibilityTracker.isValidState())
3266  return nullptr;
3267  if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3268  if (AA)
3269  A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3270  UsedAssumedInformation = true;
3271  } else {
3272  UsedAssumedInformation = false;
3273  }
3274  auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
3275  !SPMDCompatibilityTracker.isAssumed());
3276  return Val;
3277  };
3278 
3279  constexpr const int InitModeArgNo = 1;
3280  constexpr const int DeinitModeArgNo = 1;
3281  constexpr const int InitUseStateMachineArgNo = 2;
3282  constexpr const int InitRequiresFullRuntimeArgNo = 3;
3283  constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
3284  A.registerSimplificationCallback(
3285  IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
3286  StateMachineSimplifyCB);
3287  A.registerSimplificationCallback(
3288  IRPosition::callsite_argument(*KernelInitCB, InitModeArgNo),
3289  ModeSimplifyCB);
3290  A.registerSimplificationCallback(
3291  IRPosition::callsite_argument(*KernelDeinitCB, DeinitModeArgNo),
3292  ModeSimplifyCB);
3293  A.registerSimplificationCallback(
3294  IRPosition::callsite_argument(*KernelInitCB,
3295  InitRequiresFullRuntimeArgNo),
3296  IsGenericModeSimplifyCB);
3297  A.registerSimplificationCallback(
3298  IRPosition::callsite_argument(*KernelDeinitCB,
3299  DeinitRequiresFullRuntimeArgNo),
3300  IsGenericModeSimplifyCB);
3301 
3302  // Check if we know we are in SPMD-mode already.
3303  ConstantInt *ModeArg =
3304  dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
3305  if (ModeArg && (ModeArg->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
3306  SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3307  // This is a generic region but SPMDization is disabled so stop tracking.
3308  else if (DisableOpenMPOptSPMDization)
3309  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3310  }
3311 
3312  /// Sanitize the string \p S such that it is a suitable global symbol name.
3313  static std::string sanitizeForGlobalName(std::string S) {
3314  std::replace_if(
3315  S.begin(), S.end(),
3316  [](const char C) {
3317  return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3318  (C >= '0' && C <= '9') || C == '_');
3319  },
3320  '.');
3321  return S;
3322  }
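  // Illustrative usage (not part of the pass): every character outside
  // [A-Za-z0-9_] is replaced by '.', so an input such as "x.addr (guarded)"
  // becomes "x.addr..guarded." before it is used to name the shared-memory
  // globals created below.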
3323 
3324  /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3325  /// finished now.
3326  ChangeStatus manifest(Attributor &A) override {
3327  // If we are not looking at a kernel with __kmpc_target_init and
3328  // __kmpc_target_deinit call we cannot actually manifest the information.
3329  if (!KernelInitCB || !KernelDeinitCB)
3330  return ChangeStatus::UNCHANGED;
3331 
3332  // If we can, we change the execution mode to SPMD-mode; otherwise we build a
3333  // custom state machine.
3334  ChangeStatus Changed = ChangeStatus::UNCHANGED;
3335  if (!changeToSPMDMode(A, Changed)) {
3336  if (!KernelInitCB->getCalledFunction()->isDeclaration())
3337  return buildCustomStateMachine(A);
3338  }
3339 
3340  return Changed;
3341  }
3342 
3343  void insertInstructionGuardsHelper(Attributor &A) {
3344  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3345 
3346  auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3347  Instruction *RegionEndI) {
3348  LoopInfo *LI = nullptr;
3349  DominatorTree *DT = nullptr;
3350  MemorySSAUpdater *MSU = nullptr;
3351  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3352 
3353  BasicBlock *ParentBB = RegionStartI->getParent();
3354  Function *Fn = ParentBB->getParent();
3355  Module &M = *Fn->getParent();
3356 
3357  // Create all the blocks and logic.
3358  // ParentBB:
3359  // goto RegionCheckTidBB
3360  // RegionCheckTidBB:
3361  // Tid = __kmpc_hardware_thread_id()
3362  // if (Tid != 0)
3363  // goto RegionBarrierBB
3364  // RegionStartBB:
3365  // <execute instructions guarded>
3366  // goto RegionEndBB
3367  // RegionEndBB:
3368  // <store escaping values to shared mem>
3369  // goto RegionBarrierBB
3370  // RegionBarrierBB:
3371  // __kmpc_simple_barrier_spmd()
3372  // // second barrier is omitted if lacking escaping values.
3373  // <load escaping values from shared mem>
3374  // __kmpc_simple_barrier_spmd()
3375  // goto RegionExitBB
3376  // RegionExitBB:
3377  // <execute rest of instructions>
3378 
3379  BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3380  DT, LI, MSU, "region.guarded.end");
3381  BasicBlock *RegionBarrierBB =
3382  SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3383  MSU, "region.barrier");
3384  BasicBlock *RegionExitBB =
3385  SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3386  DT, LI, MSU, "region.exit");
3387  BasicBlock *RegionStartBB =
3388  SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3389 
3390  assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3391  "Expected a different CFG");
3392 
3393  BasicBlock *RegionCheckTidBB = SplitBlock(
3394  ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3395 
3396  // Register basic blocks with the Attributor.
3397  A.registerManifestAddedBasicBlock(*RegionEndBB);
3398  A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3399  A.registerManifestAddedBasicBlock(*RegionExitBB);
3400  A.registerManifestAddedBasicBlock(*RegionStartBB);
3401  A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3402 
3403  bool HasBroadcastValues = false;
3404  // Find escaping outputs from the guarded region to outside users and
3405  // broadcast their values to them.
3406  for (Instruction &I : *RegionStartBB) {
3407  SmallPtrSet<Instruction *, 4> OutsideUsers;
3408  for (User *Usr : I.users()) {
3409  Instruction &UsrI = *cast<Instruction>(Usr);
3410  if (UsrI.getParent() != RegionStartBB)
3411  OutsideUsers.insert(&UsrI);
3412  }
3413 
3414  if (OutsideUsers.empty())
3415  continue;
3416 
3417  HasBroadcastValues = true;
3418 
3419  // Emit a global variable in shared memory to store the broadcasted
3420  // value.
3421  auto *SharedMem = new GlobalVariable(
3422  M, I.getType(), /* IsConstant */ false,
3423  GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
3424  sanitizeForGlobalName(
3425  (I.getName() + ".guarded.output.alloc").str()),
3426  nullptr, GlobalValue::NotThreadLocal,
3427  static_cast<unsigned>(AddressSpace::Shared));
3428 
3429  // Emit a store instruction to update the value.
3430  new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
3431 
3432  LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
3433  I.getName() + ".guarded.output.load",
3434  RegionBarrierBB->getTerminator());
3435 
3436  // Emit a load instruction and replace uses of the output value.
3437  for (Instruction *UsrI : OutsideUsers)
3438  UsrI->replaceUsesOfWith(&I, LoadI);
3439  }
3440 
3441  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3442 
3443  // Go to tid check BB in ParentBB.
3444  const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
3445  ParentBB->getTerminator()->eraseFromParent();
3446  OpenMPIRBuilder::LocationDescription Loc(
3447  InsertPointTy(ParentBB, ParentBB->end()), DL);
3448  OMPInfoCache.OMPBuilder.updateToLocation(Loc);
3449  uint32_t SrcLocStrSize;
3450  auto *SrcLocStr =
3451  OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3452  Value *Ident =
3453  OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3454  BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
3455 
3456  // Add check for Tid in RegionCheckTidBB
3457  RegionCheckTidBB->getTerminator()->eraseFromParent();
3458  OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
3459  InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
3460  OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
3461  FunctionCallee HardwareTidFn =
3462  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3463  M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3464  CallInst *Tid =
3465  OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
3466  Tid->setDebugLoc(DL);
3467  OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
3468  Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
3469  OMPInfoCache.OMPBuilder.Builder
3470  .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
3471  ->setDebugLoc(DL);
3472 
3473  // First barrier for synchronization, ensures main thread has updated
3474  // values.
3475  FunctionCallee BarrierFn =
3476  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3477  M, OMPRTL___kmpc_barrier_simple_spmd);
3478  OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
3479  RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
3480  CallInst *Barrier =
3481  OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
3482  Barrier->setDebugLoc(DL);
3483  OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3484 
3485  // Second barrier ensures workers have read broadcast values.
3486  if (HasBroadcastValues) {
3487  CallInst *Barrier = CallInst::Create(BarrierFn, {Ident, Tid}, "",
3488  RegionBarrierBB->getTerminator());
3489  Barrier->setDebugLoc(DL);
3490  OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3491  }
3492  };
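  // Rough sketch of the broadcast above (illustrative IR; value and symbol
  // names are assumed): a value %v defined inside region.guarded and used
  // after the region is written by thread 0 and re-read by all threads:
  //   region.guarded.end:
  //     store i32 %v, ptr addrspace(3) @v.guarded.output.alloc
  //   region.barrier:
  //     call void @__kmpc_barrier_simple_spmd(ptr @ident, i32 %tid)
  //     %v.guarded.output.load = load i32, ptr addrspace(3) @v.guarded.output.alloc
  //     call void @__kmpc_barrier_simple_spmd(ptr @ident, i32 %tid)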
3493 
3494  auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3495  SmallPtrSet<BasicBlock *, 8> Visited;
3496  for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3497  BasicBlock *BB = GuardedI->getParent();
3498  if (!Visited.insert(BB).second)
3499  continue;
3500 
3501  SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
3502  Instruction *LastEffect = nullptr;
3503  BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
3504  while (++IP != IPEnd) {
3505  if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
3506  continue;
3507  Instruction *I = &*IP;
3508  if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
3509  continue;
3510  if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
3511  LastEffect = nullptr;
3512  continue;
3513  }
3514  if (LastEffect)
3515  Reorders.push_back({I, LastEffect});
3516  LastEffect = &*IP;
3517  }
3518  for (auto &Reorder : Reorders)
3519  Reorder.first->moveBefore(Reorder.second);
3520  }
3521 
3522  SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
3523 
3524  for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3525  BasicBlock *BB = GuardedI->getParent();
3526  auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
3527  IRPosition::function(*GuardedI->getFunction()), nullptr,
3528  DepClassTy::NONE);
3529  assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
3530  auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3531  // Continue if instruction is already guarded.
3532  if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
3533  continue;
3534 
3535  Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
3536  for (Instruction &I : *BB) {
3537  // If instruction I needs to be guarded update the guarded region
3538  // bounds.
3539  if (SPMDCompatibilityTracker.contains(&I)) {
3540  CalleeAAFunction.getGuardedInstructions().insert(&I);
3541  if (GuardedRegionStart)
3542  GuardedRegionEnd = &I;
3543  else
3544  GuardedRegionStart = GuardedRegionEnd = &I;
3545 
3546  continue;
3547  }
3548 
3549  // Instruction I does not need guarding, store
3550  // any region found and reset bounds.
3551  if (GuardedRegionStart) {
3552  GuardedRegions.push_back(
3553  std::make_pair(GuardedRegionStart, GuardedRegionEnd));
3554  GuardedRegionStart = nullptr;
3555  GuardedRegionEnd = nullptr;
3556  }
3557  }
3558  }
3559 
3560  for (auto &GR : GuardedRegions)
3561  CreateGuardedRegion(GR.first, GR.second);
3562  }
3563 
3564  void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
3565  // Only allow 1 thread per workgroup to continue executing the user code.
3566  //
3567  // InitCB = __kmpc_target_init(...)
3568  // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
3569  // if (ThreadIdInBlock != 0) return;
3570  // UserCode:
3571  // // user code
3572  //
3573  auto &Ctx = getAnchorValue().getContext();
3574  Function *Kernel = getAssociatedFunction();
3575  assert(Kernel && "Expected an associated function!");
3576 
3577  // Create block for user code to branch to from initial block.
3578  BasicBlock *InitBB = KernelInitCB->getParent();
3579  BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
3580  KernelInitCB->getNextNode(), "main.thread.user_code");
3581  BasicBlock *ReturnBB =
3582  BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
3583 
3584  // Register blocks with attributor:
3585  A.registerManifestAddedBasicBlock(*InitBB);
3586  A.registerManifestAddedBasicBlock(*UserCodeBB);
3587  A.registerManifestAddedBasicBlock(*ReturnBB);
3588 
3589  // Debug location:
3590  const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3591  ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
3592  InitBB->getTerminator()->eraseFromParent();
3593 
3594  // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
3595  Module &M = *Kernel->getParent();
3596  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3597  FunctionCallee ThreadIdInBlockFn =
3598  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3599  M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3600 
3601  // Get thread ID in block.
3602  CallInst *ThreadIdInBlock =
3603  CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
3604  OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
3605  ThreadIdInBlock->setDebugLoc(DLoc);
3606 
3607  // Eliminate all threads in the block with ID not equal to 0:
3608  Instruction *IsMainThread =
3609  ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
3610  ConstantInt::get(ThreadIdInBlock->getType(), 0),
3611  "thread.is_main", InitBB);
3612  IsMainThread->setDebugLoc(DLoc);
3613  BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
3614  }
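  // This path applies, e.g., to a target region without any parallel
  // construct (illustrative user code, not from this file):
  //   #pragma omp target
  //   { A[0] = 42; }
  // After the rewrite only thread 0 of each workgroup executes the user code;
  // all other threads branch to exit.threads and return immediately.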
3615 
3616  bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
3617  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3618 
3619  if (!SPMDCompatibilityTracker.isAssumed()) {
3620  for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
3621  if (!NonCompatibleI)
3622  continue;
3623 
3624  // Skip diagnostics on calls to known OpenMP runtime functions for now.
3625  if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
3626  if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
3627  continue;
3628 
3629  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3630  ORA << "Value has potential side effects preventing SPMD-mode "
3631  "execution";
3632  if (isa<CallBase>(NonCompatibleI)) {
3633  ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
3634  "the called function to override";
3635  }
3636  return ORA << ".";
3637  };
3638  A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
3639  Remark);
3640 
3641  LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
3642  << *NonCompatibleI << "\n");
3643  }
3644 
3645  return false;
3646  }
3647 
3648  // Get the actual kernel, could be the caller of the anchor scope if we have
3649  // a debug wrapper.
3650  Function *Kernel = getAnchorScope();
3651  if (Kernel->hasLocalLinkage()) {
3652  assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
3653  auto *CB = cast<CallBase>(Kernel->user_back());
3654  Kernel = CB->getCaller();
3655  }
3656  assert(OMPInfoCache.Kernels.count(Kernel) && "Expected kernel function!");
3657 
3658  // Check if the kernel is already in SPMD mode, if so, return success.
3659  GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
3660  (Kernel->getName() + "_exec_mode").str());
3661  assert(ExecMode && "Kernel without exec mode?");
3662  assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!");
3663 
3664  // Set the global exec mode flag to indicate SPMD-Generic mode.
3665  assert(isa<ConstantInt>(ExecMode->getInitializer()) &&
3666  "ExecMode is not an integer!");
3667  const int8_t ExecModeVal =
3668  cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue();
3669  if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
3670  return true;
3671 
3672  // We will now unconditionally modify the IR, indicate a change.
3673  Changed = ChangeStatus::CHANGED;
3674 
3675  // Do not use instruction guards when no parallel region is present inside
3676  // the target region.
3677  if (mayContainParallelRegion())
3678  insertInstructionGuardsHelper(A);
3679  else
3680  forceSingleThreadPerWorkgroupHelper(A);
3681 
3682  // Adjust the global exec mode flag that tells the runtime what mode this
3683  // kernel is executed in.
3684  assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
3685  "Initially non-SPMD kernel has SPMD exec mode!");
3686  ExecMode->setInitializer(
3687  ConstantInt::get(ExecMode->getInitializer()->getType(),
3688  ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
3689 
3690  // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
3691  const int InitModeArgNo = 1;
3692  const int DeinitModeArgNo = 1;
3693  const int InitUseStateMachineArgNo = 2;
3694  const int InitRequiresFullRuntimeArgNo = 3;
3695  const int DeinitRequiresFullRuntimeArgNo = 2;
3696 
3697  auto &Ctx = getAnchorValue().getContext();
3698  A.changeUseAfterManifest(
3699  KernelInitCB->getArgOperandUse(InitModeArgNo),
3700  *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
3701  OMP_TGT_EXEC_MODE_SPMD));
3702  A.changeUseAfterManifest(
3703  KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
3704  *ConstantInt::getBool(Ctx, false));
3705  A.changeUseAfterManifest(
3706  KernelDeinitCB->getArgOperandUse(DeinitModeArgNo),
3707  *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
3708  OMP_TGT_EXEC_MODE_SPMD));
3709  A.changeUseAfterManifest(
3710  KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
3711  *ConstantInt::getBool(Ctx, false));
3712  A.changeUseAfterManifest(
3713  KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
3714  *ConstantInt::getBool(Ctx, false));
3715 
3716  ++NumOpenMPTargetRegionKernelsSPMD;
3717 
3718  auto Remark = [&](OptimizationRemark OR) {
3719  return OR << "Transformed generic-mode kernel to SPMD-mode.";
3720  };
3721  A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
3722  return true;
3723  };
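  // For illustration only (assumed IR shapes, not emitted verbatim by this
  // file): a successful SPMD-ization rewrites the kernel init call roughly
  // from
  //   call i32 @__kmpc_target_init(ptr @ident, i8 1 /*GENERIC*/, i1 true, i1 true)
  // to
  //   call i32 @__kmpc_target_init(ptr @ident, i8 2 /*SPMD*/, i1 false, i1 false)
  // and ORs OMP_TGT_EXEC_MODE_GENERIC_SPMD into the <kernel>_exec_mode global.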
3724 
3725  ChangeStatus buildCustomStateMachine(Attributor &A) {
3726  // If we have disabled state machine rewrites, don't make a custom one
3727  if (DisableOpenMPOptStateMachineRewrite)
3728  return ChangeStatus::UNCHANGED;
3729 
3730  // Don't rewrite the state machine if we are not in a valid state.
3731  if (!ReachedKnownParallelRegions.isValidState())
3732  return ChangeStatus::UNCHANGED;
3733 
3734  const int InitModeArgNo = 1;
3735  const int InitUseStateMachineArgNo = 2;
3736 
3737  // Check if the current configuration is non-SPMD and generic state machine.
3738  // If we already have SPMD mode or a custom state machine we do not need to
3739  // go any further. If it is anything but a constant something is weird and
3740  // we give up.
3741  ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
3742  KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
3743  ConstantInt *Mode =
3744  dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
3745 
3746  // If we are stuck with generic mode, try to create a custom device (=GPU)
3747  // state machine which is specialized for the parallel regions that are
3748  // reachable by the kernel.
3749  if (!UseStateMachine || UseStateMachine->isZero() || !Mode ||
3750  (Mode->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
3751  return ChangeStatus::UNCHANGED;
3752 
3753  // If not SPMD mode, indicate we use a custom state machine now.
3754  auto &Ctx = getAnchorValue().getContext();
3755  auto *FalseVal = ConstantInt::getBool(Ctx, false);
3756  A.changeUseAfterManifest(
3757  KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
3758 
3759  // If we don't actually need a state machine we are done here. This can
3760  // happen if there simply are no parallel regions. In the resulting kernel
3761  // all worker threads will simply exit right away, leaving the main thread
3762  // to do the work alone.
3763  if (!mayContainParallelRegion()) {
3764  ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
3765 
3766  auto Remark = [&](OptimizationRemark OR) {
3767  return OR << "Removing unused state machine from generic-mode kernel.";
3768  };
3769  A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
3770 
3771  return ChangeStatus::CHANGED;
3772  }
3773 
3774  // Keep track in the statistics of our new shiny custom state machine.
3775  if (ReachedUnknownParallelRegions.empty()) {
3776  ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
3777 
3778  auto Remark = [&](OptimizationRemark OR) {
3779  return OR << "Rewriting generic-mode kernel with a customized state "
3780  "machine.";
3781  };
3782  A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
3783  } else {
3784  ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
3785 
3786  auto Remark = [&](OptimizationRemarkAnalysis OR) {
3787  return OR << "Generic-mode kernel is executed with a customized state "
3788  "machine that requires a fallback.";
3789  };
3790  A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
3791 
3792  // Tell the user why we ended up with a fallback.
3793  for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
3794  if (!UnknownParallelRegionCB)
3795  continue;
3796  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3797  return ORA << "Call may contain unknown parallel regions. Use "
3798  << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
3799  "override.";
3800  };
3801  A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
3802  "OMP133", Remark);
3803  }
3804  }
3805 
3806  // Create all the blocks:
3807  //
3808  // InitCB = __kmpc_target_init(...)
3809  // BlockHwSize =
3810  // __kmpc_get_hardware_num_threads_in_block();
3811  // WarpSize = __kmpc_get_warp_size();
3812  // BlockSize = BlockHwSize - WarpSize;
3813  // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
3814  // if (IsWorker) {
3815  // if (InitCB >= BlockSize) return;
3816  // SMBeginBB: __kmpc_barrier_simple_generic(...);
3817  // void *WorkFn;
3818  // bool Active = __kmpc_kernel_parallel(&WorkFn);
3819  // if (!WorkFn) return;
3820  // SMIsActiveCheckBB: if (Active) {
3821  // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
3822  // ParFn0(...);
3823  // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
3824  // ParFn1(...);
3825  // ...
3826  // SMIfCascadeCurrentBB: else
3827  // ((WorkFnTy*)WorkFn)(...);
3828  // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
3829  // }
3830  // SMDoneBB: __kmpc_barrier_simple_generic(...);
3831  // goto SMBeginBB;
3832  // }
3833  // UserCodeEntryBB: // user code
3834  // __kmpc_target_deinit(...)
3835  //
3836  Function *Kernel = getAssociatedFunction();
3837  assert(Kernel && "Expected an associated function!");
3838 
3839  BasicBlock *InitBB = KernelInitCB->getParent();
3840  BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
3841  KernelInitCB->getNextNode(), "thread.user_code.check");
3842  BasicBlock *IsWorkerCheckBB =
3843  BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
3844  BasicBlock *StateMachineBeginBB = BasicBlock::Create(
3845  Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
3846  BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
3847  Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
3848  BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
3849  Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
3850  BasicBlock *StateMachineIfCascadeCurrentBB =
3851  BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3852  Kernel, UserCodeEntryBB);
3853  BasicBlock *StateMachineEndParallelBB =
3854  BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
3855  Kernel, UserCodeEntryBB);
3856  BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
3857  Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
3858  A.registerManifestAddedBasicBlock(*InitBB);
3859  A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
3860  A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
3861  A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
3862  A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
3863  A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
3864  A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
3865  A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
3866  A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
3867 
3868  const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3869  ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
3870  InitBB->getTerminator()->eraseFromParent();
3871 
3872  Instruction *IsWorker =
3873  ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
3874  ConstantInt::get(KernelInitCB->getType(), -1),
3875  "thread.is_worker", InitBB);
3876  IsWorker->setDebugLoc(DLoc);
3877  BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
3878 
3879  Module &M = *Kernel->getParent();
3880  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3881  FunctionCallee BlockHwSizeFn =
3882  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3883  M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
3884  FunctionCallee WarpSizeFn =
3885  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3886  M, OMPRTL___kmpc_get_warp_size);
3887  CallInst *BlockHwSize =
3888  CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
3889  OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
3890  BlockHwSize->setDebugLoc(DLoc);
3891  CallInst *WarpSize =
3892  CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
3893  OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
3894  WarpSize->setDebugLoc(DLoc);
3895  Instruction *BlockSize = BinaryOperator::CreateSub(
3896  BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
3897  BlockSize->setDebugLoc(DLoc);
3898  Instruction *IsMainOrWorker = ICmpInst::Create(
3899  ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
3900  "thread.is_main_or_worker", IsWorkerCheckBB);
3901  IsMainOrWorker->setDebugLoc(DLoc);
3902  BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
3903  IsMainOrWorker, IsWorkerCheckBB);
3904 
3905  // Create local storage for the work function pointer.
3906  const DataLayout &DL = M.getDataLayout();
3907  Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
3908  Instruction *WorkFnAI =
3909  new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
3910  "worker.work_fn.addr", &Kernel->getEntryBlock().front());
3911  WorkFnAI->setDebugLoc(DLoc);
3912 
3913  OMPInfoCache.OMPBuilder.updateToLocation(
3914  OpenMPIRBuilder::LocationDescription(
3915  IRBuilder<>::InsertPoint(StateMachineBeginBB,
3916  StateMachineBeginBB->end()),
3917  DLoc));
3918 
3919  Value *Ident = KernelInitCB->getArgOperand(0);
3920  Value *GTid = KernelInitCB;
3921 
3922  FunctionCallee BarrierFn =
3923  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3924  M, OMPRTL___kmpc_barrier_simple_generic);
3925  CallInst *Barrier =
3926  CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
3927  OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3928  Barrier->setDebugLoc(DLoc);
3929 
3930  if (WorkFnAI->getType()->getPointerAddressSpace() !=
3931  (unsigned int)AddressSpace::Generic) {
3932  WorkFnAI = new AddrSpaceCastInst(
3933  WorkFnAI,
3934  PointerType::getWithSamePointeeType(
3935  cast<PointerType>(WorkFnAI->getType()),
3936  (unsigned int)AddressSpace::Generic),
3937  WorkFnAI->getName() + ".generic", StateMachineBeginBB);
3938  WorkFnAI->setDebugLoc(DLoc);
3939  }
3940 
3941  FunctionCallee KernelParallelFn =
3942  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3943  M, OMPRTL___kmpc_kernel_parallel);
3944  CallInst *IsActiveWorker = CallInst::Create(
3945  KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
3946  OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
3947  IsActiveWorker->setDebugLoc(DLoc);
3948  Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
3949  StateMachineBeginBB);
3950  WorkFn->setDebugLoc(DLoc);
3951 
3952  FunctionType *ParallelRegionFnTy = FunctionType::get(
3953  Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
3954  false);
3955  Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
3956  WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
3957  StateMachineBeginBB);
3958 
3959  Instruction *IsDone =
3960  ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
3961  Constant::getNullValue(VoidPtrTy), "worker.is_done",
3962  StateMachineBeginBB);
3963  IsDone->setDebugLoc(DLoc);
3964  BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
3965  IsDone, StateMachineBeginBB)
3966  ->setDebugLoc(DLoc);
3967 
3968  BranchInst::Create(StateMachineIfCascadeCurrentBB,
3969  StateMachineDoneBarrierBB, IsActiveWorker,
3970  StateMachineIsActiveCheckBB)
3971  ->setDebugLoc(DLoc);
3972 
3973  Value *ZeroArg =
3974  Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
3975 
3976  // Now that we have most of the CFG skeleton it is time for the if-cascade
3977  // that checks the function pointer we got from the runtime against the
3978  // parallel regions we expect, if there are any.
3979  for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
3980  auto *ParallelRegion = ReachedKnownParallelRegions[I];
3981  BasicBlock *PRExecuteBB = BasicBlock::Create(
3982  Ctx, "worker_state_machine.parallel_region.execute", Kernel,
3983  StateMachineEndParallelBB);
3984  CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
3985  ->setDebugLoc(DLoc);
3986  BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
3987  ->setDebugLoc(DLoc);
3988 
3989  BasicBlock *PRNextBB =
3990  BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3991  Kernel, StateMachineEndParallelBB);
3992 
3993  // Check if we need to compare the pointer at all or if we can just
3994  // call the parallel region function.
3995  Value *IsPR;
3996  if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
3997  Instruction *CmpI = ICmpInst::Create(
3998  ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
3999  "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4000  CmpI->setDebugLoc(DLoc);
4001  IsPR = CmpI;
4002  } else {
4003  IsPR = ConstantInt::getTrue(Ctx);
4004  }
4005 
4006  BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
4007  StateMachineIfCascadeCurrentBB)
4008  ->setDebugLoc(DLoc);
4009  StateMachineIfCascadeCurrentBB = PRNextBB;
4010  }
4011 
4012  // At the end of the if-cascade we place the indirect function pointer call
4013  // in case we might need it, that is if there can be parallel regions we
4014  // have not handled in the if-cascade above.
4015  if (!ReachedUnknownParallelRegions.empty()) {
4016  StateMachineIfCascadeCurrentBB->setName(
4017  "worker_state_machine.parallel_region.fallback.execute");
4018  CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
4019  StateMachineIfCascadeCurrentBB)
4020  ->setDebugLoc(DLoc);
4021  }
4022  BranchInst::Create(StateMachineEndParallelBB,
4023  StateMachineIfCascadeCurrentBB)
4024  ->setDebugLoc(DLoc);
4025 
4026  FunctionCallee EndParallelFn =
4027  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4028  M, OMPRTL___kmpc_kernel_end_parallel);
4029  CallInst *EndParallel =
4030  CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4031  OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4032  EndParallel->setDebugLoc(DLoc);
4033  BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4034  ->setDebugLoc(DLoc);
4035 
4036  CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4037  ->setDebugLoc(DLoc);
4038  BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4039  ->setDebugLoc(DLoc);
4040 
4041  return ChangeStatus::CHANGED;
4042  }
4043 
4044  /// Fixpoint iteration update function. Will be called every time a dependence
4045  /// changed its state (and in the beginning).
4046  ChangeStatus updateImpl(Attributor &A) override {
4047  KernelInfoState StateBefore = getState();
4048 
4049  // Callback to check a read/write instruction.
4050  auto CheckRWInst = [&](Instruction &I) {
4051  // We handle calls later.
4052  if (isa<CallBase>(I))
4053  return true;
4054  // We only care about write effects.
4055  if (!I.mayWriteToMemory())
4056  return true;
4057  if (auto *SI = dyn_cast<StoreInst>(&I)) {
4058  SmallVector<const Value *> Objects;
4059  getUnderlyingObjects(SI->getPointerOperand(), Objects);
4060  if (llvm::all_of(Objects,
4061  [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
4062  return true;
4063  // Check for AAHeapToStack moved objects which must not be guarded.
4064  auto &HS = A.getAAFor<AAHeapToStack>(
4065  *this, IRPosition::function(*I.getFunction()),
4066  DepClassTy::OPTIONAL);
4067  if (llvm::all_of(Objects, [&HS](const Value *Obj) {
4068  auto *CB = dyn_cast<CallBase>(Obj);
4069  if (!CB)
4070  return false;
4071  return HS.isAssumedHeapToStack(*CB);
4072  })) {
4073  return true;
4074  }
4075  }
4076 
4077  // Insert instruction that needs guarding.
4078  SPMDCompatibilityTracker.insert(&I);
4079  return true;
4080  };
4081 
4082  bool UsedAssumedInformationInCheckRWInst = false;
4083  if (!SPMDCompatibilityTracker.isAtFixpoint())
4084  if (!A.checkForAllReadWriteInstructions(
4085  CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4086  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4087 
4088  bool UsedAssumedInformationFromReachingKernels = false;
4089  if (!IsKernelEntry) {
4090  updateParallelLevels(A);
4091 
4092  bool AllReachingKernelsKnown = true;
4093  updateReachingKernelEntries(A, AllReachingKernelsKnown);
4094  UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4095 
4096  if (!ParallelLevels.isValidState())
4097  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4098  else if (!ReachingKernelEntries.isValidState())
4099  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4100  else if (!SPMDCompatibilityTracker.empty()) {
4101  // Check if all reaching kernels agree on the mode as we can otherwise
4102  // not guard instructions. We might not be sure about the mode so we
4103  // cannot fix the internal SPMD-ization state either.
4104  int SPMD = 0, Generic = 0;
4105  for (auto *Kernel : ReachingKernelEntries) {
4106  auto &CBAA = A.getAAFor<AAKernelInfo>(
4107  *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4108  if (CBAA.SPMDCompatibilityTracker.isValidState() &&
4109  CBAA.SPMDCompatibilityTracker.isAssumed())
4110  ++SPMD;
4111  else
4112  ++Generic;
4113  if (!CBAA.SPMDCompatibilityTracker.isAtFixpoint())
4114  UsedAssumedInformationFromReachingKernels = true;
4115  }
4116  if (SPMD != 0 && Generic != 0)
4117  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4118  }
4119  }
4120 
4121  // Callback to check a call instruction.
4122  bool AllParallelRegionStatesWereFixed = true;
4123  bool AllSPMDStatesWereFixed = true;
4124  auto CheckCallInst = [&](Instruction &I) {
4125  auto &CB = cast<CallBase>(I);
4126  auto &CBAA = A.getAAFor<AAKernelInfo>(
4127  *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4128  getState() ^= CBAA.getState();
4129  AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
4130  AllParallelRegionStatesWereFixed &=
4131  CBAA.ReachedKnownParallelRegions.isAtFixpoint();
4132  AllParallelRegionStatesWereFixed &=
4133  CBAA.ReachedUnknownParallelRegions.isAtFixpoint();
4134  return true;
4135  };
4136 
4137  bool UsedAssumedInformationInCheckCallInst = false;
4138  if (!A.checkForAllCallLikeInstructions(
4139  CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4140  LLVM_DEBUG(dbgs() << TAG
4141  << "Failed to visit all call-like instructions!\n";);
4142  return indicatePessimisticFixpoint();
4143  }
4144 
4145  // If we haven't used any assumed information for the reached parallel
4146  // region states we can fix it.
4147  if (!UsedAssumedInformationInCheckCallInst &&
4148  AllParallelRegionStatesWereFixed) {
4149  ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4150  ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4151  }
4152 
4153  // If we haven't used any assumed information for the SPMD state we can fix
4154  // it.
4155  if (!UsedAssumedInformationInCheckRWInst &&
4156  !UsedAssumedInformationInCheckCallInst &&
4157  !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4158  SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4159 
4160  return StateBefore == getState() ? ChangeStatus::UNCHANGED
4161  : ChangeStatus::CHANGED;
4162  }
4163 
4164 private:
4165  /// Update info regarding reaching kernels.
4166  void updateReachingKernelEntries(Attributor &A,
4167  bool &AllReachingKernelsKnown) {
4168  auto PredCallSite = [&](AbstractCallSite ACS) {
4169  Function *Caller = ACS.getInstruction()->getFunction();
4170 
4171  assert(Caller && "Caller is nullptr");
4172 
4173  auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
4174  IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4175  if (CAA.ReachingKernelEntries.isValidState()) {
4176  ReachingKernelEntries ^= CAA.ReachingKernelEntries;
4177  return true;
4178  }
4179 
4180  // We lost track of the caller of the associated function, any kernel
4181  // could reach now.
4182  ReachingKernelEntries.indicatePessimisticFixpoint();
4183 
4184  return true;
4185  };
4186 
4187  if (!A.checkForAllCallSites(PredCallSite, *this,
4188  true /* RequireAllCallSites */,
4189  AllReachingKernelsKnown))
4190  ReachingKernelEntries.indicatePessimisticFixpoint();
4191  }
4192 
4193  /// Update info regarding parallel levels.
4194  void updateParallelLevels(Attributor &A) {
4195  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4196  OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4197  OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4198 
4199  auto PredCallSite = [&](AbstractCallSite ACS) {
4200  Function *Caller = ACS.getInstruction()->getFunction();
4201 
4202  assert(Caller && "Caller is nullptr");
4203 
4204  auto &CAA =
4205  A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4206  if (CAA.ParallelLevels.isValidState()) {
4207  // Any function that is called by `__kmpc_parallel_51` will not be
4208  // folded as the parallel level in the function is updated. Getting it
4209  // right would make the analysis depend on the implementation, and any
4210  // future change to that implementation could make the analysis wrong.
4211  // As a consequence, we are just conservative here.
4212  if (Caller == Parallel51RFI.Declaration) {
4213  ParallelLevels.indicatePessimisticFixpoint();
4214  return true;
4215  }
4216 
4217  ParallelLevels ^= CAA.ParallelLevels;
4218 
4219  return true;
4220  }
4221 
4222  // We lost track of the caller of the associated function, any kernel
4223  // could reach now.
4224  ParallelLevels.indicatePessimisticFixpoint();
4225 
4226  return true;
4227  };
4228 
4229  bool AllCallSitesKnown = true;
4230  if (!A.checkForAllCallSites(PredCallSite, *this,
4231  true /* RequireAllCallSites */,
4232  AllCallSitesKnown))
4233  ParallelLevels.indicatePessimisticFixpoint();
4234  }
4235 };
4236 
4237 /// The call site kernel info abstract attribute, basically, what can we say
4238 /// about a call site with regards to the KernelInfoState. For now this simply
4239 /// forwards the information from the callee.
4240 struct AAKernelInfoCallSite : AAKernelInfo {
4241  AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4242  : AAKernelInfo(IRP, A) {}
4243 
4244  /// See AbstractAttribute::initialize(...).
4245  void initialize(Attributor &A) override {
4246  AAKernelInfo::initialize(A);
4247 
4248  CallBase &CB = cast<CallBase>(getAssociatedValue());
4249  Function *Callee = getAssociatedFunction();
4250 
4251  auto &AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4252  *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4253 
4254  // Check for SPMD-mode assumptions.
4255  if (AssumptionAA.hasAssumption("ompx_spmd_amenable")) {
4256  SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4257  indicateOptimisticFixpoint();
4258  }
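  // Illustrative source-level use of this assumption (not from this file):
  //   __attribute__((assume("ompx_spmd_amenable")))
  //   void external_helper(void);
  // lets the pass treat calls to external_helper as safe for SPMD execution
  // even though its body is unknown.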
4259 
4260  // First weed out calls we do not care about, that is readonly/readnone
4261  // calls, intrinsics, and "no_openmp" calls. None of these can reach a
4262  // parallel region or anything else we are looking for.
4263  if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4264  indicateOptimisticFixpoint();
4265  return;
4266  }
4267 
4268  // Next we check if we know the callee. If it is a known OpenMP function
4269  // we will handle them explicitly in the switch below. If it is not, we
4270  // will use an AAKernelInfo object on the callee to gather information and
4271  // merge that into the current state. The latter happens in the updateImpl.
4272  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4273  const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4274  if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4275  // Unknown callees or declarations are not analyzable, we give up.
4276  if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4277 
4278  // Unknown callees might contain parallel regions, except if they have
4279  // an appropriate assumption attached.
4280  if (!(AssumptionAA.hasAssumption("omp_no_openmp") ||
4281  AssumptionAA.hasAssumption("omp_no_parallelism")))
4282  ReachedUnknownParallelRegions.insert(&CB);
4283 
4284  // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4285  // idea we can run something unknown in SPMD-mode.
4286  if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4287  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4288  SPMDCompatibilityTracker.insert(&CB);
4289  }
4290 
4291  // We have updated the state for this unknown call properly, there won't
4292  // be any change so we indicate a fixpoint.
4293  indicateOptimisticFixpoint();
4294  }
4295  // If the callee is known and can be used in IPO, we will update the state
4296  // based on the callee state in updateImpl.
4297  return;
4298  }
4299 
4300  const unsigned int WrapperFunctionArgNo = 6;
4301  RuntimeFunction RF = It->getSecond();
4302  switch (RF) {
4303  // All the functions we know are compatible with SPMD mode.
4304  case OMPRTL___kmpc_is_spmd_exec_mode:
4305  case OMPRTL___kmpc_distribute_static_fini:
4306  case OMPRTL___kmpc_for_static_fini:
4307  case OMPRTL___kmpc_global_thread_num:
4308  case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4309  case OMPRTL___kmpc_get_hardware_num_blocks:
4310  case OMPRTL___kmpc_single:
4311  case OMPRTL___kmpc_end_single:
4312  case OMPRTL___kmpc_master:
4313  case OMPRTL___kmpc_end_master:
4314  case OMPRTL___kmpc_barrier:
4315  case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4316  case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4317  case OMPRTL___kmpc_nvptx_end_reduce_nowait:
4318  break;
4319  case OMPRTL___kmpc_distribute_static_init_4:
4320  case OMPRTL___kmpc_distribute_static_init_4u:
4321  case OMPRTL___kmpc_distribute_static_init_8:
4322  case OMPRTL___kmpc_distribute_static_init_8u:
4323  case OMPRTL___kmpc_for_static_init_4:
4324  case OMPRTL___kmpc_for_static_init_4u:
4325  case OMPRTL___kmpc_for_static_init_8:
4326  case OMPRTL___kmpc_for_static_init_8u: {
4327  // Check the schedule and allow static schedule in SPMD mode.
4328  unsigned ScheduleArgOpNo = 2;
4329  auto *ScheduleTypeCI =
4330  dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4331  unsigned ScheduleTypeVal =
4332  ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4333  switch (OMPScheduleType(ScheduleTypeVal)) {
4334  case OMPScheduleType::UnorderedStatic:
4335  case OMPScheduleType::UnorderedStaticChunked:
4336  case OMPScheduleType::OrderedDistribute:
4337  case OMPScheduleType::OrderedDistributeChunked:
4338  break;
4339  default:
4340  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4341  SPMDCompatibilityTracker.insert(&CB);
4342  break;
4343  };
4344  } break;
4345  case OMPRTL___kmpc_target_init:
4346  KernelInitCB = &CB;
4347  break;
4348  case OMPRTL___kmpc_target_deinit:
4349  KernelDeinitCB = &CB;
4350  break;
4351  case OMPRTL___kmpc_parallel_51:
4352  if (auto *ParallelRegion = dyn_cast<Function>(
4353  CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
4354  ReachedKnownParallelRegions.insert(ParallelRegion);
4355  break;
4356  }
4357  // The condition above should usually get the parallel region function
4358  // pointer and record it. In the off chance it doesn't we assume the
4359  // worst.
4360  ReachedUnknownParallelRegions.insert(&CB);
4361  break;
4362  case OMPRTL___kmpc_omp_task:
4363  // We do not look into tasks right now, just give up.
4364  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4365  SPMDCompatibilityTracker.insert(&CB);
4366  ReachedUnknownParallelRegions.insert(&CB);
4367  break;
4368  case OMPRTL___kmpc_alloc_shared:
4369  case OMPRTL___kmpc_free_shared:
4370  // Return without setting a fixpoint, to be resolved in updateImpl.
4371  return;
4372  default:
4373  // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
4374  // generally. However, they do not hide parallel regions.
4375  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4376  SPMDCompatibilityTracker.insert(&CB);
4377  break;
4378  }
4379  // All other OpenMP runtime calls will not reach parallel regions so they
4380  // can be safely ignored for now. Since it is a known OpenMP runtime call we
4381  // have now modeled all effects and there is no need for any update.
4382  indicateOptimisticFixpoint();
4383  }
4384 
4385  ChangeStatus updateImpl(Attributor &A) override {
4386  // TODO: Once we have call site specific value information we can provide
4387  // call site specific liveness information and then it makes
4388  // sense to specialize attributes for call sites arguments instead of
4389  // redirecting requests to the callee argument.
4390  Function *F = getAssociatedFunction();
4391 
4392  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4393  const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
4394 
4395  // If F is not a runtime function, propagate the AAKernelInfo of the callee.
4396  if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4397  const IRPosition &FnPos = IRPosition::function(*F);
4398  auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
4399  if (getState() == FnAA.getState())
4400  return ChangeStatus::UNCHANGED;
4401  getState() = FnAA.getState();
4402  return ChangeStatus::CHANGED;
4403  }
4404 
4405  // F is a runtime function that allocates or frees memory, check
4406  // AAHeapToStack and AAHeapToShared.
4407  KernelInfoState StateBefore = getState();
4408  assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
4409  It->getSecond() == OMPRTL___kmpc_free_shared) &&
4410  "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
4411 
4412  CallBase &CB = cast<CallBase>(getAssociatedValue());
4413 
4414  auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
4415  *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
4416  auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
4417  *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
4418 
4419  RuntimeFunction RF = It->getSecond();
4420 
4421  switch (RF) {
4422  // If neither HeapToStack nor HeapToShared assume the call is removed,
4423  // assume SPMD incompatibility.
4424  case OMPRTL___kmpc_alloc_shared:
4425  if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
4426  !HeapToSharedAA.isAssumedHeapToShared(CB))
4427  SPMDCompatibilityTracker.insert(&CB);
4428  break;
4429  case OMPRTL___kmpc_free_shared:
4430  if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
4431  !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
4432  SPMDCompatibilityTracker.insert(&CB);
4433  break;
4434  default:
4435  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4436  SPMDCompatibilityTracker.insert(&CB);
4437  }
4438 
4439  return StateBefore == getState() ? ChangeStatus::UNCHANGED
4441  }
4442 };
4443 
4444 struct AAFoldRuntimeCall
4445  : public StateWrapper<BooleanState, AbstractAttribute> {
4446  using Base = StateWrapper<BooleanState, AbstractAttribute>;
4447 
4448  AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
4449 
4450  /// Statistics are tracked as part of manifest for now.
4451  void trackStatistics() const override {}
4452 
4453  /// Create an abstract attribute view for the position \p IRP.
4454  static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
4455  Attributor &A);
4456 
4457  /// See AbstractAttribute::getName()
4458  const std::string getName() const override { return "AAFoldRuntimeCall"; }
4459 
4460  /// See AbstractAttribute::getIdAddr()
4461  const char *getIdAddr() const override { return &ID; }
4462 
4463  /// This function should return true if the type of the \p AA is
4464  /// AAFoldRuntimeCall
4465  static bool classof(const AbstractAttribute *AA) {
4466  return (AA->getIdAddr() == &ID);
4467  }
4468 
4469  static const char ID;
4470 };
4471 
4472 struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
4473  AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
4474  : AAFoldRuntimeCall(IRP, A) {}
4475 
4476  /// See AbstractAttribute::getAsStr()
4477  const std::string getAsStr() const override {
4478  if (!isValidState())
4479  return "<invalid>";
4480 
4481  std::string Str("simplified value: ");
4482 
4483  if (!SimplifiedValue)
4484  return Str + std::string("none");
4485 
4486  if (!SimplifiedValue.value())
4487  return Str + std::string("nullptr");
4488 
4489  if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.value()))
4490  return Str + std::to_string(CI->getSExtValue());
4491 
4492  return Str + std::string("unknown");
4493  }
4494 
4495  void initialize(Attributor &A) override {
4496  if (DisableOpenMPOptFolding)
4497  indicatePessimisticFixpoint();
4498 
4499  Function *Callee = getAssociatedFunction();
4500 
4501  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4502  const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4503  assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
4504  "Expected a known OpenMP runtime function");
4505 
4506  RFKind = It->getSecond();
4507 
4508  CallBase &CB = cast<CallBase>(getAssociatedValue());
4509  A.registerSimplificationCallback(
4510  IRPosition::callsite_returned(CB),
4511  [&](const IRPosition &IRP, const AbstractAttribute *AA,
4512  bool &UsedAssumedInformation) -> Optional<Value *> {
4513  assert((isValidState() ||
4514  (SimplifiedValue && SimplifiedValue.value() == nullptr)) &&
4515  "Unexpected invalid state!");
4516 
4517  if (!isAtFixpoint()) {
4518  UsedAssumedInformation = true;
4519  if (AA)
4520  A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
4521  }
4522  return SimplifiedValue;
4523  });
4524  }
4525 
4526  ChangeStatus updateImpl(Attributor &A) override {
4527  ChangeStatus Changed = ChangeStatus::UNCHANGED;
4528  switch (RFKind) {
4529  case OMPRTL___kmpc_is_spmd_exec_mode:
4530  Changed |= foldIsSPMDExecMode(A);
4531  break;
4532  case OMPRTL___kmpc_is_generic_main_thread_id:
4533  Changed |= foldIsGenericMainThread(A);
4534  break;
4535  case OMPRTL___kmpc_parallel_level:
4536  Changed |= foldParallelLevel(A);
4537  break;
4538  case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4539  Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
4540  break;
4541  case OMPRTL___kmpc_get_hardware_num_blocks:
4542  Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
4543  break;
4544  default:
4545  llvm_unreachable("Unhandled OpenMP runtime function!");
4546  }
4547 
4548  return Changed;
4549  }
4550 
4551  ChangeStatus manifest(Attributor &A) override {
4552  ChangeStatus Changed = ChangeStatus::UNCHANGED;
4553 
4554  if (SimplifiedValue && *SimplifiedValue) {
4555  Instruction &I = *getCtxI();
4556  A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
4557  A.deleteAfterManifest(I);
4558 
4559  CallBase *CB = dyn_cast<CallBase>(&I);
4560  auto Remark = [&](OptimizationRemark OR) {
4561  if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
4562  return OR << "Replacing OpenMP runtime call "
4563  << CB->getCalledFunction()->getName() << " with "
4564  << ore::NV("FoldedValue", C->getZExtValue()) << ".";
4565  return OR << "Replacing OpenMP runtime call "
4566  << CB->getCalledFunction()->getName() << ".";
4567  };
4568 
4569  if (CB && EnableVerboseRemarks)
4570  A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
4571 
4572  LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
4573  << **SimplifiedValue << "\n");
4574 
4575  Changed = ChangeStatus::CHANGED;
4576  }
4577 
4578  return Changed;
4579  }
4580 
4581  ChangeStatus indicatePessimisticFixpoint() override {
4582  SimplifiedValue = nullptr;
4583  return AAFoldRuntimeCall::indicatePessimisticFixpoint();
4584  }
4585 
4586 private:
4587  /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
4588  ChangeStatus foldIsSPMDExecMode(Attributor &A) {
4589  Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4590 
4591  unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4592  unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4593  auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4594  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4595 
4596  if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4597  return indicatePessimisticFixpoint();
4598 
4599  for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4600  auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4601  DepClassTy::REQUIRED);
4602 
4603  if (!AA.isValidState()) {
4604  SimplifiedValue = nullptr;
4605  return indicatePessimisticFixpoint();
4606  }
4607 
4608  if (AA.SPMDCompatibilityTracker.isAssumed()) {
4609  if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4610  ++KnownSPMDCount;
4611  else
4612  ++AssumedSPMDCount;
4613  } else {
4614  if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4615  ++KnownNonSPMDCount;
4616  else
4617  ++AssumedNonSPMDCount;
4618  }
4619  }
4620 
4621  if ((AssumedSPMDCount + KnownSPMDCount) &&
4622  (AssumedNonSPMDCount + KnownNonSPMDCount))
4623  return indicatePessimisticFixpoint();
4624 
4625  auto &Ctx = getAnchorValue().getContext();
4626  if (KnownSPMDCount || AssumedSPMDCount) {
4627  assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4628  "Expected only SPMD kernels!");
4629  // All reaching kernels are in SPMD mode. Update all function calls to
4630  // __kmpc_is_spmd_exec_mode to 1.
4631  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4632  } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
4633  assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4634  "Expected only non-SPMD kernels!");
4635  // All reaching kernels are in non-SPMD mode. Update all function
4636  // calls to __kmpc_is_spmd_exec_mode to 0.
4637  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
4638  } else {
4639  // We have empty reaching kernels, therefore we cannot tell if the
4640  // associated call site can be folded. At this moment, SimplifiedValue
4641  // must be none.
4642  assert(!SimplifiedValue && "SimplifiedValue should be none");
4643  }
4644 
4645  return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4646  : ChangeStatus::CHANGED;
4647  }
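  // Illustrative folding result (assumed IR): if every reaching kernel is in
  // SPMD mode, a call
  //   %spmd = call i8 @__kmpc_is_spmd_exec_mode()
  // folds to the constant i8 1; if all reaching kernels are generic it folds
  // to i8 0; a mix of modes leaves the call untouched.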
4648 
4649  /// Fold __kmpc_is_generic_main_thread_id into a constant if possible.
4650  ChangeStatus foldIsGenericMainThread(Attributor &A) {
4651  Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4652 
4653  CallBase &CB = cast<CallBase>(getAssociatedValue());
4654  Function *F = CB.getFunction();
4655  const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
4656  *this, IRPosition::function(*F), DepClassTy::REQUIRED);
4657 
4658  if (!ExecutionDomainAA.isValidState())
4659  return indicatePessimisticFixpoint();
4660 
4661  auto &Ctx = getAnchorValue().getContext();
4662  if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB))
4663  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4664  else
4665  return indicatePessimisticFixpoint();
4666 
4667  return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4668  : ChangeStatus::CHANGED;
4669  }
4670 
4671  /// Fold __kmpc_parallel_level into a constant if possible.
4672  ChangeStatus foldParallelLevel(Attributor &A) {
4673  Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4674 
4675  auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4676  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4677 
4678  if (!CallerKernelInfoAA.ParallelLevels.isValidState())
4679  return indicatePessimisticFixpoint();
4680 
4681  if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4682  return indicatePessimisticFixpoint();
4683 
4684  if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
4685  assert(!SimplifiedValue &&
4686  "SimplifiedValue should keep none at this point");
4687  return ChangeStatus::UNCHANGED;
4688  }
4689 
4690  unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4691  unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4692  for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4693  auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4694  DepClassTy::REQUIRED);
4695  if (!AA.SPMDCompatibilityTracker.isValidState())
4696  return indicatePessimisticFixpoint();
4697 
4698  if (AA.SPMDCompatibilityTracker.isAssumed()) {
4699  if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4700  ++KnownSPMDCount;
4701  else
4702  ++AssumedSPMDCount;
4703  } else {
4704  if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4705  ++KnownNonSPMDCount;
4706  else
4707  ++AssumedNonSPMDCount;
4708  }
4709  }
4710 
4711  if ((AssumedSPMDCount + KnownSPMDCount) &&
4712  (AssumedNonSPMDCount + KnownNonSPMDCount))
4713  return indicatePessimisticFixpoint();
4714 
4715  auto &Ctx = getAnchorValue().getContext();
4716  // If the caller can only be reached by SPMD kernel entries, the parallel
4717  // level is 1. Similarly, if the caller can only be reached by non-SPMD
4718  // kernel entries, it is 0.
4719  if (AssumedSPMDCount || KnownSPMDCount) {
4720  assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4721  "Expected only SPMD kernels!");
4722  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
4723  } else {
4724  assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4725  "Expected only non-SPMD kernels!");
4726  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
4727  }
4728  return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4729  : ChangeStatus::CHANGED;
4730  }
4731 
4732  ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
4733  // Specialize only if all the calls agree with the attribute constant value
4734  int32_t CurrentAttrValue = -1;
4735  Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4736 
4737  auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4738  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4739 
4740  if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4741  return indicatePessimisticFixpoint();
4742 
4743  // Iterate over the kernels that reach this function
4744  for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4745  int32_t NextAttrVal = -1;
4746  if (K->hasFnAttribute(Attr))
4747  NextAttrVal =
4748  std::stoi(K->getFnAttribute(Attr).getValueAsString().str());
4749 
4750  if (NextAttrVal == -1 ||
4751  (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
4752  return indicatePessimisticFixpoint();
4753  CurrentAttrValue = NextAttrVal;
4754  }
4755 
4756  if (CurrentAttrValue != -1) {
4757  auto &Ctx = getAnchorValue().getContext();
4758  SimplifiedValue =
4759  ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
4760  }
4761  return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4762  : ChangeStatus::CHANGED;
4763  }
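The attribute folded here is an integer encoded as a string-valued function attribute on each reaching kernel; presumably these are the launch-bounds style attributes such as "omp_target_thread_limit" and "omp_target_num_teams" (an assumption here, since the registering call sites live elsewhere in this file). A hedged sketch of reading one such value, using StringRef::getAsInteger rather than std::stoi:

// Illustrative only: reading a kernel's integer-valued string attribute.
// Assumes the usual llvm/IR/Function.h include already present in this file;
// the attribute name passed in is an assumption made by the caller.
static int32_t readKernelIntAttr(const llvm::Function &Kernel,
                                 llvm::StringRef AttrName) {
  if (!Kernel.hasFnAttribute(AttrName))
    return -1; // Treated as "unknown" by the fold above.
  int32_t Val = -1;
  // getValueAsString() yields the textual value, e.g. "128".
  if (Kernel.getFnAttribute(AttrName).getValueAsString().getAsInteger(10, Val))
    return -1; // Not a valid decimal integer.
  return Val;
}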
4764 
4765  /// An optional value the associated value is assumed to fold to. That is, we
4766  /// assume the associated value (which is a call) can be replaced by this
4767  /// simplified value.
4768  Optional<Value *> SimplifiedValue;
4769 
4770  /// The runtime function kind of the callee of the associated call site.
4771  RuntimeFunction RFKind;
4772 };
4773 
4774 } // namespace
4775 
4776 /// Register folding callsite
4777 void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
4778  auto &RFI = OMPInfoCache.RFIs[RF];
4779  RFI.foreachUse(SCC, [&](Use &U, Function &F) {
4780  CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
4781  if (!CI)
4782  return false;
4783  A.getOrCreateAAFor<AAFoldRuntimeCall>(
4784  IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
4785  DepClassTy::NONE, /* ForceUpdate */ false,
4786  /* UpdateAfterInit */ false);
4787  return false;
4788  });
4789 }
4790 
4791 void OpenMPOpt::registerAAs(bool IsModulePass) {
4792  if (SCC.empty())
4793  return;
4794 
4795  if (IsModulePass) {
4796  // Ensure we create the AAKernelInfo AAs first and without triggering an
4797  // update. This will make sure we register all value simplification
4798  // callbacks before any other AA has the chance to create an AAValueSimplify
4799  // or similar.
4800  auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
4801  A.getOrCreateAAFor<AAKernelInfo>(
4802  IRPosition::function(Kernel), /* QueryingAA */ nullptr,
4803  DepClassTy::NONE, /* ForceUpdate */ false,
4804  /* UpdateAfterInit */ false);
4805  return false;
4806  };
4807  OMPInformationCache::RuntimeFunctionInfo &InitRFI =
4808  OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
4809  InitRFI.foreachUse(SCC, CreateKernelInfoCB);
4810 
4811  registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
4812  registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
4813  registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
4814  registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
4815  registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
4816  }
4817 
4818  // Create CallSite AA for all Getters.
4819  for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
4820  auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
4821 
4822  auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
4823 
4824  auto CreateAA = [&](Use &U, Function &Caller) {
4825  CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
4826  if (!CI)
4827  return false;
4828 
4829  auto &CB = cast<CallBase>(*CI);
4830 
4831  IRPosition CBPos = IRPosition::callsite_function(CB);
4832  A.getOrCreateAAFor<AAICVTracker>(CBPos);
4833  return false;
4834  };
4835 
4836  GetterRFI.foreachUse(SCC, CreateAA);
4837  }
4838  auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4839  auto CreateAA = [&](Use &U, Function &F) {
4840  A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
4841  return false;
4842  };
4843  if (!DisableOpenMPOptDeglobalization)
4844  GlobalizationRFI.foreachUse(SCC, CreateAA);
4845 
4846  // Create an ExecutionDomain AA for every function and a HeapToStack AA for
4847  // every function if there is a device kernel.
4848  if (!isOpenMPDevice(M))
4849  return;
4850 
4851  for (auto *F : SCC) {
4852  if (F->isDeclaration())
4853  continue;
4854 
4855  A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
4856  if (!DisableOpenMPOptDeglobalization)
4857  A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
4858 
4859  for (auto &I : instructions(*F)) {
4860  if (auto *LI = dyn_cast<LoadInst>(&I)) {
4861  bool UsedAssumedInformation = false;
4862  A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
4863  UsedAssumedInformation, AA::Interprocedural);
4864  } else if (auto *SI = dyn_cast<StoreInst>(&I)) {
4865  A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
4866  }
4867  }
4868  }
4869 }
4870 
4871 const char AAICVTracker::ID = 0;
4872 const char AAKernelInfo::ID = 0;
4873 const char AAExecutionDomain::ID = 0;
4874 const char AAHeapToShared::ID = 0;
4875 const char AAFoldRuntimeCall::ID = 0;
4876 
4877 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
4878  Attributor &A) {
4879  AAICVTracker *AA = nullptr;
4880  switch (IRP.getPositionKind()) {
4881  case IRPosition::IRP_INVALID:
4882  case IRPosition::IRP_FLOAT:
4883  case IRPosition::IRP_ARGUMENT:
4884  case IRPosition::IRP_CALL_SITE_ARGUMENT:
4885  llvm_unreachable("ICVTracker can only be created for function position!");
4886  case IRPosition::IRP_RETURNED:
4887  AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
4888  break;
4889  case IRPosition::IRP_CALL_SITE_RETURNED:
4890  AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
4891  break;
4892  case IRPosition::IRP_CALL_SITE:
4893  AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
4894  break;
4895  case IRPosition::IRP_FUNCTION:
4896  AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
4897  break;
4898  }
4899 
4900  return *AA;
4901 }
4902 
4903 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
4904  Attributor &A) {
4905  AAExecutionDomainFunction *AA = nullptr;
4906  switch (IRP.getPositionKind()) {
4907  case IRPosition::IRP_INVALID:
4908  case IRPosition::IRP_FLOAT:
4909  case IRPosition::IRP_ARGUMENT:
4910  case IRPosition::IRP_CALL_SITE_ARGUMENT:
4911  case IRPosition::IRP_RETURNED:
4912  case IRPosition::IRP_CALL_SITE_RETURNED:
4913  case IRPosition::IRP_CALL_SITE:
4914  llvm_unreachable(
4915  "AAExecutionDomain can only be created for function position!");
4916  case IRPosition::IRP_FUNCTION:
4917  AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
4918  break;
4919  }
4920 
4921  return *AA;
4922 }
4923 
4924 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
4925  Attributor &A) {
4926  AAHeapToSharedFunction *AA = nullptr;
4927  switch (IRP.getPositionKind()) {
4928  case IRPosition::IRP_INVALID:
4929  case IRPosition::IRP_FLOAT:
4930  case IRPosition::IRP_ARGUMENT:
4931  case IRPosition::IRP_CALL_SITE_ARGUMENT:
4932  case IRPosition::IRP_RETURNED:
4933  case IRPosition::IRP_CALL_SITE_RETURNED:
4934  case IRPosition::IRP_CALL_SITE:
4935  llvm_unreachable(
4936  "AAHeapToShared can only be created for function position!");
4937  case IRPosition::IRP_FUNCTION:
4938  AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
4939  break;
4940  }
4941 
4942  return *AA;
4943 }
4944 
4945 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
4946  Attributor &A) {
4947  AAKernelInfo *AA = nullptr;
4948  switch (IRP.getPositionKind()) {
4949  case IRPosition::IRP_INVALID:
4950  case IRPosition::IRP_FLOAT:
4951  case IRPosition::IRP_ARGUMENT:
4952  case IRPosition::IRP_RETURNED:
4953  case IRPosition::IRP_CALL_SITE_RETURNED:
4954  case IRPosition::IRP_CALL_SITE_ARGUMENT:
4955  llvm_unreachable("KernelInfo can only be created for function position!");
4956  case IRPosition::IRP_CALL_SITE:
4957  AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
4958  break;
4959  case IRPosition::IRP_FUNCTION:
4960  AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
4961  break;
4962  }
4963 
4964  return *AA;
4965 }
4966 
4967 AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
4968  Attributor &A) {
4969  AAFoldRuntimeCall *AA = nullptr;
4970  switch (IRP.getPositionKind()) {
4971  case IRPosition::IRP_INVALID:
4972  case IRPosition::IRP_FLOAT:
4973  case IRPosition::IRP_ARGUMENT:
4974  case IRPosition::IRP_RETURNED:
4975  case IRPosition::IRP_FUNCTION:
4976  case IRPosition::IRP_CALL_SITE:
4977  case IRPosition::IRP_CALL_SITE_ARGUMENT:
4978  llvm_unreachable("KernelInfo can only be created for call site position!");
4979  case IRPosition::IRP_CALL_SITE_RETURNED:
4980  AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
4981  break;
4982  }
4983 
4984  return *AA;
4985 }
4986 
4987 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
4988  if (!containsOpenMP(M))
4989  return PreservedAnalyses::all();
4990  if (DisableOpenMPOptimizations)
4991  return PreservedAnalyses::all();
4992 
4993  FunctionAnalysisManager &FAM =
4994  AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
4995  KernelSet Kernels = getDeviceKernels(M);
4996 
4997  if (PrintModuleBeforeOptimizations)
4998  LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
4999 
5000  auto IsCalled = [&](Function &F) {
5001  if (Kernels.contains(&F))
5002  return true;
5003  for (const User *U : F.users())
5004  if (!isa<BlockAddress>(U))
5005  return true;
5006  return false;
5007  };
5008 
5009  auto EmitRemark = [&](Function &F) {
5010  auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
5011  ORE.emit([&]() {
5012  OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5013  return ORA << "Could not internalize function. "
5014  << "Some optimizations may not be possible. [OMP140]";
5015  });
5016  };
5017 
5018  // Create internal copies of each function if this is a kernel Module. This
5019  // allows interprocedural passes to see every call edge.
5020  DenseMap<Function *, Function *> InternalizedMap;
5021  if (isOpenMPDevice(M)) {
5022  SmallPtrSet<Function *, 16> InternalizeFns;
5023  for (Function &F : M)
5024  if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
5025  !DisableInternalization) {
5026  if (Attributor::isInternalizable(F)) {
5027  InternalizeFns.insert(&F);
5028  } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5029  EmitRemark(F);
5030  }
5031  }
5032 
5033  Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5034  }
5035 
5036  // Look at every function in the Module unless it was internalized.
5037  SmallVector<Function *, 16> SCC;
5038  for (Function &F : M)
5039  if (!F.isDeclaration() && !InternalizedMap.lookup(&F))
5040  SCC.push_back(&F);
5041 
5042  if (SCC.empty())
5043  return PreservedAnalyses::all();
5044 
5045  AnalysisGetter AG(FAM);
5046 
5047  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5048  return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5049  };
5050 
5051  BumpPtrAllocator Allocator;
5052  CallGraphUpdater CGUpdater;
5053 
5054  OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, Kernels);
5055 
5056  unsigned MaxFixpointIterations =
5057  (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5058 
5059  AttributorConfig AC(CGUpdater);
5060  AC.DefaultInitializeLiveInternals = false;
5061  AC.RewriteSignatures = false;
5062  AC.MaxFixpointIterations = MaxFixpointIterations;
5063  AC.OREGetter = OREGetter;
5064  AC.PassName = DEBUG_TYPE;
5065 
5066  SetVector<Function *> Functions;
5067  Attributor A(Functions, InfoCache, AC);
5068 
5069  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5070  bool Changed = OMPOpt.run(true);
5071 
5072  // Optionally inline device functions for potentially better performance.
5073  if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
5074  for (Function &F : M)
5075  if (!F.isDeclaration() && !Kernels.contains(&F) &&
5076  !F.hasFnAttribute(Attribute::NoInline))
5077  F.addFnAttr(Attribute::AlwaysInline);
5078 
5079  if (PrintModuleAfterOptimizations)
5080  LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5081 
5082  if (Changed)
5083  return PreservedAnalyses::none();
5084 
5085  return PreservedAnalyses::all();
5086 }
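For reference, the module pass defined above can be driven through the new pass manager; the following is only a sketch with standard PassBuilder boilerplate, assuming a parsed llvm::Module M that carries the "openmp" module flag and the OpenMPOptPass declared in llvm/Transforms/IPO/OpenMPOpt.h:

// Sketch: running OpenMPOptPass programmatically (requires
// llvm/Passes/PassBuilder.h and llvm/Transforms/IPO/OpenMPOpt.h).
llvm::PassBuilder PB;
llvm::LoopAnalysisManager LAM;
llvm::FunctionAnalysisManager FAM;
llvm::CGSCCAnalysisManager CGAM;
llvm::ModuleAnalysisManager MAM;
PB.registerModuleAnalyses(MAM);
PB.registerCGSCCAnalyses(CGAM);
PB.registerFunctionAnalyses(FAM);
PB.registerLoopAnalyses(LAM);
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

llvm::ModulePassManager MPM;
MPM.addPass(llvm::OpenMPOptPass());
MPM.run(M, MAM);

In the opt tool the same functionality is exposed under the pass names "openmp-opt" (module) and "openmp-opt-cgscc" (CGSCC).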
5087 
5088 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
5089  CGSCCAnalysisManager &AM,
5090  LazyCallGraph &CG,
5091  CGSCCUpdateResult &UR) {
5092  if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5093  return PreservedAnalyses::all();
5094  if (DisableOpenMPOptimizations)
5095  return PreservedAnalyses::all();
5096 
5097  SmallVector<Function *, 16> SCC;
5098  // If there are kernels in the module, we have to run on all SCC's.
5099  for (LazyCallGraph::Node &N : C) {
5100  Function *Fn = &N.getFunction();
5101  SCC.push_back(Fn);
5102  }
5103 
5104  if (SCC.empty())
5105  return PreservedAnalyses::all();
5106 
5107  Module &M = *C.begin()->getFunction().getParent();
5108 
5109  if (PrintModuleBeforeOptimizations)
5110  LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5111 
5112  KernelSet Kernels = getDeviceKernels(M);
5113 
5114  FunctionAnalysisManager &FAM =
5115  AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5116 
5117  AnalysisGetter AG(FAM);
5118 
5119  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5120  return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5121  };
5122 
5123  BumpPtrAllocator Allocator;
5124  CallGraphUpdater CGUpdater;
5125  CGUpdater.initialize(CG, C, AM, UR);
5126 
5127  SetVector<Function *> Functions(SCC.begin(), SCC.end());
5128  OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5129  /*CGSCC*/ &Functions, Kernels);
5130 
5131  unsigned MaxFixpointIterations =
5132  (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5133 
5134  AttributorConfig AC(CGUpdater);
5135  AC.DefaultInitializeLiveInternals = false;
5136  AC.IsModulePass = false;
5137  AC.RewriteSignatures = false;
5138  AC.MaxFixpointIterations = MaxFixpointIterations;
5139  AC.OREGetter = OREGetter;
5140  AC.PassName = DEBUG_TYPE;
5141 
5142  Attributor A(Functions, InfoCache, AC);
5143 
5144  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5145  bool Changed = OMPOpt.run(false);
5146 
5147  if (PrintModuleAfterOptimizations)
5148  LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5149 
5150  if (Changed)
5151  return PreservedAnalyses::none();
5152 
5153  return PreservedAnalyses::all();
5154 }
5155 
5156 namespace {
5157 
5158 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
5159  CallGraphUpdater CGUpdater;
5160  static char ID;
5161 
5162  OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
5163  initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
5164  }
5165 
5166  void getAnalysisUsage(AnalysisUsage &AU) const override {
5167  CallGraphSCCPass::getAnalysisUsage(AU);
5168  }
5169 
5170  bool runOnSCC(CallGraphSCC &CGSCC) override {
5171  if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
5172  return false;
5173  if (DisableOpenMPOptimizations || skipSCC(CGSCC))
5174  return false;
5175 
5176  SmallVector<Function *, 16> SCC;
5177  // If there are kernels in the module, we have to run on all SCC's.
5178  for (CallGraphNode *CGN : CGSCC) {
5179  Function *Fn = CGN->getFunction();
5180  if (!Fn || Fn->isDeclaration())
5181  continue;
5182  SCC.push_back(Fn);
5183  }
5184 
5185  if (SCC.empty())
5186  return false;
5187 
5188  Module &M = CGSCC.getCallGraph().getModule();
5189  KernelSet Kernels = getDeviceKernels(M);
5190 
5191  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
5192  CGUpdater.initialize(CG, CGSCC);
5193 
5194  // Maintain a map of functions to avoid rebuilding the ORE
5195  DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
5196  auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
5197  std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
5198  if (!ORE)
5199  ORE = std::make_unique<OptimizationRemarkEmitter>(F);
5200  return *ORE;
5201  };
5202 
5203  AnalysisGetter AG;
5204  SetVector<Function *> Functions(SCC.begin(), SCC.end());
5205  BumpPtrAllocator Allocator;
5206  OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
5207  Allocator,
5208  /*CGSCC*/ &Functions, Kernels);
5209 
5210  unsigned MaxFixpointIterations =
5211  (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5212 
5213  AttributorConfig AC(CGUpdater);
5214  AC.DefaultInitializeLiveInternals = false;
5215  AC.IsModulePass = false;
5216  AC.RewriteSignatures = false;
5217  AC.MaxFixpointIterations = MaxFixpointIterations;
5218  AC.OREGetter = OREGetter;
5219  AC.PassName = DEBUG_TYPE;
5220 
5221  Attributor A(Functions, InfoCache, AC);
5222 
5223  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5224  bool Result = OMPOpt.run(false);
5225 
5226  if (PrintModuleAfterOptimizations)
5227  LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5228 
5229  return Result;
5230  }
5231 
5232  bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
5233 };
5234 
5235 } // end anonymous namespace
5236 
5237 KernelSet llvm::omp::getDeviceKernels(Module &M) {
5238  // TODO: Create a more cross-platform way of determining device kernels.
5239  NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
5240  KernelSet Kernels;
5241 
5242  if (!MD)
5243  return Kernels;
5244 
5245  for (auto *Op : MD->operands()) {
5246  if (Op->getNumOperands() < 2)
5247  continue;
5248  MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
5249  if (!KindID || KindID->getString() != "kernel")
5250  continue;
5251 
5252  Function *KernelFn =
5253  mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
5254  if (!KernelFn)
5255  continue;
5256 
5257  ++NumOpenMPTargetRegionKernels;
5258 
5259  Kernels.insert(KernelFn);
5260  }
5261 
5262  return Kernels;
5263 }
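The walk above relies on the NVPTX-style "nvvm.annotations" named metadata, where each operand is roughly a tuple of (function, "kernel", i32 1). A sketch of producing that shape from C++, where KernelFn is a hypothetical llvm::Function* defined elsewhere and M is the module being annotated:

// Sketch: emitting the metadata that getDeviceKernels() recognizes.
// Equivalent textual IR (assumed typical form):
//   !nvvm.annotations = !{!0}
//   !0 = !{ptr @my_kernel, !"kernel", i32 1}
llvm::LLVMContext &Ctx = M.getContext();
llvm::NamedMDNode *Annots = M.getOrInsertNamedMetadata("nvvm.annotations");
llvm::Metadata *Ops[] = {
    llvm::ConstantAsMetadata::get(KernelFn),          // operand 0: the kernel
    llvm::MDString::get(Ctx, "kernel"),               // operand 1: the kind tag
    llvm::ConstantAsMetadata::get(
        llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
Annots->addOperand(llvm::MDNode::get(Ctx, Ops));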
5264 
5265 bool llvm::omp::containsOpenMP(Module &M) {
5266  Metadata *MD = M.getModuleFlag("openmp");
5267  if (!MD)
5268  return false;
5269 
5270  return true;
5271 }
5272 
5273 bool llvm::omp::isOpenMPDevice(Module &M) {
5274  Metadata *MD = M.getModuleFlag("openmp-device");
5275  if (!MD)
5276  return false;
5277 
5278  return true;
5279 }
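Both predicates only test for the presence of a module flag that the frontend is expected to set. A sketch of marking a module by hand; the flag value (here 50, in the spirit of OpenMP 5.0) is illustrative, since only presence is checked above:

// Sketch: after adding these two flags, containsOpenMP(M) and
// isOpenMPDevice(M) both return true for this module.
M.addModuleFlag(llvm::Module::Max, "openmp", 50);
M.addModuleFlag(llvm::Module::Max, "openmp-device", 50);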
5280 
5281 char OpenMPOptCGSCCLegacyPass::ID = 0;
5282 
5283 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
5284  "OpenMP specific optimizations", false, false)
5285 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
5286 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
5287  "OpenMP specific optimizations", false, false)
5288 
5289 Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
5290  return new OpenMPOptCGSCCLegacyPass();
5291 }