1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 // - Replacing globalized device memory with stack memory.
13 // - Replacing globalized device memory with shared memory.
14 // - Parallel region merging.
15 // - Transforming generic-mode device kernels to SPMD mode.
16 // - Specializing the state machine for generic-mode device kernels.
17 //
18 //===----------------------------------------------------------------------===//
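// Illustrative sketch (not part of the upstream header): the runtime-call
// deduplication listed above conceptually rewrites
//
//   %tid.0 = call i32 @omp_get_thread_num()
//   ...
//   %tid.1 = call i32 @omp_get_thread_num()
//
// into a single call whose result replaces both %tid.0 and %tid.1, which is
// legal whenever the value cannot change in between (the thread number is
// invariant for a thread within its enclosing parallel region).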
19 
21 
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/Statistic.h"
26 #include "llvm/ADT/StringRef.h"
34 #include "llvm/IR/Assumptions.h"
35 #include "llvm/IR/Constants.h"
36 #include "llvm/IR/DiagnosticInfo.h"
37 #include "llvm/IR/GlobalValue.h"
38 #include "llvm/IR/GlobalVariable.h"
39 #include "llvm/IR/Instruction.h"
40 #include "llvm/IR/Instructions.h"
41 #include "llvm/IR/IntrinsicInst.h"
42 #include "llvm/IR/IntrinsicsAMDGPU.h"
43 #include "llvm/IR/IntrinsicsNVPTX.h"
44 #include "llvm/IR/LLVMContext.h"
45 #include "llvm/InitializePasses.h"
47 #include "llvm/Support/Debug.h"
48 #include "llvm/Transforms/IPO.h"
52 
53 #include <algorithm>
54 
55 using namespace llvm;
56 using namespace omp;
57 
58 #define DEBUG_TYPE "openmp-opt"
59 
 60 static cl::opt<bool> DisableOpenMPOptimizations(
 61  "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
62  cl::Hidden, cl::init(false));
63 
 64 static cl::opt<bool> EnableParallelRegionMerging(
 65  "openmp-opt-enable-merging",
66  cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
67  cl::init(false));
68 
69 static cl::opt<bool>
70  DisableInternalization("openmp-opt-disable-internalization",
71  cl::desc("Disable function internalization."),
72  cl::Hidden, cl::init(false));
73 
74 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
75  cl::Hidden);
76 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
77  cl::init(false), cl::Hidden);
78 
 79 static cl::opt<bool> HideMemoryTransferLatency(
 80  "openmp-hide-memory-transfer-latency",
81  cl::desc("[WIP] Tries to hide the latency of host to device memory"
82  " transfers"),
83  cl::Hidden, cl::init(false));
84 
 85 static cl::opt<bool> DisableOpenMPOptDeglobalization(
 86  "openmp-opt-disable-deglobalization",
87  cl::desc("Disable OpenMP optimizations involving deglobalization."),
88  cl::Hidden, cl::init(false));
89 
 90 static cl::opt<bool> DisableOpenMPOptSPMDization(
 91  "openmp-opt-disable-spmdization",
92  cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
93  cl::Hidden, cl::init(false));
94 
 95 static cl::opt<bool> DisableOpenMPOptFolding(
 96  "openmp-opt-disable-folding",
97  cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
98  cl::init(false));
99 
 100 static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
 101  "openmp-opt-disable-state-machine-rewrite",
102  cl::desc("Disable OpenMP optimizations that replace the state machine."),
103  cl::Hidden, cl::init(false));
104 
 105 static cl::opt<bool> DisableOpenMPOptBarrierElimination(
 106  "openmp-opt-disable-barrier-elimination",
107  cl::desc("Disable OpenMP optimizations that eliminate barriers."),
108  cl::Hidden, cl::init(false));
109 
 110 static cl::opt<bool> PrintModuleAfterOptimizations(
 111  "openmp-opt-print-module-after",
112  cl::desc("Print the current module after OpenMP optimizations."),
113  cl::Hidden, cl::init(false));
114 
 115 static cl::opt<bool> PrintModuleBeforeOptimizations(
 116  "openmp-opt-print-module-before",
117  cl::desc("Print the current module before OpenMP optimizations."),
118  cl::Hidden, cl::init(false));
119 
 120 static cl::opt<bool> AlwaysInlineDeviceFunctions(
 121  "openmp-opt-inline-device",
 122  cl::desc("Inline all applicable functions on the device."), cl::Hidden,
123  cl::init(false));
124 
125 static cl::opt<bool>
126  EnableVerboseRemarks("openmp-opt-verbose-remarks",
127  cl::desc("Enables more verbose remarks."), cl::Hidden,
128  cl::init(false));
129 
130 static cl::opt<unsigned>
131  SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
132  cl::desc("Maximal number of attributor iterations."),
133  cl::init(256));
134 
135 static cl::opt<unsigned>
136  SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
137  cl::desc("Maximum amount of shared memory to use."),
 138  cl::init(std::numeric_limits<unsigned>::max()));
 139 
140 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
141  "Number of OpenMP runtime calls deduplicated");
142 STATISTIC(NumOpenMPParallelRegionsDeleted,
143  "Number of OpenMP parallel regions deleted");
144 STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
145  "Number of OpenMP runtime functions identified");
146 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
147  "Number of OpenMP runtime function uses identified");
148 STATISTIC(NumOpenMPTargetRegionKernels,
149  "Number of OpenMP target region entry points (=kernels) identified");
150 STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
151  "Number of OpenMP target region entry points (=kernels) executed in "
152  "SPMD-mode instead of generic-mode");
153 STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
154  "Number of OpenMP target region entry points (=kernels) executed in "
155  "generic-mode without a state machines");
156 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
157  "Number of OpenMP target region entry points (=kernels) executed in "
158  "generic-mode with customized state machines with fallback");
159 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
160  "Number of OpenMP target region entry points (=kernels) executed in "
161  "generic-mode with customized state machines without fallback");
162 STATISTIC(
163  NumOpenMPParallelRegionsReplacedInGPUStateMachine,
164  "Number of OpenMP parallel regions replaced with ID in GPU state machines");
165 STATISTIC(NumOpenMPParallelRegionsMerged,
166  "Number of OpenMP parallel regions merged");
167 STATISTIC(NumBytesMovedToSharedMemory,
168  "Amount of memory pushed to shared memory");
169 STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
170 
171 #if !defined(NDEBUG)
172 static constexpr auto TAG = "[" DEBUG_TYPE "]";
173 #endif
174 
175 namespace {
176 
177 struct AAHeapToShared;
178 
179 struct AAICVTracker;
180 
181 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
182 /// Attributor runs.
183 struct OMPInformationCache : public InformationCache {
184  OMPInformationCache(Module &M, AnalysisGetter &AG,
 185  BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
 186  KernelSet &Kernels)
 187  : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
188  Kernels(Kernels) {
189 
190  OMPBuilder.initialize();
191  initializeRuntimeFunctions();
192  initializeInternalControlVars();
193  }
194 
195  /// Generic information that describes an internal control variable.
196  struct InternalControlVarInfo {
197  /// The kind, as described by InternalControlVar enum.
 198  InternalControlVar Kind;
 199 
200  /// The name of the ICV.
201  StringRef Name;
202 
203  /// Environment variable associated with this ICV.
204  StringRef EnvVarName;
205 
206  /// Initial value kind.
207  ICVInitValue InitKind;
208 
209  /// Initial value.
210  ConstantInt *InitValue;
211 
212  /// Setter RTL function associated with this ICV.
213  RuntimeFunction Setter;
214 
215  /// Getter RTL function associated with this ICV.
216  RuntimeFunction Getter;
217 
218  /// RTL Function corresponding to the override clause of this ICV
220  };
221 
222  /// Generic information that describes a runtime function
223  struct RuntimeFunctionInfo {
224 
225  /// The kind, as described by the RuntimeFunction enum.
 226  RuntimeFunction Kind;
 227 
228  /// The name of the function.
229  StringRef Name;
230 
231  /// Flag to indicate a variadic function.
232  bool IsVarArg;
233 
234  /// The return type of the function.
235  Type *ReturnType;
236 
237  /// The argument types of the function.
238  SmallVector<Type *, 8> ArgumentTypes;
239 
240  /// The declaration if available.
241  Function *Declaration = nullptr;
242 
243  /// Uses of this runtime function per function containing the use.
244  using UseVector = SmallVector<Use *, 16>;
245 
246  /// Clear UsesMap for runtime function.
247  void clearUsesMap() { UsesMap.clear(); }
248 
249  /// Boolean conversion that is true if the runtime function was found.
250  operator bool() const { return Declaration; }
251 
252  /// Return the vector of uses in function \p F.
253  UseVector &getOrCreateUseVector(Function *F) {
254  std::shared_ptr<UseVector> &UV = UsesMap[F];
255  if (!UV)
256  UV = std::make_shared<UseVector>();
257  return *UV;
258  }
259 
260  /// Return the vector of uses in function \p F or `nullptr` if there are
261  /// none.
262  const UseVector *getUseVector(Function &F) const {
263  auto I = UsesMap.find(&F);
264  if (I != UsesMap.end())
265  return I->second.get();
266  return nullptr;
267  }
268 
269  /// Return how many functions contain uses of this runtime function.
270  size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
271 
272  /// Return the number of arguments (or the minimal number for variadic
273  /// functions).
274  size_t getNumArgs() const { return ArgumentTypes.size(); }
275 
276  /// Run the callback \p CB on each use and forget the use if the result is
277  /// true. The callback will be fed the function in which the use was
278  /// encountered as second argument.
279  void foreachUse(SmallVectorImpl<Function *> &SCC,
280  function_ref<bool(Use &, Function &)> CB) {
281  for (Function *F : SCC)
282  foreachUse(CB, F);
283  }
284 
285  /// Run the callback \p CB on each use within the function \p F and forget
286  /// the use if the result is true.
287  void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
288  SmallVector<unsigned, 8> ToBeDeleted;
289  ToBeDeleted.clear();
290 
291  unsigned Idx = 0;
292  UseVector &UV = getOrCreateUseVector(F);
293 
294  for (Use *U : UV) {
295  if (CB(*U, *F))
296  ToBeDeleted.push_back(Idx);
297  ++Idx;
298  }
299 
300  // Remove the to-be-deleted indices in reverse order as prior
301  // modifications will not modify the smaller indices.
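 // (Illustrative note, not in the original: each removal replaces the element
 // at Idx with the current last element and shrinks the vector, so indices
 // smaller than Idx are never touched and stay valid.)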
302  while (!ToBeDeleted.empty()) {
303  unsigned Idx = ToBeDeleted.pop_back_val();
304  UV[Idx] = UV.back();
305  UV.pop_back();
306  }
307  }
308 
309  private:
310  /// Map from functions to all uses of this runtime function contained in
311  /// them.
 312  DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
 313 
314  public:
315  /// Iterators for the uses of this runtime function.
316  decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
317  decltype(UsesMap)::iterator end() { return UsesMap.end(); }
318  };
319 
320  /// An OpenMP-IR-Builder instance
321  OpenMPIRBuilder OMPBuilder;
322 
323  /// Map from runtime function kind to the runtime function description.
324  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
325  RuntimeFunction::OMPRTL___last>
326  RFIs;
327 
328  /// Map from function declarations/definitions to their runtime enum type.
329  DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
330 
331  /// Map from ICV kind to the ICV description.
332  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
333  InternalControlVar::ICV___last>
334  ICVs;
335 
336  /// Helper to initialize all internal control variable information for those
337  /// defined in OMPKinds.def.
338  void initializeInternalControlVars() {
339 #define ICV_RT_SET(_Name, RTL) \
340  { \
341  auto &ICV = ICVs[_Name]; \
342  ICV.Setter = RTL; \
343  }
344 #define ICV_RT_GET(Name, RTL) \
345  { \
346  auto &ICV = ICVs[Name]; \
347  ICV.Getter = RTL; \
348  }
349 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
350  { \
351  auto &ICV = ICVs[Enum]; \
352  ICV.Name = _Name; \
353  ICV.Kind = Enum; \
354  ICV.InitKind = Init; \
355  ICV.EnvVarName = _EnvVarName; \
356  switch (ICV.InitKind) { \
357  case ICV_IMPLEMENTATION_DEFINED: \
358  ICV.InitValue = nullptr; \
359  break; \
360  case ICV_ZERO: \
361  ICV.InitValue = ConstantInt::get( \
362  Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
363  break; \
364  case ICV_FALSE: \
365  ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
366  break; \
367  case ICV_LAST: \
368  break; \
369  } \
370  }
371 #include "llvm/Frontend/OpenMP/OMPKinds.def"
372  }
373 
374  /// Returns true if the function declaration \p F matches the runtime
375  /// function types, that is, return type \p RTFRetType, and argument types
376  /// \p RTFArgTypes.
377  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
378  SmallVector<Type *, 8> &RTFArgTypes) {
379  // TODO: We should output information to the user (under debug output
380  // and via remarks).
381 
382  if (!F)
383  return false;
384  if (F->getReturnType() != RTFRetType)
385  return false;
386  if (F->arg_size() != RTFArgTypes.size())
387  return false;
388 
389  auto *RTFTyIt = RTFArgTypes.begin();
390  for (Argument &Arg : F->args()) {
391  if (Arg.getType() != *RTFTyIt)
392  return false;
393 
394  ++RTFTyIt;
395  }
396 
397  return true;
398  }
399 
400  // Helper to collect all uses of the declaration in the UsesMap.
401  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
402  unsigned NumUses = 0;
403  if (!RFI.Declaration)
404  return NumUses;
405  OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
406 
407  if (CollectStats) {
408  NumOpenMPRuntimeFunctionsIdentified += 1;
409  NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
410  }
411 
412  // TODO: We directly convert uses into proper calls and unknown uses.
413  for (Use &U : RFI.Declaration->uses()) {
414  if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
415  if (ModuleSlice.count(UserI->getFunction())) {
416  RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
417  ++NumUses;
418  }
419  } else {
420  RFI.getOrCreateUseVector(nullptr).push_back(&U);
421  ++NumUses;
422  }
423  }
424  return NumUses;
425  }
426 
427  // Helper function to recollect uses of a runtime function.
428  void recollectUsesForFunction(RuntimeFunction RTF) {
429  auto &RFI = RFIs[RTF];
430  RFI.clearUsesMap();
431  collectUses(RFI, /*CollectStats*/ false);
432  }
433 
434  // Helper function to recollect uses of all runtime functions.
435  void recollectUses() {
436  for (int Idx = 0; Idx < RFIs.size(); ++Idx)
437  recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
438  }
439 
440  // Helper function to inherit the calling convention of the function callee.
441  void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
442  if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
443  CI->setCallingConv(Fn->getCallingConv());
444  }
445 
446  /// Helper to initialize all runtime function information for those defined
447  /// in OpenMPKinds.def.
448  void initializeRuntimeFunctions() {
449  Module &M = *((*ModuleSlice.begin())->getParent());
450 
451  // Helper macros for handling __VA_ARGS__ in OMP_RTL
452 #define OMP_TYPE(VarName, ...) \
453  Type *VarName = OMPBuilder.VarName; \
454  (void)VarName;
455 
456 #define OMP_ARRAY_TYPE(VarName, ...) \
457  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
458  (void)VarName##Ty; \
459  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
460  (void)VarName##PtrTy;
461 
462 #define OMP_FUNCTION_TYPE(VarName, ...) \
463  FunctionType *VarName = OMPBuilder.VarName; \
464  (void)VarName; \
465  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
466  (void)VarName##Ptr;
467 
468 #define OMP_STRUCT_TYPE(VarName, ...) \
469  StructType *VarName = OMPBuilder.VarName; \
470  (void)VarName; \
471  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
472  (void)VarName##Ptr;
473 
474 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
475  { \
476  SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
477  Function *F = M.getFunction(_Name); \
478  RTLFunctions.insert(F); \
479  if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
480  RuntimeFunctionIDMap[F] = _Enum; \
481  auto &RFI = RFIs[_Enum]; \
482  RFI.Kind = _Enum; \
483  RFI.Name = _Name; \
484  RFI.IsVarArg = _IsVarArg; \
485  RFI.ReturnType = OMPBuilder._ReturnType; \
486  RFI.ArgumentTypes = std::move(ArgsTypes); \
487  RFI.Declaration = F; \
488  unsigned NumUses = collectUses(RFI); \
489  (void)NumUses; \
490  LLVM_DEBUG({ \
491  dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
492  << " found\n"; \
493  if (RFI.Declaration) \
494  dbgs() << TAG << "-> got " << NumUses << " uses in " \
495  << RFI.getNumFunctionsWithUses() \
496  << " different functions.\n"; \
497  }); \
498  } \
499  }
500 #include "llvm/Frontend/OpenMP/OMPKinds.def"
501 
502  // Remove the `noinline` attribute from `__kmpc`, `_OMP::` and `omp_`
503  // functions, except if `optnone` is present.
504  if (isOpenMPDevice(M)) {
505  for (Function &F : M) {
506  for (StringRef Prefix : {"__kmpc", "_ZN4_OMP", "omp_"})
507  if (F.hasFnAttribute(Attribute::NoInline) &&
508  F.getName().startswith(Prefix) &&
509  !F.hasFnAttribute(Attribute::OptimizeNone))
510  F.removeFnAttr(Attribute::NoInline);
511  }
512  }
513 
514  // TODO: We should attach the attributes defined in OMPKinds.def.
515  }
516 
517  /// Collection of known kernels (\see Kernel) in the module.
 518  KernelSet &Kernels;
 519 
 520  /// Collection of known OpenMP runtime functions.
521  DenseSet<const Function *> RTLFunctions;
522 };
523 
524 template <typename Ty, bool InsertInvalidates = true>
525 struct BooleanStateWithSetVector : public BooleanState {
526  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
527  bool insert(const Ty &Elem) {
528  if (InsertInvalidates)
 529  BooleanState::indicatePessimisticFixpoint();
 530  return Set.insert(Elem);
531  }
532 
533  const Ty &operator[](int Idx) const { return Set[Idx]; }
534  bool operator==(const BooleanStateWithSetVector &RHS) const {
535  return BooleanState::operator==(RHS) && Set == RHS.Set;
536  }
537  bool operator!=(const BooleanStateWithSetVector &RHS) const {
538  return !(*this == RHS);
539  }
540 
541  bool empty() const { return Set.empty(); }
542  size_t size() const { return Set.size(); }
543 
544  /// "Clamp" this state with \p RHS.
545  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
 546  BooleanState::operator^=(RHS);
 547  Set.insert(RHS.Set.begin(), RHS.Set.end());
548  return *this;
549  }
550 
551 private:
552  /// A set to keep track of elements.
553  SetVector<Ty> Set;
554 
555 public:
556  typename decltype(Set)::iterator begin() { return Set.begin(); }
557  typename decltype(Set)::iterator end() { return Set.end(); }
558  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
559  typename decltype(Set)::const_iterator end() const { return Set.end(); }
560 };
561 
562 template <typename Ty, bool InsertInvalidates = true>
563 using BooleanStateWithPtrSetVector =
564  BooleanStateWithSetVector<Ty *, InsertInvalidates>;
565 
566 struct KernelInfoState : AbstractState {
567  /// Flag to track if we reached a fixpoint.
568  bool IsAtFixpoint = false;
569 
570  /// The parallel regions (identified by the outlined parallel functions) that
571  /// can be reached from the associated function.
572  BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
573  ReachedKnownParallelRegions;
574 
575  /// State to track what parallel region we might reach.
576  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
577 
 578  /// State to track if we are in SPMD-mode, assumed or known, and why we decided
579  /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
580  /// false.
581  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
582 
583  /// The __kmpc_target_init call in this kernel, if any. If we find more than
584  /// one we abort as the kernel is malformed.
585  CallBase *KernelInitCB = nullptr;
586 
587  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
588  /// one we abort as the kernel is malformed.
589  CallBase *KernelDeinitCB = nullptr;
590 
591  /// Flag to indicate if the associated function is a kernel entry.
592  bool IsKernelEntry = false;
593 
594  /// State to track what kernel entries can reach the associated function.
595  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
596 
 597  /// State to indicate if we can track the parallel level of the associated
 598  /// function. We will give up tracking if we encounter an unknown caller or the
599  /// caller is __kmpc_parallel_51.
600  BooleanStateWithSetVector<uint8_t> ParallelLevels;
601 
602  /// Abstract State interface
603  ///{
604 
605  KernelInfoState() = default;
606  KernelInfoState(bool BestState) {
607  if (!BestState)
608  indicatePessimisticFixpoint();
609  }
610 
611  /// See AbstractState::isValidState(...)
612  bool isValidState() const override { return true; }
613 
614  /// See AbstractState::isAtFixpoint(...)
615  bool isAtFixpoint() const override { return IsAtFixpoint; }
616 
617  /// See AbstractState::indicatePessimisticFixpoint(...)
618  ChangeStatus indicatePessimisticFixpoint() override {
619  IsAtFixpoint = true;
620  ReachingKernelEntries.indicatePessimisticFixpoint();
621  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
622  ReachedKnownParallelRegions.indicatePessimisticFixpoint();
623  ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
624  return ChangeStatus::CHANGED;
625  }
626 
627  /// See AbstractState::indicateOptimisticFixpoint(...)
628  ChangeStatus indicateOptimisticFixpoint() override {
629  IsAtFixpoint = true;
630  ReachingKernelEntries.indicateOptimisticFixpoint();
631  SPMDCompatibilityTracker.indicateOptimisticFixpoint();
632  ReachedKnownParallelRegions.indicateOptimisticFixpoint();
633  ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
 634  return ChangeStatus::UNCHANGED;
 635  }
636 
637  /// Return the assumed state
638  KernelInfoState &getAssumed() { return *this; }
639  const KernelInfoState &getAssumed() const { return *this; }
640 
641  bool operator==(const KernelInfoState &RHS) const {
642  if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
643  return false;
644  if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
645  return false;
646  if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
647  return false;
648  if (ReachingKernelEntries != RHS.ReachingKernelEntries)
649  return false;
650  return true;
651  }
652 
653  /// Returns true if this kernel contains any OpenMP parallel regions.
654  bool mayContainParallelRegion() {
655  return !ReachedKnownParallelRegions.empty() ||
656  !ReachedUnknownParallelRegions.empty();
657  }
658 
659  /// Return empty set as the best state of potential values.
660  static KernelInfoState getBestState() { return KernelInfoState(true); }
661 
662  static KernelInfoState getBestState(KernelInfoState &KIS) {
663  return getBestState();
664  }
665 
666  /// Return full set as the worst state of potential values.
667  static KernelInfoState getWorstState() { return KernelInfoState(false); }
668 
669  /// "Clamp" this state with \p KIS.
670  KernelInfoState operator^=(const KernelInfoState &KIS) {
671  // Do not merge two different _init and _deinit call sites.
672  if (KIS.KernelInitCB) {
673  if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
674  llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
675  "assumptions.");
676  KernelInitCB = KIS.KernelInitCB;
677  }
678  if (KIS.KernelDeinitCB) {
679  if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
680  llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
681  "assumptions.");
682  KernelDeinitCB = KIS.KernelDeinitCB;
683  }
684  SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
685  ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
686  ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
687  return *this;
688  }
689 
690  KernelInfoState operator&=(const KernelInfoState &KIS) {
691  return (*this ^= KIS);
692  }
693 
694  ///}
695 };
696 
697 /// Used to map the values physically (in the IR) stored in an offload
698 /// array, to a vector in memory.
699 struct OffloadArray {
700  /// Physical array (in the IR).
701  AllocaInst *Array = nullptr;
702  /// Mapped values.
703  SmallVector<Value *, 8> StoredValues;
704  /// Last stores made in the offload array.
705  SmallVector<StoreInst *, 8> LastAccesses;
706 
707  OffloadArray() = default;
708 
709  /// Initializes the OffloadArray with the values stored in \p Array before
710  /// instruction \p Before is reached. Returns false if the initialization
711  /// fails.
712  /// This MUST be used immediately after the construction of the object.
713  bool initialize(AllocaInst &Array, Instruction &Before) {
714  if (!Array.getAllocatedType()->isArrayTy())
715  return false;
716 
717  if (!getValues(Array, Before))
718  return false;
719 
720  this->Array = &Array;
721  return true;
722  }
723 
724  static const unsigned DeviceIDArgNum = 1;
725  static const unsigned BasePtrsArgNum = 3;
726  static const unsigned PtrsArgNum = 4;
727  static const unsigned SizesArgNum = 5;
728 
729 private:
730  /// Traverses the BasicBlock where \p Array is, collecting the stores made to
731  /// \p Array, leaving StoredValues with the values stored before the
732  /// instruction \p Before is reached.
733  bool getValues(AllocaInst &Array, Instruction &Before) {
734  // Initialize container.
735  const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
736  StoredValues.assign(NumValues, nullptr);
737  LastAccesses.assign(NumValues, nullptr);
738 
739  // TODO: This assumes the instruction \p Before is in the same
740  // BasicBlock as Array. Make it general, for any control flow graph.
741  BasicBlock *BB = Array.getParent();
742  if (BB != Before.getParent())
743  return false;
744 
745  const DataLayout &DL = Array.getModule()->getDataLayout();
746  const unsigned int PointerSize = DL.getPointerSize();
747 
748  for (Instruction &I : *BB) {
749  if (&I == &Before)
750  break;
751 
752  if (!isa<StoreInst>(&I))
753  continue;
754 
755  auto *S = cast<StoreInst>(&I);
756  int64_t Offset = -1;
757  auto *Dst =
758  GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
759  if (Dst == &Array) {
760  int64_t Idx = Offset / PointerSize;
761  StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
762  LastAccesses[Idx] = S;
763  }
764  }
765 
766  return isFilled();
767  }
768 
769  /// Returns true if all values in StoredValues and
770  /// LastAccesses are not nullptrs.
771  bool isFilled() {
772  const unsigned NumValues = StoredValues.size();
773  for (unsigned I = 0; I < NumValues; ++I) {
774  if (!StoredValues[I] || !LastAccesses[I])
775  return false;
776  }
777 
778  return true;
779  }
780 };
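// Illustrative sketch (not part of the original sources): getValues() above
// recognizes the common pattern in which each element of an offload array is
// initialized by a store at a constant offset from the alloca, e.g.
//
//   %offload_baseptrs = alloca [2 x i8*]
//   %gep0 = getelementptr [2 x i8*], [2 x i8*]* %offload_baseptrs, i64 0, i64 0
//   store i8* %base0, i8** %gep0
//   %gep1 = getelementptr [2 x i8*], [2 x i8*]* %offload_baseptrs, i64 0, i64 1
//   store i8* %base1, i8** %gep1
//
// Offset / PointerSize gives the element index, so StoredValues ends up as
// {%base0, %base1} and LastAccesses records the two stores.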
781 
782 struct OpenMPOpt {
783 
784  using OptimizationRemarkGetter =
 785  function_ref<OptimizationRemarkEmitter &(Function *)>;
 786 
787  OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
788  OptimizationRemarkGetter OREGetter,
789  OMPInformationCache &OMPInfoCache, Attributor &A)
790  : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
791  OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
792 
793  /// Check if any remarks are enabled for openmp-opt
794  bool remarksEnabled() {
795  auto &Ctx = M.getContext();
796  return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
797  }
798 
799  /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
800  bool run(bool IsModulePass) {
801  if (SCC.empty())
802  return false;
803 
804  bool Changed = false;
805 
806  LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
807  << " functions in a slice with "
808  << OMPInfoCache.ModuleSlice.size() << " functions\n");
809 
810  if (IsModulePass) {
811  Changed |= runAttributor(IsModulePass);
812 
813  // Recollect uses, in case Attributor deleted any.
814  OMPInfoCache.recollectUses();
815 
816  // TODO: This should be folded into buildCustomStateMachine.
817  Changed |= rewriteDeviceCodeStateMachine();
818 
819  if (remarksEnabled())
820  analysisGlobalization();
821 
822  Changed |= eliminateBarriers();
823  } else {
824  if (PrintICVValues)
825  printICVs();
826  if (PrintOpenMPKernels)
827  printKernels();
828 
829  Changed |= runAttributor(IsModulePass);
830 
831  // Recollect uses, in case Attributor deleted any.
832  OMPInfoCache.recollectUses();
833 
834  Changed |= deleteParallelRegions();
835 
 836  if (HideMemoryTransferLatency)
 837  Changed |= hideMemTransfersLatency();
838  Changed |= deduplicateRuntimeCalls();
 839  if (EnableParallelRegionMerging) {
 840  if (mergeParallelRegions()) {
841  deduplicateRuntimeCalls();
842  Changed = true;
843  }
844  }
845 
846  Changed |= eliminateBarriers();
847  }
848 
849  return Changed;
850  }
851 
852  /// Print initial ICV values for testing.
853  /// FIXME: This should be done from the Attributor once it is added.
854  void printICVs() const {
855  InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
856  ICV_proc_bind};
857 
858  for (Function *F : OMPInfoCache.ModuleSlice) {
859  for (auto ICV : ICVs) {
860  auto ICVInfo = OMPInfoCache.ICVs[ICV];
861  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
862  return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
863  << " Value: "
864  << (ICVInfo.InitValue
865  ? toString(ICVInfo.InitValue->getValue(), 10, true)
866  : "IMPLEMENTATION_DEFINED");
867  };
868 
869  emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
870  }
871  }
872  }
873 
874  /// Print OpenMP GPU kernels for testing.
875  void printKernels() const {
876  for (Function *F : SCC) {
877  if (!OMPInfoCache.Kernels.count(F))
878  continue;
879 
880  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
881  return ORA << "OpenMP GPU kernel "
882  << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
883  };
884 
885  emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
886  }
887  }
888 
889  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
890  /// given it has to be the callee or a nullptr is returned.
891  static CallInst *getCallIfRegularCall(
892  Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
893  CallInst *CI = dyn_cast<CallInst>(U.getUser());
894  if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
895  (!RFI ||
896  (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
897  return CI;
898  return nullptr;
899  }
900 
901  /// Return the call if \p V is a regular call. If \p RFI is given it has to be
902  /// the callee or a nullptr is returned.
903  static CallInst *getCallIfRegularCall(
904  Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
905  CallInst *CI = dyn_cast<CallInst>(&V);
906  if (CI && !CI->hasOperandBundles() &&
907  (!RFI ||
908  (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
909  return CI;
910  return nullptr;
911  }
912 
913 private:
914  /// Merge parallel regions when it is safe.
915  bool mergeParallelRegions() {
916  const unsigned CallbackCalleeOperand = 2;
917  const unsigned CallbackFirstArgOperand = 3;
918  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
919 
920  // Check if there are any __kmpc_fork_call calls to merge.
921  OMPInformationCache::RuntimeFunctionInfo &RFI =
922  OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
923 
924  if (!RFI.Declaration)
925  return false;
926 
927  // Unmergable calls that prevent merging a parallel region.
928  OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
929  OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
930  OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
931  };
932 
933  bool Changed = false;
934  LoopInfo *LI = nullptr;
935  DominatorTree *DT = nullptr;
936 
 937  SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
 938 
939  BasicBlock *StartBB = nullptr, *EndBB = nullptr;
940  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
941  BasicBlock *CGStartBB = CodeGenIP.getBlock();
942  BasicBlock *CGEndBB =
943  SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
944  assert(StartBB != nullptr && "StartBB should not be null");
945  CGStartBB->getTerminator()->setSuccessor(0, StartBB);
946  assert(EndBB != nullptr && "EndBB should not be null");
947  EndBB->getTerminator()->setSuccessor(0, CGEndBB);
948  };
949 
950  auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
951  Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
952  ReplacementValue = &Inner;
953  return CodeGenIP;
954  };
955 
956  auto FiniCB = [&](InsertPointTy CodeGenIP) {};
957 
958  /// Create a sequential execution region within a merged parallel region,
959  /// encapsulated in a master construct with a barrier for synchronization.
960  auto CreateSequentialRegion = [&](Function *OuterFn,
961  BasicBlock *OuterPredBB,
962  Instruction *SeqStartI,
963  Instruction *SeqEndI) {
964  // Isolate the instructions of the sequential region to a separate
965  // block.
966  BasicBlock *ParentBB = SeqStartI->getParent();
967  BasicBlock *SeqEndBB =
968  SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
969  BasicBlock *SeqAfterBB =
970  SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
971  BasicBlock *SeqStartBB =
972  SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
973 
974  assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
975  "Expected a different CFG");
976  const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
977  ParentBB->getTerminator()->eraseFromParent();
978 
979  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
980  BasicBlock *CGStartBB = CodeGenIP.getBlock();
981  BasicBlock *CGEndBB =
982  SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
983  assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
984  CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
985  assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
986  SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
987  };
988  auto FiniCB = [&](InsertPointTy CodeGenIP) {};
989 
990  // Find outputs from the sequential region to outside users and
991  // broadcast their values to them.
992  for (Instruction &I : *SeqStartBB) {
993  SmallPtrSet<Instruction *, 4> OutsideUsers;
994  for (User *Usr : I.users()) {
995  Instruction &UsrI = *cast<Instruction>(Usr);
996  // Ignore outputs to LT intrinsics, code extraction for the merged
997  // parallel region will fix them.
998  if (UsrI.isLifetimeStartOrEnd())
999  continue;
1000 
1001  if (UsrI.getParent() != SeqStartBB)
1002  OutsideUsers.insert(&UsrI);
1003  }
1004 
1005  if (OutsideUsers.empty())
1006  continue;
1007 
1008  // Emit an alloca in the outer region to store the broadcasted
1009  // value.
1010  const DataLayout &DL = M.getDataLayout();
1011  AllocaInst *AllocaI = new AllocaInst(
1012  I.getType(), DL.getAllocaAddrSpace(), nullptr,
1013  I.getName() + ".seq.output.alloc", &OuterFn->front().front());
1014 
1015  // Emit a store instruction in the sequential BB to update the
1016  // value.
1017  new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
1018 
1019  // Emit a load instruction and replace the use of the output value
1020  // with it.
1021  for (Instruction *UsrI : OutsideUsers) {
1022  LoadInst *LoadI = new LoadInst(
1023  I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
1024  UsrI->replaceUsesOfWith(&I, LoadI);
1025  }
1026  }
1027 
 1028  OpenMPIRBuilder::LocationDescription Loc(
 1029  InsertPointTy(ParentBB, ParentBB->end()), DL);
1030  InsertPointTy SeqAfterIP =
1031  OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
1032 
1033  OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
1034 
1035  BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1036 
1037  LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1038  << "\n");
1039  };
1040 
1041  // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1042  // contained in BB and only separated by instructions that can be
1043  // redundantly executed in parallel. The block BB is split before the first
1044  // call (in MergableCIs) and after the last so the entire region we merge
1045  // into a single parallel region is contained in a single basic block
1046  // without any other instructions. We use the OpenMPIRBuilder to outline
1047  // that block and call the resulting function via __kmpc_fork_call.
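 // Illustrative sketch (not from the original sources) of the rewrite:
 //
 //   __kmpc_fork_call(..., @outlined1, ...)
 //   <sequential instructions>
 //   __kmpc_fork_call(..., @outlined2, ...)
 //
 // becomes a single __kmpc_fork_call of a newly outlined function that calls
 // @outlined1 directly, runs the sequential instructions guarded by a master
 // region followed by a barrier (see CreateSequentialRegion above), and then
 // calls @outlined2 directly.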
1048  auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1049  BasicBlock *BB) {
 1050  // TODO: Change the interface to allow a single CI to be expanded, e.g., to
1051  // include an outer loop.
1052  assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1053 
1054  auto Remark = [&](OptimizationRemark OR) {
1055  OR << "Parallel region merged with parallel region"
1056  << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1057  for (auto *CI : llvm::drop_begin(MergableCIs)) {
1058  OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1059  if (CI != MergableCIs.back())
1060  OR << ", ";
1061  }
1062  return OR << ".";
1063  };
1064 
1065  emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1066 
1067  Function *OriginalFn = BB->getParent();
1068  LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1069  << " parallel regions in " << OriginalFn->getName()
1070  << "\n");
1071 
1072  // Isolate the calls to merge in a separate block.
1073  EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1074  BasicBlock *AfterBB =
1075  SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1076  StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1077  "omp.par.merged");
1078 
1079  assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1080  const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1081  BB->getTerminator()->eraseFromParent();
1082 
1083  // Create sequential regions for sequential instructions that are
1084  // in-between mergable parallel regions.
1085  for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1086  It != End; ++It) {
1087  Instruction *ForkCI = *It;
1088  Instruction *NextForkCI = *(It + 1);
1089 
 1090  // Continue if there are no in-between instructions.
1091  if (ForkCI->getNextNode() == NextForkCI)
1092  continue;
1093 
1094  CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1095  NextForkCI->getPrevNode());
1096  }
1097 
1098  OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1099  DL);
1100  IRBuilder<>::InsertPoint AllocaIP(
1101  &OriginalFn->getEntryBlock(),
1102  OriginalFn->getEntryBlock().getFirstInsertionPt());
1103  // Create the merged parallel region with default proc binding, to
1104  // avoid overriding binding settings, and without explicit cancellation.
1105  InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1106  Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1107  OMP_PROC_BIND_default, /* IsCancellable */ false);
1108  BranchInst::Create(AfterBB, AfterIP.getBlock());
1109 
1110  // Perform the actual outlining.
1111  OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1112 
1113  Function *OutlinedFn = MergableCIs.front()->getCaller();
1114 
1115  // Replace the __kmpc_fork_call calls with direct calls to the outlined
1116  // callbacks.
 1117  SmallVector<Value *, 8> Args;
 1118  for (auto *CI : MergableCIs) {
1119  Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1120  FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1121  Args.clear();
1122  Args.push_back(OutlinedFn->getArg(0));
1123  Args.push_back(OutlinedFn->getArg(1));
1124  for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1125  ++U)
1126  Args.push_back(CI->getArgOperand(U));
1127 
1128  CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1129  if (CI->getDebugLoc())
1130  NewCI->setDebugLoc(CI->getDebugLoc());
1131 
1132  // Forward parameter attributes from the callback to the callee.
1133  for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1134  ++U)
1135  for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1136  NewCI->addParamAttr(
1137  U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1138 
1139  // Emit an explicit barrier to replace the implicit fork-join barrier.
1140  if (CI != MergableCIs.back()) {
1141  // TODO: Remove barrier if the merged parallel region includes the
1142  // 'nowait' clause.
1143  OMPInfoCache.OMPBuilder.createBarrier(
1144  InsertPointTy(NewCI->getParent(),
1145  NewCI->getNextNode()->getIterator()),
1146  OMPD_parallel);
1147  }
1148 
1149  CI->eraseFromParent();
1150  }
1151 
1152  assert(OutlinedFn != OriginalFn && "Outlining failed");
1153  CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1154  CGUpdater.reanalyzeFunction(*OriginalFn);
1155 
1156  NumOpenMPParallelRegionsMerged += MergableCIs.size();
1157 
1158  return true;
1159  };
1160 
 1161  // Helper function that identifies sequences of
1162  // __kmpc_fork_call uses in a basic block.
1163  auto DetectPRsCB = [&](Use &U, Function &F) {
1164  CallInst *CI = getCallIfRegularCall(U, &RFI);
1165  BB2PRMap[CI->getParent()].insert(CI);
1166 
1167  return false;
1168  };
1169 
1170  BB2PRMap.clear();
1171  RFI.foreachUse(SCC, DetectPRsCB);
1172  SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1173  // Find mergable parallel regions within a basic block that are
1174  // safe to merge, that is any in-between instructions can safely
1175  // execute in parallel after merging.
1176  // TODO: support merging across basic-blocks.
1177  for (auto &It : BB2PRMap) {
1178  auto &CIs = It.getSecond();
1179  if (CIs.size() < 2)
1180  continue;
1181 
1182  BasicBlock *BB = It.getFirst();
1183  SmallVector<CallInst *, 4> MergableCIs;
1184 
1185  /// Returns true if the instruction is mergable, false otherwise.
1186  /// A terminator instruction is unmergable by definition since merging
1187  /// works within a BB. Instructions before the mergable region are
1188  /// mergable if they are not calls to OpenMP runtime functions that may
1189  /// set different execution parameters for subsequent parallel regions.
1190  /// Instructions in-between parallel regions are mergable if they are not
1191  /// calls to any non-intrinsic function since that may call a non-mergable
1192  /// OpenMP runtime function.
1193  auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1194  // We do not merge across BBs, hence return false (unmergable) if the
1195  // instruction is a terminator.
1196  if (I.isTerminator())
1197  return false;
1198 
1199  if (!isa<CallInst>(&I))
1200  return true;
1201 
1202  CallInst *CI = cast<CallInst>(&I);
1203  if (IsBeforeMergableRegion) {
1204  Function *CalledFunction = CI->getCalledFunction();
1205  if (!CalledFunction)
1206  return false;
1207  // Return false (unmergable) if the call before the parallel
1208  // region calls an explicit affinity (proc_bind) or number of
1209  // threads (num_threads) compiler-generated function. Those settings
1210  // may be incompatible with following parallel regions.
1211  // TODO: ICV tracking to detect compatibility.
1212  for (const auto &RFI : UnmergableCallsInfo) {
1213  if (CalledFunction == RFI.Declaration)
1214  return false;
1215  }
1216  } else {
1217  // Return false (unmergable) if there is a call instruction
1218  // in-between parallel regions when it is not an intrinsic. It
1219  // may call an unmergable OpenMP runtime function in its callpath.
1220  // TODO: Keep track of possible OpenMP calls in the callpath.
1221  if (!isa<IntrinsicInst>(CI))
1222  return false;
1223  }
1224 
1225  return true;
1226  };
1227  // Find maximal number of parallel region CIs that are safe to merge.
1228  for (auto It = BB->begin(), End = BB->end(); It != End;) {
1229  Instruction &I = *It;
1230  ++It;
1231 
1232  if (CIs.count(&I)) {
1233  MergableCIs.push_back(cast<CallInst>(&I));
1234  continue;
1235  }
1236 
1237  // Continue expanding if the instruction is mergable.
1238  if (IsMergable(I, MergableCIs.empty()))
1239  continue;
1240 
1241  // Forward the instruction iterator to skip the next parallel region
1242  // since there is an unmergable instruction which can affect it.
1243  for (; It != End; ++It) {
1244  Instruction &SkipI = *It;
1245  if (CIs.count(&SkipI)) {
1246  LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1247  << " due to " << I << "\n");
1248  ++It;
1249  break;
1250  }
1251  }
1252 
1253  // Store mergable regions found.
1254  if (MergableCIs.size() > 1) {
1255  MergableCIsVector.push_back(MergableCIs);
1256  LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1257  << " parallel regions in block " << BB->getName()
1258  << " of function " << BB->getParent()->getName()
1259  << "\n";);
1260  }
1261 
1262  MergableCIs.clear();
1263  }
1264 
1265  if (!MergableCIsVector.empty()) {
1266  Changed = true;
1267 
1268  for (auto &MergableCIs : MergableCIsVector)
1269  Merge(MergableCIs, BB);
1270  MergableCIsVector.clear();
1271  }
1272  }
1273 
1274  if (Changed) {
1275  /// Re-collect use for fork calls, emitted barrier calls, and
1276  /// any emitted master/end_master calls.
1277  OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1278  OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1279  OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1280  OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1281  }
1282 
1283  return Changed;
1284  }
1285 
1286  /// Try to delete parallel regions if possible.
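 /// For example (illustrative only): a __kmpc_fork_call whose outlined
 /// callback only reads memory and is known to return (willreturn) computes
 /// nothing observable, so the fork call itself can be erased.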
1287  bool deleteParallelRegions() {
1288  const unsigned CallbackCalleeOperand = 2;
1289 
1290  OMPInformationCache::RuntimeFunctionInfo &RFI =
1291  OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1292 
1293  if (!RFI.Declaration)
1294  return false;
1295 
1296  bool Changed = false;
1297  auto DeleteCallCB = [&](Use &U, Function &) {
1298  CallInst *CI = getCallIfRegularCall(U);
1299  if (!CI)
1300  return false;
1301  auto *Fn = dyn_cast<Function>(
1302  CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1303  if (!Fn)
1304  return false;
1305  if (!Fn->onlyReadsMemory())
1306  return false;
1307  if (!Fn->hasFnAttribute(Attribute::WillReturn))
1308  return false;
1309 
1310  LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1311  << CI->getCaller()->getName() << "\n");
1312 
1313  auto Remark = [&](OptimizationRemark OR) {
1314  return OR << "Removing parallel region with no side-effects.";
1315  };
1316  emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1317 
1318  CGUpdater.removeCallSite(*CI);
1319  CI->eraseFromParent();
1320  Changed = true;
1321  ++NumOpenMPParallelRegionsDeleted;
1322  return true;
1323  };
1324 
1325  RFI.foreachUse(SCC, DeleteCallCB);
1326 
1327  return Changed;
1328  }
1329 
1330  /// Try to eliminate runtime calls by reusing existing ones.
1331  bool deduplicateRuntimeCalls() {
1332  bool Changed = false;
1333 
1334  RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1335  OMPRTL_omp_get_num_threads,
1336  OMPRTL_omp_in_parallel,
1337  OMPRTL_omp_get_cancellation,
1338  OMPRTL_omp_get_thread_limit,
1339  OMPRTL_omp_get_supported_active_levels,
1340  OMPRTL_omp_get_level,
1341  OMPRTL_omp_get_ancestor_thread_num,
1342  OMPRTL_omp_get_team_size,
1343  OMPRTL_omp_get_active_level,
1344  OMPRTL_omp_in_final,
1345  OMPRTL_omp_get_proc_bind,
1346  OMPRTL_omp_get_num_places,
1347  OMPRTL_omp_get_num_procs,
1348  OMPRTL_omp_get_place_num,
1349  OMPRTL_omp_get_partition_num_places,
1350  OMPRTL_omp_get_partition_place_nums};
1351 
1352  // Global-tid is handled separately.
1353  SmallSetVector<Value *, 16> GTIdArgs;
1354  collectGlobalThreadIdArguments(GTIdArgs);
1355  LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1356  << " global thread ID arguments\n");
1357 
1358  for (Function *F : SCC) {
1359  for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1360  Changed |= deduplicateRuntimeCalls(
1361  *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1362 
1363  // __kmpc_global_thread_num is special as we can replace it with an
1364  // argument in enough cases to make it worth trying.
1365  Value *GTIdArg = nullptr;
1366  for (Argument &Arg : F->args())
1367  if (GTIdArgs.count(&Arg)) {
1368  GTIdArg = &Arg;
1369  break;
1370  }
1371  Changed |= deduplicateRuntimeCalls(
1372  *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1373  }
1374 
1375  return Changed;
1376  }
1377 
1378  /// Tries to hide the latency of runtime calls that involve host to
1379  /// device memory transfers by splitting them into their "issue" and "wait"
1380  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
 1381  /// moved downwards as much as possible. The "issue" issues the memory transfer
 1382  /// asynchronously, returning a handle. The "wait" waits on the returned
1383  /// handle for the memory transfer to finish.
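 /// Conceptually (an illustrative sketch, see splitTargetDataBeginRTC), a call
 ///   __tgt_target_data_begin_mapper(...)
 /// is split into
 ///   handle = __tgt_target_data_begin_mapper_issue(...);
 ///   ...   // independent work overlaps the transfer
 ///   __tgt_target_data_begin_mapper_wait(device_id, handle);
 /// with the wait sunk as far down as the surrounding code allows.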
1384  bool hideMemTransfersLatency() {
1385  auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1386  bool Changed = false;
1387  auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1388  auto *RTCall = getCallIfRegularCall(U, &RFI);
1389  if (!RTCall)
1390  return false;
1391 
1392  OffloadArray OffloadArrays[3];
1393  if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1394  return false;
1395 
1396  LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1397 
1398  // TODO: Check if can be moved upwards.
1399  bool WasSplit = false;
1400  Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1401  if (WaitMovementPoint)
1402  WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1403 
1404  Changed |= WasSplit;
1405  return WasSplit;
1406  };
1407  RFI.foreachUse(SCC, SplitMemTransfers);
1408 
1409  return Changed;
1410  }
1411 
1412  /// Eliminates redundant, aligned barriers in OpenMP offloaded kernels.
1413  /// TODO: Make this an AA and expand it to work across blocks and functions.
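 /// For example (illustrative only): two back-to-back aligned barriers, e.g.
 /// consecutive calls to @llvm.nvvm.barrier0 with only accesses to
 /// thread-private (alloca) or constant memory in between, synchronize nothing
 /// additional, so one of the explicit barriers can be removed.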
1414  bool eliminateBarriers() {
1415  bool Changed = false;
1416 
 1417  if (DisableOpenMPOptBarrierElimination)
 1418  return /*Changed=*/false;
1419 
1420  if (OMPInfoCache.Kernels.empty())
1421  return /*Changed=*/false;
1422 
1423  enum ImplicitBarrierType { IBT_ENTRY, IBT_EXIT };
1424 
1425  class BarrierInfo {
1426  Instruction *I;
1427  enum ImplicitBarrierType Type;
1428 
1429  public:
1430  BarrierInfo(enum ImplicitBarrierType Type) : I(nullptr), Type(Type) {}
1431  BarrierInfo(Instruction &I) : I(&I) {}
1432 
1433  bool isImplicit() { return !I; }
1434 
1435  bool isImplicitEntry() { return isImplicit() && Type == IBT_ENTRY; }
1436 
1437  bool isImplicitExit() { return isImplicit() && Type == IBT_EXIT; }
1438 
1439  Instruction *getInstruction() { return I; }
1440  };
1441 
1442  for (Function *Kernel : OMPInfoCache.Kernels) {
1443  for (BasicBlock &BB : *Kernel) {
1444  SmallVector<BarrierInfo, 8> BarriersInBlock;
1445  SmallPtrSet<Instruction *, 8> BarriersToBeDeleted;
1446 
1447  // Add the kernel entry implicit barrier.
1448  if (&Kernel->getEntryBlock() == &BB)
1449  BarriersInBlock.push_back(IBT_ENTRY);
1450 
1451  // Find implicit and explicit aligned barriers in the same basic block.
1452  for (Instruction &I : BB) {
1453  if (isa<ReturnInst>(I)) {
1454  // Add the implicit barrier when exiting the kernel.
1455  BarriersInBlock.push_back(IBT_EXIT);
1456  continue;
1457  }
1458  CallBase *CB = dyn_cast<CallBase>(&I);
1459  if (!CB)
1460  continue;
1461 
1462  auto IsAlignBarrierCB = [&](CallBase &CB) {
1463  switch (CB.getIntrinsicID()) {
1464  case Intrinsic::nvvm_barrier0:
1465  case Intrinsic::nvvm_barrier0_and:
1466  case Intrinsic::nvvm_barrier0_or:
1467  case Intrinsic::nvvm_barrier0_popc:
1468  return true;
1469  default:
1470  break;
1471  }
1472  return hasAssumption(CB,
1473  KnownAssumptionString("ompx_aligned_barrier"));
1474  };
1475 
1476  if (IsAlignBarrierCB(*CB)) {
1477  // Add an explicit aligned barrier.
1478  BarriersInBlock.push_back(I);
1479  }
1480  }
1481 
1482  if (BarriersInBlock.size() <= 1)
1483  continue;
1484 
 1485  // A barrier in a barrier pair is removable if all instructions
1486  // between the barriers in the pair are side-effect free modulo the
1487  // barrier operation.
1488  auto IsBarrierRemoveable = [&Kernel](BarrierInfo *StartBI,
1489  BarrierInfo *EndBI) {
1490  assert(
1491  !StartBI->isImplicitExit() &&
1492  "Expected start barrier to be other than a kernel exit barrier");
1493  assert(
1494  !EndBI->isImplicitEntry() &&
1495  "Expected end barrier to be other than a kernel entry barrier");
 1496  // If the StartBI instruction is null then this is the implicit
1497  // kernel entry barrier, so iterate from the first instruction in the
1498  // entry block.
1499  Instruction *I = (StartBI->isImplicitEntry())
1500  ? &Kernel->getEntryBlock().front()
1501  : StartBI->getInstruction()->getNextNode();
1502  assert(I && "Expected non-null start instruction");
1503  Instruction *E = (EndBI->isImplicitExit())
1504  ? I->getParent()->getTerminator()
1505  : EndBI->getInstruction();
1506  assert(E && "Expected non-null end instruction");
1507 
1508  for (; I != E; I = I->getNextNode()) {
1509  if (!I->mayHaveSideEffects() && !I->mayReadFromMemory())
1510  continue;
1511 
1512  auto IsPotentiallyAffectedByBarrier =
1513  [](Optional<MemoryLocation> Loc) {
1514  const Value *Obj = (Loc && Loc->Ptr)
1515  ? getUnderlyingObject(Loc->Ptr)
1516  : nullptr;
1517  if (!Obj) {
1518  LLVM_DEBUG(
1519  dbgs()
1520  << "Access to unknown location requires barriers\n");
1521  return true;
1522  }
1523  if (isa<UndefValue>(Obj))
1524  return false;
1525  if (isa<AllocaInst>(Obj))
1526  return false;
1527  if (auto *GV = dyn_cast<GlobalVariable>(Obj)) {
1528  if (GV->isConstant())
1529  return false;
1530  if (GV->isThreadLocal())
1531  return false;
1532  if (GV->getAddressSpace() == (int)AddressSpace::Local)
1533  return false;
1534  if (GV->getAddressSpace() == (int)AddressSpace::Constant)
1535  return false;
1536  }
1537  LLVM_DEBUG(dbgs() << "Access to '" << *Obj
1538  << "' requires barriers\n");
1539  return true;
1540  };
1541 
1542  if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
 1543  Optional<MemoryLocation> Loc = MemoryLocation::getForDest(MI);
 1544  if (IsPotentiallyAffectedByBarrier(Loc))
1545  return false;
1546  if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(I)) {
 1547  Optional<MemoryLocation> Loc =
 1548  MemoryLocation::getForSource(MTI);
 1549  if (IsPotentiallyAffectedByBarrier(Loc))
1550  return false;
1551  }
1552  continue;
1553  }
1554 
1555  if (auto *LI = dyn_cast<LoadInst>(I))
1556  if (LI->hasMetadata(LLVMContext::MD_invariant_load))
1557  continue;
1558 
 1559  Optional<MemoryLocation> Loc = MemoryLocation::getOrNone(I);
 1560  if (IsPotentiallyAffectedByBarrier(Loc))
1561  return false;
1562  }
1563 
1564  return true;
1565  };
1566 
1567  // Iterate barrier pairs and remove an explicit barrier if analysis
 1568  // deems it removable.
1569  for (auto *It = BarriersInBlock.begin(),
1570  *End = BarriersInBlock.end() - 1;
1571  It != End; ++It) {
1572 
1573  BarrierInfo *StartBI = It;
1574  BarrierInfo *EndBI = (It + 1);
1575 
1576  // Cannot remove when both are implicit barriers, continue.
1577  if (StartBI->isImplicit() && EndBI->isImplicit())
1578  continue;
1579 
1580  if (!IsBarrierRemoveable(StartBI, EndBI))
1581  continue;
1582 
1583  assert(!(StartBI->isImplicit() && EndBI->isImplicit()) &&
1584  "Expected at least one explicit barrier to remove.");
1585 
1586  // Remove an explicit barrier, check first, then second.
1587  if (!StartBI->isImplicit()) {
1588  LLVM_DEBUG(dbgs() << "Remove start barrier "
1589  << *StartBI->getInstruction() << "\n");
1590  BarriersToBeDeleted.insert(StartBI->getInstruction());
1591  } else {
1592  LLVM_DEBUG(dbgs() << "Remove end barrier "
1593  << *EndBI->getInstruction() << "\n");
1594  BarriersToBeDeleted.insert(EndBI->getInstruction());
1595  }
1596  }
1597 
1598  if (BarriersToBeDeleted.empty())
1599  continue;
1600 
1601  Changed = true;
1602  for (Instruction *I : BarriersToBeDeleted) {
1603  ++NumBarriersEliminated;
1604  auto Remark = [&](OptimizationRemark OR) {
1605  return OR << "Redundant barrier eliminated.";
1606  };
1607 
 1608  if (EnableVerboseRemarks)
 1609  emitRemark<OptimizationRemark>(I, "OMP190", Remark);
1610  I->eraseFromParent();
1611  }
1612  }
1613  }
1614 
1615  return Changed;
1616  }
1617 
1618  void analysisGlobalization() {
1619  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1620 
1621  auto CheckGlobalization = [&](Use &U, Function &Decl) {
1622  if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1623  auto Remark = [&](OptimizationRemarkMissed ORM) {
1624  return ORM
1625  << "Found thread data sharing on the GPU. "
1626  << "Expect degraded performance due to data globalization.";
1627  };
1628  emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1629  }
1630 
1631  return false;
1632  };
1633 
1634  RFI.foreachUse(SCC, CheckGlobalization);
1635  }
1636 
1637  /// Maps the values stored in the offload arrays passed as arguments to
1638  /// \p RuntimeCall into the offload arrays in \p OAs.
1639  bool getValuesInOffloadArrays(CallInst &RuntimeCall,
 1640  MutableArrayRef<OffloadArray> OAs) {
 1641  assert(OAs.size() == 3 && "Need space for three offload arrays!");
1642 
1643  // A runtime call that involves memory offloading looks something like:
1644  // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1645  // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1646  // ...)
1647  // So, the idea is to access the allocas that allocate space for these
1648  // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1649  // Therefore:
1650  // i8** %offload_baseptrs.
1651  Value *BasePtrsArg =
1652  RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1653  // i8** %offload_ptrs.
1654  Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1655  // i64* %offload_sizes.
1656  Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1657 
1658  // Get values stored in **offload_baseptrs.
1659  auto *V = getUnderlyingObject(BasePtrsArg);
1660  if (!isa<AllocaInst>(V))
1661  return false;
1662  auto *BasePtrsArray = cast<AllocaInst>(V);
1663  if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1664  return false;
1665 
1666  // Get values stored in **offload_ptrs.
1667  V = getUnderlyingObject(PtrsArg);
1668  if (!isa<AllocaInst>(V))
1669  return false;
1670  auto *PtrsArray = cast<AllocaInst>(V);
1671  if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1672  return false;
1673 
1674  // Get values stored in **offload_sizes.
1675  V = getUnderlyingObject(SizesArg);
1676  // If it's a [constant] global array don't analyze it.
1677  if (isa<GlobalValue>(V))
1678  return isa<Constant>(V);
1679  if (!isa<AllocaInst>(V))
1680  return false;
1681 
1682  auto *SizesArray = cast<AllocaInst>(V);
1683  if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1684  return false;
1685 
1686  return true;
1687  }
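// Illustrative sketch (hypothetical IR, simplified) of the allocas this
// routine expects behind the mapper arguments:
//   %offload_baseptrs = alloca [2 x i8*]
//   %offload_ptrs     = alloca [2 x i8*]
//   %offload_sizes    = alloca [2 x i64]
// The values stored into these arrays before the mapper call are what ends up
// in the three OffloadArray entries.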
1688 
1689  /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1690  /// For now this is a way to test that the function getValuesInOffloadArrays
1691  /// is working properly.
1692  /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1693  void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1694  assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1695 
1696  LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1697  std::string ValuesStr;
1698  raw_string_ostream Printer(ValuesStr);
1699  std::string Separator = " --- ";
1700 
1701  for (auto *BP : OAs[0].StoredValues) {
1702  BP->print(Printer);
1703  Printer << Separator;
1704  }
1705  LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1706  ValuesStr.clear();
1707 
1708  for (auto *P : OAs[1].StoredValues) {
1709  P->print(Printer);
1710  Printer << Separator;
1711  }
1712  LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1713  ValuesStr.clear();
1714 
1715  for (auto *S : OAs[2].StoredValues) {
1716  S->print(Printer);
1717  Printer << Separator;
1718  }
1719  LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1720  }
1721 
1722  /// Returns the instruction where the "wait" counterpart of \p RuntimeCall can
1723  /// be moved to. Returns nullptr if the movement is not possible, or not worth it.
1724  Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1725  // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1726  // Make it traverse the CFG.
1727 
1728  Instruction *CurrentI = &RuntimeCall;
1729  bool IsWorthIt = false;
1730  while ((CurrentI = CurrentI->getNextNode())) {
1731 
1732  // TODO: Once we detect the regions to be offloaded we should use the
1733  // alias analysis manager to check if CurrentI may modify one of
1734  // the offloaded regions.
1735  if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1736  if (IsWorthIt)
1737  return CurrentI;
1738 
1739  return nullptr;
1740  }
1741 
1742  // FIXME: For now, moving it over anything without side effects is
1743  // considered worth it.
1744  IsWorthIt = true;
1745  }
1746 
1747  // Return end of BasicBlock.
1748  return RuntimeCall.getParent()->getTerminator();
1749  }
1750 
1751  /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1752  bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1753  Instruction &WaitMovementPoint) {
1754  // Create a stack-allocated handle (__tgt_async_info) at the beginning of
1755  // the function. It stores information about the asynchronous transfer so we
1756  // can wait on it later.
1757  auto &IRBuilder = OMPInfoCache.OMPBuilder;
1758  auto *F = RuntimeCall.getCaller();
1759  Instruction *FirstInst = &(F->getEntryBlock().front());
1760  AllocaInst *Handle = new AllocaInst(
1761  IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1762 
1763  // Add "issue" runtime call declaration:
1764  // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1765  // i8**, i8**, i64*, i64*)
1766  FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1767  M, OMPRTL___tgt_target_data_begin_mapper_issue);
1768 
1769  // Change the RuntimeCall call site for its asynchronous version.
1770  SmallVector<Value *, 16> Args;
1771  for (auto &Arg : RuntimeCall.args())
1772  Args.push_back(Arg.get());
1773  Args.push_back(Handle);
1774 
1775  CallInst *IssueCallsite =
1776  CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1777  OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1778  RuntimeCall.eraseFromParent();
1779 
1780  // Add "wait" runtime call declaration:
1781  // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1782  FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1783  M, OMPRTL___tgt_target_data_begin_mapper_wait);
1784 
1785  Value *WaitParams[2] = {
1786  IssueCallsite->getArgOperand(
1787  OffloadArray::DeviceIDArgNum), // device_id.
1788  Handle // handle to wait on.
1789  };
1790  CallInst *WaitCallsite = CallInst::Create(
1791  WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1792  OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1793 
1794  return true;
1795  }
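// Illustrative result (hypothetical IR, simplified): the synchronous mapper
// call is split so independent work can overlap with the transfer, roughly:
//   %handle = alloca %struct.__tgt_async_info
//   call void @__tgt_target_data_begin_mapper_issue(..., %struct.__tgt_async_info* %handle)
//   ; ... instructions without side effects ...
//   call void @__tgt_target_data_begin_mapper_wait(i64 %device_id, %struct.__tgt_async_info* %handle)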
1796 
1797  static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1798  bool GlobalOnly, bool &SingleChoice) {
1799  if (CurrentIdent == NextIdent)
1800  return CurrentIdent;
1801 
1802  // TODO: Figure out how to actually combine multiple debug locations. For
1803  // now we just keep an existing one if there is a single choice.
1804  if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1805  SingleChoice = !CurrentIdent;
1806  return NextIdent;
1807  }
1808  return nullptr;
1809  }
1810 
1811  /// Return a `struct ident_t*` value that represents the ones used in the
1812  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1813  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1814  /// return value we create one from scratch. We also do not yet combine
1815  /// information, e.g., the source locations, see combinedIdentStruct.
1816  Value *
1817  getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1818  Function &F, bool GlobalOnly) {
1819  bool SingleChoice = true;
1820  Value *Ident = nullptr;
1821  auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1822  CallInst *CI = getCallIfRegularCall(U, &RFI);
1823  if (!CI || &F != &Caller)
1824  return false;
1825  Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1826  /* GlobalOnly */ true, SingleChoice);
1827  return false;
1828  };
1829  RFI.foreachUse(SCC, CombineIdentStruct);
1830 
1831  if (!Ident || !SingleChoice) {
1832  // The IRBuilder uses the insertion block to get to the module; this is
1833  // unfortunate but we work around it for now.
1834  if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1835  OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1836  &F.getEntryBlock(), F.getEntryBlock().begin()));
1837  // Create a fallback location if none was found.
1838  // TODO: Use the debug locations of the calls instead.
1839  uint32_t SrcLocStrSize;
1840  Constant *Loc =
1841  OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1842  Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1843  }
1844  return Ident;
1845  }
1846 
1847  /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1848  /// \p ReplVal if given.
1849  bool deduplicateRuntimeCalls(Function &F,
1850  OMPInformationCache::RuntimeFunctionInfo &RFI,
1851  Value *ReplVal = nullptr) {
1852  auto *UV = RFI.getUseVector(F);
1853  if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1854  return false;
1855 
1856  LLVM_DEBUG(
1857  dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1858  << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1859 
1860  assert((!ReplVal || (isa<Argument>(ReplVal) &&
1861  cast<Argument>(ReplVal)->getParent() == &F)) &&
1862  "Unexpected replacement value!");
1863 
1864  // TODO: Use dominance to find a good position instead.
1865  auto CanBeMoved = [this](CallBase &CB) {
1866  unsigned NumArgs = CB.arg_size();
1867  if (NumArgs == 0)
1868  return true;
1869  if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1870  return false;
1871  for (unsigned U = 1; U < NumArgs; ++U)
1872  if (isa<Instruction>(CB.getArgOperand(U)))
1873  return false;
1874  return true;
1875  };
1876 
1877  if (!ReplVal) {
1878  for (Use *U : *UV)
1879  if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1880  if (!CanBeMoved(*CI))
1881  continue;
1882 
1883  // If the function is a kernel, dedup will move
1884  // the runtime call right after the kernel init callsite. Otherwise,
1885  // it will move it to the beginning of the caller function.
1886  if (isKernel(F)) {
1887  auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
1888  auto *KernelInitUV = KernelInitRFI.getUseVector(F);
1889 
1890  if (KernelInitUV->empty())
1891  continue;
1892 
1893  assert(KernelInitUV->size() == 1 &&
1894  "Expected a single __kmpc_target_init in kernel\n");
1895 
1896  CallInst *KernelInitCI =
1897  getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
1898  assert(KernelInitCI &&
1899  "Expected a call to __kmpc_target_init in kernel\n");
1900 
1901  CI->moveAfter(KernelInitCI);
1902  } else
1903  CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
1904  ReplVal = CI;
1905  break;
1906  }
1907  if (!ReplVal)
1908  return false;
1909  }
1910 
1911  // If we use a call as a replacement value we need to make sure the ident is
1912  // valid at the new location. For now we just pick a global one, either
1913  // existing and used by one of the calls, or created from scratch.
1914  if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1915  if (!CI->arg_empty() &&
1916  CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1917  Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1918  /* GlobalOnly */ true);
1919  CI->setArgOperand(0, Ident);
1920  }
1921  }
1922 
1923  bool Changed = false;
1924  auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1925  CallInst *CI = getCallIfRegularCall(U, &RFI);
1926  if (!CI || CI == ReplVal || &F != &Caller)
1927  return false;
1928  assert(CI->getCaller() == &F && "Unexpected call!");
1929 
1930  auto Remark = [&](OptimizationRemark OR) {
1931  return OR << "OpenMP runtime call "
1932  << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1933  };
1934  if (CI->getDebugLoc())
1935  emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1936  else
1937  emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1938 
1939  CGUpdater.removeCallSite(*CI);
1940  CI->replaceAllUsesWith(ReplVal);
1941  CI->eraseFromParent();
1942  ++NumOpenMPRuntimeCallsDeduplicated;
1943  Changed = true;
1944  return true;
1945  };
1946  RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1947 
1948  return Changed;
1949  }
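// Illustrative sketch (hypothetical example, simplified): deduplication turns
//   %tid0 = call i32 @omp_get_thread_num()
//   ...
//   %tid1 = call i32 @omp_get_thread_num()
// into a single call whose result replaces the other uses, assuming the
// surviving call can be moved to a dominating position (see CanBeMoved).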
1950 
1951  /// Collect arguments that represent the global thread id in \p GTIdArgs.
1952  void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1953  // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1954  // initialization. We could define an AbstractAttribute instead and
1955  // run the Attributor here once it can be run as an SCC pass.
1956 
1957  // Helper to check the argument \p ArgNo at all call sites of \p F for
1958  // a GTId.
1959  auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1960  if (!F.hasLocalLinkage())
1961  return false;
1962  for (Use &U : F.uses()) {
1963  if (CallInst *CI = getCallIfRegularCall(U)) {
1964  Value *ArgOp = CI->getArgOperand(ArgNo);
1965  if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1966  getCallIfRegularCall(
1967  *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1968  continue;
1969  }
1970  return false;
1971  }
1972  return true;
1973  };
1974 
1975  // Helper to identify uses of a GTId as GTId arguments.
1976  auto AddUserArgs = [&](Value &GTId) {
1977  for (Use &U : GTId.uses())
1978  if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1979  if (CI->isArgOperand(&U))
1980  if (Function *Callee = CI->getCalledFunction())
1981  if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1982  GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1983  };
1984 
1985  // The argument users of __kmpc_global_thread_num calls are GTIds.
1986  OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1987  OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1988 
1989  GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1990  if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1991  AddUserArgs(*CI);
1992  return false;
1993  });
1994 
1995  // Transitively search for more arguments by looking at the users of the
1996  // ones we know already. During the search the GTIdArgs vector is extended
1997  // so we cannot cache the size nor can we use a range based for.
1998  for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1999  AddUserArgs(*GTIdArgs[U]);
2000  }
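// Illustrative sketch (hypothetical IR, simplified): given
//   %gtid = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
//   call void @outlined(i32 %gtid, ...)
// the corresponding parameter of the internal function @outlined is recorded
// as a GTId argument, provided every call site passes a known GTId.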
2001 
2002  /// Kernel (=GPU) optimizations and utility functions
2003  ///
2004  ///{{
2005 
2006  /// Check if \p F is a kernel, hence entry point for target offloading.
2007  bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
2008 
2009  /// Cache to remember the unique kernel for a function.
2010  DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
2011 
2012  /// Find the unique kernel that will execute \p F, if any.
2013  Kernel getUniqueKernelFor(Function &F);
2014 
2015  /// Find the unique kernel that will execute \p I, if any.
2016  Kernel getUniqueKernelFor(Instruction &I) {
2017  return getUniqueKernelFor(*I.getFunction());
2018  }
2019 
2020  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
2021  /// the cases where we can avoid taking the address of a function.
2022  bool rewriteDeviceCodeStateMachine();
2023 
2024  ///
2025  ///}}
2026 
2027  /// Emit a remark generically
2028  ///
2029  /// This template function can be used to generically emit a remark. The
2030  /// RemarkKind should be one of the following:
2031  /// - OptimizationRemark to indicate a successful optimization attempt
2032  /// - OptimizationRemarkMissed to report a failed optimization attempt
2033  /// - OptimizationRemarkAnalysis to provide additional information about an
2034  /// optimization attempt
2035  ///
2036  /// The remark is built using a callback function provided by the caller that
2037  /// takes a RemarkKind as input and returns a RemarkKind.
2038  template <typename RemarkKind, typename RemarkCallBack>
2039  void emitRemark(Instruction *I, StringRef RemarkName,
2040  RemarkCallBack &&RemarkCB) const {
2041  Function *F = I->getParent()->getParent();
2042  auto &ORE = OREGetter(F);
2043 
2044  if (RemarkName.startswith("OMP"))
2045  ORE.emit([&]() {
2046  return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2047  << " [" << RemarkName << "]";
2048  });
2049  else
2050  ORE.emit(
2051  [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2052  }
2053 
2054  /// Emit a remark on a function.
2055  template <typename RemarkKind, typename RemarkCallBack>
2056  void emitRemark(Function *F, StringRef RemarkName,
2057  RemarkCallBack &&RemarkCB) const {
2058  auto &ORE = OREGetter(F);
2059 
2060  if (RemarkName.startswith("OMP"))
2061  ORE.emit([&]() {
2062  return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2063  << " [" << RemarkName << "]";
2064  });
2065  else
2066  ORE.emit(
2067  [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2068  }
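// Typical usage, mirroring the call sites in this file:
//   auto Remark = [&](OptimizationRemark OR) { return OR << "..."; };
//   emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
// Remark names starting with "OMP" are appended to the message in brackets.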
2069 
2070  /// RAII struct to temporarily change an RTL function's linkage to external.
2071  /// This prevents it from being mistakenly removed by other optimizations.
2072  struct ExternalizationRAII {
2073  ExternalizationRAII(OMPInformationCache &OMPInfoCache,
2074  RuntimeFunction RFKind)
2075  : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) {
2076  if (!Declaration)
2077  return;
2078 
2079  LinkageType = Declaration->getLinkage();
2080  Declaration->setLinkage(GlobalValue::ExternalLinkage);
2081  }
2082 
2083  ~ExternalizationRAII() {
2084  if (!Declaration)
2085  return;
2086 
2087  Declaration->setLinkage(LinkageType);
2088  }
2089 
2090  Function *Declaration;
2091  GlobalValue::LinkageTypes LinkageType;
2092  };
2093 
2094  /// The underlying module.
2095  Module &M;
2096 
2097  /// The SCC we are operating on.
2098  SmallVectorImpl<Function *> &SCC;
2099 
2100  /// Callback to update the call graph, the first argument is a removed call,
2101  /// the second an optional replacement call.
2102  CallGraphUpdater &CGUpdater;
2103 
2104  /// Callback to get an OptimizationRemarkEmitter from a Function *
2105  OptimizationRemarkGetter OREGetter;
2106 
2107  /// OpenMP-specific information cache. Also used for Attributor runs.
2108  OMPInformationCache &OMPInfoCache;
2109 
2110  /// Attributor instance.
2111  Attributor &A;
2112 
2113  /// Helper function to run Attributor on SCC.
2114  bool runAttributor(bool IsModulePass) {
2115  if (SCC.empty())
2116  return false;
2117 
2118  // Temporarily make these functions have external linkage so the Attributor
2119  // doesn't remove them when we try to look them up later.
2120  ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
2121  ExternalizationRAII EndParallel(OMPInfoCache,
2122  OMPRTL___kmpc_kernel_end_parallel);
2123  ExternalizationRAII BarrierSPMD(OMPInfoCache,
2124  OMPRTL___kmpc_barrier_simple_spmd);
2125  ExternalizationRAII BarrierGeneric(OMPInfoCache,
2126  OMPRTL___kmpc_barrier_simple_generic);
2127  ExternalizationRAII ThreadId(OMPInfoCache,
2128  OMPRTL___kmpc_get_hardware_thread_id_in_block);
2129  ExternalizationRAII NumThreads(
2130  OMPInfoCache, OMPRTL___kmpc_get_hardware_num_threads_in_block);
2131  ExternalizationRAII WarpSize(OMPInfoCache, OMPRTL___kmpc_get_warp_size);
2132 
2133  registerAAs(IsModulePass);
2134 
2135  ChangeStatus Changed = A.run();
2136 
2137  LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2138  << " functions, result: " << Changed << ".\n");
2139 
2140  return Changed == ChangeStatus::CHANGED;
2141  }
2142 
2143  void registerFoldRuntimeCall(RuntimeFunction RF);
2144 
2145  /// Populate the Attributor with abstract attribute opportunities in the
2146  /// function.
2147  void registerAAs(bool IsModulePass);
2148 };
2149 
2150 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2151  if (!OMPInfoCache.ModuleSlice.count(&F))
2152  return nullptr;
2153 
2154  // Use a scope to keep the lifetime of the CachedKernel short.
2155  {
2156  Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2157  if (CachedKernel)
2158  return *CachedKernel;
2159 
2160  // TODO: We should use an AA to create an (optimistic and callback
2161  // call-aware) call graph. For now we stick to simple patterns that
2162  // are less powerful, basically the worst fixpoint.
2163  if (isKernel(F)) {
2164  CachedKernel = Kernel(&F);
2165  return *CachedKernel;
2166  }
2167 
2168  CachedKernel = nullptr;
2169  if (!F.hasLocalLinkage()) {
2170 
2171  // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2172  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2173  return ORA << "Potentially unknown OpenMP target region caller.";
2174  };
2175  emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2176 
2177  return nullptr;
2178  }
2179  }
2180 
2181  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2182  if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2183  // Allow use in equality comparisons.
2184  if (Cmp->isEquality())
2185  return getUniqueKernelFor(*Cmp);
2186  return nullptr;
2187  }
2188  if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2189  // Allow direct calls.
2190  if (CB->isCallee(&U))
2191  return getUniqueKernelFor(*CB);
2192 
2193  OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2194  OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2195  // Allow the use in __kmpc_parallel_51 calls.
2196  if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2197  return getUniqueKernelFor(*CB);
2198  return nullptr;
2199  }
2200  // Disallow every other use.
2201  return nullptr;
2202  };
2203 
2204  // TODO: In the future we want to track more than just a unique kernel.
2205  SmallPtrSet<Kernel, 2> PotentialKernels;
2206  OMPInformationCache::foreachUse(F, [&](const Use &U) {
2207  PotentialKernels.insert(GetUniqueKernelForUse(U));
2208  });
2209 
2210  Kernel K = nullptr;
2211  if (PotentialKernels.size() == 1)
2212  K = *PotentialKernels.begin();
2213 
2214  // Cache the result.
2215  UniqueKernelMap[&F] = K;
2216 
2217  return K;
2218 }
2219 
2220 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2221  OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2222  OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2223 
2224  bool Changed = false;
2225  if (!KernelParallelRFI)
2226  return Changed;
2227 
2228  // If we have disabled state machine changes, exit.
2229  if (DisableOpenMPOptStateMachineRewrite)
2230  return Changed;
2231 
2232  for (Function *F : SCC) {
2233 
2234  // Check if the function is a use in a __kmpc_parallel_51 call at
2235  // all.
2236  bool UnknownUse = false;
2237  bool KernelParallelUse = false;
2238  unsigned NumDirectCalls = 0;
2239 
2240  SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2241  OMPInformationCache::foreachUse(*F, [&](Use &U) {
2242  if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2243  if (CB->isCallee(&U)) {
2244  ++NumDirectCalls;
2245  return;
2246  }
2247 
2248  if (isa<ICmpInst>(U.getUser())) {
2249  ToBeReplacedStateMachineUses.push_back(&U);
2250  return;
2251  }
2252 
2253  // Find wrapper functions that represent parallel kernels.
2254  CallInst *CI =
2255  OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2256  const unsigned int WrapperFunctionArgNo = 6;
2257  if (!KernelParallelUse && CI &&
2258  CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2259  KernelParallelUse = true;
2260  ToBeReplacedStateMachineUses.push_back(&U);
2261  return;
2262  }
2263  UnknownUse = true;
2264  });
2265 
2266  // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2267  // use.
2268  if (!KernelParallelUse)
2269  continue;
2270 
2271  // If this ever hits, we should investigate.
2272  // TODO: Checking the number of uses is not a necessary restriction and
2273  // should be lifted.
2274  if (UnknownUse || NumDirectCalls != 1 ||
2275  ToBeReplacedStateMachineUses.size() > 2) {
2276  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2277  return ORA << "Parallel region is used in "
2278  << (UnknownUse ? "unknown" : "unexpected")
2279  << " ways. Will not attempt to rewrite the state machine.";
2280  };
2281  emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2282  continue;
2283  }
2284 
2285  // Even if we have __kmpc_parallel_51 calls, we (for now) give
2286  // up if the function is not called from a unique kernel.
2287  Kernel K = getUniqueKernelFor(*F);
2288  if (!K) {
2289  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2290  return ORA << "Parallel region is not called from a unique kernel. "
2291  "Will not attempt to rewrite the state machine.";
2292  };
2293  emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2294  continue;
2295  }
2296 
2297  // We now know F is a parallel body function called only from the kernel K.
2298  // We also identified the state machine uses in which we replace the
2299  // function pointer by a new global symbol for identification purposes. This
2300  // ensures only direct calls to the function are left.
2301 
2302  Module &M = *F->getParent();
2303  Type *Int8Ty = Type::getInt8Ty(M.getContext());
2304 
2305  auto *ID = new GlobalVariable(
2306  M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2307  UndefValue::get(Int8Ty), F->getName() + ".ID");
2308 
2309  for (Use *U : ToBeReplacedStateMachineUses)
2310  A.changeUseAfterManifest(*U, *ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2311  ID, U->get()->getType()));
2312 
2313  ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2314 
2315  Changed = true;
2316  }
2317 
2318  return Changed;
2319 }
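// Illustrative effect (hypothetical IR, simplified): instead of the state
// machine comparing against the address of the outlined parallel function,
// its non-call uses are redirected to a private identification global, e.g.
//   @__omp_outlined__.ID = private constant i8 undef
// leaving only direct calls to the outlined function behind.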
2320 
2321 /// Abstract Attribute for tracking ICV values.
2322 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2323  using Base = StateWrapper<BooleanState, AbstractAttribute>;
2324  AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2325 
2326  void initialize(Attributor &A) override {
2327  Function *F = getAnchorScope();
2328  if (!F || !A.isFunctionIPOAmendable(*F))
2329  indicatePessimisticFixpoint();
2330  }
2331 
2332  /// Returns true if value is assumed to be tracked.
2333  bool isAssumedTracked() const { return getAssumed(); }
2334 
2335  /// Returns true if value is known to be tracked.
2336  bool isKnownTracked() const { return getAssumed(); }
2337 
2338  /// Create an abstract attribute view for the position \p IRP.
2339  static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2340 
2341  /// Return the value with which \p I can be replaced for specific \p ICV.
2342  virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
2343  const Instruction *I,
2344  Attributor &A) const {
2345  return None;
2346  }
2347 
2348  /// Return an assumed unique ICV value if a single candidate is found. If
2349  /// there cannot be one, return a nullptr. If it is not clear yet, return the
2350  /// Optional::NoneType.
2351  virtual Optional<Value *>
2352  getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2353 
2354  // Currently only nthreads is being tracked.
2355  // This array will only grow over time.
2356  InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2357 
2358  /// See AbstractAttribute::getName()
2359  const std::string getName() const override { return "AAICVTracker"; }
2360 
2361  /// See AbstractAttribute::getIdAddr()
2362  const char *getIdAddr() const override { return &ID; }
2363 
2364  /// This function should return true if the type of the \p AA is AAICVTracker
2365  static bool classof(const AbstractAttribute *AA) {
2366  return (AA->getIdAddr() == &ID);
2367  }
2368 
2369  static const char ID;
2370 };
2371 
2372 struct AAICVTrackerFunction : public AAICVTracker {
2373  AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2374  : AAICVTracker(IRP, A) {}
2375 
2376  // FIXME: come up with better string.
2377  const std::string getAsStr() const override { return "ICVTrackerFunction"; }
2378 
2379  // FIXME: come up with some stats.
2380  void trackStatistics() const override {}
2381 
2382  /// We don't manifest anything for this AA.
2383  ChangeStatus manifest(Attributor &A) override {
2384  return ChangeStatus::UNCHANGED;
2385  }
2386 
2387  // Map of an ICV to its values at specific program points.
2388  EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2389  InternalControlVar::ICV___last>
2390  ICVReplacementValuesMap;
2391 
2392  ChangeStatus updateImpl(Attributor &A) override {
2393  ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2394 
2395  Function *F = getAnchorScope();
2396 
2397  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2398 
2399  for (InternalControlVar ICV : TrackableICVs) {
2400  auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2401 
2402  auto &ValuesMap = ICVReplacementValuesMap[ICV];
2403  auto TrackValues = [&](Use &U, Function &) {
2404  CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2405  if (!CI)
2406  return false;
2407 
2408  // FIXME: handle setters with more than one argument.
2409  /// Track new value.
2410  if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2411  HasChanged = ChangeStatus::CHANGED;
2412 
2413  return false;
2414  };
2415 
2416  auto CallCheck = [&](Instruction &I) {
2417  Optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2418  if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2419  HasChanged = ChangeStatus::CHANGED;
2420 
2421  return true;
2422  };
2423 
2424  // Track all changes of an ICV.
2425  SetterRFI.foreachUse(TrackValues, F);
2426 
2427  bool UsedAssumedInformation = false;
2428  A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2429  UsedAssumedInformation,
2430  /* CheckBBLivenessOnly */ true);
2431 
2432  /// TODO: Figure out a way to avoid adding an entry to
2433  /// ICVReplacementValuesMap.
2434  Instruction *Entry = &F->getEntryBlock().front();
2435  if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2436  ValuesMap.insert(std::make_pair(Entry, nullptr));
2437  }
2438 
2439  return HasChanged;
2440  }
2441 
2442  /// Helper to check if \p I is a call and get the value for it if it is
2443  /// unique.
2444  Optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2445  InternalControlVar &ICV) const {
2446 
2447  const auto *CB = dyn_cast<CallBase>(&I);
2448  if (!CB || CB->hasFnAttr("no_openmp") ||
2449  CB->hasFnAttr("no_openmp_routines"))
2450  return None;
2451 
2452  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2453  auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2454  auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2455  Function *CalledFunction = CB->getCalledFunction();
2456 
2457  // Indirect call, assume ICV changes.
2458  if (CalledFunction == nullptr)
2459  return nullptr;
2460  if (CalledFunction == GetterRFI.Declaration)
2461  return None;
2462  if (CalledFunction == SetterRFI.Declaration) {
2463  if (ICVReplacementValuesMap[ICV].count(&I))
2464  return ICVReplacementValuesMap[ICV].lookup(&I);
2465 
2466  return nullptr;
2467  }
2468 
2469  // Since we don't know, assume it changes the ICV.
2470  if (CalledFunction->isDeclaration())
2471  return nullptr;
2472 
2473  const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2474  *this, IRPosition::function(*CalledFunction), DepClassTy::REQUIRED);
2475 
2476  if (ICVTrackingAA.isAssumedTracked()) {
2477  Optional<Value *> URV = ICVTrackingAA.getUniqueReplacementValue(ICV);
2478  if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2479  OMPInfoCache)))
2480  return URV;
2481  }
2482 
2483  // If we don't know, assume it changes.
2484  return nullptr;
2485  }
2486 
2487  // We don't check for a unique value for a function, so return None.
2488  Optional<Value *>
2489  getUniqueReplacementValue(InternalControlVar ICV) const override {
2490  return None;
2491  }
2492 
2493  /// Return the value with which \p I can be replaced for specific \p ICV.
2494  Optional<Value *> getReplacementValue(InternalControlVar ICV,
2495  const Instruction *I,
2496  Attributor &A) const override {
2497  const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2498  if (ValuesMap.count(I))
2499  return ValuesMap.lookup(I);
2500  return ValuesMap.lookup(I);
2501  SmallVector<const Instruction *, 16> Worklist;
2502  SmallPtrSet<const Instruction *, 16> Visited;
2503  Worklist.push_back(I);
2504 
2505  Optional<Value *> ReplVal;
2506 
2507  while (!Worklist.empty()) {
2508  const Instruction *CurrInst = Worklist.pop_back_val();
2509  if (!Visited.insert(CurrInst).second)
2510  continue;
2511 
2512  const BasicBlock *CurrBB = CurrInst->getParent();
2513 
2514  // Go up and look for all potential setters/calls that might change the
2515  // ICV.
2516  while ((CurrInst = CurrInst->getPrevNode())) {
2517  if (ValuesMap.count(CurrInst)) {
2518  Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2519  // Unknown value, track new.
2520  if (!ReplVal) {
2521  ReplVal = NewReplVal;
2522  break;
2523  }
2524 
2525  // If we found a new value, we can't know the ICV value anymore.
2526  if (NewReplVal)
2527  if (ReplVal != NewReplVal)
2528  return nullptr;
2529 
2530  break;
2531  }
2532 
2533  Optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2534  if (!NewReplVal)
2535  continue;
2536 
2537  // Unknown value, track new.
2538  if (!ReplVal) {
2539  ReplVal = NewReplVal;
2540  break;
2541  }
2542 
2543  // Reaching here means NewReplVal has a value; if it differs from the one
2544  // we already found, we can't know the ICV value anymore.
2545  if (ReplVal != NewReplVal)
2546  return nullptr;
2547  }
2548 
2549  // If we are in the same BB and we have a value, we are done.
2550  if (CurrBB == I->getParent() && ReplVal)
2551  return ReplVal;
2552 
2553  // Go through all predecessors and add terminators for analysis.
2554  for (const BasicBlock *Pred : predecessors(CurrBB))
2555  if (const Instruction *Terminator = Pred->getTerminator())
2556  Worklist.push_back(Terminator);
2557  }
2558 
2559  return ReplVal;
2560  }
2561 };
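// Illustrative sketch (hypothetical example, simplified): with nthreads as the
// only tracked ICV, a sequence like
//   call void @omp_set_num_threads(i32 4)
//   ...
//   %n = call i32 @omp_get_max_threads()
// lets the tracker answer the getter with the constant 4, as long as no
// intervening call may change the ICV.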
2562 
2563 struct AAICVTrackerFunctionReturned : AAICVTracker {
2564  AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2565  : AAICVTracker(IRP, A) {}
2566 
2567  // FIXME: come up with better string.
2568  const std::string getAsStr() const override {
2569  return "ICVTrackerFunctionReturned";
2570  }
2571 
2572  // FIXME: come up with some stats.
2573  void trackStatistics() const override {}
2574 
2575  /// We don't manifest anything for this AA.
2576  ChangeStatus manifest(Attributor &A) override {
2577  return ChangeStatus::UNCHANGED;
2578  }
2579 
2580  // Map of an ICV to its values at specific program points.
2581  EnumeratedArray<Optional<Value *>, InternalControlVar,
2582  InternalControlVar::ICV___last>
2583  ICVReplacementValuesMap;
2584 
2585  /// Return the value with which \p I can be replaced for specific \p ICV.
2586  Optional<Value *>
2587  getUniqueReplacementValue(InternalControlVar ICV) const override {
2588  return ICVReplacementValuesMap[ICV];
2589  }
2590 
2591  ChangeStatus updateImpl(Attributor &A) override {
2592  ChangeStatus Changed = ChangeStatus::UNCHANGED;
2593  const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2594  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2595 
2596  if (!ICVTrackingAA.isAssumedTracked())
2597  return indicatePessimisticFixpoint();
2598 
2599  for (InternalControlVar ICV : TrackableICVs) {
2600  Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2601  Optional<Value *> UniqueICVValue;
2602 
2603  auto CheckReturnInst = [&](Instruction &I) {
2604  Optional<Value *> NewReplVal =
2605  ICVTrackingAA.getReplacementValue(ICV, &I, A);
2606 
2607  // If we found a second ICV value there is no unique returned value.
2608  if (UniqueICVValue && UniqueICVValue != NewReplVal)
2609  return false;
2610 
2611  UniqueICVValue = NewReplVal;
2612 
2613  return true;
2614  };
2615 
2616  bool UsedAssumedInformation = false;
2617  if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2618  UsedAssumedInformation,
2619  /* CheckBBLivenessOnly */ true))
2620  UniqueICVValue = nullptr;
2621 
2622  if (UniqueICVValue == ReplVal)
2623  continue;
2624 
2625  ReplVal = UniqueICVValue;
2626  Changed = ChangeStatus::CHANGED;
2627  }
2628 
2629  return Changed;
2630  }
2631 };
2632 
2633 struct AAICVTrackerCallSite : AAICVTracker {
2634  AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2635  : AAICVTracker(IRP, A) {}
2636 
2637  void initialize(Attributor &A) override {
2638  Function *F = getAnchorScope();
2639  if (!F || !A.isFunctionIPOAmendable(*F))
2640  indicatePessimisticFixpoint();
2641 
2642  // We only initialize this AA for getters, so we need to know which ICV it
2643  // gets.
2644  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2645  for (InternalControlVar ICV : TrackableICVs) {
2646  auto ICVInfo = OMPInfoCache.ICVs[ICV];
2647  auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2648  if (Getter.Declaration == getAssociatedFunction()) {
2649  AssociatedICV = ICVInfo.Kind;
2650  return;
2651  }
2652  }
2653 
2654  /// Unknown ICV.
2655  indicatePessimisticFixpoint();
2656  }
2657 
2658  ChangeStatus manifest(Attributor &A) override {
2659  if (!ReplVal || !*ReplVal)
2660  return ChangeStatus::UNCHANGED;
2661 
2662  A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2663  A.deleteAfterManifest(*getCtxI());
2664 
2665  return ChangeStatus::CHANGED;
2666  }
2667 
2668  // FIXME: come up with better string.
2669  const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2670 
2671  // FIXME: come up with some stats.
2672  void trackStatistics() const override {}
2673 
2674  InternalControlVar AssociatedICV;
2675  Optional<Value *> ReplVal;
2676 
2677  ChangeStatus updateImpl(Attributor &A) override {
2678  const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2679  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2680 
2681  // We don't have any information, so we assume it changes the ICV.
2682  if (!ICVTrackingAA.isAssumedTracked())
2683  return indicatePessimisticFixpoint();
2684 
2685  Optional<Value *> NewReplVal =
2686  ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2687 
2688  if (ReplVal == NewReplVal)
2689  return ChangeStatus::UNCHANGED;
2690 
2691  ReplVal = NewReplVal;
2692  return ChangeStatus::CHANGED;
2693  }
2694 
2695  // Return the value with which the associated value can be replaced for the
2696  // specific \p ICV.
2697  Optional<Value *>
2698  getUniqueReplacementValue(InternalControlVar ICV) const override {
2699  return ReplVal;
2700  }
2701 };
2702 
2703 struct AAICVTrackerCallSiteReturned : AAICVTracker {
2704  AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2705  : AAICVTracker(IRP, A) {}
2706 
2707  // FIXME: come up with better string.
2708  const std::string getAsStr() const override {
2709  return "ICVTrackerCallSiteReturned";
2710  }
2711 
2712  // FIXME: come up with some stats.
2713  void trackStatistics() const override {}
2714 
2715  /// We don't manifest anything for this AA.
2716  ChangeStatus manifest(Attributor &A) override {
2717  return ChangeStatus::UNCHANGED;
2718  }
2719 
2720  // Map of an ICV to its values at specific program points.
2721  EnumeratedArray<Optional<Value *>, InternalControlVar,
2722  InternalControlVar::ICV___last>
2723  ICVReplacementValuesMap;
2724 
2725  /// Return the value with which the associated value can be replaced for the
2726  /// specific \p ICV.
2727  Optional<Value *>
2728  getUniqueReplacementValue(InternalControlVar ICV) const override {
2729  return ICVReplacementValuesMap[ICV];
2730  }
2731 
2732  ChangeStatus updateImpl(Attributor &A) override {
2733  ChangeStatus Changed = ChangeStatus::UNCHANGED;
2734  const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2735  *this, IRPosition::returned(*getAssociatedFunction()),
2736  DepClassTy::REQUIRED);
2737 
2738  // We don't have any information, so we assume it changes the ICV.
2739  if (!ICVTrackingAA.isAssumedTracked())
2740  return indicatePessimisticFixpoint();
2741 
2742  for (InternalControlVar ICV : TrackableICVs) {
2743  Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2744  Optional<Value *> NewReplVal =
2745  ICVTrackingAA.getUniqueReplacementValue(ICV);
2746 
2747  if (ReplVal == NewReplVal)
2748  continue;
2749 
2750  ReplVal = NewReplVal;
2751  Changed = ChangeStatus::CHANGED;
2752  }
2753  return Changed;
2754  }
2755 };
2756 
2757 struct AAExecutionDomainFunction : public AAExecutionDomain {
2758  AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2759  : AAExecutionDomain(IRP, A) {}
2760 
2761  const std::string getAsStr() const override {
2762  return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2763  "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2764  }
2765 
2766  /// See AbstractAttribute::trackStatistics().
2767  void trackStatistics() const override {}
2768 
2769  void initialize(Attributor &A) override {
2770  Function *F = getAnchorScope();
2771  for (const auto &BB : *F)
2772  SingleThreadedBBs.insert(&BB);
2773  NumBBs = SingleThreadedBBs.size();
2774  }
2775 
2776  ChangeStatus manifest(Attributor &A) override {
2777  LLVM_DEBUG({
2778  for (const BasicBlock *BB : SingleThreadedBBs)
2779  dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2780  << BB->getName() << " is executed by a single thread.\n";
2781  });
2782  return ChangeStatus::UNCHANGED;
2783  }
2784 
2785  ChangeStatus updateImpl(Attributor &A) override;
2786 
2787  /// Check if an instruction is executed by a single thread.
2788  bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
2789  return isExecutedByInitialThreadOnly(*I.getParent());
2790  }
2791 
2792  bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2793  return isValidState() && SingleThreadedBBs.contains(&BB);
2794  }
2795 
2796  /// Set of basic blocks that are executed by a single thread.
2797  SmallSetVector<const BasicBlock *, 16> SingleThreadedBBs;
2798 
2799  /// Total number of basic blocks in this function.
2800  long unsigned NumBBs = 0;
2801 };
2802 
2803 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2804  Function *F = getAnchorScope();
2805  ReversePostOrderTraversal<Function *> RPOT(F);
2806  auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2807 
2808  bool AllCallSitesKnown;
2809  auto PredForCallSite = [&](AbstractCallSite ACS) {
2810  const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2811  *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2812  DepClassTy::REQUIRED);
2813  return ACS.isDirectCall() &&
2814  ExecutionDomainAA.isExecutedByInitialThreadOnly(
2815  *ACS.getInstruction());
2816  };
2817 
2818  if (!A.checkForAllCallSites(PredForCallSite, *this,
2819  /* RequiresAllCallSites */ true,
2820  AllCallSitesKnown))
2821  SingleThreadedBBs.remove(&F->getEntryBlock());
2822 
2823  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2824  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2825 
2826  // Check if the edge into the successor block contains a condition that only
2827  // lets the main thread execute it.
2828  auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2829  if (!Edge || !Edge->isConditional())
2830  return false;
2831  if (Edge->getSuccessor(0) != SuccessorBB)
2832  return false;
2833 
2834  auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2835  if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2836  return false;
2837 
2838  ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2839  if (!C)
2840  return false;
2841 
2842  // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2843  if (C->isAllOnesValue()) {
2844  auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2845  CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2846  if (!CB)
2847  return false;
2848  const int InitModeArgNo = 1;
2849  auto *ModeCI = dyn_cast<ConstantInt>(CB->getOperand(InitModeArgNo));
2850  return ModeCI && (ModeCI->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC);
2851  }
2852 
2853  if (C->isZero()) {
2854  // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2855  if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2856  if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2857  return true;
2858 
2859  // Match: 0 == llvm.amdgcn.workitem.id.x()
2860  if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2861  if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2862  return true;
2863  }
2864 
2865  return false;
2866  };
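// Illustrative guard matched above (hypothetical IR, NVPTX flavor):
//   %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
//   %is.initial = icmp eq i32 %tid, 0
//   br i1 %is.initial, label %initial.only, label %rest
// Only the edge into %initial.only is considered executed by the initial
// thread.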
2867 
2868  // Merge all the predecessor states into the current basic block. A basic
2869  // block is executed by a single thread if all of its predecessors are.
2870  auto MergePredecessorStates = [&](BasicBlock *BB) {
2871  if (pred_empty(BB))
2872  return SingleThreadedBBs.contains(BB);
2873 
2874  bool IsInitialThread = true;
2875  for (BasicBlock *PredBB : predecessors(BB)) {
2876  if (!IsInitialThreadOnly(dyn_cast<BranchInst>(PredBB->getTerminator()),
2877  BB))
2878  IsInitialThread &= SingleThreadedBBs.contains(PredBB);
2879  }
2880 
2881  return IsInitialThread;
2882  };
2883 
2884  for (auto *BB : RPOT) {
2885  if (!MergePredecessorStates(BB))
2886  SingleThreadedBBs.remove(BB);
2887  }
2888 
2889  return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2890  ? ChangeStatus::UNCHANGED
2891  : ChangeStatus::CHANGED;
2892 }
2893 
2894 /// Try to replace memory allocation calls called by a single thread with a
2895 /// static buffer of shared memory.
2896 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
2897  using Base = StateWrapper<BooleanState, AbstractAttribute>;
2898  AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2899 
2900  /// Create an abstract attribute view for the position \p IRP.
2901  static AAHeapToShared &createForPosition(const IRPosition &IRP,
2902  Attributor &A);
2903 
2904  /// Returns true if HeapToShared conversion is assumed to be possible.
2905  virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
2906 
2907  /// Returns true if HeapToShared conversion is assumed and the CB is a
2908  /// callsite to a free operation to be removed.
2909  virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
2910 
2911  /// See AbstractAttribute::getName().
2912  const std::string getName() const override { return "AAHeapToShared"; }
2913 
2914  /// See AbstractAttribute::getIdAddr().
2915  const char *getIdAddr() const override { return &ID; }
2916 
2917  /// This function should return true if the type of the \p AA is
2918  /// AAHeapToShared.
2919  static bool classof(const AbstractAttribute *AA) {
2920  return (AA->getIdAddr() == &ID);
2921  }
2922 
2923  /// Unique ID (due to the unique address)
2924  static const char ID;
2925 };
2926 
2927 struct AAHeapToSharedFunction : public AAHeapToShared {
2928  AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
2929  : AAHeapToShared(IRP, A) {}
2930 
2931  const std::string getAsStr() const override {
2932  return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
2933  " malloc calls eligible.";
2934  }
2935 
2936  /// See AbstractAttribute::trackStatistics().
2937  void trackStatistics() const override {}
2938 
2939  /// This function finds free calls that will be removed by the
2940  /// HeapToShared transformation.
2941  void findPotentialRemovedFreeCalls(Attributor &A) {
2942  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2943  auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2944 
2945  PotentialRemovedFreeCalls.clear();
2946  // Update free call users of found malloc calls.
2947  for (CallBase *CB : MallocCalls) {
2948  SmallVector<CallBase *, 4> FreeCalls;
2949  for (auto *U : CB->users()) {
2950  CallBase *C = dyn_cast<CallBase>(U);
2951  if (C && C->getCalledFunction() == FreeRFI.Declaration)
2952  FreeCalls.push_back(C);
2953  }
2954 
2955  if (FreeCalls.size() != 1)
2956  continue;
2957 
2958  PotentialRemovedFreeCalls.insert(FreeCalls.front());
2959  }
2960  }
2961 
2962  void initialize(Attributor &A) override {
2963  if (DisableOpenMPOptDeglobalization) {
2964  indicatePessimisticFixpoint();
2965  return;
2966  }
2967 
2968  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2969  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2970 
2971  Attributor::SimplifictionCallbackTy SCB =
2972  [](const IRPosition &, const AbstractAttribute *,
2973  bool &) -> Optional<Value *> { return nullptr; };
2974  for (User *U : RFI.Declaration->users())
2975  if (CallBase *CB = dyn_cast<CallBase>(U)) {
2976  MallocCalls.insert(CB);
2977  A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
2978  SCB);
2979  }
2980 
2981  findPotentialRemovedFreeCalls(A);
2982  }
2983 
2984  bool isAssumedHeapToShared(CallBase &CB) const override {
2985  return isValidState() && MallocCalls.count(&CB);
2986  }
2987 
2988  bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
2989  return isValidState() && PotentialRemovedFreeCalls.count(&CB);
2990  }
2991 
2992  ChangeStatus manifest(Attributor &A) override {
2993  if (MallocCalls.empty())
2994  return ChangeStatus::UNCHANGED;
2995 
2996  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2997  auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2998 
2999  Function *F = getAnchorScope();
3000  auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3001  DepClassTy::OPTIONAL);
3002 
3003  ChangeStatus Changed = ChangeStatus::UNCHANGED;
3004  for (CallBase *CB : MallocCalls) {
3005  // Skip replacing this if HeapToStack has already claimed it.
3006  if (HS && HS->isAssumedHeapToStack(*CB))
3007  continue;
3008 
3009  // Find the unique free call to remove it.
3010  SmallVector<CallBase *, 4> FreeCalls;
3011  for (auto *U : CB->users()) {
3012  CallBase *C = dyn_cast<CallBase>(U);
3013  if (C && C->getCalledFunction() == FreeCall.Declaration)
3014  FreeCalls.push_back(C);
3015  }
3016  if (FreeCalls.size() != 1)
3017  continue;
3018 
3019  auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3020 
3021  if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3022  LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3023  << " with shared memory."
3024  << " Shared memory usage is limited to "
3025  << SharedMemoryLimit << " bytes\n");
3026  continue;
3027  }
3028 
3029  LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3030  << " with " << AllocSize->getZExtValue()
3031  << " bytes of shared memory\n");
3032 
3033  // Create a new shared memory buffer of the same size as the allocation
3034  // and replace all the uses of the original allocation with it.
3035  Module *M = CB->getModule();
3036  Type *Int8Ty = Type::getInt8Ty(M->getContext());
3037  Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3038  auto *SharedMem = new GlobalVariable(
3039  *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3040  UndefValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3041  GlobalValue::NotThreadLocal,
3042  static_cast<unsigned>(AddressSpace::Shared));
3043  auto *NewBuffer =
3044  ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
3045 
3046  auto Remark = [&](OptimizationRemark OR) {
3047  return OR << "Replaced globalized variable with "
3048  << ore::NV("SharedMemory", AllocSize->getZExtValue())
3049  << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
3050  << "of shared memory.";
3051  };
3052  A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3053 
3054  MaybeAlign Alignment = CB->getRetAlign();
3055  assert(Alignment &&
3056  "HeapToShared on allocation without alignment attribute");
3057  SharedMem->setAlignment(MaybeAlign(Alignment));
3058 
3059  A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3060  A.deleteAfterManifest(*CB);
3061  A.deleteAfterManifest(*FreeCalls.front());
3062 
3063  SharedMemoryUsed += AllocSize->getZExtValue();
3064  NumBytesMovedToSharedMemory = SharedMemoryUsed;
3065  Changed = ChangeStatus::CHANGED;
3066  }
3067 
3068  return Changed;
3069  }
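// Illustrative transformation (hypothetical IR, simplified): an allocation
// made by a single thread, e.g.
//   %p = call i8* @__kmpc_alloc_shared(i64 16)
//   ...
//   call void @__kmpc_free_shared(i8* %p, i64 16)
// becomes a static buffer in device shared memory,
//   @x_shared = internal addrspace(3) global [16 x i8] undef, align 16
// and the alloc/free pair is deleted.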
3070 
3071  ChangeStatus updateImpl(Attributor &A) override {
3072  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3073  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3074  Function *F = getAnchorScope();
3075 
3076  auto NumMallocCalls = MallocCalls.size();
3077 
3078  // Only consider malloc calls executed by a single thread with a constant.
3079  for (User *U : RFI.Declaration->users()) {
3080  const auto &ED = A.getAAFor<AAExecutionDomain>(
3081  *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3082  if (CallBase *CB = dyn_cast<CallBase>(U))
3083  if (!isa<ConstantInt>(CB->getArgOperand(0)) ||
3084  !ED.isExecutedByInitialThreadOnly(*CB))
3085  MallocCalls.remove(CB);
3086  }
3087 
3088  findPotentialRemovedFreeCalls(A);
3089 
3090  if (NumMallocCalls != MallocCalls.size())
3091  return ChangeStatus::CHANGED;
3092 
3093  return ChangeStatus::UNCHANGED;
3094  }
3095 
3096  /// Collection of all malloc calls in a function.
3097  SmallSetVector<CallBase *, 4> MallocCalls;
3098  /// Collection of potentially removed free calls in a function.
3099  SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3100  /// The total amount of shared memory that has been used for HeapToShared.
3101  unsigned SharedMemoryUsed = 0;
3102 };
3103 
3104 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3105  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3106  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3107 
3108  /// Statistics are tracked as part of manifest for now.
3109  void trackStatistics() const override {}
3110 
3111  /// See AbstractAttribute::getAsStr()
3112  const std::string getAsStr() const override {
3113  if (!isValidState())
3114  return "<invalid>";
3115  return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3116  : "generic") +
3117  std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3118  : "") +
3119  std::string(" #PRs: ") +
3120  (ReachedKnownParallelRegions.isValidState()
3121  ? std::to_string(ReachedKnownParallelRegions.size())
3122  : "<invalid>") +
3123  ", #Unknown PRs: " +
3124  (ReachedUnknownParallelRegions.isValidState()
3125  ? std::to_string(ReachedUnknownParallelRegions.size())
3126  : "<invalid>") +
3127  ", #Reaching Kernels: " +
3128  (ReachingKernelEntries.isValidState()
3129  ? std::to_string(ReachingKernelEntries.size())
3130  : "<invalid>");
3131  }
3132 
3133  /// Create an abstract attribute view for the position \p IRP.
3134  static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3135 
3136  /// See AbstractAttribute::getName()
3137  const std::string getName() const override { return "AAKernelInfo"; }
3138 
3139  /// See AbstractAttribute::getIdAddr()
3140  const char *getIdAddr() const override { return &ID; }
3141 
3142  /// This function should return true if the type of the \p AA is AAKernelInfo
3143  static bool classof(const AbstractAttribute *AA) {
3144  return (AA->getIdAddr() == &ID);
3145  }
3146 
3147  static const char ID;
3148 };
3149 
3150 /// The function kernel info abstract attribute, basically, what can we say
3151 /// about a function with regards to the KernelInfoState.
3152 struct AAKernelInfoFunction : AAKernelInfo {
3153  AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3154  : AAKernelInfo(IRP, A) {}
3155 
3156  SmallPtrSet<Instruction *, 4> GuardedInstructions;
3157 
3158  SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3159  return GuardedInstructions;
3160  }
3161 
3162  /// See AbstractAttribute::initialize(...).
3163  void initialize(Attributor &A) override {
3164  // This is a high-level transform that might change the constant arguments
3165  // of the init and deinit calls. We need to tell the Attributor about this
3166  // to avoid other parts using the current constant value for simplification.
3167  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3168 
3169  Function *Fn = getAnchorScope();
3170 
3171  OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3172  OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3173  OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3174  OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3175 
3176  // For kernels we perform more initialization work, first we find the init
3177  // and deinit calls.
3178  auto StoreCallBase = [](Use &U,
3179  OMPInformationCache::RuntimeFunctionInfo &RFI,
3180  CallBase *&Storage) {
3181  CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3182  assert(CB &&
3183  "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3184  assert(!Storage &&
3185  "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3186  Storage = CB;
3187  return false;
3188  };
3189  InitRFI.foreachUse(
3190  [&](Use &U, Function &) {
3191  StoreCallBase(U, InitRFI, KernelInitCB);
3192  return false;
3193  },
3194  Fn);
3195  DeinitRFI.foreachUse(
3196  [&](Use &U, Function &) {
3197  StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3198  return false;
3199  },
3200  Fn);
3201 
3202  // Ignore kernels without initializers such as global constructors.
3203  if (!KernelInitCB || !KernelDeinitCB)
3204  return;
3205 
3206  // Add itself to the reaching kernel and set IsKernelEntry.
3207  ReachingKernelEntries.insert(Fn);
3208  IsKernelEntry = true;
3209 
3210  // For kernels we might need to initialize/finalize the IsSPMD state and
3211  // we need to register a simplification callback so that the Attributor
3212  // knows the constant arguments to __kmpc_target_init and
3213  // __kmpc_target_deinit might actually change.
3214 
3215  Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
3216  [&](const IRPosition &IRP, const AbstractAttribute *AA,
3217  bool &UsedAssumedInformation) -> Optional<Value *> {
3218  // IRP represents the "use generic state machine" argument of an
3219  // __kmpc_target_init call. We will answer this one with the internal
3220  // state. As long as we are not in an invalid state, we will create a
3221  // custom state machine so the value should be a `i1 false`. If we are
3222  // in an invalid state, we won't change the value that is in the IR.
3223  if (!ReachedKnownParallelRegions.isValidState())
3224  return nullptr;
3225  // If we have disabled state machine rewrites, don't make a custom one.
3226  if (DisableOpenMPOptStateMachineRewrite)
3227  return nullptr;
3228  if (AA)
3229  A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3230  UsedAssumedInformation = !isAtFixpoint();
3231  auto *FalseVal =
3232  ConstantInt::getBool(IRP.getAnchorValue().getContext(), false);
3233  return FalseVal;
3234  };
3235 
3236  Attributor::SimplifictionCallbackTy ModeSimplifyCB =
3237  [&](const IRPosition &IRP, const AbstractAttribute *AA,
3238  bool &UsedAssumedInformation) -> Optional<Value *> {
3239  // IRP represents the "SPMDCompatibilityTracker" argument of an
3240  // __kmpc_target_init or
3241  // __kmpc_target_deinit call. We will answer this one with the internal
3242  // state.
3243  if (!SPMDCompatibilityTracker.isValidState())
3244  return nullptr;
3245  if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3246  if (AA)
3247  A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3248  UsedAssumedInformation = true;
3249  } else {
3250  UsedAssumedInformation = false;
3251  }
3252  auto *Val = ConstantInt::getSigned(
3253  IntegerType::getInt8Ty(IRP.getAnchorValue().getContext()),
3254  SPMDCompatibilityTracker.isAssumed() ? OMP_TGT_EXEC_MODE_SPMD
3255  : OMP_TGT_EXEC_MODE_GENERIC);
3256  return Val;
3257  };
3258 
3259  Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
3260  [&](const IRPosition &IRP, const AbstractAttribute *AA,
3261  bool &UsedAssumedInformation) -> Optional<Value *> {
3262  // IRP represents the "RequiresFullRuntime" argument of an
3263  // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
3264  // one with the internal state of the SPMDCompatibilityTracker, so if
3265  // generic then true, if SPMD then false.
3266  if (!SPMDCompatibilityTracker.isValidState())
3267  return nullptr;
3268  if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3269  if (AA)
3270  A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3271  UsedAssumedInformation = true;
3272  } else {
3273  UsedAssumedInformation = false;
3274  }
3275  auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
3276  !SPMDCompatibilityTracker.isAssumed());
3277  return Val;
3278  };
3279 
3280  constexpr const int InitModeArgNo = 1;
3281  constexpr const int DeinitModeArgNo = 1;
3282  constexpr const int InitUseStateMachineArgNo = 2;
3283  constexpr const int InitRequiresFullRuntimeArgNo = 3;
3284  constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
3285  A.registerSimplificationCallback(
3286  IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
3287  StateMachineSimplifyCB);
3288  A.registerSimplificationCallback(
3289  IRPosition::callsite_argument(*KernelInitCB, InitModeArgNo),
3290  ModeSimplifyCB);
3291  A.registerSimplificationCallback(
3292  IRPosition::callsite_argument(*KernelDeinitCB, DeinitModeArgNo),
3293  ModeSimplifyCB);
3294  A.registerSimplificationCallback(
3295  IRPosition::callsite_argument(*KernelInitCB,
3296  InitRequiresFullRuntimeArgNo),
3297  IsGenericModeSimplifyCB);
3298  A.registerSimplificationCallback(
3299  IRPosition::callsite_argument(*KernelDeinitCB,
3300  DeinitRequiresFullRuntimeArgNo),
3301  IsGenericModeSimplifyCB);
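// To make the registrations above concrete: with the argument positions
// assumed by the constants defined just above (Mode at 1,
// UseGenericStateMachine at 2, RequiresFullRuntime at 3 for
// __kmpc_target_init; Mode at 1 and RequiresFullRuntime at 2 for
// __kmpc_target_deinit), a simplification query against IR such as
//
//   call i32 @__kmpc_target_init(..., i8 1, i1 true, i1 true)
//
// is answered with the internal, possibly still assumed, state instead, e.g.
// "mode = SPMD, use-generic-state-machine = false, full-runtime = false" for
// a kernel currently assumed SPMD-compatible. The literal i8/i1 values shown
// here are illustrative; the real constants come from OMPConstants.h.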
3302 
3303  // Check if we know we are in SPMD-mode already.
3304  ConstantInt *ModeArg =
3305  dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
3306  if (ModeArg && (ModeArg->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
3307  SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3308  // This is a generic region but SPMDization is disabled so stop tracking.
3309  else if (DisableOpenMPOptSPMDization)
3310  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3311  }
3312 
3313  /// Sanitize the string \p S such that it is a suitable global symbol name.
3314  static std::string sanitizeForGlobalName(std::string S) {
3315  std::replace_if(
3316  S.begin(), S.end(),
3317  [](const char C) {
3318  return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3319  (C >= '0' && C <= '9') || C == '_');
3320  },
3321  '.');
3322  return S;
3323  }
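// A worked example of the sanitization above: every character outside
// [a-zA-Z0-9_] is turned into a '.', so for instance
//   sanitizeForGlobalName("foo<int>::x.addr") == "foo.int...x.addr"
// (the input string here is made up for illustration).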
3324 
3325  /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3326  /// finished now.
3327  ChangeStatus manifest(Attributor &A) override {
3328  // If we are not looking at a kernel with __kmpc_target_init and
3329  // __kmpc_target_deinit call we cannot actually manifest the information.
3330  if (!KernelInitCB || !KernelDeinitCB)
3331  return ChangeStatus::UNCHANGED;
3332 
3333  // If we can, we change the execution mode to SPMD-mode; otherwise we build
3334  // a custom state machine.
3335  ChangeStatus Changed = ChangeStatus::UNCHANGED;
3336  if (!changeToSPMDMode(A, Changed))
3337  return buildCustomStateMachine(A);
3338 
3339  return Changed;
3340  }
3341 
3342  bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
3343  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3344 
3345  if (!SPMDCompatibilityTracker.isAssumed()) {
3346  for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
3347  if (!NonCompatibleI)
3348  continue;
3349 
3350  // Skip diagnostics on calls to known OpenMP runtime functions for now.
3351  if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
3352  if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
3353  continue;
3354 
3355  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3356  ORA << "Value has potential side effects preventing SPMD-mode "
3357  "execution";
3358  if (isa<CallBase>(NonCompatibleI)) {
3359  ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
3360  "the called function to override";
3361  }
3362  return ORA << ".";
3363  };
3364  A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
3365  Remark);
3366 
3367  LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
3368  << *NonCompatibleI << "\n");
3369  }
3370 
3371  return false;
3372  }
3373 
3374  // Get the actual kernel; it could be the caller of the anchor scope if we
3375  // have a debug wrapper.
3376  Function *Kernel = getAnchorScope();
3377  if (Kernel->hasLocalLinkage()) {
3378  assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
3379  auto *CB = cast<CallBase>(Kernel->user_back());
3380  Kernel = CB->getCaller();
3381  }
3382  assert(OMPInfoCache.Kernels.count(Kernel) && "Expected kernel function!");
3383 
3384  // Check if the kernel is already in SPMD mode, if so, return success.
3385  GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
3386  (Kernel->getName() + "_exec_mode").str());
3387  assert(ExecMode && "Kernel without exec mode?");
3388  assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!");
3389 
3390  // Set the global exec mode flag to indicate SPMD-Generic mode.
3391  assert(isa<ConstantInt>(ExecMode->getInitializer()) &&
3392  "ExecMode is not an integer!");
3393  const int8_t ExecModeVal =
3394  cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue();
3395  if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
3396  return true;
3397 
3398  // We will now unconditionally modify the IR, indicate a change.
3399  Changed = ChangeStatus::CHANGED;
3400 
3401  auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3402  Instruction *RegionEndI) {
3403  LoopInfo *LI = nullptr;
3404  DominatorTree *DT = nullptr;
3405  MemorySSAUpdater *MSU = nullptr;
3406  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3407 
3408  BasicBlock *ParentBB = RegionStartI->getParent();
3409  Function *Fn = ParentBB->getParent();
3410  Module &M = *Fn->getParent();
3411 
3412  // Create all the blocks and logic.
3413  // ParentBB:
3414  // goto RegionCheckTidBB
3415  // RegionCheckTidBB:
3416  // Tid = __kmpc_hardware_thread_id()
3417  // if (Tid != 0)
3418  // goto RegionBarrierBB
3419  // RegionStartBB:
3420  // <execute instructions guarded>
3421  // goto RegionEndBB
3422  // RegionEndBB:
3423  // <store escaping values to shared mem>
3424  // goto RegionBarrierBB
3425  // RegionBarrierBB:
3426  // __kmpc_simple_barrier_spmd()
3427  // // second barrier is omitted if lacking escaping values.
3428  // <load escaping values from shared mem>
3429  // __kmpc_simple_barrier_spmd()
3430  // goto RegionExitBB
3431  // RegionExitBB:
3432  // <execute rest of instructions>
3433 
3434  BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3435  DT, LI, MSU, "region.guarded.end");
3436  BasicBlock *RegionBarrierBB =
3437  SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3438  MSU, "region.barrier");
3439  BasicBlock *RegionExitBB =
3440  SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3441  DT, LI, MSU, "region.exit");
3442  BasicBlock *RegionStartBB =
3443  SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3444 
3445  assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3446  "Expected a different CFG");
3447 
3448  BasicBlock *RegionCheckTidBB = SplitBlock(
3449  ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3450 
3451  // Register basic blocks with the Attributor.
3452  A.registerManifestAddedBasicBlock(*RegionEndBB);
3453  A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3454  A.registerManifestAddedBasicBlock(*RegionExitBB);
3455  A.registerManifestAddedBasicBlock(*RegionStartBB);
3456  A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3457 
3458  bool HasBroadcastValues = false;
3459  // Find escaping outputs from the guarded region to outside users and
3460  // broadcast their values to them.
3461  for (Instruction &I : *RegionStartBB) {
3462  SmallPtrSet<Instruction *, 4> OutsideUsers;
3463  for (User *Usr : I.users()) {
3464  Instruction &UsrI = *cast<Instruction>(Usr);
3465  if (UsrI.getParent() != RegionStartBB)
3466  OutsideUsers.insert(&UsrI);
3467  }
3468 
3469  if (OutsideUsers.empty())
3470  continue;
3471 
3472  HasBroadcastValues = true;
3473 
3474  // Emit a global variable in shared memory to store the broadcasted
3475  // value.
3476  auto *SharedMem = new GlobalVariable(
3477  M, I.getType(), /* IsConstant */ false,
3478  GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
3479  sanitizeForGlobalName(
3480  (I.getName() + ".guarded.output.alloc").str()),
3481  nullptr, GlobalValue::NotThreadLocal,
3482  static_cast<unsigned>(AddressSpace::Shared));
3483 
3484  // Emit a store instruction to update the value.
3485  new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
3486 
3487  LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
3488  I.getName() + ".guarded.output.load",
3489  RegionBarrierBB->getTerminator());
3490 
3491  // Emit a load instruction and replace uses of the output value.
3492  for (Instruction *UsrI : OutsideUsers)
3493  UsrI->replaceUsesOfWith(&I, LoadI);
3494  }
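// For a single escaping value %v of type i32, the net effect of the loop
// above is roughly the following IR (names and linkage are illustrative;
// address space 3 is what AddressSpace::Shared maps to on the GPU targets
// handled here):
//
//   @v.guarded.output.alloc = internal addrspace(3) global i32 undef
//   ...
//   region.guarded.end:
//     store i32 %v, i32 addrspace(3)* @v.guarded.output.alloc
//   region.barrier:
//     %v.guarded.output.load =
//         load i32, i32 addrspace(3)* @v.guarded.output.alloc
//
// and every user outside the guarded block is rewired to the load.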
3495 
3496  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3497 
3498  // Go to tid check BB in ParentBB.
3499  const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
3500  ParentBB->getTerminator()->eraseFromParent();
3501  OpenMPIRBuilder::LocationDescription Loc(
3502  InsertPointTy(ParentBB, ParentBB->end()), DL);
3503  OMPInfoCache.OMPBuilder.updateToLocation(Loc);
3504  uint32_t SrcLocStrSize;
3505  auto *SrcLocStr =
3506  OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3507  Value *Ident =
3508  OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3509  BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
3510 
3511  // Add check for Tid in RegionCheckTidBB
3512  RegionCheckTidBB->getTerminator()->eraseFromParent();
3513  OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
3514  InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
3515  OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
3516  FunctionCallee HardwareTidFn =
3517  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3518  M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3519  CallInst *Tid =
3520  OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
3521  Tid->setDebugLoc(DL);
3522  OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
3523  Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
3524  OMPInfoCache.OMPBuilder.Builder
3525  .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
3526  ->setDebugLoc(DL);
3527 
3528  // First barrier for synchronization, ensures main thread has updated
3529  // values.
3530  FunctionCallee BarrierFn =
3531  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3532  M, OMPRTL___kmpc_barrier_simple_spmd);
3533  OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
3534  RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
3535  CallInst *Barrier =
3536  OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
3537  Barrier->setDebugLoc(DL);
3538  OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3539 
3540  // Second barrier ensures workers have read broadcast values.
3541  if (HasBroadcastValues) {
3542  CallInst *Barrier = CallInst::Create(BarrierFn, {Ident, Tid}, "",
3543  RegionBarrierBB->getTerminator());
3544  Barrier->setDebugLoc(DL);
3545  OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3546  }
3547  };
3548 
3549  auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3550  SmallPtrSet<BasicBlock *, 8> Visited;
3551  for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3552  BasicBlock *BB = GuardedI->getParent();
3553  if (!Visited.insert(BB).second)
3554  continue;
3555 
3556  SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
3557  Instruction *LastEffect = nullptr;
3558  BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
3559  while (++IP != IPEnd) {
3560  if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
3561  continue;
3562  Instruction *I = &*IP;
3563  if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
3564  continue;
3565  if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
3566  LastEffect = nullptr;
3567  continue;
3568  }
3569  if (LastEffect)
3570  Reorders.push_back({I, LastEffect});
3571  LastEffect = &*IP;
3572  }
3573  for (auto &Reorder : Reorders)
3574  Reorder.first->moveBefore(Reorder.second);
3575  }
3576 
3577  SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
3578 
3579  for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3580  BasicBlock *BB = GuardedI->getParent();
3581  auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
3582  IRPosition::function(*GuardedI->getFunction()), nullptr,
3583  DepClassTy::NONE);
3584  assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
3585  auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3586  // Continue if instruction is already guarded.
3587  if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
3588  continue;
3589 
3590  Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
3591  for (Instruction &I : *BB) {
3592  // If instruction I needs to be guarded update the guarded region
3593  // bounds.
3594  if (SPMDCompatibilityTracker.contains(&I)) {
3595  CalleeAAFunction.getGuardedInstructions().insert(&I);
3596  if (GuardedRegionStart)
3597  GuardedRegionEnd = &I;
3598  else
3599  GuardedRegionStart = GuardedRegionEnd = &I;
3600 
3601  continue;
3602  }
3603 
3604  // Instruction I does not need guarding, store
3605  // any region found and reset bounds.
3606  if (GuardedRegionStart) {
3607  GuardedRegions.push_back(
3608  std::make_pair(GuardedRegionStart, GuardedRegionEnd));
3609  GuardedRegionStart = nullptr;
3610  GuardedRegionEnd = nullptr;
3611  }
3612  }
3613  }
3614 
3615  for (auto &GR : GuardedRegions)
3616  CreateGuardedRegion(GR.first, GR.second);
3617 
3618  // Adjust the global exec mode flag that tells the runtime what mode this
3619  // kernel is executed in.
3620  assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
3621  "Initially non-SPMD kernel has SPMD exec mode!");
3622  ExecMode->setInitializer(
3623  ConstantInt::get(ExecMode->getInitializer()->getType(),
3624  ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
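// Concretely, assuming the usual bit encoding from OMPConstants.h in which
// OMP_TGT_EXEC_MODE_GENERIC and OMP_TGT_EXEC_MODE_SPMD are distinct bits,
// the kernel's companion global changes along the lines of
//
//   before: @<kernel>_exec_mode = global i8 1   ; GENERIC
//   after:  @<kernel>_exec_mode = global i8 3   ; GENERIC | SPMD
//
// (linkage and the exact numeric values are assumptions of this sketch), so
// the device runtime can still tell that this kernel started out in generic
// mode and was SPMD-ized by the compiler.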
3625 
3626  // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
3627  const int InitModeArgNo = 1;
3628  const int DeinitModeArgNo = 1;
3629  const int InitUseStateMachineArgNo = 2;
3630  const int InitRequiresFullRuntimeArgNo = 3;
3631  const int DeinitRequiresFullRuntimeArgNo = 2;
3632 
3633  auto &Ctx = getAnchorValue().getContext();
3634  A.changeUseAfterManifest(
3635  KernelInitCB->getArgOperandUse(InitModeArgNo),
3636  *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
3637  OMP_TGT_EXEC_MODE_SPMD));
3638  A.changeUseAfterManifest(
3639  KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
3640  *ConstantInt::getBool(Ctx, false));
3641  A.changeUseAfterManifest(
3642  KernelDeinitCB->getArgOperandUse(DeinitModeArgNo),
3643  *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
3644  OMP_TGT_EXEC_MODE_SPMD));
3645  A.changeUseAfterManifest(
3646  KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
3647  *ConstantInt::getBool(Ctx, false));
3648  A.changeUseAfterManifest(
3649  KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
3650  *ConstantInt::getBool(Ctx, false));
3651 
3652  ++NumOpenMPTargetRegionKernelsSPMD;
3653 
3654  auto Remark = [&](OptimizationRemark OR) {
3655  return OR << "Transformed generic-mode kernel to SPMD-mode.";
3656  };
3657  A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
3658  return true;
3659  };
3660 
3661  ChangeStatus buildCustomStateMachine(Attributor &A) {
3662  // If we have disabled state machine rewrites, don't make a custom one.
3663  if (DisableOpenMPOptStateMachineRewrite)
3664  return ChangeStatus::UNCHANGED;
3665 
3666  // Don't rewrite the state machine if we are not in a valid state.
3667  if (!ReachedKnownParallelRegions.isValidState())
3668  return ChangeStatus::UNCHANGED;
3669 
3670  const int InitModeArgNo = 1;
3671  const int InitUseStateMachineArgNo = 2;
3672 
3673  // Check if the current configuration is non-SPMD with the generic state
3674  // machine. If we already have SPMD mode or a custom state machine we do
3675  // not need to go any further. If it is anything but a constant, something
3676  // is weird and we give up.
3677  ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
3678  KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
3679  ConstantInt *Mode =
3680  dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
3681 
3682  // If we are stuck with generic mode, try to create a custom device (=GPU)
3683  // state machine which is specialized for the parallel regions that are
3684  // reachable by the kernel.
3685  if (!UseStateMachine || UseStateMachine->isZero() || !Mode ||
3686  (Mode->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
3687  return ChangeStatus::UNCHANGED;
3688 
3689  // If not SPMD mode, indicate we use a custom state machine now.
3690  auto &Ctx = getAnchorValue().getContext();
3691  auto *FalseVal = ConstantInt::getBool(Ctx, false);
3692  A.changeUseAfterManifest(
3693  KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
3694 
3695  // If we don't actually need a state machine we are done here. This can
3696  // happen if there simply are no parallel regions. In the resulting kernel
3697  // all worker threads will simply exit right away, leaving the main thread
3698  // to do the work alone.
3699  if (!mayContainParallelRegion()) {
3700  ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
3701 
3702  auto Remark = [&](OptimizationRemark OR) {
3703  return OR << "Removing unused state machine from generic-mode kernel.";
3704  };
3705  A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
3706 
3707  return ChangeStatus::CHANGED;
3708  }
3709 
3710  // Keep track in the statistics of our new shiny custom state machine.
3711  if (ReachedUnknownParallelRegions.empty()) {
3712  ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
3713 
3714  auto Remark = [&](OptimizationRemark OR) {
3715  return OR << "Rewriting generic-mode kernel with a customized state "
3716  "machine.";
3717  };
3718  A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
3719  } else {
3720  ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
3721 
3722  auto Remark = [&](OptimizationRemarkAnalysis OR) {
3723  return OR << "Generic-mode kernel is executed with a customized state "
3724  "machine that requires a fallback.";
3725  };
3726  A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
3727 
3728  // Tell the user why we ended up with a fallback.
3729  for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
3730  if (!UnknownParallelRegionCB)
3731  continue;
3732  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3733  return ORA << "Call may contain unknown parallel regions. Use "
3734  << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
3735  "override.";
3736  };
3737  A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
3738  "OMP133", Remark);
3739  }
3740  }
3741 
3742  // Create all the blocks:
3743  //
3744  // InitCB = __kmpc_target_init(...)
3745  // BlockHwSize =
3746  // __kmpc_get_hardware_num_threads_in_block();
3747  // WarpSize = __kmpc_get_warp_size();
3748  // BlockSize = BlockHwSize - WarpSize;
3749  // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
3750  // if (IsWorker) {
3751  // if (InitCB >= BlockSize) return;
3752  // SMBeginBB: __kmpc_barrier_simple_generic(...);
3753  // void *WorkFn;
3754  // bool Active = __kmpc_kernel_parallel(&WorkFn);
3755  // if (!WorkFn) return;
3756  // SMIsActiveCheckBB: if (Active) {
3757  // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
3758  // ParFn0(...);
3759  // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
3760  // ParFn1(...);
3761  // ...
3762  // SMIfCascadeCurrentBB: else
3763  // ((WorkFnTy*)WorkFn)(...);
3764  // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
3765  // }
3766  // SMDoneBB: __kmpc_barrier_simple_generic(...);
3767  // goto SMBeginBB;
3768  // }
3769  // UserCodeEntryBB: // user code
3770  // __kmpc_target_deinit(...)
3771  //
3772  Function *Kernel = getAssociatedFunction();
3773  assert(Kernel && "Expected an associated function!");
3774 
3775  BasicBlock *InitBB = KernelInitCB->getParent();
3776  BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
3777  KernelInitCB->getNextNode(), "thread.user_code.check");
3778  BasicBlock *IsWorkerCheckBB =
3779  BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
3780  BasicBlock *StateMachineBeginBB = BasicBlock::Create(
3781  Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
3782  BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
3783  Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
3784  BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
3785  Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
3786  BasicBlock *StateMachineIfCascadeCurrentBB =
3787  BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3788  Kernel, UserCodeEntryBB);
3789  BasicBlock *StateMachineEndParallelBB =
3790  BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
3791  Kernel, UserCodeEntryBB);
3792  BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
3793  Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
3794  A.registerManifestAddedBasicBlock(*InitBB);
3795  A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
3796  A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
3797  A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
3798  A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
3799  A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
3800  A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
3801  A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
3802  A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
3803 
3804  const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3805  ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
3806  InitBB->getTerminator()->eraseFromParent();
3807 
3808  Instruction *IsWorker =
3809  ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
3810  ConstantInt::get(KernelInitCB->getType(), -1),
3811  "thread.is_worker", InitBB);
3812  IsWorker->setDebugLoc(DLoc);
3813  BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
3814 
3815  Module &M = *Kernel->getParent();
3816  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3817  FunctionCallee BlockHwSizeFn =
3818  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3819  M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
3820  FunctionCallee WarpSizeFn =
3821  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3822  M, OMPRTL___kmpc_get_warp_size);
3823  CallInst *BlockHwSize =
3824  CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
3825  OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
3826  BlockHwSize->setDebugLoc(DLoc);
3827  CallInst *WarpSize =
3828  CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
3829  OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
3830  WarpSize->setDebugLoc(DLoc);
3831  Instruction *BlockSize = BinaryOperator::CreateSub(
3832  BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
3833  BlockSize->setDebugLoc(DLoc);
3834  Instruction *IsMainOrWorker = ICmpInst::Create(
3835  ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
3836  "thread.is_main_or_worker", IsWorkerCheckBB);
3837  IsMainOrWorker->setDebugLoc(DLoc);
3838  BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
3839  IsMainOrWorker, IsWorkerCheckBB);
3840 
3841  // Create local storage for the work function pointer.
3842  const DataLayout &DL = M.getDataLayout();
3843  Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
3844  Instruction *WorkFnAI =
3845  new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
3846  "worker.work_fn.addr", &Kernel->getEntryBlock().front());
3847  WorkFnAI->setDebugLoc(DLoc);
3848 
3849  OMPInfoCache.OMPBuilder.updateToLocation(
3851  IRBuilder<>::InsertPoint(StateMachineBeginBB,
3852  StateMachineBeginBB->end()),
3853  DLoc));
3854 
3855  Value *Ident = KernelInitCB->getArgOperand(0);
3856  Value *GTid = KernelInitCB;
3857 
3858  FunctionCallee BarrierFn =
3859  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3860  M, OMPRTL___kmpc_barrier_simple_generic);
3861  CallInst *Barrier =
3862  CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
3863  OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3864  Barrier->setDebugLoc(DLoc);
3865 
3866  if (WorkFnAI->getType()->getPointerAddressSpace() !=
3867  (unsigned int)AddressSpace::Generic) {
3868  WorkFnAI = new AddrSpaceCastInst(
3869  WorkFnAI,
3870  PointerType::getWithSamePointeeType(
3871  cast<PointerType>(WorkFnAI->getType()),
3872  (unsigned int)AddressSpace::Generic),
3873  WorkFnAI->getName() + ".generic", StateMachineBeginBB);
3874  WorkFnAI->setDebugLoc(DLoc);
3875  }
3876 
3877  FunctionCallee KernelParallelFn =
3878  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3879  M, OMPRTL___kmpc_kernel_parallel);
3880  CallInst *IsActiveWorker = CallInst::Create(
3881  KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
3882  OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
3883  IsActiveWorker->setDebugLoc(DLoc);
3884  Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
3885  StateMachineBeginBB);
3886  WorkFn->setDebugLoc(DLoc);
3887 
3888  FunctionType *ParallelRegionFnTy = FunctionType::get(
3889  Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
3890  false);
3891  Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
3892  WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
3893  StateMachineBeginBB);
3894 
3895  Instruction *IsDone =
3896  ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
3897  Constant::getNullValue(VoidPtrTy), "worker.is_done",
3898  StateMachineBeginBB);
3899  IsDone->setDebugLoc(DLoc);
3900  BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
3901  IsDone, StateMachineBeginBB)
3902  ->setDebugLoc(DLoc);
3903 
3904  BranchInst::Create(StateMachineIfCascadeCurrentBB,
3905  StateMachineDoneBarrierBB, IsActiveWorker,
3906  StateMachineIsActiveCheckBB)
3907  ->setDebugLoc(DLoc);
3908 
3909  Value *ZeroArg =
3910  Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
3911 
3912  // Now that we have most of the CFG skeleton it is time for the if-cascade
3913  // that checks the function pointer we got from the runtime against the
3914  // parallel regions we expect, if there are any.
3915  for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
3916  auto *ParallelRegion = ReachedKnownParallelRegions[I];
3917  BasicBlock *PRExecuteBB = BasicBlock::Create(
3918  Ctx, "worker_state_machine.parallel_region.execute", Kernel,
3919  StateMachineEndParallelBB);
3920  CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
3921  ->setDebugLoc(DLoc);
3922  BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
3923  ->setDebugLoc(DLoc);
3924 
3925  BasicBlock *PRNextBB =
3926  BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3927  Kernel, StateMachineEndParallelBB);
3928 
3929  // Check if we need to compare the pointer at all or if we can just
3930  // call the parallel region function.
3931  Value *IsPR;
3932  if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
3933  Instruction *CmpI = ICmpInst::Create(
3934  ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
3935  "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
3936  CmpI->setDebugLoc(DLoc);
3937  IsPR = CmpI;
3938  } else {
3939  IsPR = ConstantInt::getTrue(Ctx);
3940  }
3941 
3942  BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
3943  StateMachineIfCascadeCurrentBB)
3944  ->setDebugLoc(DLoc);
3945  StateMachineIfCascadeCurrentBB = PRNextBB;
3946  }
3947 
3948  // At the end of the if-cascade we place the indirect function pointer call
3949  // in case we might need it, that is if there can be parallel regions we
3950  // have not handled in the if-cascade above.
3951  if (!ReachedUnknownParallelRegions.empty()) {
3952  StateMachineIfCascadeCurrentBB->setName(
3953  "worker_state_machine.parallel_region.fallback.execute");
3954  CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
3955  StateMachineIfCascadeCurrentBB)
3956  ->setDebugLoc(DLoc);
3957  }
3958  BranchInst::Create(StateMachineEndParallelBB,
3959  StateMachineIfCascadeCurrentBB)
3960  ->setDebugLoc(DLoc);
3961 
3962  FunctionCallee EndParallelFn =
3963  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3964  M, OMPRTL___kmpc_kernel_end_parallel);
3965  CallInst *EndParallel =
3966  CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
3967  OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
3968  EndParallel->setDebugLoc(DLoc);
3969  BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
3970  ->setDebugLoc(DLoc);
3971 
3972  CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
3973  ->setDebugLoc(DLoc);
3974  BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
3975  ->setDebugLoc(DLoc);
3976 
3977  return ChangeStatus::CHANGED;
3978  }
3979 
3980  /// Fixpoint iteration update function. Will be called every time a dependence
3981  /// changed its state (and in the beginning).
3982  ChangeStatus updateImpl(Attributor &A) override {
3983  KernelInfoState StateBefore = getState();
3984 
3985  // Callback to check a read/write instruction.
3986  auto CheckRWInst = [&](Instruction &I) {
3987  // We handle calls later.
3988  if (isa<CallBase>(I))
3989  return true;
3990  // We only care about write effects.
3991  if (!I.mayWriteToMemory())
3992  return true;
3993  if (auto *SI = dyn_cast<StoreInst>(&I)) {
3994  SmallVector<const Value *> Objects;
3995  getUnderlyingObjects(SI->getPointerOperand(), Objects);
3996  if (llvm::all_of(Objects,
3997  [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
3998  return true;
3999  // Check for AAHeapToStack moved objects which must not be guarded.
4000  auto &HS = A.getAAFor<AAHeapToStack>(
4001  *this, IRPosition::function(*I.getFunction()),
4002  DepClassTy::OPTIONAL);
4003  if (llvm::all_of(Objects, [&HS](const Value *Obj) {
4004  auto *CB = dyn_cast<CallBase>(Obj);
4005  if (!CB)
4006  return false;
4007  return HS.isAssumedHeapToStack(*CB);
4008  })) {
4009  return true;
4010  }
4011  }
4012 
4013  // Insert instruction that needs guarding.
4014  SPMDCompatibilityTracker.insert(&I);
4015  return true;
4016  };
4017 
4018  bool UsedAssumedInformationInCheckRWInst = false;
4019  if (!SPMDCompatibilityTracker.isAtFixpoint())
4020  if (!A.checkForAllReadWriteInstructions(
4021  CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4022  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4023 
4024  bool UsedAssumedInformationFromReachingKernels = false;
4025  if (!IsKernelEntry) {
4026  updateParallelLevels(A);
4027 
4028  bool AllReachingKernelsKnown = true;
4029  updateReachingKernelEntries(A, AllReachingKernelsKnown);
4030  UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4031 
4032  if (!ParallelLevels.isValidState())
4033  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4034  else if (!ReachingKernelEntries.isValidState())
4035  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4036  else if (!SPMDCompatibilityTracker.empty()) {
4037  // Check if all reaching kernels agree on the mode as we otherwise cannot
4038  // guard instructions. We might not be sure about the mode so we cannot
4039  // fix the internal SPMD-ization state either.
4040  int SPMD = 0, Generic = 0;
4041  for (auto *Kernel : ReachingKernelEntries) {
4042  auto &CBAA = A.getAAFor<AAKernelInfo>(
4043  *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4044  if (CBAA.SPMDCompatibilityTracker.isValidState() &&
4045  CBAA.SPMDCompatibilityTracker.isAssumed())
4046  ++SPMD;
4047  else
4048  ++Generic;
4049  if (!CBAA.SPMDCompatibilityTracker.isAtFixpoint())
4050  UsedAssumedInformationFromReachingKernels = true;
4051  }
4052  if (SPMD != 0 && Generic != 0)
4053  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4054  }
4055  }
4056 
4057  // Callback to check a call instruction.
4058  bool AllParallelRegionStatesWereFixed = true;
4059  bool AllSPMDStatesWereFixed = true;
4060  auto CheckCallInst = [&](Instruction &I) {
4061  auto &CB = cast<CallBase>(I);
4062  auto &CBAA = A.getAAFor<AAKernelInfo>(
4063  *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4064  getState() ^= CBAA.getState();
4065  AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
4066  AllParallelRegionStatesWereFixed &=
4067  CBAA.ReachedKnownParallelRegions.isAtFixpoint();
4068  AllParallelRegionStatesWereFixed &=
4069  CBAA.ReachedUnknownParallelRegions.isAtFixpoint();
4070  return true;
4071  };
4072 
4073  bool UsedAssumedInformationInCheckCallInst = false;
4074  if (!A.checkForAllCallLikeInstructions(
4075  CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4076  LLVM_DEBUG(dbgs() << TAG
4077  << "Failed to visit all call-like instructions!\n";);
4078  return indicatePessimisticFixpoint();
4079  }
4080 
4081  // If we haven't used any assumed information for the reached parallel
4082  // region states we can fix it.
4083  if (!UsedAssumedInformationInCheckCallInst &&
4084  AllParallelRegionStatesWereFixed) {
4085  ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4086  ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4087  }
4088 
4089  // If we are sure there are no parallel regions in the kernel we do not
4090  // want SPMD mode.
4091  if (IsKernelEntry && ReachedUnknownParallelRegions.isAtFixpoint() &&
4092  ReachedKnownParallelRegions.isAtFixpoint() &&
4093  ReachedUnknownParallelRegions.isValidState() &&
4094  ReachedKnownParallelRegions.isValidState() &&
4095  !mayContainParallelRegion())
4096  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4097 
4098  // If we haven't used any assumed information for the SPMD state we can fix
4099  // it.
4100  if (!UsedAssumedInformationInCheckRWInst &&
4101  !UsedAssumedInformationInCheckCallInst &&
4102  !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4103  SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4104 
4105  return StateBefore == getState() ? ChangeStatus::UNCHANGED
4106  : ChangeStatus::CHANGED;
4107  }
4108 
4109 private:
4110  /// Update info regarding reaching kernels.
4111  void updateReachingKernelEntries(Attributor &A,
4112  bool &AllReachingKernelsKnown) {
4113  auto PredCallSite = [&](AbstractCallSite ACS) {
4114  Function *Caller = ACS.getInstruction()->getFunction();
4115 
4116  assert(Caller && "Caller is nullptr");
4117 
4118  auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
4119  IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4120  if (CAA.ReachingKernelEntries.isValidState()) {
4121  ReachingKernelEntries ^= CAA.ReachingKernelEntries;
4122  return true;
4123  }
4124 
4125  // We lost track of the caller of the associated function, so any kernel
4126  // could reach it now.
4127  ReachingKernelEntries.indicatePessimisticFixpoint();
4128 
4129  return true;
4130  };
4131 
4132  if (!A.checkForAllCallSites(PredCallSite, *this,
4133  true /* RequireAllCallSites */,
4134  AllReachingKernelsKnown))
4135  ReachingKernelEntries.indicatePessimisticFixpoint();
4136  }
4137 
4138  /// Update info regarding parallel levels.
4139  void updateParallelLevels(Attributor &A) {
4140  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4141  OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4142  OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4143 
4144  auto PredCallSite = [&](AbstractCallSite ACS) {
4145  Function *Caller = ACS.getInstruction()->getFunction();
4146 
4147  assert(Caller && "Caller is nullptr");
4148 
4149  auto &CAA =
4150  A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4151  if (CAA.ParallelLevels.isValidState()) {
4152  // Any function that is called by `__kmpc_parallel_51` will not be folded
4153  // because the parallel level inside it is updated at runtime. Getting that
4154  // right would make the analysis depend on the runtime implementation, and
4155  // any future change to that implementation could silently invalidate the
4156  // result. As a consequence, we are just conservative here.
4157  if (Caller == Parallel51RFI.Declaration) {
4158  ParallelLevels.indicatePessimisticFixpoint();
4159  return true;
4160  }
4161 
4162  ParallelLevels ^= CAA.ParallelLevels;
4163 
4164  return true;
4165  }
4166 
4167  // We lost track of the caller of the associated function, so any kernel
4168  // could reach it now.
4169  ParallelLevels.indicatePessimisticFixpoint();
4170 
4171  return true;
4172  };
4173 
4174  bool AllCallSitesKnown = true;
4175  if (!A.checkForAllCallSites(PredCallSite, *this,
4176  true /* RequireAllCallSites */,
4177  AllCallSitesKnown))
4178  ParallelLevels.indicatePessimisticFixpoint();
4179  }
4180 };
4181 
4182 /// The call site kernel info abstract attribute, basically, what can we say
4183 /// about a call site with regards to the KernelInfoState. For now this simply
4184 /// forwards the information from the callee.
4185 struct AAKernelInfoCallSite : AAKernelInfo {
4186  AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4187  : AAKernelInfo(IRP, A) {}
4188 
4189  /// See AbstractAttribute::initialize(...).
4190  void initialize(Attributor &A) override {
4191  AAKernelInfo::initialize(A);
4192 
4193  CallBase &CB = cast<CallBase>(getAssociatedValue());
4194  Function *Callee = getAssociatedFunction();
4195 
4196  auto &AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4197  *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4198 
4199  // Check for SPMD-mode assumptions.
4200  if (AssumptionAA.hasAssumption("ompx_spmd_amenable")) {
4201  SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4202  indicateOptimisticFixpoint();
4203  }
4204 
4205  // First weed out calls we do not care about, that is readonly/readnone
4206  // calls, intrinsics, and "no_openmp" calls. None of these can reach a
4207  // parallel region or anything else we are looking for.
4208  if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4209  indicateOptimisticFixpoint();
4210  return;
4211  }
4212 
4213  // Next we check if we know the callee. If it is a known OpenMP function
4214  // we will handle them explicitly in the switch below. If it is not, we
4215  // will use an AAKernelInfo object on the callee to gather information and
4216  // merge that into the current state. The latter happens in the updateImpl.
4217  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4218  const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4219  if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4220  // Unknown callees or mere declarations are not analyzable, we give up.
4221  if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4222 
4223  // Unknown callees might contain parallel regions, except if they have
4224  // an appropriate assumption attached.
4225  if (!(AssumptionAA.hasAssumption("omp_no_openmp") ||
4226  AssumptionAA.hasAssumption("omp_no_parallelism")))
4227  ReachedUnknownParallelRegions.insert(&CB);
4228 
4229  // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4230  // idea we can run something unknown in SPMD-mode.
4231  if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4232  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4233  SPMDCompatibilityTracker.insert(&CB);
4234  }
4235 
4236  // We have updated the state for this unknown call properly, there won't
4237  // be any change so we indicate a fixpoint.
4238  indicateOptimisticFixpoint();
4239  }
4240  // If the callee is known and can be used in IPO, we will update the state
4241  // based on the callee state in updateImpl.
4242  return;
4243  }
4244 
4245  const unsigned int WrapperFunctionArgNo = 6;
4246  RuntimeFunction RF = It->getSecond();
4247  switch (RF) {
4248  // All the functions we know are compatible with SPMD mode.
4249  case OMPRTL___kmpc_is_spmd_exec_mode:
4250  case OMPRTL___kmpc_distribute_static_fini:
4251  case OMPRTL___kmpc_for_static_fini:
4252  case OMPRTL___kmpc_global_thread_num:
4253  case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4254  case OMPRTL___kmpc_get_hardware_num_blocks:
4255  case OMPRTL___kmpc_single:
4256  case OMPRTL___kmpc_end_single:
4257  case OMPRTL___kmpc_master:
4258  case OMPRTL___kmpc_end_master:
4259  case OMPRTL___kmpc_barrier:
4260  case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4261  case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4262  case OMPRTL___kmpc_nvptx_end_reduce_nowait:
4263  break;
4264  case OMPRTL___kmpc_distribute_static_init_4:
4265  case OMPRTL___kmpc_distribute_static_init_4u:
4266  case OMPRTL___kmpc_distribute_static_init_8:
4267  case OMPRTL___kmpc_distribute_static_init_8u:
4268  case OMPRTL___kmpc_for_static_init_4:
4269  case OMPRTL___kmpc_for_static_init_4u:
4270  case OMPRTL___kmpc_for_static_init_8:
4271  case OMPRTL___kmpc_for_static_init_8u: {
4272  // Check the schedule and allow static schedule in SPMD mode.
4273  unsigned ScheduleArgOpNo = 2;
4274  auto *ScheduleTypeCI =
4275  dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4276  unsigned ScheduleTypeVal =
4277  ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4278  switch (OMPScheduleType(ScheduleTypeVal)) {
4283  break;
4284  default:
4285  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4286  SPMDCompatibilityTracker.insert(&CB);
4287  break;
4288  };
4289  } break;
4290  case OMPRTL___kmpc_target_init:
4291  KernelInitCB = &CB;
4292  break;
4293  case OMPRTL___kmpc_target_deinit:
4294  KernelDeinitCB = &CB;
4295  break;
4296  case OMPRTL___kmpc_parallel_51:
4297  if (auto *ParallelRegion = dyn_cast<Function>(
4298  CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
4299  ReachedKnownParallelRegions.insert(ParallelRegion);
4300  break;
4301  }
4302  // The condition above should usually get the parallel region function
4303  // pointer and record it. In the off chance it doesn't we assume the
4304  // worst.
4305  ReachedUnknownParallelRegions.insert(&CB);
4306  break;
4307  case OMPRTL___kmpc_omp_task:
4308  // We do not look into tasks right now, just give up.
4309  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4310  SPMDCompatibilityTracker.insert(&CB);
4311  ReachedUnknownParallelRegions.insert(&CB);
4312  break;
4313  case OMPRTL___kmpc_alloc_shared:
4314  case OMPRTL___kmpc_free_shared:
4315  // Return without setting a fixpoint, to be resolved in updateImpl.
4316  return;
4317  default:
4318  // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
4319  // generally. However, they do not hide parallel regions.
4320  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4321  SPMDCompatibilityTracker.insert(&CB);
4322  break;
4323  }
4324  // All other OpenMP runtime calls will not reach parallel regions so they
4325  // can be safely ignored for now. Since it is a known OpenMP runtime call we
4326  // have now modeled all effects and there is no need for any update.
4327  indicateOptimisticFixpoint();
4328  }
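// For the __kmpc_parallel_51 case above, the outlined parallel region is read
// from operand index WrapperFunctionArgNo (6). An illustrative call site is
// shown below; the parameter names are assumptions, only the position of the
// wrapper function matters for this code:
//
//   call void @__kmpc_parallel_51(%struct.ident_t* %loc, i32 %gtid,
//                                 i32 %if_expr, i32 %num_threads,
//                                 i32 %proc_bind, i8* %outlined_fn,
//                                 i8* %outlined_wrapper_fn, i8** %args,
//                                 i64 %nargs)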
4329 
4330  ChangeStatus updateImpl(Attributor &A) override {
4331  // TODO: Once we have call site specific value information we can provide
4332  // call site specific liveness information and then it makes
4333  // sense to specialize attributes for call sites arguments instead of
4334  // redirecting requests to the callee argument.
4335  Function *F = getAssociatedFunction();
4336 
4337  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4338  const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
4339 
4340  // If F is not a runtime function, propagate the AAKernelInfo of the callee.
4341  if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4342  const IRPosition &FnPos = IRPosition::function(*F);
4343  auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
4344  if (getState() == FnAA.getState())
4345  return ChangeStatus::UNCHANGED;
4346  getState() = FnAA.getState();
4347  return ChangeStatus::CHANGED;
4348  }
4349 
4350  // F is a runtime function that allocates or frees memory, check
4351  // AAHeapToStack and AAHeapToShared.
4352  KernelInfoState StateBefore = getState();
4353  assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
4354  It->getSecond() == OMPRTL___kmpc_free_shared) &&
4355  "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
4356 
4357  CallBase &CB = cast<CallBase>(getAssociatedValue());
4358 
4359  auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
4360  *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
4361  auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
4362  *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
4363 
4364  RuntimeFunction RF = It->getSecond();
4365 
4366  switch (RF) {
4367  // If neither HeapToStack nor HeapToShared assume the call is removed,
4368  // assume SPMD incompatibility.
4369  case OMPRTL___kmpc_alloc_shared:
4370  if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
4371  !HeapToSharedAA.isAssumedHeapToShared(CB))
4372  SPMDCompatibilityTracker.insert(&CB);
4373  break;
4374  case OMPRTL___kmpc_free_shared:
4375  if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
4376  !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
4377  SPMDCompatibilityTracker.insert(&CB);
4378  break;
4379  default:
4380  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4381  SPMDCompatibilityTracker.insert(&CB);
4382  }
4383 
4384  return StateBefore == getState() ? ChangeStatus::UNCHANGED
4385  : ChangeStatus::CHANGED;
4386  }
4387 };
4388 
4389 struct AAFoldRuntimeCall
4390  : public StateWrapper<BooleanState, AbstractAttribute> {
4391  using Base = StateWrapper<BooleanState, AbstractAttribute>;
4392 
4393  AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
4394 
4395  /// Statistics are tracked as part of manifest for now.
4396  void trackStatistics() const override {}
4397 
4398  /// Create an abstract attribute view for the position \p IRP.
4399  static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
4400  Attributor &A);
4401 
4402  /// See AbstractAttribute::getName()
4403  const std::string getName() const override { return "AAFoldRuntimeCall"; }
4404 
4405  /// See AbstractAttribute::getIdAddr()
4406  const char *getIdAddr() const override { return &ID; }
4407 
4408  /// This function should return true if the type of the \p AA is
4409  /// AAFoldRuntimeCall
4410  static bool classof(const AbstractAttribute *AA) {
4411  return (AA->getIdAddr() == &ID);
4412  }
4413 
4414  static const char ID;
4415 };
4416 
4417 struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
4418  AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
4419  : AAFoldRuntimeCall(IRP, A) {}
4420 
4421  /// See AbstractAttribute::getAsStr()
4422  const std::string getAsStr() const override {
4423  if (!isValidState())
4424  return "<invalid>";
4425 
4426  std::string Str("simplified value: ");
4427 
4428  if (!SimplifiedValue)
4429  return Str + std::string("none");
4430 
4431  if (!SimplifiedValue.getValue())
4432  return Str + std::string("nullptr");
4433 
4434  if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue()))
4435  return Str + std::to_string(CI->getSExtValue());
4436 
4437  return Str + std::string("unknown");
4438  }
4439 
4440  void initialize(Attributor &A) override {
4441  if (DisableOpenMPOptFolding)
4442  indicatePessimisticFixpoint();
4443 
4444  Function *Callee = getAssociatedFunction();
4445 
4446  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4447  const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4448  assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
4449  "Expected a known OpenMP runtime function");
4450 
4451  RFKind = It->getSecond();
4452 
4453  CallBase &CB = cast<CallBase>(getAssociatedValue());
4454  A.registerSimplificationCallback(
4455  IRPosition::callsite_returned(CB),
4456  [&](const IRPosition &IRP, const AbstractAttribute *AA,
4457  bool &UsedAssumedInformation) -> Optional<Value *> {
4458  assert((isValidState() ||
4459  (SimplifiedValue && SimplifiedValue.getValue() == nullptr)) &&
4460  "Unexpected invalid state!");
4461 
4462  if (!isAtFixpoint()) {
4463  UsedAssumedInformation = true;
4464  if (AA)
4465  A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
4466  }
4467  return SimplifiedValue;
4468  });
4469  }
4470 
4471  ChangeStatus updateImpl(Attributor &A) override {
4472  ChangeStatus Changed = ChangeStatus::UNCHANGED;
4473  switch (RFKind) {
4474  case OMPRTL___kmpc_is_spmd_exec_mode:
4475  Changed |= foldIsSPMDExecMode(A);
4476  break;
4477  case OMPRTL___kmpc_is_generic_main_thread_id:
4478  Changed |= foldIsGenericMainThread(A);
4479  break;
4480  case OMPRTL___kmpc_parallel_level:
4481  Changed |= foldParallelLevel(A);
4482  break;
4483  case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4484  Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
4485  break;
4486  case OMPRTL___kmpc_get_hardware_num_blocks:
4487  Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
4488  break;
4489  default:
4490  llvm_unreachable("Unhandled OpenMP runtime function!");
4491  }
4492 
4493  return Changed;
4494  }
4495 
4496  ChangeStatus manifest(Attributor &A) override {
4497  ChangeStatus Changed = ChangeStatus::UNCHANGED;
4498 
4499  if (SimplifiedValue && *SimplifiedValue) {
4500  Instruction &I = *getCtxI();
4501  A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
4502  A.deleteAfterManifest(I);
4503 
4504  CallBase *CB = dyn_cast<CallBase>(&I);
4505  auto Remark = [&](OptimizationRemark OR) {
4506  if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
4507  return OR << "Replacing OpenMP runtime call "
4508  << CB->getCalledFunction()->getName() << " with "
4509  << ore::NV("FoldedValue", C->getZExtValue()) << ".";
4510  return OR << "Replacing OpenMP runtime call "
4511  << CB->getCalledFunction()->getName() << ".";
4512  };
4513 
4514  if (CB && EnableVerboseRemarks)
4515  A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
4516 
4517  LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
4518  << **SimplifiedValue << "\n");
4519 
4520  Changed = ChangeStatus::CHANGED;
4521  }
4522 
4523  return Changed;
4524  }
4525 
4526  ChangeStatus indicatePessimisticFixpoint() override {
4527  SimplifiedValue = nullptr;
4528  return AAFoldRuntimeCall::indicatePessimisticFixpoint();
4529  }
4530 
4531 private:
4532  /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
4533  ChangeStatus foldIsSPMDExecMode(Attributor &A) {
4534  Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4535 
4536  unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4537  unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4538  auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4539  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4540 
4541  if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4542  return indicatePessimisticFixpoint();
4543 
4544  for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4545  auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4546  DepClassTy::REQUIRED);
4547 
4548  if (!AA.isValidState()) {
4549  SimplifiedValue = nullptr;
4550  return indicatePessimisticFixpoint();
4551  }
4552 
4553  if (AA.SPMDCompatibilityTracker.isAssumed()) {
4554  if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4555  ++KnownSPMDCount;
4556  else
4557  ++AssumedSPMDCount;
4558  } else {
4559  if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4560  ++KnownNonSPMDCount;
4561  else
4562  ++AssumedNonSPMDCount;
4563  }
4564  }
4565 
4566  if ((AssumedSPMDCount + KnownSPMDCount) &&
4567  (AssumedNonSPMDCount + KnownNonSPMDCount))
4568  return indicatePessimisticFixpoint();
4569 
4570  auto &Ctx = getAnchorValue().getContext();
4571  if (KnownSPMDCount || AssumedSPMDCount) {
4572  assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4573  "Expected only SPMD kernels!");
4574  // All reaching kernels are in SPMD mode. Update all function calls to
4575  // __kmpc_is_spmd_exec_mode to 1.
4576  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4577  } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
4578  assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4579  "Expected only non-SPMD kernels!");
4580  // All reaching kernels are in non-SPMD mode. Update all function
4581  // calls to __kmpc_is_spmd_exec_mode to 0.
4582  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
4583  } else {
4584  // We have empty reaching kernels, therefore we cannot tell if the
4585  // associated call site can be folded. At this moment, SimplifiedValue
4586  // must be none.
4587  assert(!SimplifiedValue && "SimplifiedValue should be none");
4588  }
4589 
4590  return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4591  : ChangeStatus::CHANGED;
4592  }
4593 
4594  /// Fold __kmpc_is_generic_main_thread_id into a constant if possible.
4595  ChangeStatus foldIsGenericMainThread(Attributor &A) {
4596  Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4597 
4598  CallBase &CB = cast<CallBase>(getAssociatedValue());
4599  Function *F = CB.getFunction();
4600  const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
4601  *this, IRPosition::function(*F), DepClassTy::REQUIRED);
4602 
4603  if (!ExecutionDomainAA.isValidState())
4604  return indicatePessimisticFixpoint();
4605 
4606  auto &Ctx = getAnchorValue().getContext();
4607  if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB))
4608  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4609  else
4610  return indicatePessimisticFixpoint();
4611 
4612  return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4613  : ChangeStatus::CHANGED;
4614  }
4615 
4616  /// Fold __kmpc_parallel_level into a constant if possible.
4617  ChangeStatus foldParallelLevel(Attributor &A) {
4618  Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4619 
4620  auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4621  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4622 
4623  if (!CallerKernelInfoAA.ParallelLevels.isValidState())
4624  return indicatePessimisticFixpoint();
4625 
4626  if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4627  return indicatePessimisticFixpoint();
4628 
4629  if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
4630  assert(!SimplifiedValue &&
4631  "SimplifiedValue should keep none at this point");
4632  return ChangeStatus::UNCHANGED;
4633  }
4634 
4635  unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4636  unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4637  for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4638  auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4639  DepClassTy::REQUIRED);
4640  if (!AA.SPMDCompatibilityTracker.isValidState())
4641  return indicatePessimisticFixpoint();
4642 
4643  if (AA.SPMDCompatibilityTracker.isAssumed()) {
4644  if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4645  ++KnownSPMDCount;
4646  else
4647  ++AssumedSPMDCount;
4648  } else {
4649  if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4650  ++KnownNonSPMDCount;
4651  else
4652  ++AssumedNonSPMDCount;
4653  }
4654  }
4655 
4656  if ((AssumedSPMDCount + KnownSPMDCount) &&
4657  (AssumedNonSPMDCount + KnownNonSPMDCount))
4658  return indicatePessimisticFixpoint();
4659 
4660  auto &Ctx = getAnchorValue().getContext();
4661  // If the caller can only be reached by SPMD kernel entries, the parallel
4662  // level is 1. Similarly, if the caller can only be reached by non-SPMD
4663  // kernel entries, it is 0.
4664  if (AssumedSPMDCount || KnownSPMDCount) {
4665  assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4666  "Expected only SPMD kernels!");
4667  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
4668  } else {
4669  assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4670  "Expected only non-SPMD kernels!");
4671  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
4672  }
4673  return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4674  : ChangeStatus::CHANGED;
4675  }
4676 
4677  ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
4678  // Specialize only if all the calls agree with the attribute constant value
4679  int32_t CurrentAttrValue = -1;
4680  Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4681 
4682  auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4683  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4684 
4685  if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4686  return indicatePessimisticFixpoint();
4687 
4688  // Iterate over the kernels that reach this function
4689  for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4690  int32_t NextAttrVal = -1;
4691  if (K->hasFnAttribute(Attr))
4692  NextAttrVal =
4693  std::stoi(K->getFnAttribute(Attr).getValueAsString().str());
4694 
4695  if (NextAttrVal == -1 ||
4696  (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
4697  return indicatePessimisticFixpoint();
4698  CurrentAttrValue = NextAttrVal;
4699  }
4700 
4701  if (CurrentAttrValue != -1) {
4702  auto &Ctx = getAnchorValue().getContext();
4703  SimplifiedValue =
4704  ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
4705  }
4706  return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4707  : ChangeStatus::CHANGED;
4708  }
4709 
4710  /// An optional value the associated value is assumed to fold to. That is, we
4711  /// assume the associated value (which is a call) can be replaced by this
4712  /// simplified value.
4713  Optional<Value *> SimplifiedValue;
4714 
4715  /// The runtime function kind of the callee of the associated call site.
4716  RuntimeFunction RFKind;
4717 };
4718 
4719 } // namespace
4720 
4721 /// Register folding callsite
4722 void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
4723  auto &RFI = OMPInfoCache.RFIs[RF];
4724  RFI.foreachUse(SCC, [&](Use &U, Function &F) {
4725  CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
4726  if (!CI)
4727  return false;
4728  A.getOrCreateAAFor<AAFoldRuntimeCall>(
4729  IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
4730  DepClassTy::NONE, /* ForceUpdate */ false,
4731  /* UpdateAfterInit */ false);
4732  return false;
4733  });
4734 }
4735 
4736 void OpenMPOpt::registerAAs(bool IsModulePass) {
4737  if (SCC.empty())
4738  return;
4739 
4740  if (IsModulePass) {
4741  // Ensure we create the AAKernelInfo AAs first and without triggering an
4742  // update. This will make sure we register all value simplification
4743  // callbacks before any other AA has the chance to create an AAValueSimplify
4744  // or similar.
4745  auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
4746  A.getOrCreateAAFor<AAKernelInfo>(
4747  IRPosition::function(Kernel), /* QueryingAA */ nullptr,
4748  DepClassTy::NONE, /* ForceUpdate */ false,
4749  /* UpdateAfterInit */ false);
4750  return false;
4751  };
4752  OMPInformationCache::RuntimeFunctionInfo &InitRFI =
4753  OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
4754  InitRFI.foreachUse(SCC, CreateKernelInfoCB);
4755 
4756  registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
4757  registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
4758  registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
4759  registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
4760  registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
4761  }
4762 
4763  // Create CallSite AA for all Getters.
4764  for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
4765  auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
4766 
4767  auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
4768 
4769  auto CreateAA = [&](Use &U, Function &Caller) {
4770  CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
4771  if (!CI)
4772  return false;
4773 
4774  auto &CB = cast<CallBase>(*CI);
4775 
4776  IRPosition CBPos = IRPosition::callsite_function(CB);
4777  A.getOrCreateAAFor<AAICVTracker>(CBPos);
4778  return false;
4779  };
4780 
4781  GetterRFI.foreachUse(SCC, CreateAA);
4782  }
4783  auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4784  auto CreateAA = [&](Use &U, Function &F) {
4785  A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
4786  return false;
4787  };
4788  if (!DisableOpenMPOptDeglobalization)
4789  GlobalizationRFI.foreachUse(SCC, CreateAA);
4790 
4791  // Create an ExecutionDomain AA for every function and a HeapToStack AA for
4792  // every function if there is a device kernel.
4793  if (!isOpenMPDevice(M))
4794  return;
4795 
4796  for (auto *F : SCC) {
4797  if (F->isDeclaration())
4798  continue;
4799 
4800  A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
4801  if (!DisableOpenMPOptDeglobalization)
4802  A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
4803 
4804  for (auto &I : instructions(*F)) {
4805  if (auto *LI = dyn_cast<LoadInst>(&I)) {
4806  bool UsedAssumedInformation = false;
4807  A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
4808  UsedAssumedInformation);
4809  } else if (auto *SI = dyn_cast<StoreInst>(&I)) {
4810  A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
4811  }
4812  }
4813  }
4814 }
4815 
4816 const char AAICVTracker::ID = 0;
4817 const char AAKernelInfo::ID = 0;
4818 const char AAExecutionDomain::ID = 0;
4819 const char AAHeapToShared::ID = 0;
4820 const char AAFoldRuntimeCall::ID = 0;
4821 
4822 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
4823  Attributor &A) {
4824  AAICVTracker *AA = nullptr;
4825  switch (IRP.getPositionKind()) {
4826  case IRPosition::IRP_INVALID:
4827  case IRPosition::IRP_FLOAT:
4828  case IRPosition::IRP_ARGUMENT:
4829  case IRPosition::IRP_CALL_SITE_ARGUMENT:
4830  llvm_unreachable("ICVTracker can only be created for function position!");
4831  case IRPosition::IRP_RETURNED:
4832  AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
4833  break;
4834  case IRPosition::IRP_CALL_SITE_RETURNED:
4835  AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
4836  break;
4837  case IRPosition::IRP_CALL_SITE:
4838  AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
4839  break;
4840  case IRPosition::IRP_FUNCTION:
4841  AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
4842  break;
4843  }
4844 
4845  return *AA;
4846 }
4847 
4848 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
4849  Attributor &A) {
4850  AAExecutionDomainFunction *AA = nullptr;
4851  switch (IRP.getPositionKind()) {
4852  case IRPosition::IRP_INVALID:
4853  case IRPosition::IRP_FLOAT:
4854  case IRPosition::IRP_ARGUMENT:
4855  case IRPosition::IRP_CALL_SITE_ARGUMENT:
4856  case IRPosition::IRP_RETURNED:
4857  case IRPosition::IRP_CALL_SITE_RETURNED:
4858  case IRPosition::IRP_CALL_SITE:
4859  llvm_unreachable(
4860  "AAExecutionDomain can only be created for function position!");
4861  case IRPosition::IRP_FUNCTION:
4862  AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
4863  break;
4864  }
4865 
4866  return *AA;
4867 }
4868 
4869 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
4870  Attributor &A) {
4871  AAHeapToSharedFunction *AA = nullptr;
4872  switch (IRP.getPositionKind()) {
4873  case IRPosition::IRP_INVALID:
4874  case IRPosition::IRP_FLOAT:
4875  case IRPosition::IRP_ARGUMENT:
4876  case IRPosition::IRP_CALL_SITE_ARGUMENT:
4877  case IRPosition::IRP_RETURNED:
4878  case IRPosition::IRP_CALL_SITE_RETURNED:
4879  case IRPosition::IRP_CALL_SITE:
4880  llvm_unreachable(
4881  "AAHeapToShared can only be created for function position!");
4882  case IRPosition::IRP_FUNCTION:
4883  AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
4884  break;
4885  }
4886 
4887  return *AA;
4888 }
4889 
4890 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
4891  Attributor &A) {
4892  AAKernelInfo *AA = nullptr;
4893  switch (IRP.getPositionKind()) {
4894  case IRPosition::IRP_INVALID:
4895  case IRPosition::IRP_FLOAT:
4896  case IRPosition::IRP_ARGUMENT:
4897  case IRPosition::IRP_RETURNED:
4898  case IRPosition::IRP_CALL_SITE_RETURNED:
4899  case IRPosition::IRP_CALL_SITE_ARGUMENT:
4900  llvm_unreachable("KernelInfo can only be created for function position!");
4901  case IRPosition::IRP_CALL_SITE:
4902  AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
4903  break;
4904  case IRPosition::IRP_FUNCTION:
4905  AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
4906  break;
4907  }
4908 
4909  return *AA;
4910 }
4911 
4912 AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
4913  Attributor &A) {
4914  AAFoldRuntimeCall *AA = nullptr;
4915  switch (IRP.getPositionKind()) {
4916  case IRPosition::IRP_INVALID:
4917  case IRPosition::IRP_FLOAT:
4918  case IRPosition::IRP_ARGUMENT:
4919  case IRPosition::IRP_RETURNED:
4920  case IRPosition::IRP_FUNCTION:
4921  case IRPosition::IRP_CALL_SITE:
4922  case IRPosition::IRP_CALL_SITE_ARGUMENT:
4923  llvm_unreachable("KernelInfo can only be created for call site position!");
4924  case IRPosition::IRP_CALL_SITE_RETURNED:
4925  AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
4926  break;
4927  }
4928 
4929  return *AA;
4930 }
4931 
4932 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
4933  if (!containsOpenMP(M))
4934  return PreservedAnalyses::all();
4935  if (DisableOpenMPOptimizations)
4936  return PreservedAnalyses::all();
4937 
4938  FunctionAnalysisManager &FAM =
4939  AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
4940  KernelSet Kernels = getDeviceKernels(M);
4941 
4942  if (PrintModuleBeforeOptimizations)
4943  LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
4944 
4945  auto IsCalled = [&](Function &F) {
4946  if (Kernels.contains(&F))
4947  return true;
4948  for (const User *U : F.users())
4949  if (!isa<BlockAddress>(U))
4950  return true;
4951  return false;
4952  };
4953 
4954  auto EmitRemark = [&](Function &F) {
4955  auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
4956  ORE.emit([&]() {
4957  OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
4958  return ORA << "Could not internalize function. "
4959  << "Some optimizations may not be possible. [OMP140]";
4960  });
4961  };
4962 
4963  // Create internal copies of each function if this is a kernel Module. This
4964  // allows interprocedural passes to see every call edge.
4965  DenseMap<Function *, Function *> InternalizedMap;
4966  if (isOpenMPDevice(M)) {
4967  SmallPtrSet<Function *, 16> InternalizeFns;
4968  for (Function &F : M)
4969  if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
4970  !DisableInternalization) {
4971  if (Attributor::isInternalizable(F)) {
4972  InternalizeFns.insert(&F);
4973  } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
4974  EmitRemark(F);
4975  }
4976  }
4977 
4978  Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
4979  }
4980 
4981  // Look at every function in the Module unless it was internalized.
4982  SmallVector<Function *, 16> SCC;
4983  for (Function &F : M)
4984  if (!F.isDeclaration() && !InternalizedMap.lookup(&F))
4985  SCC.push_back(&F);
4986 
4987  if (SCC.empty())
4988  return PreservedAnalyses::all();
4989 
4990  AnalysisGetter AG(FAM);
4991 
4992  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4993  return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4994  };
4995 
4996  BumpPtrAllocator Allocator;
4997  CallGraphUpdater CGUpdater;
4998 
4999  SetVector<Function *> Functions(SCC.begin(), SCC.end());
5000  OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels);
5001 
5002  unsigned MaxFixpointIterations =
5003  (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5004 
5005  AttributorConfig AC(CGUpdater);
5006  AC.DefaultInitializeLiveInternals = false;
5007  AC.RewriteSignatures = false;
5008  AC.MaxFixpointIterations = MaxFixpointIterations;
5009  AC.OREGetter = OREGetter;
5010  AC.PassName = DEBUG_TYPE;
5011 
5012  Attributor A(Functions, InfoCache, AC);
5013 
5014  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5015  bool Changed = OMPOpt.run(true);
5016 
5017  // Optionally inline device functions for potentially better performance.
5018  if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
5019  for (Function &F : M)
5020  if (!F.isDeclaration() && !Kernels.contains(&F) &&
5021  !F.hasFnAttribute(Attribute::NoInline))
5022  F.addFnAttr(Attribute::AlwaysInline);
5023 
5024  if (PrintModuleAfterOptimizations)
5025  LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5026 
5027  if (Changed)
5028  return PreservedAnalyses::none();
5029 
5030  return PreservedAnalyses::all();
5031 }
5032 
5033 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
5034  CGSCCAnalysisManager &AM,
5035  LazyCallGraph &CG,
5036  CGSCCUpdateResult &UR) {
5037  if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5038  return PreservedAnalyses::all();
5039  if (DisableOpenMPOptimizations)
5040  return PreservedAnalyses::all();
5041 
5042  SmallVector<Function *, 16> SCC;
5043  // If there are kernels in the module, we have to run on all SCC's.
5044  for (LazyCallGraph::Node &N : C) {
5045  Function *Fn = &N.getFunction();
5046  SCC.push_back(Fn);
5047  }
5048 
5049  if (SCC.empty())
5050  return PreservedAnalyses::all();
5051 
5052  Module &M = *C.begin()->getFunction().getParent();
5053 
5054  if (PrintModuleBeforeOptimizations)
5055  LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5056 
5057  KernelSet Kernels = getDeviceKernels(M);
5058 
5059  FunctionAnalysisManager &FAM =
5060  AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5061 
5062  AnalysisGetter AG(FAM);
5063 
5064  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5065  return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5066  };
5067 
5068  BumpPtrAllocator Allocator;
5069  CallGraphUpdater CGUpdater;
5070  CGUpdater.initialize(CG, C, AM, UR);
5071 
5072  SetVector<Function *> Functions(SCC.begin(), SCC.end());
5073  OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5074  /*CGSCC*/ Functions, Kernels);
5075 
5076  unsigned MaxFixpointIterations =
5077  (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5078 
5079  AttributorConfig AC(CGUpdater);
5080  AC.DefaultInitializeLiveInternals = false;
5081  AC.IsModulePass = false;
5082  AC.RewriteSignatures = false;
5083  AC.MaxFixpointIterations = MaxFixpointIterations;
5084  AC.OREGetter = OREGetter;
5085  AC.PassName = DEBUG_TYPE;
5086 
5087  Attributor A(Functions, InfoCache, AC);
5088 
5089  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5090  bool Changed = OMPOpt.run(false);
5091 
5092  if (PrintModuleAfterOptimizations)
5093  LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5094 
5095  if (Changed)
5096  return PreservedAnalyses::none();
5097 
5098  return PreservedAnalyses::all();
5099 }
5100 
5101 namespace {
5102 
5103 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
5104  CallGraphUpdater CGUpdater;
5105  static char ID;
5106 
5107  OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
5108  initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
5109  }
5110 
5111  void getAnalysisUsage(AnalysisUsage &AU) const override {
5112  CallGraphSCCPass::getAnalysisUsage(AU);
5113  }
5114 
5115  bool runOnSCC(CallGraphSCC &CGSCC) override {
5116  if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
5117  return false;
5118  if (DisableOpenMPOptimizations || skipSCC(CGSCC))
5119  return false;
5120 
5121  SmallVector<Function *, 16> SCC;
5122  // If there are kernels in the module, we have to run on all SCC's.
5123  for (CallGraphNode *CGN : CGSCC) {
5124  Function *Fn = CGN->getFunction();
5125  if (!Fn || Fn->isDeclaration())
5126  continue;
5127  SCC.push_back(Fn);
5128  }
5129 
5130  if (SCC.empty())
5131  return false;
5132 
5133  Module &M = CGSCC.getCallGraph().getModule();
5134  KernelSet Kernels = getDeviceKernels(M);
5135 
5136  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
5137  CGUpdater.initialize(CG, CGSCC);
5138 
5139  // Maintain a map of functions to avoid rebuilding the ORE
5140  DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
5141  auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
5142  std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
5143  if (!ORE)
5144  ORE = std::make_unique<OptimizationRemarkEmitter>(F);
5145  return *ORE;
5146  };
5147 
5148  AnalysisGetter AG;
5149  SetVector<Function *> Functions(SCC.begin(), SCC.end());
5150  BumpPtrAllocator Allocator;
5151  OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
5152  Allocator,
5153  /*CGSCC*/ Functions, Kernels);
5154 
5155  unsigned MaxFixpointIterations =
5156  (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5157 
5158  AttributorConfig AC(CGUpdater);
5159  AC.DefaultInitializeLiveInternals = false;
5160  AC.IsModulePass = false;
5161  AC.RewriteSignatures = false;
5162  AC.MaxFixpointIterations = MaxFixpointIterations;
5163  AC.OREGetter = OREGetter;
5164  AC.PassName = DEBUG_TYPE;
5165 
5166  Attributor A(Functions, InfoCache, AC);
5167 
5168  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5169  bool Result = OMPOpt.run(false);
5170 
5171  if (PrintModuleAfterOptimizations)
5172  LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5173 
5174  return Result;
5175  }
5176 
5177  bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
5178 };
5179 
5180 } // end anonymous namespace
5181 
5182 KernelSet llvm::omp::getDeviceKernels(Module &M) {
5183  // TODO: Create a more cross-platform way of determining device kernels.
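 // For illustration, the metadata walked below usually has this shape for an
 // offloaded kernel (function name is made up; only operands 0 and 1 are
 // inspected):
 //   !nvvm.annotations = !{!0}
 //   !0 = !{void ()* @kernel_fn, !"kernel", i32 1}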
5184  NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
5185  KernelSet Kernels;
5186 
5187  if (!MD)
5188  return Kernels;
5189 
5190  for (auto *Op : MD->operands()) {
5191  if (Op->getNumOperands() < 2)
5192  continue;
5193  MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
5194  if (!KindID || KindID->getString() != "kernel")
5195  continue;
5196 
5197  Function *KernelFn =
5198  mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
5199  if (!KernelFn)
5200  continue;
5201 
5202  ++NumOpenMPTargetRegionKernels;
5203 
5204  Kernels.insert(KernelFn);
5205  }
5206 
5207  return Kernels;
5208 }
5209 
5210 bool llvm::omp::containsOpenMP(Module &M) {
5211  Metadata *MD = M.getModuleFlag("openmp");
5212  if (!MD)
5213  return false;
5214 
5215  return true;
5216 }
5217 
5218 bool llvm::omp::isOpenMPDevice(Module &M) {
5219  Metadata *MD = M.getModuleFlag("openmp-device");
5220  if (!MD)
5221  return false;
5222 
5223  return true;
5224 }
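 // For reference, the flags tested by the two helpers above are ordinary LLVM
 // module flags, roughly of the form below (the version value is an example;
 // only the presence of the flag is checked):
 //   !llvm.module.flags = !{!0, !1}
 //   !0 = !{i32 7, !"openmp", i32 50}
 //   !1 = !{i32 7, !"openmp-device", i32 50}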
5225 
5226 char OpenMPOptCGSCCLegacyPass::ID = 0;
5227 
5228 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
5229  "OpenMP specific optimizations", false, false)
5230 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
5231 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
5232  "OpenMP specific optimizations", false, false)
5233 
5234 Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
5235  return new OpenMPOptCGSCCLegacyPass();
5236 }