Bug Summary

File: llvm/lib/Transforms/IPO/OpenMPOpt.cpp
Warning: line 3718, column 9
Called C++ object pointer is null
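
For orientation, a minimal sketch of this defect class (illustrative only; it is not the code at line 3718): the analyzer has found a path on which a member function is invoked through a pointer it knows to be null.

    struct S { void f(); };
    void g(S *P) {
      if (!P) {
        // ... handling that forgets to return ...
      }
      P->f(); // on the !P path: called C++ object pointer is null
    }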

Annotated Source Code

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -disable-llvm-verifier -discard-value-names -main-file-name OpenMPOpt.cpp -analyzer-store=region -analyzer-opt-analyze-nested-blocks -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -fno-rounding-math -mconstructor-aliases -munwind-tables -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/build-llvm -resource-dir /usr/lib/llvm-14/lib/clang/14.0.0 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I lib/Transforms/IPO -I /build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/llvm/lib/Transforms/IPO -I include -I /build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/llvm/include -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-14/lib/clang/14.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -O2 -Wno-unused-command-line-argument -Wno-unknown-warning-option -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -std=c++14 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/build-llvm -ferror-limit 19 -fvisibility-inlines-hidden -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2021-09-26-234817-15343-1 -x c++ /build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/llvm/lib/Transforms/IPO/OpenMPOpt.cpp

/build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/llvm/lib/Transforms/IPO/OpenMPOpt.cpp

1//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpenMP specific optimizations:
10//
11// - Deduplication of runtime calls, e.g., omp_get_thread_num.
12// - Replacing globalized device memory with stack memory.
13// - Replacing globalized device memory with shared memory.
14// - Parallel region merging.
15// - Transforming generic-mode device kernels to SPMD mode.
16// - Specializing the state machine for generic-mode device kernels.
17//
18//===----------------------------------------------------------------------===//
19
20#include "llvm/Transforms/IPO/OpenMPOpt.h"
21
22#include "llvm/ADT/EnumeratedArray.h"
23#include "llvm/ADT/PostOrderIterator.h"
24#include "llvm/ADT/Statistic.h"
25#include "llvm/Analysis/CallGraph.h"
26#include "llvm/Analysis/CallGraphSCCPass.h"
27#include "llvm/Analysis/OptimizationRemarkEmitter.h"
28#include "llvm/Analysis/ValueTracking.h"
29#include "llvm/Frontend/OpenMP/OMPConstants.h"
30#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
31#include "llvm/IR/Assumptions.h"
32#include "llvm/IR/DiagnosticInfo.h"
33#include "llvm/IR/GlobalValue.h"
34#include "llvm/IR/Instruction.h"
35#include "llvm/IR/IntrinsicInst.h"
36#include "llvm/IR/IntrinsicsAMDGPU.h"
37#include "llvm/IR/IntrinsicsNVPTX.h"
38#include "llvm/InitializePasses.h"
39#include "llvm/Support/CommandLine.h"
40#include "llvm/Transforms/IPO.h"
41#include "llvm/Transforms/IPO/Attributor.h"
42#include "llvm/Transforms/Utils/BasicBlockUtils.h"
43#include "llvm/Transforms/Utils/CallGraphUpdater.h"
44#include "llvm/Transforms/Utils/CodeExtractor.h"
45
46using namespace llvm;
47using namespace omp;
48
49#define DEBUG_TYPE "openmp-opt"
50
51static cl::opt<bool> DisableOpenMPOptimizations(
52 "openmp-opt-disable", cl::ZeroOrMore,
53 cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
54 cl::init(false));
55
56static cl::opt<bool> EnableParallelRegionMerging(
57 "openmp-opt-enable-merging", cl::ZeroOrMore,
58 cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
59 cl::init(false));
60
61static cl::opt<bool>
62 DisableInternalization("openmp-opt-disable-internalization", cl::ZeroOrMore,
63 cl::desc("Disable function internalization."),
64 cl::Hidden, cl::init(false));
65
66static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
67 cl::Hidden);
68static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
69 cl::init(false), cl::Hidden);
70
71static cl::opt<bool> HideMemoryTransferLatency(
72 "openmp-hide-memory-transfer-latency",
73 cl::desc("[WIP] Tries to hide the latency of host to device memory"
74 " transfers"),
75 cl::Hidden, cl::init(false));
76
77static cl::opt<bool> DisableOpenMPOptDeglobalization(
78 "openmp-opt-disable-deglobalization", cl::ZeroOrMore,
79 cl::desc("Disable OpenMP optimizations involving deglobalization."),
80 cl::Hidden, cl::init(false));
81
82static cl::opt<bool> DisableOpenMPOptSPMDization(
83 "openmp-opt-disable-spmdization", cl::ZeroOrMore,
84 cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
85 cl::Hidden, cl::init(false));
86
87static cl::opt<bool> DisableOpenMPOptFolding(
88 "openmp-opt-disable-folding", cl::ZeroOrMore,
89 cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
90 cl::init(false));
91
92static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
93 "openmp-opt-disable-state-machine-rewrite", cl::ZeroOrMore,
94 cl::desc("Disable OpenMP optimizations that replace the state machine."),
95 cl::Hidden, cl::init(false));
96
97static cl::opt<bool> PrintModuleAfterOptimizations(
98 "openmp-opt-print-module", cl::ZeroOrMore,
99 cl::desc("Print the current module after OpenMP optimizations."),
100 cl::Hidden, cl::init(false));
101
102static cl::opt<bool> AlwaysInlineDeviceFunctions(
103 "openmp-opt-inline-device", cl::ZeroOrMore,
104 cl::desc("Inline all applicible functions on the device."), cl::Hidden,
105 cl::init(false));
106
107static cl::opt<bool>
108 EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::ZeroOrMore,
109 cl::desc("Enables more verbose remarks."), cl::Hidden,
110 cl::init(false));
111
112STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
113 "Number of OpenMP runtime calls deduplicated");
114STATISTIC(NumOpenMPParallelRegionsDeleted,
115 "Number of OpenMP parallel regions deleted");
116STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
117 "Number of OpenMP runtime functions identified");
118STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
119 "Number of OpenMP runtime function uses identified");
120STATISTIC(NumOpenMPTargetRegionKernels,
121 "Number of OpenMP target region entry points (=kernels) identified");
122STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
123 "Number of OpenMP target region entry points (=kernels) executed in "
124 "SPMD-mode instead of generic-mode");
125STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
126 "Number of OpenMP target region entry points (=kernels) executed in "
127 "generic-mode without a state machine");
128STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
129 "Number of OpenMP target region entry points (=kernels) executed in "
130 "generic-mode with customized state machines with fallback");
131STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
132 "Number of OpenMP target region entry points (=kernels) executed in "
133 "generic-mode with customized state machines without fallback");
134STATISTIC(
135 NumOpenMPParallelRegionsReplacedInGPUStateMachine,
136 "Number of OpenMP parallel regions replaced with ID in GPU state machines");
137STATISTIC(NumOpenMPParallelRegionsMerged,
138 "Number of OpenMP parallel regions merged");
139STATISTIC(NumBytesMovedToSharedMemory,
140 "Amount of memory pushed to shared memory");
141
142#if !defined(NDEBUG)
143static constexpr auto TAG = "[" DEBUG_TYPE "]";
144#endif
145
146namespace {
147
148enum class AddressSpace : unsigned {
149 Generic = 0,
150 Global = 1,
151 Shared = 3,
152 Constant = 4,
153 Local = 5,
154};
155
156struct AAHeapToShared;
157
158struct AAICVTracker;
159
160/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
161/// Attributor runs.
162struct OMPInformationCache : public InformationCache {
163 OMPInformationCache(Module &M, AnalysisGetter &AG,
164 BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
165 SmallPtrSetImpl<Kernel> &Kernels)
166 : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
167 Kernels(Kernels) {
168
169 OMPBuilder.initialize();
170 initializeRuntimeFunctions();
171 initializeInternalControlVars();
172 }
173
174 /// Generic information that describes an internal control variable.
175 struct InternalControlVarInfo {
176 /// The kind, as described by InternalControlVar enum.
177 InternalControlVar Kind;
178
179 /// The name of the ICV.
180 StringRef Name;
181
182 /// Environment variable associated with this ICV.
183 StringRef EnvVarName;
184
185 /// Initial value kind.
186 ICVInitValue InitKind;
187
188 /// Initial value.
189 ConstantInt *InitValue;
190
191 /// Setter RTL function associated with this ICV.
192 RuntimeFunction Setter;
193
194 /// Getter RTL function associated with this ICV.
195 RuntimeFunction Getter;
196
197 /// RTL function corresponding to the override clause of this ICV.
198 RuntimeFunction Clause;
199 };
200
201 /// Generic information that describes a runtime function
202 struct RuntimeFunctionInfo {
203
204 /// The kind, as described by the RuntimeFunction enum.
205 RuntimeFunction Kind;
206
207 /// The name of the function.
208 StringRef Name;
209
210 /// Flag to indicate a variadic function.
211 bool IsVarArg;
212
213 /// The return type of the function.
214 Type *ReturnType;
215
216 /// The argument types of the function.
217 SmallVector<Type *, 8> ArgumentTypes;
218
219 /// The declaration if available.
220 Function *Declaration = nullptr;
221
222 /// Uses of this runtime function per function containing the use.
223 using UseVector = SmallVector<Use *, 16>;
224
225 /// Clear UsesMap for runtime function.
226 void clearUsesMap() { UsesMap.clear(); }
227
228 /// Boolean conversion that is true if the runtime function was found.
229 operator bool() const { return Declaration; }
230
231 /// Return the vector of uses in function \p F.
232 UseVector &getOrCreateUseVector(Function *F) {
233 std::shared_ptr<UseVector> &UV = UsesMap[F];
234 if (!UV)
235 UV = std::make_shared<UseVector>();
236 return *UV;
237 }
238
239 /// Return the vector of uses in function \p F or `nullptr` if there are
240 /// none.
241 const UseVector *getUseVector(Function &F) const {
242 auto I = UsesMap.find(&F);
243 if (I != UsesMap.end())
244 return I->second.get();
245 return nullptr;
246 }
247
248 /// Return how many functions contain uses of this runtime function.
249 size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
250
251 /// Return the number of arguments (or the minimal number for variadic
252 /// functions).
253 size_t getNumArgs() const { return ArgumentTypes.size(); }
254
255 /// Run the callback \p CB on each use and forget the use if the result is
256 /// true. The callback will be fed the function in which the use was
257 /// encountered as its second argument.
258 void foreachUse(SmallVectorImpl<Function *> &SCC,
259 function_ref<bool(Use &, Function &)> CB) {
260 for (Function *F : SCC)
261 foreachUse(CB, F);
262 }
263
264 /// Run the callback \p CB on each use within the function \p F and forget
265 /// the use if the result is true.
266 void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
267 SmallVector<unsigned, 8> ToBeDeleted;
268 ToBeDeleted.clear();
269
270 unsigned Idx = 0;
271 UseVector &UV = getOrCreateUseVector(F);
272
273 for (Use *U : UV) {
274 if (CB(*U, *F))
275 ToBeDeleted.push_back(Idx);
276 ++Idx;
277 }
278
279 // Remove the to-be-deleted indices in reverse order so the swap-with-back
280 // below cannot disturb smaller indices that are still pending removal.
281 while (!ToBeDeleted.empty()) {
282 unsigned Idx = ToBeDeleted.pop_back_val();
283 UV[Idx] = UV.back();
284 UV.pop_back();
285 }
286 }
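 // Usage sketch (illustrative; names hypothetical): forget every use of
 // this runtime function that is a direct call inside function F.
 //   RFI.foreachUse([](Use &U, Function &Fn) {
 //     return OpenMPOpt::getCallIfRegularCall(U) != nullptr; // true = forget
 //   }, &F);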
287
288 private:
289 /// Map from functions to all uses of this runtime function contained in
290 /// them.
291 DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
292
293 public:
294 /// Iterators for the uses of this runtime function.
295 decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
296 decltype(UsesMap)::iterator end() { return UsesMap.end(); }
297 };
298
299 /// An OpenMP-IR-Builder instance
300 OpenMPIRBuilder OMPBuilder;
301
302 /// Map from runtime function kind to the runtime function description.
303 EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
304 RuntimeFunction::OMPRTL___last>
305 RFIs;
306
307 /// Map from function declarations/definitions to their runtime enum type.
308 DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
309
310 /// Map from ICV kind to the ICV description.
311 EnumeratedArray<InternalControlVarInfo, InternalControlVar,
312 InternalControlVar::ICV___last>
313 ICVs;
314
315 /// Helper to initialize all internal control variable information for those
316 /// defined in OMPKinds.def.
317 void initializeInternalControlVars() {
318#define ICV_RT_SET(_Name, RTL) \
319 { \
320 auto &ICV = ICVs[_Name]; \
321 ICV.Setter = RTL; \
322 }
323#define ICV_RT_GET(Name, RTL) \
324 { \
325 auto &ICV = ICVs[Name]; \
326 ICV.Getter = RTL; \
327 }
328#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
329 { \
330 auto &ICV = ICVs[Enum]; \
331 ICV.Name = _Name; \
332 ICV.Kind = Enum; \
333 ICV.InitKind = Init; \
334 ICV.EnvVarName = _EnvVarName; \
335 switch (ICV.InitKind) { \
336 case ICV_IMPLEMENTATION_DEFINED: \
337 ICV.InitValue = nullptr; \
338 break; \
339 case ICV_ZERO: \
340 ICV.InitValue = ConstantInt::get( \
341 Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
342 break; \
343 case ICV_FALSE: \
344 ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
345 break; \
346 case ICV_LAST: \
347 break; \
348 } \
349 }
350#include "llvm/Frontend/OpenMP/OMPKinds.def"
351 }
352
353 /// Returns true if the function declaration \p F matches the runtime
354 /// function types, that is, return type \p RTFRetType, and argument types
355 /// \p RTFArgTypes.
356 static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
357 SmallVector<Type *, 8> &RTFArgTypes) {
358 // TODO: We should output information to the user (under debug output
359 // and via remarks).
360
361 if (!F)
362 return false;
363 if (F->getReturnType() != RTFRetType)
364 return false;
365 if (F->arg_size() != RTFArgTypes.size())
366 return false;
367
368 auto RTFTyIt = RTFArgTypes.begin();
369 for (Argument &Arg : F->args()) {
370 if (Arg.getType() != *RTFTyIt)
371 return false;
372
373 ++RTFTyIt;
374 }
375
376 return true;
377 }
378
379 // Helper to collect all uses of the declaration in the UsesMap.
380 unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
381 unsigned NumUses = 0;
382 if (!RFI.Declaration)
383 return NumUses;
384 OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
385
386 if (CollectStats) {
387 NumOpenMPRuntimeFunctionsIdentified += 1;
388 NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
389 }
390
391 // TODO: We directly convert uses into proper calls and unknown uses.
392 for (Use &U : RFI.Declaration->uses()) {
393 if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
394 if (ModuleSlice.count(UserI->getFunction())) {
395 RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
396 ++NumUses;
397 }
398 } else {
399 RFI.getOrCreateUseVector(nullptr).push_back(&U);
400 ++NumUses;
401 }
402 }
403 return NumUses;
404 }
405
406 // Helper function to recollect uses of a runtime function.
407 void recollectUsesForFunction(RuntimeFunction RTF) {
408 auto &RFI = RFIs[RTF];
409 RFI.clearUsesMap();
410 collectUses(RFI, /*CollectStats*/ false);
411 }
412
413 // Helper function to recollect uses of all runtime functions.
414 void recollectUses() {
415 for (int Idx = 0; Idx < RFIs.size(); ++Idx)
416 recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
417 }
418
419 /// Helper to initialize all runtime function information for those defined
420 /// in OpenMPKinds.def.
421 void initializeRuntimeFunctions() {
422 Module &M = *((*ModuleSlice.begin())->getParent());
423
424 // Helper macros for handling __VA_ARGS__ in OMP_RTL
425#define OMP_TYPE(VarName, ...) \
426 Type *VarName = OMPBuilder.VarName; \
427 (void)VarName;
428
429#define OMP_ARRAY_TYPE(VarName, ...) \
430 ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
431 (void)VarName##Ty; \
432 PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
433 (void)VarName##PtrTy;
434
435#define OMP_FUNCTION_TYPE(VarName, ...) \
436 FunctionType *VarName = OMPBuilder.VarName; \
437 (void)VarName; \
438 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
439 (void)VarName##Ptr;
440
441#define OMP_STRUCT_TYPE(VarName, ...) \
442 StructType *VarName = OMPBuilder.VarName; \
443 (void)VarName; \
444 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
445 (void)VarName##Ptr;
446
447#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
448 { \
449 SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
450 Function *F = M.getFunction(_Name); \
451 RTLFunctions.insert(F); \
452 if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
453 RuntimeFunctionIDMap[F] = _Enum; \
454 F->removeFnAttr(Attribute::NoInline); \
455 auto &RFI = RFIs[_Enum]; \
456 RFI.Kind = _Enum; \
457 RFI.Name = _Name; \
458 RFI.IsVarArg = _IsVarArg; \
459 RFI.ReturnType = OMPBuilder._ReturnType; \
460 RFI.ArgumentTypes = std::move(ArgsTypes); \
461 RFI.Declaration = F; \
462 unsigned NumUses = collectUses(RFI); \
463 (void)NumUses; \
464 LLVM_DEBUG({ \
465 dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
466 << " found\n"; \
467 if (RFI.Declaration) \
468 dbgs() << TAG << "-> got " << NumUses << " uses in " \
469 << RFI.getNumFunctionsWithUses() \
470 << " different functions.\n"; \
471 }); \
472 } \
473 }
474#include "llvm/Frontend/OpenMP/OMPKinds.def"
475
476 // TODO: We should attach the attributes defined in OMPKinds.def.
477 }
478
479 /// Collection of known kernels (\see Kernel) in the module.
480 SmallPtrSetImpl<Kernel> &Kernels;
481
482 /// Collection of known OpenMP runtime functions.
483 DenseSet<const Function *> RTLFunctions;
484};
485
486template <typename Ty, bool InsertInvalidates = true>
487struct BooleanStateWithSetVector : public BooleanState {
488 bool contains(const Ty &Elem) const { return Set.contains(Elem); }
489 bool insert(const Ty &Elem) {
490 if (InsertInvalidates)
491 BooleanState::indicatePessimisticFixpoint();
492 return Set.insert(Elem);
493 }
494
495 const Ty &operator[](int Idx) const { return Set[Idx]; }
496 bool operator==(const BooleanStateWithSetVector &RHS) const {
497 return BooleanState::operator==(RHS) && Set == RHS.Set;
498 }
499 bool operator!=(const BooleanStateWithSetVector &RHS) const {
500 return !(*this == RHS);
501 }
502
503 bool empty() const { return Set.empty(); }
504 size_t size() const { return Set.size(); }
505
506 /// "Clamp" this state with \p RHS.
507 BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
508 BooleanState::operator^=(RHS);
509 Set.insert(RHS.Set.begin(), RHS.Set.end());
510 return *this;
511 }
512
513private:
514 /// A set to keep track of elements.
515 SetVector<Ty> Set;
516
517public:
518 typename decltype(Set)::iterator begin() { return Set.begin(); }
519 typename decltype(Set)::iterator end() { return Set.end(); }
520 typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
521 typename decltype(Set)::const_iterator end() const { return Set.end(); }
522};
523
524template <typename Ty, bool InsertInvalidates = true>
525using BooleanStateWithPtrSetVector =
526 BooleanStateWithSetVector<Ty *, InsertInvalidates>;
527
528struct KernelInfoState : AbstractState {
529 /// Flag to track if we reached a fixpoint.
530 bool IsAtFixpoint = false;
531
532 /// The parallel regions (identified by the outlined parallel functions) that
533 /// can be reached from the associated function.
534 BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
535 ReachedKnownParallelRegions;
536
537 /// State to track what parallel region we might reach.
538 BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
539
540 /// State to track if we are in SPMD-mode, assumed or known, and why we decided
541 /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
542 /// false.
543 BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
544
545 /// The __kmpc_target_init call in this kernel, if any. If we find more than
546 /// one we abort as the kernel is malformed.
547 CallBase *KernelInitCB = nullptr;
548
549 /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
550 /// one we abort as the kernel is malformed.
551 CallBase *KernelDeinitCB = nullptr;
552
553 /// Flag to indicate if the associated function is a kernel entry.
554 bool IsKernelEntry = false;
555
556 /// State to track what kernel entries can reach the associated function.
557 BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
558
559 /// State to indicate if we can track parallel level of the associated
560 /// function. We will give up tracking if we encounter an unknown caller or the
561 /// caller is __kmpc_parallel_51.
562 BooleanStateWithSetVector<uint8_t> ParallelLevels;
563
564 /// Abstract State interface
565 ///{
566
567 KernelInfoState() {}
568 KernelInfoState(bool BestState) {
569 if (!BestState)
570 indicatePessimisticFixpoint();
571 }
572
573 /// See AbstractState::isValidState(...)
574 bool isValidState() const override { return true; }
575
576 /// See AbstractState::isAtFixpoint(...)
577 bool isAtFixpoint() const override { return IsAtFixpoint; }
578
579 /// See AbstractState::indicatePessimisticFixpoint(...)
580 ChangeStatus indicatePessimisticFixpoint() override {
581 IsAtFixpoint = true;
582 ReachingKernelEntries.indicatePessimisticFixpoint();
583 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
584 ReachedKnownParallelRegions.indicatePessimisticFixpoint();
585 ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
586 return ChangeStatus::CHANGED;
587 }
588
589 /// See AbstractState::indicateOptimisticFixpoint(...)
590 ChangeStatus indicateOptimisticFixpoint() override {
591 IsAtFixpoint = true;
592 return ChangeStatus::UNCHANGED;
593 }
594
595 /// Return the assumed state
596 KernelInfoState &getAssumed() { return *this; }
597 const KernelInfoState &getAssumed() const { return *this; }
598
599 bool operator==(const KernelInfoState &RHS) const {
600 if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
601 return false;
602 if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
603 return false;
604 if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
605 return false;
606 if (ReachingKernelEntries != RHS.ReachingKernelEntries)
607 return false;
608 return true;
609 }
610
611 /// Returns true if this kernel may contain any OpenMP parallel regions.
612 bool mayContainParallelRegion() {
613 return !ReachedKnownParallelRegions.empty() ||
614 !ReachedUnknownParallelRegions.empty();
615 }
616
617 /// Return empty set as the best state of potential values.
618 static KernelInfoState getBestState() { return KernelInfoState(true); }
619
620 static KernelInfoState getBestState(KernelInfoState &KIS) {
621 return getBestState();
622 }
623
624 /// Return full set as the worst state of potential values.
625 static KernelInfoState getWorstState() { return KernelInfoState(false); }
626
627 /// "Clamp" this state with \p KIS.
628 KernelInfoState operator^=(const KernelInfoState &KIS) {
629 // Do not merge two different _init and _deinit call sites.
630 if (KIS.KernelInitCB) {
631 if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
632 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt assumptions.");
633 KernelInitCB = KIS.KernelInitCB;
634 }
635 if (KIS.KernelDeinitCB) {
636 if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
637 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt assumptions.");
638 KernelDeinitCB = KIS.KernelDeinitCB;
639 }
640 SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
641 ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
642 ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
643 return *this;
644 }
645
646 KernelInfoState operator&=(const KernelInfoState &KIS) {
647 return (*this ^= KIS);
648 }
649
650 ///}
651};
652
653/// Used to map the values physically stored (in the IR) in an offload
654/// array to a vector in memory.
655struct OffloadArray {
656 /// Physical array (in the IR).
657 AllocaInst *Array = nullptr;
658 /// Mapped values.
659 SmallVector<Value *, 8> StoredValues;
660 /// Last stores made in the offload array.
661 SmallVector<StoreInst *, 8> LastAccesses;
662
663 OffloadArray() = default;
664
665 /// Initializes the OffloadArray with the values stored in \p Array before
666 /// instruction \p Before is reached. Returns false if the initialization
667 /// fails.
668 /// This MUST be used immediately after the construction of the object.
669 bool initialize(AllocaInst &Array, Instruction &Before) {
670 if (!Array.getAllocatedType()->isArrayTy())
671 return false;
672
673 if (!getValues(Array, Before))
674 return false;
675
676 this->Array = &Array;
677 return true;
678 }
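 // Usage sketch (illustrative; names hypothetical): recover the values
 // stored into an offload array right before a runtime call consumes it.
 //   OffloadArray OA;
 //   if (OA.initialize(*ArrayAlloca, *RuntimeCall))
 //     for (Value *V : OA.StoredValues)
 //       ...; // underlying objects written before RuntimeCall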
679
680 static const unsigned DeviceIDArgNum = 1;
681 static const unsigned BasePtrsArgNum = 3;
682 static const unsigned PtrsArgNum = 4;
683 static const unsigned SizesArgNum = 5;
684
685private:
686 /// Traverses the BasicBlock where \p Array is, collecting the stores made to
687 /// \p Array, leaving StoredValues with the values stored before the
688 /// instruction \p Before is reached.
689 bool getValues(AllocaInst &Array, Instruction &Before) {
690 // Initialize container.
691 const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
692 StoredValues.assign(NumValues, nullptr);
693 LastAccesses.assign(NumValues, nullptr);
694
695 // TODO: This assumes the instruction \p Before is in the same
696 // BasicBlock as Array. Make it general, for any control flow graph.
697 BasicBlock *BB = Array.getParent();
698 if (BB != Before.getParent())
699 return false;
700
701 const DataLayout &DL = Array.getModule()->getDataLayout();
702 const unsigned int PointerSize = DL.getPointerSize();
703
704 for (Instruction &I : *BB) {
705 if (&I == &Before)
706 break;
707
708 if (!isa<StoreInst>(&I))
709 continue;
710
711 auto *S = cast<StoreInst>(&I);
712 int64_t Offset = -1;
713 auto *Dst =
714 GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
715 if (Dst == &Array) {
716 int64_t Idx = Offset / PointerSize;
717 StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
718 LastAccesses[Idx] = S;
719 }
720 }
721
722 return isFilled();
723 }
724
725 /// Returns true if all values in StoredValues and
726 /// LastAccesses are not nullptrs.
727 bool isFilled() {
728 const unsigned NumValues = StoredValues.size();
729 for (unsigned I = 0; I < NumValues; ++I) {
730 if (!StoredValues[I] || !LastAccesses[I])
731 return false;
732 }
733
734 return true;
735 }
736};
737
738struct OpenMPOpt {
739
740 using OptimizationRemarkGetter =
741 function_ref<OptimizationRemarkEmitter &(Function *)>;
742
743 OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
744 OptimizationRemarkGetter OREGetter,
745 OMPInformationCache &OMPInfoCache, Attributor &A)
746 : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
747 OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
748
749 /// Check if any remarks are enabled for openmp-opt
750 bool remarksEnabled() {
751 auto &Ctx = M.getContext();
752 return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
753 }
754
755 /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
756 bool run(bool IsModulePass) {
757 if (SCC.empty())
758 return false;
759
760 bool Changed = false;
761
762 LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
763 << " functions in a slice with "
764 << OMPInfoCache.ModuleSlice.size() << " functions\n");
765
766 if (IsModulePass) {
767 Changed |= runAttributor(IsModulePass);
768
769 // Recollect uses, in case Attributor deleted any.
770 OMPInfoCache.recollectUses();
771
772 // TODO: This should be folded into buildCustomStateMachine.
773 Changed |= rewriteDeviceCodeStateMachine();
774
775 if (remarksEnabled())
776 analysisGlobalization();
777 } else {
778 if (PrintICVValues)
779 printICVs();
780 if (PrintOpenMPKernels)
781 printKernels();
782
783 Changed |= runAttributor(IsModulePass);
784
785 // Recollect uses, in case Attributor deleted any.
786 OMPInfoCache.recollectUses();
787
788 Changed |= deleteParallelRegions();
789
790 if (HideMemoryTransferLatency)
791 Changed |= hideMemTransfersLatency();
792 Changed |= deduplicateRuntimeCalls();
793 if (EnableParallelRegionMerging) {
794 if (mergeParallelRegions()) {
795 deduplicateRuntimeCalls();
796 Changed = true;
797 }
798 }
799 }
800
801 return Changed;
802 }
803
804 /// Print initial ICV values for testing.
805 /// FIXME: This should be done from the Attributor once it is added.
806 void printICVs() const {
807 InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
808 ICV_proc_bind};
809
810 for (Function *F : OMPInfoCache.ModuleSlice) {
811 for (auto ICV : ICVs) {
812 auto ICVInfo = OMPInfoCache.ICVs[ICV];
813 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
814 return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
815 << " Value: "
816 << (ICVInfo.InitValue
817 ? toString(ICVInfo.InitValue->getValue(), 10, true)
818 : "IMPLEMENTATION_DEFINED");
819 };
820
821 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
822 }
823 }
824 }
825
826 /// Print OpenMP GPU kernels for testing.
827 void printKernels() const {
828 for (Function *F : SCC) {
829 if (!OMPInfoCache.Kernels.count(F))
830 continue;
831
832 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
833 return ORA << "OpenMP GPU kernel "
834 << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
835 };
836
837 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
838 }
839 }
840
841 /// Return the call if \p U is a callee use in a regular call. If \p RFI is
842 /// given it has to be the callee or a nullptr is returned.
843 static CallInst *getCallIfRegularCall(
844 Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
845 CallInst *CI = dyn_cast<CallInst>(U.getUser());
846 if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
847 (!RFI ||
848 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
849 return CI;
850 return nullptr;
851 }
852
853 /// Return the call if \p V is a regular call. If \p RFI is given it has to be
854 /// the callee or a nullptr is returned.
855 static CallInst *getCallIfRegularCall(
856 Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
857 CallInst *CI = dyn_cast<CallInst>(&V);
858 if (CI && !CI->hasOperandBundles() &&
859 (!RFI ||
860 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
861 return CI;
862 return nullptr;
863 }
864
865private:
866 /// Merge parallel regions when it is safe.
867 bool mergeParallelRegions() {
868 const unsigned CallbackCalleeOperand = 2;
869 const unsigned CallbackFirstArgOperand = 3;
870 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
871
872 // Check if there are any __kmpc_fork_call calls to merge.
873 OMPInformationCache::RuntimeFunctionInfo &RFI =
874 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
875
876 if (!RFI.Declaration)
877 return false;
878
879 // Unmergable calls that prevent merging a parallel region.
880 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
881 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
882 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
883 };
884
885 bool Changed = false;
886 LoopInfo *LI = nullptr;
887 DominatorTree *DT = nullptr;
888
889 SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
890
891 BasicBlock *StartBB = nullptr, *EndBB = nullptr;
892 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
893 BasicBlock &ContinuationIP) {
894 BasicBlock *CGStartBB = CodeGenIP.getBlock();
895 BasicBlock *CGEndBB =
896 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
897 assert(StartBB != nullptr && "StartBB should not be null");
898 CGStartBB->getTerminator()->setSuccessor(0, StartBB);
899 assert(EndBB != nullptr && "EndBB should not be null");
900 EndBB->getTerminator()->setSuccessor(0, CGEndBB);
901 };
902
903 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
904 Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
905 ReplacementValue = &Inner;
906 return CodeGenIP;
907 };
908
909 auto FiniCB = [&](InsertPointTy CodeGenIP) {};
910
911 /// Create a sequential execution region within a merged parallel region,
912 /// encapsulated in a master construct with a barrier for synchronization.
913 auto CreateSequentialRegion = [&](Function *OuterFn,
914 BasicBlock *OuterPredBB,
915 Instruction *SeqStartI,
916 Instruction *SeqEndI) {
917 // Isolate the instructions of the sequential region to a separate
918 // block.
919 BasicBlock *ParentBB = SeqStartI->getParent();
920 BasicBlock *SeqEndBB =
921 SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
922 BasicBlock *SeqAfterBB =
923 SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
924 BasicBlock *SeqStartBB =
925 SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
926
927 assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
928 "Expected a different CFG");
929 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
930 ParentBB->getTerminator()->eraseFromParent();
931
932 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
933 BasicBlock &ContinuationIP) {
934 BasicBlock *CGStartBB = CodeGenIP.getBlock();
935 BasicBlock *CGEndBB =
936 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
937 assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
938 CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
939 assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
940 SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
941 };
942 auto FiniCB = [&](InsertPointTy CodeGenIP) {};
943
944 // Find outputs from the sequential region to outside users and
945 // broadcast their values to them.
946 for (Instruction &I : *SeqStartBB) {
947 SmallPtrSet<Instruction *, 4> OutsideUsers;
948 for (User *Usr : I.users()) {
949 Instruction &UsrI = *cast<Instruction>(Usr);
950 // Ignore outputs to LT intrinsics, code extraction for the merged
951 // parallel region will fix them.
952 if (UsrI.isLifetimeStartOrEnd())
953 continue;
954
955 if (UsrI.getParent() != SeqStartBB)
956 OutsideUsers.insert(&UsrI);
957 }
958
959 if (OutsideUsers.empty())
960 continue;
961
962 // Emit an alloca in the outer region to store the broadcasted
963 // value.
964 const DataLayout &DL = M.getDataLayout();
965 AllocaInst *AllocaI = new AllocaInst(
966 I.getType(), DL.getAllocaAddrSpace(), nullptr,
967 I.getName() + ".seq.output.alloc", &OuterFn->front().front());
968
969 // Emit a store instruction in the sequential BB to update the
970 // value.
971 new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
972
973 // Emit a load instruction and replace the use of the output value
974 // with it.
975 for (Instruction *UsrI : OutsideUsers) {
976 LoadInst *LoadI = new LoadInst(
977 I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
978 UsrI->replaceUsesOfWith(&I, LoadI);
979 }
980 }
981
982 OpenMPIRBuilder::LocationDescription Loc(
983 InsertPointTy(ParentBB, ParentBB->end()), DL);
984 InsertPointTy SeqAfterIP =
985 OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
986
987 OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
988
989 BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
990
991 LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
992 << "\n");
993 };
994
995 // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
996 // contained in BB and only separated by instructions that can be
997 // redundantly executed in parallel. The block BB is split before the first
998 // call (in MergableCIs) and after the last so the entire region we merge
999 // into a single parallel region is contained in a single basic block
1000 // without any other instructions. We use the OpenMPIRBuilder to outline
1001 // that block and call the resulting function via __kmpc_fork_call.
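 // Illustrative before/after of this rewrite (sketch only):
 //   before: __kmpc_fork_call(loc, n, cb0, ...); X; __kmpc_fork_call(loc, n, cb1, ...);
 //   after:  one __kmpc_fork_call whose outlined body calls cb0 and cb1
 //           directly, with X wrapped in a master region plus barrier.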
1002 auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
1003 // TODO: Change the interface to allow single CIs expanded, e.g, to
1004 // include an outer loop.
1005 assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1006
1007 auto Remark = [&](OptimizationRemark OR) {
1008 OR << "Parallel region merged with parallel region"
1009 << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1010 for (auto *CI : llvm::drop_begin(MergableCIs)) {
1011 OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1012 if (CI != MergableCIs.back())
1013 OR << ", ";
1014 }
1015 return OR << ".";
1016 };
1017
1018 emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1019
1020 Function *OriginalFn = BB->getParent();
1021 LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1022 << " parallel regions in " << OriginalFn->getName()
1023 << "\n");
1024
1025 // Isolate the calls to merge in a separate block.
1026 EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1027 BasicBlock *AfterBB =
1028 SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1029 StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1030 "omp.par.merged");
1031
1032 assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1033 const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1034 BB->getTerminator()->eraseFromParent();
1035
1036 // Create sequential regions for sequential instructions that are
1037 // in-between mergable parallel regions.
1038 for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1039 It != End; ++It) {
1040 Instruction *ForkCI = *It;
1041 Instruction *NextForkCI = *(It + 1);
1042
1043 // Continue if there are no in-between instructions.
1044 if (ForkCI->getNextNode() == NextForkCI)
1045 continue;
1046
1047 CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1048 NextForkCI->getPrevNode());
1049 }
1050
1051 OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1052 DL);
1053 IRBuilder<>::InsertPoint AllocaIP(
1054 &OriginalFn->getEntryBlock(),
1055 OriginalFn->getEntryBlock().getFirstInsertionPt());
1056 // Create the merged parallel region with default proc binding, to
1057 // avoid overriding binding settings, and without explicit cancellation.
1058 InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1059 Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1060 OMP_PROC_BIND_default, /* IsCancellable */ false);
1061 BranchInst::Create(AfterBB, AfterIP.getBlock());
1062
1063 // Perform the actual outlining.
1064 OMPInfoCache.OMPBuilder.finalize(OriginalFn,
1065 /* AllowExtractorSinking */ true);
1066
1067 Function *OutlinedFn = MergableCIs.front()->getCaller();
1068
1069 // Replace the __kmpc_fork_call calls with direct calls to the outlined
1070 // callbacks.
1071 SmallVector<Value *, 8> Args;
1072 for (auto *CI : MergableCIs) {
1073 Value *Callee =
1074 CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
1075 FunctionType *FT =
1076 cast<FunctionType>(Callee->getType()->getPointerElementType());
1077 Args.clear();
1078 Args.push_back(OutlinedFn->getArg(0));
1079 Args.push_back(OutlinedFn->getArg(1));
1080 for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1081 U < E; ++U)
1082 Args.push_back(CI->getArgOperand(U));
1083
1084 CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1085 if (CI->getDebugLoc())
1086 NewCI->setDebugLoc(CI->getDebugLoc());
1087
1088 // Forward parameter attributes from the callback to the callee.
1089 for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1090 U < E; ++U)
1091 for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1092 NewCI->addParamAttr(
1093 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1094
1095 // Emit an explicit barrier to replace the implicit fork-join barrier.
1096 if (CI != MergableCIs.back()) {
1097 // TODO: Remove barrier if the merged parallel region includes the
1098 // 'nowait' clause.
1099 OMPInfoCache.OMPBuilder.createBarrier(
1100 InsertPointTy(NewCI->getParent(),
1101 NewCI->getNextNode()->getIterator()),
1102 OMPD_parallel);
1103 }
1104
1105 CI->eraseFromParent();
1106 }
1107
1108 assert(OutlinedFn != OriginalFn && "Outlining failed");
1109 CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1110 CGUpdater.reanalyzeFunction(*OriginalFn);
1111
1112 NumOpenMPParallelRegionsMerged += MergableCIs.size();
1113
1114 return true;
1115 };
1116
1117 // Helper function that identifies sequences of
1118 // __kmpc_fork_call uses in a basic block.
1119 auto DetectPRsCB = [&](Use &U, Function &F) {
1120 CallInst *CI = getCallIfRegularCall(U, &RFI);
1121 BB2PRMap[CI->getParent()].insert(CI);
1122
1123 return false;
1124 };
1125
1126 BB2PRMap.clear();
1127 RFI.foreachUse(SCC, DetectPRsCB);
1128 SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1129 // Find mergable parallel regions within a basic block that are
1130 // safe to merge, that is any in-between instructions can safely
1131 // execute in parallel after merging.
1132 // TODO: support merging across basic-blocks.
1133 for (auto &It : BB2PRMap) {
1134 auto &CIs = It.getSecond();
1135 if (CIs.size() < 2)
1136 continue;
1137
1138 BasicBlock *BB = It.getFirst();
1139 SmallVector<CallInst *, 4> MergableCIs;
1140
1141 /// Returns true if the instruction is mergable, false otherwise.
1142 /// A terminator instruction is unmergable by definition since merging
1143 /// works within a BB. Instructions before the mergable region are
1144 /// mergable if they are not calls to OpenMP runtime functions that may
1145 /// set different execution parameters for subsequent parallel regions.
1146 /// Instructions in-between parallel regions are mergable if they are not
1147 /// calls to any non-intrinsic function since that may call a non-mergable
1148 /// OpenMP runtime function.
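/// E.g. (hypothetical source): two back-to-back `#pragma omp parallel`
/// regions separated only by simple arithmetic can be merged, whereas any
/// non-intrinsic call in between, or a preceding call that sets num_threads
/// or proc_bind, blocks the merge.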
1149 auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1150 // We do not merge across BBs, hence return false (unmergable) if the
1151 // instruction is a terminator.
1152 if (I.isTerminator())
1153 return false;
1154
1155 if (!isa<CallInst>(&I))
1156 return true;
1157
1158 CallInst *CI = cast<CallInst>(&I);
1159 if (IsBeforeMergableRegion) {
1160 Function *CalledFunction = CI->getCalledFunction();
1161 if (!CalledFunction)
1162 return false;
1163 // Return false (unmergable) if the call before the parallel
1164 // region calls an explicit affinity (proc_bind) or number of
1165 // threads (num_threads) compiler-generated function. Those settings
1166 // may be incompatible with following parallel regions.
1167 // TODO: ICV tracking to detect compatibility.
1168 for (const auto &RFI : UnmergableCallsInfo) {
1169 if (CalledFunction == RFI.Declaration)
1170 return false;
1171 }
1172 } else {
1173 // Return false (unmergable) if there is a call instruction
1174 // in-between parallel regions when it is not an intrinsic. It
1175 // may call an unmergable OpenMP runtime function in its callpath.
1176 // TODO: Keep track of possible OpenMP calls in the callpath.
1177 if (!isa<IntrinsicInst>(CI))
1178 return false;
1179 }
1180
1181 return true;
1182 };
1183 // Find maximal number of parallel region CIs that are safe to merge.
1184 for (auto It = BB->begin(), End = BB->end(); It != End;) {
1185 Instruction &I = *It;
1186 ++It;
1187
1188 if (CIs.count(&I)) {
1189 MergableCIs.push_back(cast<CallInst>(&I));
1190 continue;
1191 }
1192
1193 // Continue expanding if the instruction is mergable.
1194 if (IsMergable(I, MergableCIs.empty()))
1195 continue;
1196
1197 // Forward the instruction iterator to skip the next parallel region
1198 // since there is an unmergable instruction which can affect it.
1199 for (; It != End; ++It) {
1200 Instruction &SkipI = *It;
1201 if (CIs.count(&SkipI)) {
1202 LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1203 << " due to " << I << "\n");
1204 ++It;
1205 break;
1206 }
1207 }
1208
1209 // Store mergable regions found.
1210 if (MergableCIs.size() > 1) {
1211 MergableCIsVector.push_back(MergableCIs);
1212 LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1213 << " parallel regions in block " << BB->getName()
1214 << " of function " << BB->getParent()->getName()
1215 << "\n");
1216 }
1217
1218 MergableCIs.clear();
1219 }
1220
1221 if (!MergableCIsVector.empty()) {
1222 Changed = true;
1223
1224 for (auto &MergableCIs : MergableCIsVector)
1225 Merge(MergableCIs, BB);
1226 MergableCIsVector.clear();
1227 }
1228 }
1229
1230 if (Changed) {
1231 /// Re-collect uses for fork calls, emitted barrier calls, and
1232 /// any emitted master/end_master calls.
1233 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1234 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1235 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1236 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1237 }
1238
1239 return Changed;
1240 }
1241
1242 /// Try to delete parallel regions if possible.
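/// A fork call is deletable when its outlined callback only reads memory
/// and is guaranteed to return, e.g. (hypothetical source):
///   #pragma omp parallel
///   { int Tmp = A[omp_get_thread_num()]; (void)Tmp; }
/// Such a region has no observable effect and the call can be erased.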
1243 bool deleteParallelRegions() {
1244 const unsigned CallbackCalleeOperand = 2;
1245
1246 OMPInformationCache::RuntimeFunctionInfo &RFI =
1247 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1248
1249 if (!RFI.Declaration)
1250 return false;
1251
1252 bool Changed = false;
1253 auto DeleteCallCB = [&](Use &U, Function &) {
1254 CallInst *CI = getCallIfRegularCall(U);
1255 if (!CI)
1256 return false;
1257 auto *Fn = dyn_cast<Function>(
1258 CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1259 if (!Fn)
1260 return false;
1261 if (!Fn->onlyReadsMemory())
1262 return false;
1263 if (!Fn->hasFnAttribute(Attribute::WillReturn))
1264 return false;
1265
1266 LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1267 << CI->getCaller()->getName() << "\n");
1268
1269 auto Remark = [&](OptimizationRemark OR) {
1270 return OR << "Removing parallel region with no side-effects.";
1271 };
1272 emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1273
1274 CGUpdater.removeCallSite(*CI);
1275 CI->eraseFromParent();
1276 Changed = true;
1277 ++NumOpenMPParallelRegionsDeleted;
1278 return true;
1279 };
1280
1281 RFI.foreachUse(SCC, DeleteCallCB);
1282
1283 return Changed;
1284 }
1285
1286 /// Try to eliminate runtime calls by reusing existing ones.
1287 bool deduplicateRuntimeCalls() {
1288 bool Changed = false;
1289
1290 RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1291 OMPRTL_omp_get_num_threads,
1292 OMPRTL_omp_in_parallel,
1293 OMPRTL_omp_get_cancellation,
1294 OMPRTL_omp_get_thread_limit,
1295 OMPRTL_omp_get_supported_active_levels,
1296 OMPRTL_omp_get_level,
1297 OMPRTL_omp_get_ancestor_thread_num,
1298 OMPRTL_omp_get_team_size,
1299 OMPRTL_omp_get_active_level,
1300 OMPRTL_omp_in_final,
1301 OMPRTL_omp_get_proc_bind,
1302 OMPRTL_omp_get_num_places,
1303 OMPRTL_omp_get_num_procs,
1304 OMPRTL_omp_get_place_num,
1305 OMPRTL_omp_get_partition_num_places,
1306 OMPRTL_omp_get_partition_place_nums};
1307
1308 // Global-tid is handled separately.
1309 SmallSetVector<Value *, 16> GTIdArgs;
1310 collectGlobalThreadIdArguments(GTIdArgs);
1311 LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1312 << " global thread ID arguments\n");
1313
1314 for (Function *F : SCC) {
1315 for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1316 Changed |= deduplicateRuntimeCalls(
1317 *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1318
1319 // __kmpc_global_thread_num is special as we can replace it with an
1320 // argument in enough cases to make it worth trying.
1321 Value *GTIdArg = nullptr;
1322 for (Argument &Arg : F->args())
1323 if (GTIdArgs.count(&Arg)) {
1324 GTIdArg = &Arg;
1325 break;
1326 }
1327 Changed |= deduplicateRuntimeCalls(
1328 *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1329 }
1330
1331 return Changed;
1332 }
1333
1334 /// Tries to hide the latency of runtime calls that involve host-to-device
1335 /// memory transfers by splitting them into their "issue" and "wait"
1336 /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1337 /// moved downwards as much as possible. The "issue" starts the memory
1338 /// transfer asynchronously, returning a handle. The "wait" waits on the
1339 /// returned handle for the memory transfer to finish.
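/// Rough sketch of the resulting code (hypothetical IR; the exact
/// declarations are created in splitTargetDataBeginRTC below):
///   %handle = alloca %struct.__tgt_async_info
///   ... @__tgt_target_data_begin_mapper_issue(<original args>, %handle)
///   ; independent instructions here overlap with the transfer
///   ... @__tgt_target_data_begin_mapper_wait(%device_id, %handle)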
1340 bool hideMemTransfersLatency() {
1341 auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1342 bool Changed = false;
1343 auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1344 auto *RTCall = getCallIfRegularCall(U, &RFI);
1345 if (!RTCall)
1346 return false;
1347
1348 OffloadArray OffloadArrays[3];
1349 if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1350 return false;
1351
1352 LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1353
1354 // TODO: Check if can be moved upwards.
1355 bool WasSplit = false;
1356 Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1357 if (WaitMovementPoint)
1358 WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1359
1360 Changed |= WasSplit;
1361 return WasSplit;
1362 };
1363 RFI.foreachUse(SCC, SplitMemTransfers);
1364
1365 return Changed;
1366 }
1367
1368 void analysisGlobalization() {
1369 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1370
1371 auto CheckGlobalization = [&](Use &U, Function &Decl) {
1372 if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1373 auto Remark = [&](OptimizationRemarkMissed ORM) {
1374 return ORM
1375 << "Found thread data sharing on the GPU. "
1376 << "Expect degraded performance due to data globalization.";
1377 };
1378 emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1379 }
1380
1381 return false;
1382 };
1383
1384 RFI.foreachUse(SCC, CheckGlobalization);
1385 }
1386
1387 /// Maps the values stored in the offload arrays passed as arguments to
1388 /// \p RuntimeCall into the offload arrays in \p OAs.
1389 bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1390 MutableArrayRef<OffloadArray> OAs) {
1391 assert(OAs.size() == 3 && "Need space for three offload arrays!");
1392
1393 // A runtime call that involves memory offloading looks something like:
1394 // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1395 // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1396 // ...)
1397 // So, the idea is to access the allocas that allocate space for these
1398 // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1399 // Therefore:
1400 // i8** %offload_baseptrs.
1401 Value *BasePtrsArg =
1402 RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1403 // i8** %offload_ptrs.
1404 Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1405 // i8** %offload_sizes.
1406 Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1407
1408 // Get values stored in **offload_baseptrs.
1409 auto *V = getUnderlyingObject(BasePtrsArg);
1410 if (!isa<AllocaInst>(V))
1411 return false;
1412 auto *BasePtrsArray = cast<AllocaInst>(V);
1413 if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1414 return false;
1415
1416 // Get values stored in **offload_ptrs.
1417 V = getUnderlyingObject(PtrsArg);
1418 if (!isa<AllocaInst>(V))
1419 return false;
1420 auto *PtrsArray = cast<AllocaInst>(V);
1421 if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1422 return false;
1423
1424 // Get values stored in **offload_sizes.
1425 V = getUnderlyingObject(SizesArg);
1426 // If it's a [constant] global array don't analyze it.
1427 if (isa<GlobalValue>(V))
1428 return isa<Constant>(V);
1429 if (!isa<AllocaInst>(V))
1430 return false;
1431
1432 auto *SizesArray = cast<AllocaInst>(V);
1433 if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1434 return false;
1435
1436 return true;
1437 }
1438
1439 /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1440 /// For now this is a way to test that the function getValuesInOffloadArrays
1441 /// is working properly.
1442 /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1443 void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1444 assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1445
1446 LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1447 std::string ValuesStr;
1448 raw_string_ostream Printer(ValuesStr);
1449 std::string Separator = " --- ";
1450
1451 for (auto *BP : OAs[0].StoredValues) {
1452 BP->print(Printer);
1453 Printer << Separator;
1454 }
1455 LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("openmp-opt")) { dbgs() << "\t\toffload_baseptrs: " <<
Printer.str() << "\n"; } } while (false)
;
1456 ValuesStr.clear();
1457
1458 for (auto *P : OAs[1].StoredValues) {
1459 P->print(Printer);
1460 Printer << Separator;
1461 }
1462 LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("openmp-opt")) { dbgs() << "\t\toffload_ptrs: " <<
Printer.str() << "\n"; } } while (false)
;
1463 ValuesStr.clear();
1464
1465 for (auto *S : OAs[2].StoredValues) {
1466 S->print(Printer);
1467 Printer << Separator;
1468 }
1469 LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("openmp-opt")) { dbgs() << "\t\toffload_sizes: " <<
Printer.str() << "\n"; } } while (false)
;
1470 }
1471
1472 /// Returns the instruction where the "wait" counterpart of \p RuntimeCall can
1473 /// be moved. Returns nullptr if the movement is not possible, or not worth it.
1474 Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1475 // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1476 // Make it traverse the CFG.
1477
1478 Instruction *CurrentI = &RuntimeCall;
1479 bool IsWorthIt = false;
1480 while ((CurrentI = CurrentI->getNextNode())) {
1481
1482 // TODO: Once we detect the regions to be offloaded we should use the
1483 // alias analysis manager to check if CurrentI may modify one of
1484 // the offloaded regions.
1485 if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1486 if (IsWorthIt)
1487 return CurrentI;
1488
1489 return nullptr;
1490 }
1491
1492 // FIXME: For now, moving it over anything without side effects is
1493 // considered worth it.
1494 IsWorthIt = true;
1495 }
1496
1497 // Return end of BasicBlock.
1498 return RuntimeCall.getParent()->getTerminator();
1499 }
1500
1501 /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1502 bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1503 Instruction &WaitMovementPoint) {
1504 // Create a stack-allocated handle (__tgt_async_info) at the beginning of
1505 // the function. It stores information about the async transfer, allowing
1506 // us to wait on it later.
1507 auto &IRBuilder = OMPInfoCache.OMPBuilder;
1508 auto *F = RuntimeCall.getCaller();
1509 Instruction *FirstInst = &(F->getEntryBlock().front());
1510 AllocaInst *Handle = new AllocaInst(
1511 IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1512
1513 // Add "issue" runtime call declaration:
1514 // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1515 // i8**, i8**, i64*, i64*)
1516 FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1517 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1518
1519 // Change RuntimeCall call site for its asynchronous version.
1520 SmallVector<Value *, 16> Args;
1521 for (auto &Arg : RuntimeCall.args())
1522 Args.push_back(Arg.get());
1523 Args.push_back(Handle);
1524
1525 CallInst *IssueCallsite =
1526 CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1527 RuntimeCall.eraseFromParent();
1528
1529 // Add "wait" runtime call declaration:
1530 // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1531 FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1532 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1533
1534 Value *WaitParams[2] = {
1535 IssueCallsite->getArgOperand(
1536 OffloadArray::DeviceIDArgNum), // device_id.
1537 Handle // handle to wait on.
1538 };
1539 CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1540
1541 return true;
1542 }
1543
1544 static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1545 bool GlobalOnly, bool &SingleChoice) {
1546 if (CurrentIdent == NextIdent)
1547 return CurrentIdent;
1548
1549 // TODO: Figure out how to actually combine multiple debug locations. For
1550 // now we just keep an existing one if there is a single choice.
1551 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1552 SingleChoice = !CurrentIdent;
1553 return NextIdent;
1554 }
1555 return nullptr;
1556 }
1557
1558 /// Return a `struct ident_t*` value that represents the ones used in the
1559 /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1560 /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1561 /// return value we create one from scratch. We also do not yet combine
1562 /// information, e.g., the source locations, see combinedIdentStruct.
1563 Value *
1564 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1565 Function &F, bool GlobalOnly) {
1566 bool SingleChoice = true;
1567 Value *Ident = nullptr;
1568 auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1569 CallInst *CI = getCallIfRegularCall(U, &RFI);
1570 if (!CI || &F != &Caller)
1571 return false;
1572 Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1573 /* GlobalOnly */ true, SingleChoice);
1574 return false;
1575 };
1576 RFI.foreachUse(SCC, CombineIdentStruct);
1577
1578 if (!Ident || !SingleChoice) {
1579 // The IRBuilder uses the insertion block to get to the module; this is
1580 // unfortunate, but we work around it for now.
1581 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1582 OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1583 &F.getEntryBlock(), F.getEntryBlock().begin()));
1584 // Create a fallback location if none was found.
1585 // TODO: Use the debug locations of the calls instead.
1586 Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
1587 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
1588 }
1589 return Ident;
1590 }
1591
1592 /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1593 /// \p ReplVal if given.
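/// Illustrative effect (hypothetical IR, not from this file):
///   %a = call i32 @omp_get_level()
///   ...
///   %b = call i32 @omp_get_level()
/// A single movable call (or \p ReplVal) is kept and every other call is
/// replaced by its value and erased.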
1594 bool deduplicateRuntimeCalls(Function &F,
1595 OMPInformationCache::RuntimeFunctionInfo &RFI,
1596 Value *ReplVal = nullptr) {
1597 auto *UV = RFI.getUseVector(F);
1598 if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1599 return false;
1600
1601 LLVM_DEBUG(
1602 dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1603 << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1604
1605 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1606 cast<Argument>(ReplVal)->getParent() == &F)) &&
1607 "Unexpected replacement value!");
1608
1609 // TODO: Use dominance to find a good position instead.
1610 auto CanBeMoved = [this](CallBase &CB) {
1611 unsigned NumArgs = CB.getNumArgOperands();
1612 if (NumArgs == 0)
1613 return true;
1614 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1615 return false;
1616 for (unsigned u = 1; u < NumArgs; ++u)
1617 if (isa<Instruction>(CB.getArgOperand(u)))
1618 return false;
1619 return true;
1620 };
1621
1622 if (!ReplVal) {
1623 for (Use *U : *UV)
1624 if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1625 if (!CanBeMoved(*CI))
1626 continue;
1627
1628 // If the function is a kernel, dedup will move
1629 // the runtime call right after the kernel init callsite. Otherwise,
1630 // it will move it to the beginning of the caller function.
1631 if (isKernel(F)) {
1632 auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
1633 auto *KernelInitUV = KernelInitRFI.getUseVector(F);
1634
1635 if (KernelInitUV->empty())
1636 continue;
1637
1638 assert(KernelInitUV->size() == 1 &&
1639 "Expected a single __kmpc_target_init in kernel\n");
1640
1641 CallInst *KernelInitCI =
1642 getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
1643 assert(KernelInitCI &&
1644 "Expected a call to __kmpc_target_init in kernel\n");
1645
1646 CI->moveAfter(KernelInitCI);
1647 } else
1648 CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
1649 ReplVal = CI;
1650 break;
1651 }
1652 if (!ReplVal)
1653 return false;
1654 }
1655
1656 // If we use a call as a replacement value we need to make sure the ident is
1657 // valid at the new location. For now we just pick a global one, either
1658 // existing and used by one of the calls, or created from scratch.
1659 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1660 if (!CI->arg_empty() &&
1661 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1662 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1663 /* GlobalOnly */ true);
1664 CI->setArgOperand(0, Ident);
1665 }
1666 }
1667
1668 bool Changed = false;
1669 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1670 CallInst *CI = getCallIfRegularCall(U, &RFI);
1671 if (!CI || CI == ReplVal || &F != &Caller)
1672 return false;
1673 assert(CI->getCaller() == &F && "Unexpected call!");
1674
1675 auto Remark = [&](OptimizationRemark OR) {
1676 return OR << "OpenMP runtime call "
1677 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1678 };
1679 if (CI->getDebugLoc())
1680 emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1681 else
1682 emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1683
1684 CGUpdater.removeCallSite(*CI);
1685 CI->replaceAllUsesWith(ReplVal);
1686 CI->eraseFromParent();
1687 ++NumOpenMPRuntimeCallsDeduplicated;
1688 Changed = true;
1689 return true;
1690 };
1691 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1692
1693 return Changed;
1694 }
1695
1696 /// Collect arguments that represent the global thread id in \p GTIdArgs.
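/// E.g. (hypothetical IR): if `%gtid = call i32 @__kmpc_global_thread_num(...)`
/// is passed as `call void @local_fn(i32 %gtid, ...)` and every call site of
/// @local_fn passes a GTId in that position, the matching parameter of
/// @local_fn is itself recorded as a GTId.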
1697 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1698 // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1699 // initialization. We could define an AbstractAttribute instead and
1700 // run the Attributor here once it can be run as an SCC pass.
1701
1702 // Helper to check the argument \p ArgNo at all call sites of \p F for
1703 // a GTId.
1704 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1705 if (!F.hasLocalLinkage())
1706 return false;
1707 for (Use &U : F.uses()) {
1708 if (CallInst *CI = getCallIfRegularCall(U)) {
1709 Value *ArgOp = CI->getArgOperand(ArgNo);
1710 if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1711 getCallIfRegularCall(
1712 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1713 continue;
1714 }
1715 return false;
1716 }
1717 return true;
1718 };
1719
1720 // Helper to identify uses of a GTId as GTId arguments.
1721 auto AddUserArgs = [&](Value &GTId) {
1722 for (Use &U : GTId.uses())
1723 if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1724 if (CI->isArgOperand(&U))
1725 if (Function *Callee = CI->getCalledFunction())
1726 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1727 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1728 };
1729
1730 // The argument users of __kmpc_global_thread_num calls are GTIds.
1731 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1732 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1733
1734 GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1735 if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1736 AddUserArgs(*CI);
1737 return false;
1738 });
1739
1740 // Transitively search for more arguments by looking at the users of the
1741 // ones we know already. During the search the GTIdArgs vector is extended
1742 // so we cannot cache the size nor use a range-based for loop.
1743 for (unsigned u = 0; u < GTIdArgs.size(); ++u)
1744 AddUserArgs(*GTIdArgs[u]);
1745 }
1746
1747 /// Kernel (=GPU) optimizations and utility functions
1748 ///
1749 ///{{
1750
1751 /// Check if \p F is a kernel, hence entry point for target offloading.
1752 bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
1753
1754 /// Cache to remember the unique kernel for a function.
1755 DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
1756
1757 /// Find the unique kernel that will execute \p F, if any.
1758 Kernel getUniqueKernelFor(Function &F);
1759
1760 /// Find the unique kernel that will execute \p I, if any.
1761 Kernel getUniqueKernelFor(Instruction &I) {
1762 return getUniqueKernelFor(*I.getFunction());
1763 }
1764
1765 /// Rewrite the device (=GPU) code state machine created in non-SPMD mode
1766 /// in the cases where we can avoid taking the address of a function.
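/// Illustrative effect (hypothetical IR): a state-machine check such as
///   %eq = icmp eq i8* %work_fn, bitcast (void (i16, i32)* @par_body to i8*)
/// is rewritten to compare against a new private global @par_body.ID, so
/// @par_body is no longer address-taken except at direct call sites.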
1767 bool rewriteDeviceCodeStateMachine();
1768
1769 ///
1770 ///}}
1771
1772 /// Emit a remark generically
1773 ///
1774 /// This template function can be used to generically emit a remark. The
1775 /// RemarkKind should be one of the following:
1776 /// - OptimizationRemark to indicate a successful optimization attempt
1777 /// - OptimizationRemarkMissed to report a failed optimization attempt
1778 /// - OptimizationRemarkAnalysis to provide additional information about an
1779 /// optimization attempt
1780 ///
1781 /// The remark is built using a callback function provided by the caller that
1782 /// takes a RemarkKind as input and returns a RemarkKind.
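/// Example use, mirroring deleteParallelRegions above:
///   auto Remark = [&](OptimizationRemark OR) {
///     return OR << "Removing parallel region with no side-effects.";
///   };
///   emitRemark<OptimizationRemark>(CI, "OMP160", Remark);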
1783 template <typename RemarkKind, typename RemarkCallBack>
1784 void emitRemark(Instruction *I, StringRef RemarkName,
1785 RemarkCallBack &&RemarkCB) const {
1786 Function *F = I->getParent()->getParent();
1787 auto &ORE = OREGetter(F);
1788
1789 if (RemarkName.startswith("OMP"))
1790 ORE.emit([&]() {
1791 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
1792 << " [" << RemarkName << "]";
1793 });
1794 else
1795 ORE.emit(
1796 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
1797 }
1798
1799 /// Emit a remark on a function.
1800 template <typename RemarkKind, typename RemarkCallBack>
1801 void emitRemark(Function *F, StringRef RemarkName,
1802 RemarkCallBack &&RemarkCB) const {
1803 auto &ORE = OREGetter(F);
1804
1805 if (RemarkName.startswith("OMP"))
1806 ORE.emit([&]() {
1807 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
1808 << " [" << RemarkName << "]";
1809 });
1810 else
1811 ORE.emit(
1812 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
1813 }
1814
1815 /// RAII struct to temporarily change an RTL function's linkage to external.
1816 /// This prevents it from being mistakenly removed by other optimizations.
1817 struct ExternalizationRAII {
1818 ExternalizationRAII(OMPInformationCache &OMPInfoCache,
1819 RuntimeFunction RFKind)
1820 : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) {
1821 if (!Declaration)
1822 return;
1823
1824 LinkageType = Declaration->getLinkage();
1825 Declaration->setLinkage(GlobalValue::ExternalLinkage);
1826 }
1827
1828 ~ExternalizationRAII() {
1829 if (!Declaration)
1830 return;
1831
1832 Declaration->setLinkage(LinkageType);
1833 }
1834
1835 Function *Declaration;
1836 GlobalValue::LinkageTypes LinkageType;
1837 };
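// Example use, mirroring runAttributor below:
//   ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);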
1838
1839 /// The underlying module.
1840 Module &M;
1841
1842 /// The SCC we are operating on.
1843 SmallVectorImpl<Function *> &SCC;
1844
1845 /// Callback to update the call graph, the first argument is a removed call,
1846 /// the second an optional replacement call.
1847 CallGraphUpdater &CGUpdater;
1848
1849 /// Callback to get an OptimizationRemarkEmitter from a Function *
1850 OptimizationRemarkGetter OREGetter;
1851
1852 /// OpenMP-specific information cache. Also used for Attributor runs.
1853 OMPInformationCache &OMPInfoCache;
1854
1855 /// Attributor instance.
1856 Attributor &A;
1857
1858 /// Helper function to run Attributor on SCC.
1859 bool runAttributor(bool IsModulePass) {
1860 if (SCC.empty())
1861 return false;
1862
1863 // Temporarily make these functions have external linkage so the Attributor
1864 // doesn't remove them when we try to look them up later.
1865 ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
1866 ExternalizationRAII EndParallel(OMPInfoCache,
1867 OMPRTL___kmpc_kernel_end_parallel);
1868 ExternalizationRAII BarrierSPMD(OMPInfoCache,
1869 OMPRTL___kmpc_barrier_simple_spmd);
1870 ExternalizationRAII ThreadId(OMPInfoCache,
1871 OMPRTL___kmpc_get_hardware_thread_id_in_block);
1872
1873 registerAAs(IsModulePass);
1874
1875 ChangeStatus Changed = A.run();
1876
1877 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("openmp-opt")) { dbgs() << "[Attributor] Done with " <<
SCC.size() << " functions, result: " << Changed <<
".\n"; } } while (false)
1878 << " functions, result: " << Changed << ".\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("openmp-opt")) { dbgs() << "[Attributor] Done with " <<
SCC.size() << " functions, result: " << Changed <<
".\n"; } } while (false)
;
1879
1880 return Changed == ChangeStatus::CHANGED;
1881 }
1882
1883 void registerFoldRuntimeCall(RuntimeFunction RF);
1884
1885 /// Populate the Attributor with abstract attribute opportunities in the
1886 /// function.
1887 void registerAAs(bool IsModulePass);
1888};
1889
1890Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
1891 if (!OMPInfoCache.ModuleSlice.count(&F))
1892 return nullptr;
1893
1894 // Use a scope to keep the lifetime of the CachedKernel short.
1895 {
1896 Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
1897 if (CachedKernel)
1898 return *CachedKernel;
1899
1900 // TODO: We should use an AA to create an (optimistic and callback
1901 // call-aware) call graph. For now we stick to simple patterns that
1902 // are less powerful, basically the worst fixpoint.
1903 if (isKernel(F)) {
1904 CachedKernel = Kernel(&F);
1905 return *CachedKernel;
1906 }
1907
1908 CachedKernel = nullptr;
1909 if (!F.hasLocalLinkage()) {
1910
1911 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
1912 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1913 return ORA << "Potentially unknown OpenMP target region caller.";
1914 };
1915 emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
1916
1917 return nullptr;
1918 }
1919 }
1920
1921 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
1922 if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
1923 // Allow use in equality comparisons.
1924 if (Cmp->isEquality())
1925 return getUniqueKernelFor(*Cmp);
1926 return nullptr;
1927 }
1928 if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
1929 // Allow direct calls.
1930 if (CB->isCallee(&U))
1931 return getUniqueKernelFor(*CB);
1932
1933 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1934 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1935 // Allow the use in __kmpc_parallel_51 calls.
1936 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
1937 return getUniqueKernelFor(*CB);
1938 return nullptr;
1939 }
1940 // Disallow every other use.
1941 return nullptr;
1942 };
1943
1944 // TODO: In the future we want to track more than just a unique kernel.
1945 SmallPtrSet<Kernel, 2> PotentialKernels;
1946 OMPInformationCache::foreachUse(F, [&](const Use &U) {
1947 PotentialKernels.insert(GetUniqueKernelForUse(U));
1948 });
1949
1950 Kernel K = nullptr;
1951 if (PotentialKernels.size() == 1)
1952 K = *PotentialKernels.begin();
1953
1954 // Cache the result.
1955 UniqueKernelMap[&F] = K;
1956
1957 return K;
1958}
1959
1960bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
1961 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1962 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1963
1964 bool Changed = false;
1965 if (!KernelParallelRFI)
1966 return Changed;
1967
1968 // If we have disabled state machine changes, exit
1969 if (DisableOpenMPOptStateMachineRewrite)
1970 return Changed;
1971
1972 for (Function *F : SCC) {
1973
1974 // Check if the function is used in a __kmpc_parallel_51 call at
1975 // all.
1976 bool UnknownUse = false;
1977 bool KernelParallelUse = false;
1978 unsigned NumDirectCalls = 0;
1979
1980 SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
1981 OMPInformationCache::foreachUse(*F, [&](Use &U) {
1982 if (auto *CB = dyn_cast<CallBase>(U.getUser()))
1983 if (CB->isCallee(&U)) {
1984 ++NumDirectCalls;
1985 return;
1986 }
1987
1988 if (isa<ICmpInst>(U.getUser())) {
1989 ToBeReplacedStateMachineUses.push_back(&U);
1990 return;
1991 }
1992
1993 // Find wrapper functions that represent parallel kernels.
1994 CallInst *CI =
1995 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
1996 const unsigned int WrapperFunctionArgNo = 6;
1997 if (!KernelParallelUse && CI &&
1998 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
1999 KernelParallelUse = true;
2000 ToBeReplacedStateMachineUses.push_back(&U);
2001 return;
2002 }
2003 UnknownUse = true;
2004 });
2005
2006 // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2007 // use.
2008 if (!KernelParallelUse)
2009 continue;
2010
2011 // If this ever hits, we should investigate.
2012 // TODO: Checking the number of uses is not a necessary restriction and
2013 // should be lifted.
2014 if (UnknownUse || NumDirectCalls != 1 ||
2015 ToBeReplacedStateMachineUses.size() > 2) {
2016 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2017 return ORA << "Parallel region is used in "
2018 << (UnknownUse ? "unknown" : "unexpected")
2019 << " ways. Will not attempt to rewrite the state machine.";
2020 };
2021 emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2022 continue;
2023 }
2024
2025 // Even if we have __kmpc_parallel_51 calls, we (for now) give
2026 // up if the function is not called from a unique kernel.
2027 Kernel K = getUniqueKernelFor(*F);
2028 if (!K) {
2029 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2030 return ORA << "Parallel region is not called from a unique kernel. "
2031 "Will not attempt to rewrite the state machine.";
2032 };
2033 emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2034 continue;
2035 }
2036
2037 // We now know F is a parallel body function called only from the kernel K.
2038 // We also identified the state machine uses in which we replace the
2039 // function pointer with a new global symbol for identification purposes. This
2040 // ensures only direct calls to the function are left.
2041
2042 Module &M = *F->getParent();
2043 Type *Int8Ty = Type::getInt8Ty(M.getContext());
2044
2045 auto *ID = new GlobalVariable(
2046 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2047 UndefValue::get(Int8Ty), F->getName() + ".ID");
2048
2049 for (Use *U : ToBeReplacedStateMachineUses)
2050 U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2051 ID, U->get()->getType()));
2052
2053 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2054
2055 Changed = true;
2056 }
2057
2058 return Changed;
2059}
2060
2061/// Abstract Attribute for tracking ICV values.
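/// E.g. (hypothetical): after `omp_set_num_threads(4)`, a later call to the
/// corresponding getter may be assumed to yield 4, provided no intervening
/// call could have changed the ICV.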
2062struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2063 using Base = StateWrapper<BooleanState, AbstractAttribute>;
2064 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2065
2066 void initialize(Attributor &A) override {
2067 Function *F = getAnchorScope();
2068 if (!F || !A.isFunctionIPOAmendable(*F))
2069 indicatePessimisticFixpoint();
2070 }
2071
2072 /// Returns true if value is assumed to be tracked.
2073 bool isAssumedTracked() const { return getAssumed(); }
2074
2075 /// Returns true if value is known to be tracked.
2076 bool isKnownTracked() const { return getAssumed(); }
2077
2078 /// Create an abstract attribute view for the position \p IRP.
2079 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2080
2081 /// Return the value with which \p I can be replaced for specific \p ICV.
2082 virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
2083 const Instruction *I,
2084 Attributor &A) const {
2085 return None;
2086 }
2087
2088 /// Return an assumed unique ICV value if a single candidate is found. If
2089 /// there cannot be one, return a nullptr. If it is not clear yet, return the
2090 /// Optional::NoneType.
2091 virtual Optional<Value *>
2092 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2093
2094 // Currently only nthreads is being tracked.
2095 // This array will only grow with time.
2096 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2097
2098 /// See AbstractAttribute::getName()
2099 const std::string getName() const override { return "AAICVTracker"; }
2100
2101 /// See AbstractAttribute::getIdAddr()
2102 const char *getIdAddr() const override { return &ID; }
2103
2104 /// This function should return true if the type of the \p AA is AAICVTracker
2105 static bool classof(const AbstractAttribute *AA) {
2106 return (AA->getIdAddr() == &ID);
2107 }
2108
2109 static const char ID;
2110};
2111
2112struct AAICVTrackerFunction : public AAICVTracker {
2113 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2114 : AAICVTracker(IRP, A) {}
2115
2116 // FIXME: come up with better string.
2117 const std::string getAsStr() const override { return "ICVTrackerFunction"; }
2118
2119 // FIXME: come up with some stats.
2120 void trackStatistics() const override {}
2121
2122 /// We don't manifest anything for this AA.
2123 ChangeStatus manifest(Attributor &A) override {
2124 return ChangeStatus::UNCHANGED;
2125 }
2126
2127 // Map of ICVs to their values at specific program points.
2128 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2129 InternalControlVar::ICV___last>
2130 ICVReplacementValuesMap;
2131
2132 ChangeStatus updateImpl(Attributor &A) override {
2133 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2134
2135 Function *F = getAnchorScope();
2136
2137 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2138
2139 for (InternalControlVar ICV : TrackableICVs) {
2140 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2141
2142 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2143 auto TrackValues = [&](Use &U, Function &) {
2144 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2145 if (!CI)
2146 return false;
2147
2148 // FIXME: handle setters with more than one argument.
2149 /// Track new value.
2150 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2151 HasChanged = ChangeStatus::CHANGED;
2152
2153 return false;
2154 };
2155
2156 auto CallCheck = [&](Instruction &I) {
2157 Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
2158 if (ReplVal.hasValue() &&
2159 ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2160 HasChanged = ChangeStatus::CHANGED;
2161
2162 return true;
2163 };
2164
2165 // Track all changes of an ICV.
2166 SetterRFI.foreachUse(TrackValues, F);
2167
2168 bool UsedAssumedInformation = false;
2169 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2170 UsedAssumedInformation,
2171 /* CheckBBLivenessOnly */ true);
2172
2173 /// TODO: Figure out a way to avoid adding an entry in
2174 /// ICVReplacementValuesMap
2175 Instruction *Entry = &F->getEntryBlock().front();
2176 if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2177 ValuesMap.insert(std::make_pair(Entry, nullptr));
2178 }
2179
2180 return HasChanged;
2181 }
2182
2183 /// Helper to check if \p I is a call and get the value for it if it is
2184 /// unique.
2185 Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
2186 InternalControlVar &ICV) const {
2187
2188 const auto *CB = dyn_cast<CallBase>(I);
2189 if (!CB || CB->hasFnAttr("no_openmp") ||
2190 CB->hasFnAttr("no_openmp_routines"))
2191 return None;
2192
2193 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2194 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2195 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2196 Function *CalledFunction = CB->getCalledFunction();
2197
2198 // Indirect call, assume ICV changes.
2199 if (CalledFunction == nullptr)
2200 return nullptr;
2201 if (CalledFunction == GetterRFI.Declaration)
2202 return None;
2203 if (CalledFunction == SetterRFI.Declaration) {
2204 if (ICVReplacementValuesMap[ICV].count(I))
2205 return ICVReplacementValuesMap[ICV].lookup(I);
2206
2207 return nullptr;
2208 }
2209
2210 // Since we don't know, assume it changes the ICV.
2211 if (CalledFunction->isDeclaration())
2212 return nullptr;
2213
2214 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2215 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2216
2217 if (ICVTrackingAA.isAssumedTracked())
2218 return ICVTrackingAA.getUniqueReplacementValue(ICV);
2219
2220 // If we don't know, assume it changes.
2221 return nullptr;
2222 }
2223
2224 // We don't check for a unique value for a function, so return None.
2225 Optional<Value *>
2226 getUniqueReplacementValue(InternalControlVar ICV) const override {
2227 return None;
2228 }
2229
2230 /// Return the value with which \p I can be replaced for specific \p ICV.
2231 Optional<Value *> getReplacementValue(InternalControlVar ICV,
2232 const Instruction *I,
2233 Attributor &A) const override {
2234 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2235 if (ValuesMap.count(I))
2236 return ValuesMap.lookup(I);
2237
2238 SmallVector<const Instruction *, 16> Worklist;
2239 SmallPtrSet<const Instruction *, 16> Visited;
2240 Worklist.push_back(I);
2241
2242 Optional<Value *> ReplVal;
2243
2244 while (!Worklist.empty()) {
2245 const Instruction *CurrInst = Worklist.pop_back_val();
2246 if (!Visited.insert(CurrInst).second)
2247 continue;
2248
2249 const BasicBlock *CurrBB = CurrInst->getParent();
2250
2251 // Go up and look for all potential setters/calls that might change the
2252 // ICV.
2253 while ((CurrInst = CurrInst->getPrevNode())) {
2254 if (ValuesMap.count(CurrInst)) {
2255 Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2256 // Unknown value, track new.
2257 if (!ReplVal.hasValue()) {
2258 ReplVal = NewReplVal;
2259 break;
2260 }
2261
2262 // If we found a new value, we can't know the ICV value anymore.
2263 if (NewReplVal.hasValue())
2264 if (ReplVal != NewReplVal)
2265 return nullptr;
2266
2267 break;
2268 }
2269
2270 Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV);
2271 if (!NewReplVal.hasValue())
2272 continue;
2273
2274 // Unknown value, track new.
2275 if (!ReplVal.hasValue()) {
2276 ReplVal = NewReplVal;
2277 break;
2278 }
2279
2280 // We found a new value here, so we can't know the ICV value
2281 // anymore.
2282 if (ReplVal != NewReplVal)
2283 return nullptr;
2284 }
2285
2286 // If we are in the same BB and we have a value, we are done.
2287 if (CurrBB == I->getParent() && ReplVal.hasValue())
2288 return ReplVal;
2289
2290 // Go through all predecessors and add terminators for analysis.
2291 for (const BasicBlock *Pred : predecessors(CurrBB))
2292 if (const Instruction *Terminator = Pred->getTerminator())
2293 Worklist.push_back(Terminator);
2294 }
2295
2296 return ReplVal;
2297 }
2298};
2299
2300struct AAICVTrackerFunctionReturned : AAICVTracker {
2301 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2302 : AAICVTracker(IRP, A) {}
2303
2304 // FIXME: come up with better string.
2305 const std::string getAsStr() const override {
2306 return "ICVTrackerFunctionReturned";
2307 }
2308
2309 // FIXME: come up with some stats.
2310 void trackStatistics() const override {}
2311
2312 /// We don't manifest anything for this AA.
2313 ChangeStatus manifest(Attributor &A) override {
2314 return ChangeStatus::UNCHANGED;
2315 }
2316
2317 // Map of ICVs to their values at specific program points.
2318 EnumeratedArray<Optional<Value *>, InternalControlVar,
2319 InternalControlVar::ICV___last>
2320 ICVReplacementValuesMap;
2321
2322 /// Return the value with which \p I can be replaced for specific \p ICV.
2323 Optional<Value *>
2324 getUniqueReplacementValue(InternalControlVar ICV) const override {
2325 return ICVReplacementValuesMap[ICV];
2326 }
2327
2328 ChangeStatus updateImpl(Attributor &A) override {
2329 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2330 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2331 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2332
2333 if (!ICVTrackingAA.isAssumedTracked())
2334 return indicatePessimisticFixpoint();
2335
2336 for (InternalControlVar ICV : TrackableICVs) {
2337 Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2338 Optional<Value *> UniqueICVValue;
2339
2340 auto CheckReturnInst = [&](Instruction &I) {
2341 Optional<Value *> NewReplVal =
2342 ICVTrackingAA.getReplacementValue(ICV, &I, A);
2343
2344 // If we found a second ICV value there is no unique returned value.
2345 if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
2346 return false;
2347
2348 UniqueICVValue = NewReplVal;
2349
2350 return true;
2351 };
2352
2353 bool UsedAssumedInformation = false;
2354 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2355 UsedAssumedInformation,
2356 /* CheckBBLivenessOnly */ true))
2357 UniqueICVValue = nullptr;
2358
2359 if (UniqueICVValue == ReplVal)
2360 continue;
2361
2362 ReplVal = UniqueICVValue;
2363 Changed = ChangeStatus::CHANGED;
2364 }
2365
2366 return Changed;
2367 }
2368};
2369
2370struct AAICVTrackerCallSite : AAICVTracker {
2371 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2372 : AAICVTracker(IRP, A) {}
2373
2374 void initialize(Attributor &A) override {
2375 Function *F = getAnchorScope();
2376 if (!F || !A.isFunctionIPOAmendable(*F))
2377 indicatePessimisticFixpoint();
2378
2379 // We only initialize this AA for getters, so we need to know which ICV it
2380 // gets.
2381 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2382 for (InternalControlVar ICV : TrackableICVs) {
2383 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2384 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2385 if (Getter.Declaration == getAssociatedFunction()) {
2386 AssociatedICV = ICVInfo.Kind;
2387 return;
2388 }
2389 }
2390
2391 /// Unknown ICV.
2392 indicatePessimisticFixpoint();
2393 }
2394
2395 ChangeStatus manifest(Attributor &A) override {
2396 if (!ReplVal.hasValue() || !ReplVal.getValue())
2397 return ChangeStatus::UNCHANGED;
2398
2399 A.changeValueAfterManifest(*getCtxI(), **ReplVal);
2400 A.deleteAfterManifest(*getCtxI());
2401
2402 return ChangeStatus::CHANGED;
2403 }
2404
2405 // FIXME: come up with better string.
2406 const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2407
2408 // FIXME: come up with some stats.
2409 void trackStatistics() const override {}
2410
2411 InternalControlVar AssociatedICV;
2412 Optional<Value *> ReplVal;
2413
2414 ChangeStatus updateImpl(Attributor &A) override {
2415 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2416 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2417
2418 // We don't have any information, so we assume it changes the ICV.
2419 if (!ICVTrackingAA.isAssumedTracked())
2420 return indicatePessimisticFixpoint();
2421
2422 Optional<Value *> NewReplVal =
2423 ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2424
2425 if (ReplVal == NewReplVal)
2426 return ChangeStatus::UNCHANGED;
2427
2428 ReplVal = NewReplVal;
2429 return ChangeStatus::CHANGED;
2430 }
2431
2432 // Return the value with which the associated value can be replaced for the
2433 // specific \p ICV.
2434 Optional<Value *>
2435 getUniqueReplacementValue(InternalControlVar ICV) const override {
2436 return ReplVal;
2437 }
2438};
2439
2440struct AAICVTrackerCallSiteReturned : AAICVTracker {
2441 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2442 : AAICVTracker(IRP, A) {}
2443
2444 // FIXME: come up with better string.
2445 const std::string getAsStr() const override {
2446 return "ICVTrackerCallSiteReturned";
2447 }
2448
2449 // FIXME: come up with some stats.
2450 void trackStatistics() const override {}
2451
2452 /// We don't manifest anything for this AA.
2453 ChangeStatus manifest(Attributor &A) override {
2454 return ChangeStatus::UNCHANGED;
2455 }
2456
2457 // Map of ICVs to their values at specific program points.
2458 EnumeratedArray<Optional<Value *>, InternalControlVar,
2459 InternalControlVar::ICV___last>
2460 ICVReplacementValuesMap;
2461
2462 /// Return the value with which the associated value can be replaced for the
2463 /// specific \p ICV.
2464 Optional<Value *>
2465 getUniqueReplacementValue(InternalControlVar ICV) const override {
2466 return ICVReplacementValuesMap[ICV];
2467 }
2468
2469 ChangeStatus updateImpl(Attributor &A) override {
2470 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2471 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2472 *this, IRPosition::returned(*getAssociatedFunction()),
2473 DepClassTy::REQUIRED);
2474
2475 // We don't have any information, so we assume it changes the ICV.
2476 if (!ICVTrackingAA.isAssumedTracked())
2477 return indicatePessimisticFixpoint();
2478
2479 for (InternalControlVar ICV : TrackableICVs) {
2480 Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2481 Optional<Value *> NewReplVal =
2482 ICVTrackingAA.getUniqueReplacementValue(ICV);
2483
2484 if (ReplVal == NewReplVal)
2485 continue;
2486
2487 ReplVal = NewReplVal;
2488 Changed = ChangeStatus::CHANGED;
2489 }
2490 return Changed;
2491 }
2492};
2493
2494struct AAExecutionDomainFunction : public AAExecutionDomain {
2495 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2496 : AAExecutionDomain(IRP, A) {}
2497
2498 const std::string getAsStr() const override {
2499 return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2500 "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2501 }
2502
2503 /// See AbstractAttribute::trackStatistics().
2504 void trackStatistics() const override {}
2505
2506 void initialize(Attributor &A) override {
2507 Function *F = getAnchorScope();
2508 for (const auto &BB : *F)
2509 SingleThreadedBBs.insert(&BB);
2510 NumBBs = SingleThreadedBBs.size();
2511 }
2512
2513 ChangeStatus manifest(Attributor &A) override {
2514 LLVM_DEBUG({
2515 for (const BasicBlock *BB : SingleThreadedBBs)
2516 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2517 << BB->getName() << " is executed by a single thread.\n";
2518 });
2519 return ChangeStatus::UNCHANGED;
2520 }
2521
2522 ChangeStatus updateImpl(Attributor &A) override;
2523
2524 /// Check if an instruction is executed by a single thread.
2525 bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
2526 return isExecutedByInitialThreadOnly(*I.getParent());
2527 }
2528
2529 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2530 return isValidState() && SingleThreadedBBs.contains(&BB);
2531 }
2532
2533 /// Set of basic blocks that are executed by a single thread.
2534 DenseSet<const BasicBlock *> SingleThreadedBBs;
2535
2536 /// Total number of basic blocks in this function.
2537 long unsigned NumBBs;
2538};
2539
2540ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2541 Function *F = getAnchorScope();
2542 ReversePostOrderTraversal<Function *> RPOT(F);
2543 auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2544
2545 bool AllCallSitesKnown;
2546 auto PredForCallSite = [&](AbstractCallSite ACS) {
2547 const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2548 *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2549 DepClassTy::REQUIRED);
2550 return ACS.isDirectCall() &&
2551 ExecutionDomainAA.isExecutedByInitialThreadOnly(
2552 *ACS.getInstruction());
2553 };
2554
2555 if (!A.checkForAllCallSites(PredForCallSite, *this,
2556 /* RequiresAllCallSites */ true,
2557 AllCallSitesKnown))
2558 SingleThreadedBBs.erase(&F->getEntryBlock());
2559
2560 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2561 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2562
2563 // Check if the edge into the successor block contains a condition that only
2564 // lets the main thread execute it.
2565 auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2566 if (!Edge || !Edge->isConditional())
2567 return false;
2568 if (Edge->getSuccessor(0) != SuccessorBB)
2569 return false;
2570
2571 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2572 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2573 return false;
2574
2575 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2576 if (!C)
2577 return false;
2578
2579 // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2580 if (C->isAllOnesValue()) {
2581 auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2582 CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2583 if (!CB)
2584 return false;
2585 const int InitModeArgNo = 1;
2586 auto *ModeCI = dyn_cast<ConstantInt>(CB->getOperand(InitModeArgNo));
2587 return ModeCI && (ModeCI->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC);
2588 }
2589
2590 if (C->isZero()) {
2591 // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2592 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2593 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2594 return true;
2595
2596 // Match: 0 == llvm.amdgcn.workitem.id.x()
2597 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2598 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2599 return true;
2600 }
2601
2602 return false;
2603 };
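// Editor's note: a hedged sketch (not from the source) of the IR shapes the
// IsInitialThreadOnly lambda above matches; value names are hypothetical:
//
//   ; non-SPMD kernel prologue: the true edge is initial-thread only
//   %init = call i32 @__kmpc_target_init(...)
//   %is_main = icmp eq i32 %init, -1
//   br i1 %is_main, label %main.only, label %workers
//
//   ; explicit hardware thread-id check (NVPTX; AMDGPU is analogous)
//   %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
//   %is_t0 = icmp eq i32 %tid, 0
//   br i1 %is_t0, label %t0.only, label %rest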
2604
2605 // Merge all the predecessor states into the current basic block. A basic
2606 // block is executed by a single thread if all of its predecessors are.
2607 auto MergePredecessorStates = [&](BasicBlock *BB) {
2608 if (pred_begin(BB) == pred_end(BB))
2609 return SingleThreadedBBs.contains(BB);
2610
2611 bool IsInitialThread = true;
2612 for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
2613 PredBB != PredEndBB; ++PredBB) {
2614 if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
2615 BB))
2616 IsInitialThread &= SingleThreadedBBs.contains(*PredBB);
2617 }
2618
2619 return IsInitialThread;
2620 };
2621
2622 for (auto *BB : RPOT) {
2623 if (!MergePredecessorStates(BB))
2624 SingleThreadedBBs.erase(BB);
2625 }
2626
2627 return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2628 ? ChangeStatus::UNCHANGED
2629 : ChangeStatus::CHANGED;
2630}
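// Editor's note: a hedged worked example (not from the source) of the fixpoint
// sweep above. For a CFG `entry -> check -> {guarded, rest}` where `check`
// ends in a conditional branch on `tid == 0`, the RPOT walk keeps `guarded`
// in SingleThreadedBBs (its one incoming edge is initial-thread only) but
// erases `rest` as soon as any of its predecessors is known to be executed by
// more than the initial thread.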
2631
2632/// Try to replace memory allocation calls called by a single thread with a
2633/// static buffer of shared memory.
2634struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
2635 using Base = StateWrapper<BooleanState, AbstractAttribute>;
2636 AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2637
2638 /// Create an abstract attribute view for the position \p IRP.
2639 static AAHeapToShared &createForPosition(const IRPosition &IRP,
2640 Attributor &A);
2641
2642 /// Returns true if HeapToShared conversion is assumed to be possible.
2643 virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
2644
2645 /// Returns true if HeapToShared conversion is assumed and the CB is a
2646 /// callsite to a free operation to be removed.
2647 virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
2648
2649 /// See AbstractAttribute::getName().
2650 const std::string getName() const override { return "AAHeapToShared"; }
2651
2652 /// See AbstractAttribute::getIdAddr().
2653 const char *getIdAddr() const override { return &ID; }
2654
2655 /// This function should return true if the type of the \p AA is
2656 /// AAHeapToShared.
2657 static bool classof(const AbstractAttribute *AA) {
2658 return (AA->getIdAddr() == &ID);
2659 }
2660
2661 /// Unique ID (due to the unique address)
2662 static const char ID;
2663};
2664
2665struct AAHeapToSharedFunction : public AAHeapToShared {
2666 AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
2667 : AAHeapToShared(IRP, A) {}
2668
2669 const std::string getAsStr() const override {
2670 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
2671 " malloc calls eligible.";
2672 }
2673
2674 /// See AbstractAttribute::trackStatistics().
2675 void trackStatistics() const override {}
2676
2677  /// This function finds free calls that will be removed by the
2678 /// HeapToShared transformation.
2679 void findPotentialRemovedFreeCalls(Attributor &A) {
2680 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2681 auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2682
2683 PotentialRemovedFreeCalls.clear();
2684 // Update free call users of found malloc calls.
2685 for (CallBase *CB : MallocCalls) {
2686 SmallVector<CallBase *, 4> FreeCalls;
2687 for (auto *U : CB->users()) {
2688 CallBase *C = dyn_cast<CallBase>(U);
2689 if (C && C->getCalledFunction() == FreeRFI.Declaration)
2690 FreeCalls.push_back(C);
2691 }
2692
2693 if (FreeCalls.size() != 1)
2694 continue;
2695
2696 PotentialRemovedFreeCalls.insert(FreeCalls.front());
2697 }
2698 }
2699
2700 void initialize(Attributor &A) override {
2701 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2702 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2703
2704 for (User *U : RFI.Declaration->users())
2705 if (CallBase *CB = dyn_cast<CallBase>(U))
2706 MallocCalls.insert(CB);
2707
2708 findPotentialRemovedFreeCalls(A);
2709 }
2710
2711 bool isAssumedHeapToShared(CallBase &CB) const override {
2712 return isValidState() && MallocCalls.count(&CB);
2713 }
2714
2715 bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
2716 return isValidState() && PotentialRemovedFreeCalls.count(&CB);
2717 }
2718
2719 ChangeStatus manifest(Attributor &A) override {
2720 if (MallocCalls.empty())
2721 return ChangeStatus::UNCHANGED;
2722
2723 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2724 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2725
2726 Function *F = getAnchorScope();
2727 auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
2728 DepClassTy::OPTIONAL);
2729
2730 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2731 for (CallBase *CB : MallocCalls) {
2732 // Skip replacing this if HeapToStack has already claimed it.
2733 if (HS && HS->isAssumedHeapToStack(*CB))
2734 continue;
2735
2736 // Find the unique free call to remove it.
2737 SmallVector<CallBase *, 4> FreeCalls;
2738 for (auto *U : CB->users()) {
2739 CallBase *C = dyn_cast<CallBase>(U);
2740 if (C && C->getCalledFunction() == FreeCall.Declaration)
2741 FreeCalls.push_back(C);
2742 }
2743 if (FreeCalls.size() != 1)
2744 continue;
2745
2746 ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0));
2747
2748      LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
2749                        << " with " << AllocSize->getZExtValue()
2750                        << " bytes of shared memory\n");
2751
2752 // Create a new shared memory buffer of the same size as the allocation
2753 // and replace all the uses of the original allocation with it.
2754 Module *M = CB->getModule();
2755 Type *Int8Ty = Type::getInt8Ty(M->getContext());
2756 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
2757 auto *SharedMem = new GlobalVariable(
2758 *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
2759 UndefValue::get(Int8ArrTy), CB->getName(), nullptr,
2760 GlobalValue::NotThreadLocal,
2761 static_cast<unsigned>(AddressSpace::Shared));
2762 auto *NewBuffer =
2763 ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
2764
2765 auto Remark = [&](OptimizationRemark OR) {
2766 return OR << "Replaced globalized variable with "
2767 << ore::NV("SharedMemory", AllocSize->getZExtValue())
2768 << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
2769 << "of shared memory.";
2770 };
2771 A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
2772
2773 SharedMem->setAlignment(MaybeAlign(32));
2774
2775 A.changeValueAfterManifest(*CB, *NewBuffer);
2776 A.deleteAfterManifest(*CB);
2777 A.deleteAfterManifest(*FreeCalls.front());
2778
2779 NumBytesMovedToSharedMemory += AllocSize->getZExtValue();
2780 Changed = ChangeStatus::CHANGED;
2781 }
2782
2783 return Changed;
2784 }
2785
2786 ChangeStatus updateImpl(Attributor &A) override {
2787 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2788 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2789 Function *F = getAnchorScope();
2790
2791 auto NumMallocCalls = MallocCalls.size();
2792
2793    // Only consider malloc calls executed by a single thread and with a constant size.
2794 for (User *U : RFI.Declaration->users()) {
2795 const auto &ED = A.getAAFor<AAExecutionDomain>(
2796 *this, IRPosition::function(*F), DepClassTy::REQUIRED);
2797 if (CallBase *CB = dyn_cast<CallBase>(U))
2798 if (!dyn_cast<ConstantInt>(CB->getArgOperand(0)) ||
2799 !ED.isExecutedByInitialThreadOnly(*CB))
2800 MallocCalls.erase(CB);
2801 }
2802
2803 findPotentialRemovedFreeCalls(A);
2804
2805 if (NumMallocCalls != MallocCalls.size())
2806 return ChangeStatus::CHANGED;
2807
2808 return ChangeStatus::UNCHANGED;
2809 }
2810
2811 /// Collection of all malloc calls in a function.
2812 SmallPtrSet<CallBase *, 4> MallocCalls;
2813 /// Collection of potentially removed free calls in a function.
2814 SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
2815};
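// Editor's note: a hedged before/after sketch (not from the source) of the
// HeapToShared rewrite performed by manifest() above; names and the exact
// runtime-call signatures are illustrative only:
//
//   ; before
//   %p = call i8* @__kmpc_alloc_shared(i64 16)
//   ...                                ; uses of %p by the initial thread
//   call void @__kmpc_free_shared(...) ; the unique free of %p, removed
//
//   ; after
//   @p.shared = internal addrspace(3) global [16 x i8] undef, align 32
//   ; uses of %p are rewritten to a pointer cast of @p.shared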
2816
2817struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
2818 using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
2819 AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2820
2821 /// Statistics are tracked as part of manifest for now.
2822 void trackStatistics() const override {}
2823
2824 /// See AbstractAttribute::getAsStr()
2825 const std::string getAsStr() const override {
2826 if (!isValidState())
2827 return "<invalid>";
2828 return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
2829 : "generic") +
2830 std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
2831 : "") +
2832 std::string(" #PRs: ") +
2833 (ReachedKnownParallelRegions.isValidState()
2834 ? std::to_string(ReachedKnownParallelRegions.size())
2835 : "<invalid>") +
2836 ", #Unknown PRs: " +
2837 (ReachedUnknownParallelRegions.isValidState()
2838 ? std::to_string(ReachedUnknownParallelRegions.size())
2839 : "<invalid>") +
2840 ", #Reaching Kernels: " +
2841 (ReachingKernelEntries.isValidState()
2842 ? std::to_string(ReachingKernelEntries.size())
2843 : "<invalid>");
2844 }
2845
2846  /// Create an abstract attribute view for the position \p IRP.
2847 static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
2848
2849 /// See AbstractAttribute::getName()
2850 const std::string getName() const override { return "AAKernelInfo"; }
2851
2852 /// See AbstractAttribute::getIdAddr()
2853 const char *getIdAddr() const override { return &ID; }
2854
2855 /// This function should return true if the type of the \p AA is AAKernelInfo
2856 static bool classof(const AbstractAttribute *AA) {
2857 return (AA->getIdAddr() == &ID);
2858 }
2859
2860 static const char ID;
2861};
2862
2863/// The function kernel info abstract attribute, basically, what can we say
2864/// about a function with regards to the KernelInfoState.
2865struct AAKernelInfoFunction : AAKernelInfo {
2866 AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
2867 : AAKernelInfo(IRP, A) {}
2868
2869 SmallPtrSet<Instruction *, 4> GuardedInstructions;
2870
2871 SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
2872 return GuardedInstructions;
2873 }
2874
2875 /// See AbstractAttribute::initialize(...).
2876 void initialize(Attributor &A) override {
2877 // This is a high-level transform that might change the constant arguments
2878    // of the init and deinit calls. We need to tell the Attributor about this
2879    // to avoid other parts using the current constant value for simplification.
2880 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2881
2882 Function *Fn = getAnchorScope();
2883 if (!OMPInfoCache.Kernels.count(Fn))
2884 return;
2885
2886 // Add itself to the reaching kernel and set IsKernelEntry.
2887 ReachingKernelEntries.insert(Fn);
2888 IsKernelEntry = true;
2889
2890 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
2891 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2892 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
2893 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
2894
2895 // For kernels we perform more initialization work, first we find the init
2896 // and deinit calls.
2897 auto StoreCallBase = [](Use &U,
2898 OMPInformationCache::RuntimeFunctionInfo &RFI,
2899 CallBase *&Storage) {
2900 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
2901      assert(CB &&
2902             "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
2903      assert(!Storage &&
2904             "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
2905 Storage = CB;
2906 return false;
2907 };
2908 InitRFI.foreachUse(
2909 [&](Use &U, Function &) {
2910 StoreCallBase(U, InitRFI, KernelInitCB);
2911 return false;
2912 },
2913 Fn);
2914 DeinitRFI.foreachUse(
2915 [&](Use &U, Function &) {
2916 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
2917 return false;
2918 },
2919 Fn);
2920
2921 // Ignore kernels without initializers such as global constructors.
2922 if (!KernelInitCB || !KernelDeinitCB) {
2923 indicateOptimisticFixpoint();
2924 return;
2925 }
2926
2927 // For kernels we might need to initialize/finalize the IsSPMD state and
2928 // we need to register a simplification callback so that the Attributor
2929 // knows the constant arguments to __kmpc_target_init and
2930 // __kmpc_target_deinit might actually change.
2931
2932 Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
2933 [&](const IRPosition &IRP, const AbstractAttribute *AA,
2934 bool &UsedAssumedInformation) -> Optional<Value *> {
2935 // IRP represents the "use generic state machine" argument of an
2936 // __kmpc_target_init call. We will answer this one with the internal
2937 // state. As long as we are not in an invalid state, we will create a
2938 // custom state machine so the value should be a `i1 false`. If we are
2939 // in an invalid state, we won't change the value that is in the IR.
2940 if (!isValidState())
2941 return nullptr;
2942 // If we have disabled state machine rewrites, don't make a custom one.
2943 if (DisableOpenMPOptStateMachineRewrite)
2944 return nullptr;
2945 if (AA)
2946 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2947 UsedAssumedInformation = !isAtFixpoint();
2948 auto *FalseVal =
2949 ConstantInt::getBool(IRP.getAnchorValue().getContext(), 0);
2950 return FalseVal;
2951 };
2952
2953 Attributor::SimplifictionCallbackTy ModeSimplifyCB =
2954 [&](const IRPosition &IRP, const AbstractAttribute *AA,
2955 bool &UsedAssumedInformation) -> Optional<Value *> {
2956 // IRP represents the "SPMDCompatibilityTracker" argument of an
2957 // __kmpc_target_init or
2958 // __kmpc_target_deinit call. We will answer this one with the internal
2959 // state.
2960 if (!SPMDCompatibilityTracker.isValidState())
2961 return nullptr;
2962 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2963 if (AA)
2964 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2965 UsedAssumedInformation = true;
2966 } else {
2967 UsedAssumedInformation = false;
2968 }
2969 auto *Val = ConstantInt::getSigned(
2970 IntegerType::getInt8Ty(IRP.getAnchorValue().getContext()),
2971 SPMDCompatibilityTracker.isAssumed() ? OMP_TGT_EXEC_MODE_SPMD
2972 : OMP_TGT_EXEC_MODE_GENERIC);
2973 return Val;
2974 };
2975
2976 Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
2977 [&](const IRPosition &IRP, const AbstractAttribute *AA,
2978 bool &UsedAssumedInformation) -> Optional<Value *> {
2979 // IRP represents the "RequiresFullRuntime" argument of an
2980 // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
2981 // one with the internal state of the SPMDCompatibilityTracker, so if
2982 // generic then true, if SPMD then false.
2983 if (!SPMDCompatibilityTracker.isValidState())
2984 return nullptr;
2985 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2986 if (AA)
2987 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2988 UsedAssumedInformation = true;
2989 } else {
2990 UsedAssumedInformation = false;
2991 }
2992 auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2993 !SPMDCompatibilityTracker.isAssumed());
2994 return Val;
2995 };
2996
2997 constexpr const int InitModeArgNo = 1;
2998 constexpr const int DeinitModeArgNo = 1;
2999 constexpr const int InitUseStateMachineArgNo = 2;
3000 constexpr const int InitRequiresFullRuntimeArgNo = 3;
3001 constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
3002 A.registerSimplificationCallback(
3003 IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
3004 StateMachineSimplifyCB);
3005 A.registerSimplificationCallback(
3006 IRPosition::callsite_argument(*KernelInitCB, InitModeArgNo),
3007 ModeSimplifyCB);
3008 A.registerSimplificationCallback(
3009 IRPosition::callsite_argument(*KernelDeinitCB, DeinitModeArgNo),
3010 ModeSimplifyCB);
3011 A.registerSimplificationCallback(
3012 IRPosition::callsite_argument(*KernelInitCB,
3013 InitRequiresFullRuntimeArgNo),
3014 IsGenericModeSimplifyCB);
3015 A.registerSimplificationCallback(
3016 IRPosition::callsite_argument(*KernelDeinitCB,
3017 DeinitRequiresFullRuntimeArgNo),
3018 IsGenericModeSimplifyCB);
3019
3020 // Check if we know we are in SPMD-mode already.
3021 ConstantInt *ModeArg =
3022 dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
3023 if (ModeArg && (ModeArg->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
3024 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3025 // This is a generic region but SPMDization is disabled so stop tracking.
3026 else if (DisableOpenMPOptSPMDization)
3027 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3028 }
3029
3030 /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3031 /// finished now.
3032 ChangeStatus manifest(Attributor &A) override {
3033 // If we are not looking at a kernel with __kmpc_target_init and
3034 // __kmpc_target_deinit call we cannot actually manifest the information.
3035 if (!KernelInitCB || !KernelDeinitCB)
3036 return ChangeStatus::UNCHANGED;
3037
3038 // Known SPMD-mode kernels need no manifest changes.
3039 if (SPMDCompatibilityTracker.isKnown())
3040 return ChangeStatus::UNCHANGED;
3041
3042    // If we can, we change the execution mode to SPMD-mode; otherwise we
3043    // build a custom state machine.
3044 if (!mayContainParallelRegion() || !changeToSPMDMode(A))
3045 return buildCustomStateMachine(A);
3046
3047 return ChangeStatus::CHANGED;
3048 }
3049
3050 bool changeToSPMDMode(Attributor &A) {
3051 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3052
3053 if (!SPMDCompatibilityTracker.isAssumed()) {
3054 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
3055 if (!NonCompatibleI)
3056 continue;
3057
3058 // Skip diagnostics on calls to known OpenMP runtime functions for now.
3059 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
3060 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
3061 continue;
3062
3063 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3064 ORA << "Value has potential side effects preventing SPMD-mode "
3065 "execution";
3066 if (isa<CallBase>(NonCompatibleI)) {
3067 ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
3068 "the called function to override";
3069 }
3070 return ORA << ".";
3071 };
3072 A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
3073 Remark);
3074
3075        LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
3076                          << *NonCompatibleI << "\n");
3077 }
3078
3079 return false;
3080 }
3081
3082 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3083 Instruction *RegionEndI) {
3084 LoopInfo *LI = nullptr;
3085 DominatorTree *DT = nullptr;
3086 MemorySSAUpdater *MSU = nullptr;
3087 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3088
3089 BasicBlock *ParentBB = RegionStartI->getParent();
3090 Function *Fn = ParentBB->getParent();
3091 Module &M = *Fn->getParent();
3092
3093 // Create all the blocks and logic.
3094 // ParentBB:
3095 // goto RegionCheckTidBB
3096 // RegionCheckTidBB:
3097 // Tid = __kmpc_hardware_thread_id()
3098 // if (Tid != 0)
3099 // goto RegionBarrierBB
3100 // RegionStartBB:
3101 // <execute instructions guarded>
3102 // goto RegionEndBB
3103 // RegionEndBB:
3104 // <store escaping values to shared mem>
3105 // goto RegionBarrierBB
3106 // RegionBarrierBB:
3107 // __kmpc_simple_barrier_spmd()
3108 // // second barrier is omitted if lacking escaping values.
3109 // <load escaping values from shared mem>
3110 // __kmpc_simple_barrier_spmd()
3111 // goto RegionExitBB
3112 // RegionExitBB:
3113 // <execute rest of instructions>
3114
3115 BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3116 DT, LI, MSU, "region.guarded.end");
3117 BasicBlock *RegionBarrierBB =
3118 SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3119 MSU, "region.barrier");
3120 BasicBlock *RegionExitBB =
3121 SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3122 DT, LI, MSU, "region.exit");
3123 BasicBlock *RegionStartBB =
3124 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3125
3126      assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3127             "Expected a different CFG");
3128
3129 BasicBlock *RegionCheckTidBB = SplitBlock(
3130 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3131
3132 // Register basic blocks with the Attributor.
3133 A.registerManifestAddedBasicBlock(*RegionEndBB);
3134 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3135 A.registerManifestAddedBasicBlock(*RegionExitBB);
3136 A.registerManifestAddedBasicBlock(*RegionStartBB);
3137 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3138
3139 bool HasBroadcastValues = false;
3140 // Find escaping outputs from the guarded region to outside users and
3141 // broadcast their values to them.
3142 for (Instruction &I : *RegionStartBB) {
3143 SmallPtrSet<Instruction *, 4> OutsideUsers;
3144 for (User *Usr : I.users()) {
3145 Instruction &UsrI = *cast<Instruction>(Usr);
3146 if (UsrI.getParent() != RegionStartBB)
3147 OutsideUsers.insert(&UsrI);
3148 }
3149
3150 if (OutsideUsers.empty())
3151 continue;
3152
3153 HasBroadcastValues = true;
3154
3155 // Emit a global variable in shared memory to store the broadcasted
3156 // value.
3157 auto *SharedMem = new GlobalVariable(
3158 M, I.getType(), /* IsConstant */ false,
3159 GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
3160 I.getName() + ".guarded.output.alloc", nullptr,
3161 GlobalValue::NotThreadLocal,
3162 static_cast<unsigned>(AddressSpace::Shared));
3163
3164 // Emit a store instruction to update the value.
3165 new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
3166
3167 LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
3168 I.getName() + ".guarded.output.load",
3169 RegionBarrierBB->getTerminator());
3170
3171 // Emit a load instruction and replace uses of the output value.
3172 for (Instruction *UsrI : OutsideUsers) {
3173          assert(UsrI->getParent() == RegionExitBB &&
3174                 "Expected escaping users in exit region");
3175 UsrI->replaceUsesOfWith(&I, LoadI);
3176 }
3177 }
3178
3179 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3180
3181 // Go to tid check BB in ParentBB.
3182 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
3183 ParentBB->getTerminator()->eraseFromParent();
3184 OpenMPIRBuilder::LocationDescription Loc(
3185 InsertPointTy(ParentBB, ParentBB->end()), DL);
3186 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
3187 auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc);
3188 Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr);
3189 BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
3190
3191 // Add check for Tid in RegionCheckTidBB
3192 RegionCheckTidBB->getTerminator()->eraseFromParent();
3193 OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
3194 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
3195 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
3196 FunctionCallee HardwareTidFn =
3197 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3198 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3199 Value *Tid =
3200 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
3201 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
3202 OMPInfoCache.OMPBuilder.Builder
3203 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
3204 ->setDebugLoc(DL);
3205
3206 // First barrier for synchronization, ensures main thread has updated
3207 // values.
3208 FunctionCallee BarrierFn =
3209 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3210 M, OMPRTL___kmpc_barrier_simple_spmd);
3211 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
3212 RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
3213 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid})
3214 ->setDebugLoc(DL);
3215
3216 // Second barrier ensures workers have read broadcast values.
3217 if (HasBroadcastValues)
3218 CallInst::Create(BarrierFn, {Ident, Tid}, "",
3219 RegionBarrierBB->getTerminator())
3220 ->setDebugLoc(DL);
3221 };
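// Editor's note: a hedged source-level intuition (not from the source) for
// the guarding above. In a kernel being converted to SPMD-mode, code such as
//
//   int v = compute_sequential_part(); // must run exactly once
//   #pragma omp parallel
//   use(v);
//
// has the sequential statement wrapped so only hardware thread 0 executes it;
// `v` escapes the guarded region, so it is broadcast through shared memory,
// and the barriers emitted above keep the writer and the readers in step.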
3222
3223 auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3224 SmallPtrSet<BasicBlock *, 8> Visited;
3225 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3226 BasicBlock *BB = GuardedI->getParent();
3227 if (!Visited.insert(BB).second)
3228 continue;
3229
3230 SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
3231 Instruction *LastEffect = nullptr;
3232 BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
3233 while (++IP != IPEnd) {
3234 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
3235 continue;
3236 Instruction *I = &*IP;
3237 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
3238 continue;
3239 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
3240 LastEffect = nullptr;
3241 continue;
3242 }
3243 if (LastEffect)
3244 Reorders.push_back({I, LastEffect});
3245 LastEffect = &*IP;
3246 }
3247 for (auto &Reorder : Reorders)
3248 Reorder.first->moveBefore(Reorder.second);
3249 }
3250
3251 SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
3252
3253 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3254 BasicBlock *BB = GuardedI->getParent();
3255 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
3256 IRPosition::function(*GuardedI->getFunction()), nullptr,
3257 DepClassTy::NONE);
3258      assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
3259 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3260 // Continue if instruction is already guarded.
3261 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
3262 continue;
3263
3264 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
3265 for (Instruction &I : *BB) {
3266 // If instruction I needs to be guarded update the guarded region
3267 // bounds.
3268 if (SPMDCompatibilityTracker.contains(&I)) {
3269 CalleeAAFunction.getGuardedInstructions().insert(&I);
3270 if (GuardedRegionStart)
3271 GuardedRegionEnd = &I;
3272 else
3273 GuardedRegionStart = GuardedRegionEnd = &I;
3274
3275 continue;
3276 }
3277
3278 // Instruction I does not need guarding, store
3279 // any region found and reset bounds.
3280 if (GuardedRegionStart) {
3281 GuardedRegions.push_back(
3282 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
3283 GuardedRegionStart = nullptr;
3284 GuardedRegionEnd = nullptr;
3285 }
3286 }
3287 }
3288
3289 for (auto &GR : GuardedRegions)
3290 CreateGuardedRegion(GR.first, GR.second);
3291
3292 // Adjust the global exec mode flag that tells the runtime what mode this
3293 // kernel is executed in.
3294 Function *Kernel = getAnchorScope();
3295 GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
3296 (Kernel->getName() + "_exec_mode").str());
3297 assert(ExecMode && "Kernel without exec mode?")(static_cast <bool> (ExecMode && "Kernel without exec mode?"
) ? void (0) : __assert_fail ("ExecMode && \"Kernel without exec mode?\""
, "/build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/llvm/lib/Transforms/IPO/OpenMPOpt.cpp"
, 3297, __extension__ __PRETTY_FUNCTION__))
;
3298 assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!")(static_cast <bool> (ExecMode->getInitializer() &&
"ExecMode doesn't have initializer!") ? void (0) : __assert_fail
("ExecMode->getInitializer() && \"ExecMode doesn't have initializer!\""
, "/build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/llvm/lib/Transforms/IPO/OpenMPOpt.cpp"
, 3298, __extension__ __PRETTY_FUNCTION__))
;
3299
3300 // Set the global exec mode flag to indicate SPMD-Generic mode.
3301 assert(isa<ConstantInt>(ExecMode->getInitializer()) &&(static_cast <bool> (isa<ConstantInt>(ExecMode->
getInitializer()) && "ExecMode is not an integer!") ?
void (0) : __assert_fail ("isa<ConstantInt>(ExecMode->getInitializer()) && \"ExecMode is not an integer!\""
, "/build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/llvm/lib/Transforms/IPO/OpenMPOpt.cpp"
, 3302, __extension__ __PRETTY_FUNCTION__))
3302 "ExecMode is not an integer!")(static_cast <bool> (isa<ConstantInt>(ExecMode->
getInitializer()) && "ExecMode is not an integer!") ?
void (0) : __assert_fail ("isa<ConstantInt>(ExecMode->getInitializer()) && \"ExecMode is not an integer!\""
, "/build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/llvm/lib/Transforms/IPO/OpenMPOpt.cpp"
, 3302, __extension__ __PRETTY_FUNCTION__))
;
3303 const int8_t ExecModeVal =
3304 cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue();
3305    assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
3306           "Initially non-SPMD kernel has SPMD exec mode!");
3307 ExecMode->setInitializer(
3308 ConstantInt::get(ExecMode->getInitializer()->getType(),
3309 ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
3310
3311 // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
3312 const int InitModeArgNo = 1;
3313 const int DeinitModeArgNo = 1;
3314 const int InitUseStateMachineArgNo = 2;
3315 const int InitRequiresFullRuntimeArgNo = 3;
3316 const int DeinitRequiresFullRuntimeArgNo = 2;
3317
3318 auto &Ctx = getAnchorValue().getContext();
3319 A.changeUseAfterManifest(
3320 KernelInitCB->getArgOperandUse(InitModeArgNo),
3321 *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
3322 OMP_TGT_EXEC_MODE_SPMD));
3323 A.changeUseAfterManifest(
3324 KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
3325 *ConstantInt::getBool(Ctx, 0));
3326 A.changeUseAfterManifest(
3327 KernelDeinitCB->getArgOperandUse(DeinitModeArgNo),
3328 *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
3329 OMP_TGT_EXEC_MODE_SPMD));
3330 A.changeUseAfterManifest(
3331 KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
3332 *ConstantInt::getBool(Ctx, 0));
3333 A.changeUseAfterManifest(
3334 KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
3335 *ConstantInt::getBool(Ctx, 0));
3336
3337 ++NumOpenMPTargetRegionKernelsSPMD;
3338
3339 auto Remark = [&](OptimizationRemark OR) {
3340 return OR << "Transformed generic-mode kernel to SPMD-mode.";
3341 };
3342 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
3343 return true;
3344 };
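// Editor's note: a hedged sketch (not from the source) of the net effect of
// changeToSPMDMode() on the kernel prologue, following the ArgNo constants
// above; the exact __kmpc_target_init signature and enum values are assumed,
// not verified here:
//
//   ; before: generic mode, state machine on, full runtime required
//   call i32 @__kmpc_target_init(%ident, i8 1, i1 true, i1 true)
//   ; after: SPMD mode, no state machine, no full runtime
//   call i32 @__kmpc_target_init(%ident, i8 2, i1 false, i1 false)
//
// plus the `<kernel>_exec_mode` global being OR'ed with
// OMP_TGT_EXEC_MODE_GENERIC_SPMD.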
3345
3346 ChangeStatus buildCustomStateMachine(Attributor &A) {
3347 // If we have disabled state machine rewrites, don't make a custom one
3348 if (DisableOpenMPOptStateMachineRewrite)
3349 return ChangeStatus::UNCHANGED;
3350
3351    assert(ReachedKnownParallelRegions.isValidState() &&
3352           "Custom state machine with invalid parallel region states?");
3353
3354 const int InitModeArgNo = 1;
3355 const int InitUseStateMachineArgNo = 2;
3356
3357    // Check if the current configuration is non-SPMD with a generic state
3358    // machine. If we already have SPMD mode or a custom state machine, we do
3359    // not need to go any further. If either argument is anything but a
3360    // constant, something is weird and we give up.
3361 ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
3362 KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
3363 ConstantInt *Mode =
3364 dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
3365
3366 // If we are stuck with generic mode, try to create a custom device (=GPU)
3367 // state machine which is specialized for the parallel regions that are
3368 // reachable by the kernel.
3369 if (!UseStateMachine || UseStateMachine->isZero() || !Mode ||
3370 (Mode->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
3371 return ChangeStatus::UNCHANGED;
3372
3373 // If not SPMD mode, indicate we use a custom state machine now.
3374 auto &Ctx = getAnchorValue().getContext();
3375 auto *FalseVal = ConstantInt::getBool(Ctx, 0);
3376 A.changeUseAfterManifest(
3377 KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
3378
3379 // If we don't actually need a state machine we are done here. This can
3380 // happen if there simply are no parallel regions. In the resulting kernel
3381 // all worker threads will simply exit right away, leaving the main thread
3382 // to do the work alone.
3383 if (!mayContainParallelRegion()) {
3384 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
3385
3386 auto Remark = [&](OptimizationRemark OR) {
3387 return OR << "Removing unused state machine from generic-mode kernel.";
3388 };
3389 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
3390
3391 return ChangeStatus::CHANGED;
3392 }
3393
3394 // Keep track in the statistics of our new shiny custom state machine.
3395 if (ReachedUnknownParallelRegions.empty()) {
3396 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
3397
3398 auto Remark = [&](OptimizationRemark OR) {
3399 return OR << "Rewriting generic-mode kernel with a customized state "
3400 "machine.";
3401 };
3402 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
3403 } else {
3404 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
3405
3406 auto Remark = [&](OptimizationRemarkAnalysis OR) {
3407 return OR << "Generic-mode kernel is executed with a customized state "
3408 "machine that requires a fallback.";
3409 };
3410 A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
3411
3412 // Tell the user why we ended up with a fallback.
3413 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
3414 if (!UnknownParallelRegionCB)
3415 continue;
3416 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3417 return ORA << "Call may contain unknown parallel regions. Use "
3418 << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
3419 "override.";
3420 };
3421 A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
3422 "OMP133", Remark);
3423 }
3424 }
3425
3426 // Create all the blocks:
3427 //
3428 // InitCB = __kmpc_target_init(...)
3429 // bool IsWorker = InitCB >= 0;
3430 // if (IsWorker) {
3431 // SMBeginBB: __kmpc_barrier_simple_spmd(...);
3432 // void *WorkFn;
3433 // bool Active = __kmpc_kernel_parallel(&WorkFn);
3434 // if (!WorkFn) return;
3435 // SMIsActiveCheckBB: if (Active) {
3436 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
3437 // ParFn0(...);
3438 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
3439 // ParFn1(...);
3440 // ...
3441 // SMIfCascadeCurrentBB: else
3442 // ((WorkFnTy*)WorkFn)(...);
3443 // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
3444 // }
3445 // SMDoneBB: __kmpc_barrier_simple_spmd(...);
3446 // goto SMBeginBB;
3447 // }
3448 // UserCodeEntryBB: // user code
3449 // __kmpc_target_deinit(...)
3450 //
3451 Function *Kernel = getAssociatedFunction();
3452 assert(Kernel && "Expected an associated function!")(static_cast <bool> (Kernel && "Expected an associated function!"
) ? void (0) : __assert_fail ("Kernel && \"Expected an associated function!\""
, "/build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/llvm/lib/Transforms/IPO/OpenMPOpt.cpp"
, 3452, __extension__ __PRETTY_FUNCTION__))
;
3453
3454 BasicBlock *InitBB = KernelInitCB->getParent();
3455 BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
3456 KernelInitCB->getNextNode(), "thread.user_code.check");
3457 BasicBlock *StateMachineBeginBB = BasicBlock::Create(
3458 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
3459 BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
3460 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
3461 BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
3462 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
3463 BasicBlock *StateMachineIfCascadeCurrentBB =
3464 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3465 Kernel, UserCodeEntryBB);
3466 BasicBlock *StateMachineEndParallelBB =
3467 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
3468 Kernel, UserCodeEntryBB);
3469 BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
3470 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
3471 A.registerManifestAddedBasicBlock(*InitBB);
3472 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
3473 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
3474 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
3475 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
3476 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
3477 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
3478 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
3479
3480 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3481 ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
3482
3483 InitBB->getTerminator()->eraseFromParent();
3484 Instruction *IsWorker =
3485 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
3486 ConstantInt::get(KernelInitCB->getType(), -1),
3487 "thread.is_worker", InitBB);
3488 IsWorker->setDebugLoc(DLoc);
3489 BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);
3490
3491 Module &M = *Kernel->getParent();
3492
3493 // Create local storage for the work function pointer.
3494 const DataLayout &DL = M.getDataLayout();
3495 Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
3496 Instruction *WorkFnAI =
3497 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
3498 "worker.work_fn.addr", &Kernel->getEntryBlock().front());
3499 WorkFnAI->setDebugLoc(DLoc);
3500
3501 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3502 OMPInfoCache.OMPBuilder.updateToLocation(
3503 OpenMPIRBuilder::LocationDescription(
3504 IRBuilder<>::InsertPoint(StateMachineBeginBB,
3505 StateMachineBeginBB->end()),
3506 DLoc));
3507
3508 Value *Ident = KernelInitCB->getArgOperand(0);
3509 Value *GTid = KernelInitCB;
3510
3511 FunctionCallee BarrierFn =
3512 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3513 M, OMPRTL___kmpc_barrier_simple_spmd);
3514 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB)
3515 ->setDebugLoc(DLoc);
3516
3517 if (WorkFnAI->getType()->getPointerAddressSpace() !=
3518 (unsigned int)AddressSpace::Generic) {
3519 WorkFnAI = new AddrSpaceCastInst(
3520 WorkFnAI,
3521 PointerType::getWithSamePointeeType(
3522 cast<PointerType>(WorkFnAI->getType()),
3523 (unsigned int)AddressSpace::Generic),
3524 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
3525 WorkFnAI->setDebugLoc(DLoc);
3526 }
3527
3528 FunctionCallee KernelParallelFn =
3529 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3530 M, OMPRTL___kmpc_kernel_parallel);
3531 Instruction *IsActiveWorker = CallInst::Create(
3532 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
3533 IsActiveWorker->setDebugLoc(DLoc);
3534 Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
3535 StateMachineBeginBB);
3536 WorkFn->setDebugLoc(DLoc);
3537
3538 FunctionType *ParallelRegionFnTy = FunctionType::get(
3539 Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
3540 false);
3541 Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
3542 WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
3543 StateMachineBeginBB);
3544
3545 Instruction *IsDone =
3546 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
3547 Constant::getNullValue(VoidPtrTy), "worker.is_done",
3548 StateMachineBeginBB);
3549 IsDone->setDebugLoc(DLoc);
3550 BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
3551 IsDone, StateMachineBeginBB)
3552 ->setDebugLoc(DLoc);
3553
3554 BranchInst::Create(StateMachineIfCascadeCurrentBB,
3555 StateMachineDoneBarrierBB, IsActiveWorker,
3556 StateMachineIsActiveCheckBB)
3557 ->setDebugLoc(DLoc);
3558
3559 Value *ZeroArg =
3560 Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
3561
3562 // Now that we have most of the CFG skeleton it is time for the if-cascade
3563 // that checks the function pointer we got from the runtime against the
3564 // parallel regions we expect, if there are any.
3565 for (int i = 0, e = ReachedKnownParallelRegions.size(); i < e; ++i) {
3566 auto *ParallelRegion = ReachedKnownParallelRegions[i];
3567 BasicBlock *PRExecuteBB = BasicBlock::Create(
3568 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
3569 StateMachineEndParallelBB);
3570 CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
3571 ->setDebugLoc(DLoc);
3572 BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
3573 ->setDebugLoc(DLoc);
3574
3575 BasicBlock *PRNextBB =
3576 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3577 Kernel, StateMachineEndParallelBB);
3578
3579 // Check if we need to compare the pointer at all or if we can just
3580 // call the parallel region function.
3581 Value *IsPR;
3582 if (i + 1 < e || !ReachedUnknownParallelRegions.empty()) {
3583 Instruction *CmpI = ICmpInst::Create(
3584 ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
3585 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
3586 CmpI->setDebugLoc(DLoc);
3587 IsPR = CmpI;
3588 } else {
3589 IsPR = ConstantInt::getTrue(Ctx);
3590 }
3591
3592 BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
3593 StateMachineIfCascadeCurrentBB)
3594 ->setDebugLoc(DLoc);
3595 StateMachineIfCascadeCurrentBB = PRNextBB;
3596 }
3597
3598 // At the end of the if-cascade we place the indirect function pointer call
3599 // in case we might need it, that is if there can be parallel regions we
3600 // have not handled in the if-cascade above.
3601 if (!ReachedUnknownParallelRegions.empty()) {
3602 StateMachineIfCascadeCurrentBB->setName(
3603 "worker_state_machine.parallel_region.fallback.execute");
3604 CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
3605 StateMachineIfCascadeCurrentBB)
3606 ->setDebugLoc(DLoc);
3607 }
3608 BranchInst::Create(StateMachineEndParallelBB,
3609 StateMachineIfCascadeCurrentBB)
3610 ->setDebugLoc(DLoc);
3611
3612 CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3613 M, OMPRTL___kmpc_kernel_end_parallel),
3614 {}, "", StateMachineEndParallelBB)
3615 ->setDebugLoc(DLoc);
3616 BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
3617 ->setDebugLoc(DLoc);
3618
3619 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
3620 ->setDebugLoc(DLoc);
3621 BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
3622 ->setDebugLoc(DLoc);
3623
3624 return ChangeStatus::CHANGED;
3625 }
3626
3627 /// Fixpoint iteration update function. Will be called every time a dependence
3628 /// changed its state (and in the beginning).
3629 ChangeStatus updateImpl(Attributor &A) override {
3630 KernelInfoState StateBefore = getState();
3631
3632 // Callback to check a read/write instruction.
3633 auto CheckRWInst = [&](Instruction &I) {
3634 // We handle calls later.
3635 if (isa<CallBase>(I))
3636 return true;
3637 // We only care about write effects.
3638 if (!I.mayWriteToMemory())
3639 return true;
3640 if (auto *SI = dyn_cast<StoreInst>(&I)) {
3641 SmallVector<const Value *> Objects;
3642 getUnderlyingObjects(SI->getPointerOperand(), Objects);
3643 if (llvm::all_of(Objects,
3644 [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
3645 return true;
3646 // Check for AAHeapToStack moved objects which must not be guarded.
3647 auto &HS = A.getAAFor<AAHeapToStack>(
3648 *this, IRPosition::function(*I.getFunction()),
3649 DepClassTy::REQUIRED);
3650 if (llvm::all_of(Objects, [&HS](const Value *Obj) {
3651 auto *CB = dyn_cast<CallBase>(Obj);
3652 if (!CB)
3653 return false;
3654 return HS.isAssumedHeapToStack(*CB);
3655 })) {
3656 return true;
3657 }
3658 }
3659
3660 // Insert instruction that needs guarding.
3661 SPMDCompatibilityTracker.insert(&I);
3662 return true;
3663 };
3664
3665 bool UsedAssumedInformationInCheckRWInst = false;
3666 if (!SPMDCompatibilityTracker.isAtFixpoint())
3667 if (!A.checkForAllReadWriteInstructions(
3668 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
3669 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3670
3671 if (!IsKernelEntry) {
3672 updateReachingKernelEntries(A);
3673 updateParallelLevels(A);
3674
3675 if (!ParallelLevels.isValidState())
3676 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3677 }
3678
3679 // Callback to check a call instruction.
3680 bool AllSPMDStatesWereFixed = true;
3681 auto CheckCallInst = [&](Instruction &I) {
3682 auto &CB = cast<CallBase>(I);
3683 auto &CBAA = A.getAAFor<AAKernelInfo>(
3684 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
3685 getState() ^= CBAA.getState();
3686 AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
3687 return true;
3688 };
3689
3690 bool UsedAssumedInformationInCheckCallInst = false;
3691 if (!A.checkForAllCallLikeInstructions(
3692 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
3693      LLVM_DEBUG(dbgs() << TAG << "Failed to visit all call-like instructions!\n");
3694 return indicatePessimisticFixpoint();
3695 }
3696
3697 // If we haven't used any assumed information for the SPMD state we can fix
3698 // it.
3699 if (!UsedAssumedInformationInCheckRWInst &&
3700 !UsedAssumedInformationInCheckCallInst && AllSPMDStatesWereFixed)
3701 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3702
3703 return StateBefore == getState() ? ChangeStatus::UNCHANGED
3704 : ChangeStatus::CHANGED;
3705 }
3706
3707private:
3708 /// Update info regarding reaching kernels.
3709 void updateReachingKernelEntries(Attributor &A) {
3710 auto PredCallSite = [&](AbstractCallSite ACS) {
3711 Function *Caller = ACS.getInstruction()->getFunction();
3712
3713 assert(Caller && "Caller is nullptr")(static_cast <bool> (Caller && "Caller is nullptr"
) ? void (0) : __assert_fail ("Caller && \"Caller is nullptr\""
, "/build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/llvm/lib/Transforms/IPO/OpenMPOpt.cpp"
, 3713, __extension__ __PRETTY_FUNCTION__))
;
1
Assuming 'Caller' is non-null
2
'?' condition is true
3714
3715 auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
3716 IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
3717 if (CAA.ReachingKernelEntries.isValidState()) {
    3. Calling 'IntegerStateBase::isValidState'
    6. Returning from 'IntegerStateBase::isValidState'
    7. Taking true branch
3718 ReachingKernelEntries ^= CAA.ReachingKernelEntries;
    8. Called C++ object pointer is null
3719 return true;
3720 }
3721
3722 // We lost track of the caller of the associated function, any kernel
3723 // could reach now.
3724 ReachingKernelEntries.indicatePessimisticFixpoint();
3725
3726 return true;
3727 };
3728
3729 bool AllCallSitesKnown;
3730 if (!A.checkForAllCallSites(PredCallSite, *this,
3731 true /* RequireAllCallSites */,
3732 AllCallSitesKnown))
3733 ReachingKernelEntries.indicatePessimisticFixpoint();
3734 }
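// Editor's note on the reported warning (path notes 1-8 above): the analyzer
// concludes that the object on which `operator^=` is invoked at line 3718 can
// be null. A hedged, minimal illustration -- not the author's fix -- of making
// the nullability explicit is to query the attribute through the pointer-
// returning lookup already used elsewhere in this file and bail out on null:
//
//   auto *CAA = A.lookupAAFor<AAKernelInfo>(IRPosition::function(*Caller),
//                                           this, DepClassTy::REQUIRED);
//   if (!CAA || !CAA->ReachingKernelEntries.isValidState()) {
//     ReachingKernelEntries.indicatePessimisticFixpoint();
//     return true;
//   }
//   ReachingKernelEntries ^= CAA->ReachingKernelEntries;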
3735
3736 /// Update info regarding parallel levels.
3737 void updateParallelLevels(Attributor &A) {
3738 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3739 OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
3740 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
3741
3742 auto PredCallSite = [&](AbstractCallSite ACS) {
3743 Function *Caller = ACS.getInstruction()->getFunction();
3744
3745 assert(Caller && "Caller is nullptr")(static_cast <bool> (Caller && "Caller is nullptr"
) ? void (0) : __assert_fail ("Caller && \"Caller is nullptr\""
, "/build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/llvm/lib/Transforms/IPO/OpenMPOpt.cpp"
, 3745, __extension__ __PRETTY_FUNCTION__))
;
3746
3747 auto &CAA =
3748 A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
3749 if (CAA.ParallelLevels.isValidState()) {
3750        // Any function that is called by `__kmpc_parallel_51` will not be
3751        // folded, as the parallel level in the function is updated. Getting
3752        // this right would make the analysis depend on the implementation;
3753        // if the implementation changed in the future, the analysis could
3754        // become wrong. As a consequence, we are just conservative here.
3755 if (Caller == Parallel51RFI.Declaration) {
3756 ParallelLevels.indicatePessimisticFixpoint();
3757 return true;
3758 }
3759
3760 ParallelLevels ^= CAA.ParallelLevels;
3761
3762 return true;
3763 }
3764
3765 // We lost track of the caller of the associated function, any kernel
3766 // could reach now.
3767 ParallelLevels.indicatePessimisticFixpoint();
3768
3769 return true;
3770 };
3771
3772 bool AllCallSitesKnown = true;
3773 if (!A.checkForAllCallSites(PredCallSite, *this,
3774 true /* RequireAllCallSites */,
3775 AllCallSitesKnown))
3776 ParallelLevels.indicatePessimisticFixpoint();
3777 }
3778};
3779
3780/// The call site kernel info abstract attribute, basically, what can we say
3781/// about a call site with regard to the KernelInfoState. For now this simply
3782/// forwards the information from the callee.
3783struct AAKernelInfoCallSite : AAKernelInfo {
3784 AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
3785 : AAKernelInfo(IRP, A) {}
3786
3787 /// See AbstractAttribute::initialize(...).
3788 void initialize(Attributor &A) override {
3789 AAKernelInfo::initialize(A);
3790
3791 CallBase &CB = cast<CallBase>(getAssociatedValue());
3792 Function *Callee = getAssociatedFunction();
3793
3794 // Helper to lookup an assumption string.
3795 auto HasAssumption = [](CallBase &CB, StringRef AssumptionStr) {
3796 return hasAssumption(CB, AssumptionStr);
3797 };
3798
3799 // Check for SPMD-mode assumptions.
3800 if (HasAssumption(CB, "ompx_spmd_amenable")) {
3801 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3802 indicateOptimisticFixpoint();
3803 }
3804
3805 // First weed out calls we do not care about, that is readonly/readnone
3806    // calls, intrinsics, and "no_openmp" calls. None of these can reach a
3807 // parallel region or anything else we are looking for.
3808 if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
3809 indicateOptimisticFixpoint();
3810 return;
3811 }
3812
3813 // Next we check if we know the callee. If it is a known OpenMP function
3814 // we will handle them explicitly in the switch below. If it is not, we
3815 // will use an AAKernelInfo object on the callee to gather information and
3816 // merge that into the current state. The latter happens in the updateImpl.
3817 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3818 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3819 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3820      // Unknown callees or mere declarations are not analyzable, we give up.
3821 if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
3822
3823 // Unknown callees might contain parallel regions, except if they have
3824 // an appropriate assumption attached.
3825 if (!(HasAssumption(CB, "omp_no_openmp") ||
3826 HasAssumption(CB, "omp_no_parallelism")))
3827 ReachedUnknownParallelRegions.insert(&CB);
3828
3829 // If SPMDCompatibilityTracker is not fixed, we need to give up on the
3830 // idea we can run something unknown in SPMD-mode.
3831 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3832 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3833 SPMDCompatibilityTracker.insert(&CB);
3834 }
3835
3836 // We have updated the state for this unknown call properly, there won't
3837 // be any change so we indicate a fixpoint.
3838 indicateOptimisticFixpoint();
3839 }
3840 // If the callee is known and can be used in IPO, we will update the state
3841 // based on the callee state in updateImpl.
3842 return;
3843 }
3844
3845 const unsigned int WrapperFunctionArgNo = 6;
3846 RuntimeFunction RF = It->getSecond();
3847 switch (RF) {
3848 // All the functions we know are compatible with SPMD mode.
3849 case OMPRTL___kmpc_is_spmd_exec_mode:
3850 case OMPRTL___kmpc_for_static_fini:
3851 case OMPRTL___kmpc_global_thread_num:
3852 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
3853 case OMPRTL___kmpc_get_hardware_num_blocks:
3854 case OMPRTL___kmpc_single:
3855 case OMPRTL___kmpc_end_single:
3856 case OMPRTL___kmpc_master:
3857 case OMPRTL___kmpc_end_master:
3858 case OMPRTL___kmpc_barrier:
3859 break;
3860 case OMPRTL___kmpc_for_static_init_4:
3861 case OMPRTL___kmpc_for_static_init_4u:
3862 case OMPRTL___kmpc_for_static_init_8:
3863 case OMPRTL___kmpc_for_static_init_8u: {
3864 // Check the schedule and allow static schedule in SPMD mode.
3865 unsigned ScheduleArgOpNo = 2;
3866 auto *ScheduleTypeCI =
3867 dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
3868 unsigned ScheduleTypeVal =
3869 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
3870 switch (OMPScheduleType(ScheduleTypeVal)) {
3871 case OMPScheduleType::Static:
3872 case OMPScheduleType::StaticChunked:
3873 case OMPScheduleType::Distribute:
3874 case OMPScheduleType::DistributeChunked:
3875 break;
3876 default:
3877 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3878 SPMDCompatibilityTracker.insert(&CB);
3879 break;
3880 };
3881 } break;
3882 case OMPRTL___kmpc_target_init:
3883 KernelInitCB = &CB;
3884 break;
3885 case OMPRTL___kmpc_target_deinit:
3886 KernelDeinitCB = &CB;
3887 break;
3888 case OMPRTL___kmpc_parallel_51:
3889 if (auto *ParallelRegion = dyn_cast<Function>(
3890 CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
3891 ReachedKnownParallelRegions.insert(ParallelRegion);
3892 break;
3893 }
3894 // The condition above should usually get the parallel region function
3895      // pointer and record it. On the off chance it doesn't, we assume the
3896 // worst.
3897 ReachedUnknownParallelRegions.insert(&CB);
3898 break;
3899 case OMPRTL___kmpc_omp_task:
3900 // We do not look into tasks right now, just give up.
3901 SPMDCompatibilityTracker.insert(&CB);
3902 ReachedUnknownParallelRegions.insert(&CB);
3903 break;
3904 case OMPRTL___kmpc_alloc_shared:
3905 case OMPRTL___kmpc_free_shared:
3906 // Return without setting a fixpoint, to be resolved in updateImpl.
3907 return;
3908 default:
3909 // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
3910 // generally. However, they do not hide parallel regions.
3911 SPMDCompatibilityTracker.insert(&CB);
3912 break;
3913 }
3914 // All other OpenMP runtime calls will not reach parallel regions so they
3915 // can be safely ignored for now. Since it is a known OpenMP runtime call we
3916 // have now modeled all effects and there is no need for any update.
3917 indicateOptimisticFixpoint();
3918 }
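
For the '__kmpc_parallel_51' case above, operand 6 (WrapperFunctionArgNo)
carries the outlined parallel region. A hypothetical device-side source
that would produce such a call site (illustrative only):

    // The frontend outlines the parallel body into a wrapper function,
    // which is then passed as argument 6 of __kmpc_parallel_51.
    #pragma omp target
    #pragma omp parallel
    work();
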
3919
3920 ChangeStatus updateImpl(Attributor &A) override {
3921 // TODO: Once we have call site specific value information we can provide
3922 // call site specific liveness information and then it makes
3923 // sense to specialize attributes for call sites arguments instead of
3924 // redirecting requests to the callee argument.
3925 Function *F = getAssociatedFunction();
3926
3927 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3928 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
3929
3930 // If F is not a runtime function, propagate the AAKernelInfo of the callee.
3931 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3932 const IRPosition &FnPos = IRPosition::function(*F);
3933 auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
3934 if (getState() == FnAA.getState())
3935 return ChangeStatus::UNCHANGED;
3936 getState() = FnAA.getState();
3937 return ChangeStatus::CHANGED;
3938 }
3939
3940 // F is a runtime function that allocates or frees memory, check
3941 // AAHeapToStack and AAHeapToShared.
3942 KernelInfoState StateBefore = getState();
3943    assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
3944            It->getSecond() == OMPRTL___kmpc_free_shared) &&
3945           "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
3946
3947 CallBase &CB = cast<CallBase>(getAssociatedValue());
3948
3949 auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
3950 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3951 auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
3952 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3953
3954 RuntimeFunction RF = It->getSecond();
3955
3956 switch (RF) {
3957 // If neither HeapToStack nor HeapToShared assume the call is removed,
3958 // assume SPMD incompatibility.
3959 case OMPRTL___kmpc_alloc_shared:
3960 if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
3961 !HeapToSharedAA.isAssumedHeapToShared(CB))
3962 SPMDCompatibilityTracker.insert(&CB);
3963 break;
3964 case OMPRTL___kmpc_free_shared:
3965 if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
3966 !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
3967 SPMDCompatibilityTracker.insert(&CB);
3968 break;
3969 default:
3970 SPMDCompatibilityTracker.insert(&CB);
3971 }
3972
3973 return StateBefore == getState() ? ChangeStatus::UNCHANGED
3974 : ChangeStatus::CHANGED;
3975 }
3976};
3977
3978struct AAFoldRuntimeCall
3979 : public StateWrapper<BooleanState, AbstractAttribute> {
3980 using Base = StateWrapper<BooleanState, AbstractAttribute>;
3981
3982 AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3983
3984 /// Statistics are tracked as part of manifest for now.
3985 void trackStatistics() const override {}
3986
3987  /// Create an abstract attribute view for the position \p IRP.
3988 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
3989 Attributor &A);
3990
3991 /// See AbstractAttribute::getName()
3992 const std::string getName() const override { return "AAFoldRuntimeCall"; }
3993
3994 /// See AbstractAttribute::getIdAddr()
3995 const char *getIdAddr() const override { return &ID; }
3996
3997 /// This function should return true if the type of the \p AA is
3998 /// AAFoldRuntimeCall
3999 static bool classof(const AbstractAttribute *AA) {
4000 return (AA->getIdAddr() == &ID);
4001 }
4002
4003 static const char ID;
4004};
4005
4006struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
4007 AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
4008 : AAFoldRuntimeCall(IRP, A) {}
4009
4010 /// See AbstractAttribute::getAsStr()
4011 const std::string getAsStr() const override {
4012 if (!isValidState())
4013 return "<invalid>";
4014
4015 std::string Str("simplified value: ");
4016
4017 if (!SimplifiedValue.hasValue())
4018 return Str + std::string("none");
4019
4020 if (!SimplifiedValue.getValue())
4021 return Str + std::string("nullptr");
4022
4023 if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue()))
4024 return Str + std::to_string(CI->getSExtValue());
4025
4026 return Str + std::string("unknown");
4027 }
4028
4029 void initialize(Attributor &A) override {
4030 if (DisableOpenMPOptFolding)
4031 indicatePessimisticFixpoint();
4032
4033 Function *Callee = getAssociatedFunction();
4034
4035 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4036 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4037    assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
4038           "Expected a known OpenMP runtime function");
4039
4040 RFKind = It->getSecond();
4041
4042 CallBase &CB = cast<CallBase>(getAssociatedValue());
4043 A.registerSimplificationCallback(
4044 IRPosition::callsite_returned(CB),
4045 [&](const IRPosition &IRP, const AbstractAttribute *AA,
4046 bool &UsedAssumedInformation) -> Optional<Value *> {
4047          assert((isValidState() || (SimplifiedValue.hasValue() &&
4048                                     SimplifiedValue.getValue() == nullptr)) &&
4049                 "Unexpected invalid state!");
4050
4051 if (!isAtFixpoint()) {
4052 UsedAssumedInformation = true;
4053 if (AA)
4054 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
4055 }
4056 return SimplifiedValue;
4057 });
4058 }
4059
4060 ChangeStatus updateImpl(Attributor &A) override {
4061 ChangeStatus Changed = ChangeStatus::UNCHANGED;
4062 switch (RFKind) {
4063 case OMPRTL___kmpc_is_spmd_exec_mode:
4064 Changed |= foldIsSPMDExecMode(A);
4065 break;
4066 case OMPRTL___kmpc_is_generic_main_thread_id:
4067 Changed |= foldIsGenericMainThread(A);
4068 break;
4069 case OMPRTL___kmpc_parallel_level:
4070 Changed |= foldParallelLevel(A);
4071 break;
4072 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4073 Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
4074 break;
4075 case OMPRTL___kmpc_get_hardware_num_blocks:
4076 Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
4077 break;
4078 default:
4079      llvm_unreachable("Unhandled OpenMP runtime function!");
4080 }
4081
4082 return Changed;
4083 }
4084
4085 ChangeStatus manifest(Attributor &A) override {
4086 ChangeStatus Changed = ChangeStatus::UNCHANGED;
4087
4088 if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) {
4089 Instruction &I = *getCtxI();
4090 A.changeValueAfterManifest(I, **SimplifiedValue);
4091 A.deleteAfterManifest(I);
4092
4093 CallBase *CB = dyn_cast<CallBase>(&I);
4094 auto Remark = [&](OptimizationRemark OR) {
4095 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
4096 return OR << "Replacing OpenMP runtime call "
4097 << CB->getCalledFunction()->getName() << " with "
4098 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
4099 else
4100 return OR << "Replacing OpenMP runtime call "
4101 << CB->getCalledFunction()->getName() << ".";
4102 };
4103
4104 if (CB && EnableVerboseRemarks)
4105 A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
4106
4107      LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
4108                        << **SimplifiedValue << "\n");
4109
4110 Changed = ChangeStatus::CHANGED;
4111 }
4112
4113 return Changed;
4114 }
4115
4116 ChangeStatus indicatePessimisticFixpoint() override {
4117 SimplifiedValue = nullptr;
4118 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
4119 }
4120
4121private:
4122 /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
4123 ChangeStatus foldIsSPMDExecMode(Attributor &A) {
4124 Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4125
4126 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4127 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4128 auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4129 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4130
4131 if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4132 return indicatePessimisticFixpoint();
4133
4134 for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4135 auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4136 DepClassTy::REQUIRED);
4137
4138 if (!AA.isValidState()) {
4139 SimplifiedValue = nullptr;
4140 return indicatePessimisticFixpoint();
4141 }
4142
4143 if (AA.SPMDCompatibilityTracker.isAssumed()) {
4144 if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4145 ++KnownSPMDCount;
4146 else
4147 ++AssumedSPMDCount;
4148 } else {
4149 if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4150 ++KnownNonSPMDCount;
4151 else
4152 ++AssumedNonSPMDCount;
4153 }
4154 }
4155
4156 if ((AssumedSPMDCount + KnownSPMDCount) &&
4157 (AssumedNonSPMDCount + KnownNonSPMDCount))
4158 return indicatePessimisticFixpoint();
4159
4160 auto &Ctx = getAnchorValue().getContext();
4161 if (KnownSPMDCount || AssumedSPMDCount) {
4162      assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4163             "Expected only SPMD kernels!");
4164 // All reaching kernels are in SPMD mode. Update all function calls to
4165 // __kmpc_is_spmd_exec_mode to 1.
4166 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4167 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
4168      assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4169             "Expected only non-SPMD kernels!");
4170 // All reaching kernels are in non-SPMD mode. Update all function
4171 // calls to __kmpc_is_spmd_exec_mode to 0.
4172 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
4173 } else {
4174 // We have empty reaching kernels, therefore we cannot tell if the
4175 // associated call site can be folded. At this moment, SimplifiedValue
4176 // must be none.
4177      assert(!SimplifiedValue.hasValue() && "SimplifiedValue should be none");
4178 }
4179
4180 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4181 : ChangeStatus::CHANGED;
4182 }
4183
4184 /// Fold __kmpc_is_generic_main_thread_id into a constant if possible.
4185 ChangeStatus foldIsGenericMainThread(Attributor &A) {
4186 Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4187
4188 CallBase &CB = cast<CallBase>(getAssociatedValue());
4189 Function *F = CB.getFunction();
4190 const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
4191 *this, IRPosition::function(*F), DepClassTy::REQUIRED);
4192
4193 if (!ExecutionDomainAA.isValidState())
4194 return indicatePessimisticFixpoint();
4195
4196 auto &Ctx = getAnchorValue().getContext();
4197 if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB))
4198 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4199 else
4200 return indicatePessimisticFixpoint();
4201
4202 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4203 : ChangeStatus::CHANGED;
4204 }
4205
4206 /// Fold __kmpc_parallel_level into a constant if possible.
4207 ChangeStatus foldParallelLevel(Attributor &A) {
4208 Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4209
4210 auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4211 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4212
4213 if (!CallerKernelInfoAA.ParallelLevels.isValidState())
4214 return indicatePessimisticFixpoint();
4215
4216 if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4217 return indicatePessimisticFixpoint();
4218
4219 if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
4220      assert(!SimplifiedValue.hasValue() &&
4221             "SimplifiedValue should keep none at this point");
4222 return ChangeStatus::UNCHANGED;
4223 }
4224
4225 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4226 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4227 for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4228 auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4229 DepClassTy::REQUIRED);
4230 if (!AA.SPMDCompatibilityTracker.isValidState())
4231 return indicatePessimisticFixpoint();
4232
4233 if (AA.SPMDCompatibilityTracker.isAssumed()) {
4234 if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4235 ++KnownSPMDCount;
4236 else
4237 ++AssumedSPMDCount;
4238 } else {
4239 if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4240 ++KnownNonSPMDCount;
4241 else
4242 ++AssumedNonSPMDCount;
4243 }
4244 }
4245
4246 if ((AssumedSPMDCount + KnownSPMDCount) &&
4247 (AssumedNonSPMDCount + KnownNonSPMDCount))
4248 return indicatePessimisticFixpoint();
4249
4250 auto &Ctx = getAnchorValue().getContext();
4251 // If the caller can only be reached by SPMD kernel entries, the parallel
4252 // level is 1. Similarly, if the caller can only be reached by non-SPMD
4253 // kernel entries, it is 0.
4254 if (AssumedSPMDCount || KnownSPMDCount) {
4255      assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4256             "Expected only SPMD kernels!");
4257 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
4258 } else {
4259      assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4260             "Expected only non-SPMD kernels!");
4261 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
4262 }
4263 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4264 : ChangeStatus::CHANGED;
4265 }
4266
4267 ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
4268 // Specialize only if all the calls agree with the attribute constant value
4269 int32_t CurrentAttrValue = -1;
4270 Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4271
4272 auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4273 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4274
4275 if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4276 return indicatePessimisticFixpoint();
4277
4278 // Iterate over the kernels that reach this function
4279 for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4280 int32_t NextAttrVal = -1;
4281 if (K->hasFnAttribute(Attr))
4282 NextAttrVal =
4283 std::stoi(K->getFnAttribute(Attr).getValueAsString().str());
4284
4285 if (NextAttrVal == -1 ||
4286 (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
4287 return indicatePessimisticFixpoint();
4288 CurrentAttrValue = NextAttrVal;
4289 }
4290
4291 if (CurrentAttrValue != -1) {
4292 auto &Ctx = getAnchorValue().getContext();
4293 SimplifiedValue =
4294 ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
4295 }
4296 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4297 : ChangeStatus::CHANGED;
4298 }
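
Note that 'std::stoi' at line 4283 throws on a malformed attribute value. A
sketch of an exception-free alternative, assuming a malformed string should
simply be treated like a missing attribute:

    int32_t NextAttrVal = -1;
    if (K->hasFnAttribute(Attr)) {
      int32_t Parsed;
      // StringRef::getAsInteger returns true on parse failure.
      if (!K->getFnAttribute(Attr).getValueAsString().getAsInteger(10, Parsed))
        NextAttrVal = Parsed;
    }
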
4299
4300 /// An optional value the associated value is assumed to fold to. That is, we
4301 /// assume the associated value (which is a call) can be replaced by this
4302 /// simplified value.
4303 Optional<Value *> SimplifiedValue;
4304
4305 /// The runtime function kind of the callee of the associated call site.
4306 RuntimeFunction RFKind;
4307};
4308
4309} // namespace
4310
4311/// Register folding callsite
4312void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
4313 auto &RFI = OMPInfoCache.RFIs[RF];
4314 RFI.foreachUse(SCC, [&](Use &U, Function &F) {
4315 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
4316 if (!CI)
4317 return false;
4318 A.getOrCreateAAFor<AAFoldRuntimeCall>(
4319 IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
4320 DepClassTy::NONE, /* ForceUpdate */ false,
4321 /* UpdateAfterInit */ false);
4322 return false;
4323 });
4324}
4325
4326void OpenMPOpt::registerAAs(bool IsModulePass) {
4327 if (SCC.empty())
4328
4329 return;
4330 if (IsModulePass) {
4331 // Ensure we create the AAKernelInfo AAs first and without triggering an
4332 // update. This will make sure we register all value simplification
4333 // callbacks before any other AA has the chance to create an AAValueSimplify
4334 // or similar.
4335 for (Function *Kernel : OMPInfoCache.Kernels)
4336 A.getOrCreateAAFor<AAKernelInfo>(
4337 IRPosition::function(*Kernel), /* QueryingAA */ nullptr,
4338 DepClassTy::NONE, /* ForceUpdate */ false,
4339 /* UpdateAfterInit */ false);
4340
4341 registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
4342 registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
4343 registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
4344 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
4345 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
4346 }
4347
4348 // Create CallSite AA for all Getters.
4349 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
4350 auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
4351
4352 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
4353
4354 auto CreateAA = [&](Use &U, Function &Caller) {
4355 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
4356 if (!CI)
4357 return false;
4358
4359 auto &CB = cast<CallBase>(*CI);
4360
4361 IRPosition CBPos = IRPosition::callsite_function(CB);
4362 A.getOrCreateAAFor<AAICVTracker>(CBPos);
4363 return false;
4364 };
4365
4366 GetterRFI.foreachUse(SCC, CreateAA);
4367 }
4368 auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4369 auto CreateAA = [&](Use &U, Function &F) {
4370 A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
4371 return false;
4372 };
4373 if (!DisableOpenMPOptDeglobalization)
4374 GlobalizationRFI.foreachUse(SCC, CreateAA);
4375
4376 // Create an ExecutionDomain AA for every function and a HeapToStack AA for
4377 // every function if there is a device kernel.
4378 if (!isOpenMPDevice(M))
4379 return;
4380
4381 for (auto *F : SCC) {
4382 if (F->isDeclaration())
4383 continue;
4384
4385 A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
4386 if (!DisableOpenMPOptDeglobalization)
4387 A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
4388
4389 for (auto &I : instructions(*F)) {
4390 if (auto *LI = dyn_cast<LoadInst>(&I)) {
4391 bool UsedAssumedInformation = false;
4392 A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
4393 UsedAssumedInformation);
4394 }
4395 }
4396 }
4397}
4398
4399const char AAICVTracker::ID = 0;
4400const char AAKernelInfo::ID = 0;
4401const char AAExecutionDomain::ID = 0;
4402const char AAHeapToShared::ID = 0;
4403const char AAFoldRuntimeCall::ID = 0;
4404
4405AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
4406 Attributor &A) {
4407 AAICVTracker *AA = nullptr;
4408 switch (IRP.getPositionKind()) {
4409 case IRPosition::IRP_INVALID:
4410 case IRPosition::IRP_FLOAT:
4411 case IRPosition::IRP_ARGUMENT:
4412 case IRPosition::IRP_CALL_SITE_ARGUMENT:
4413    llvm_unreachable("ICVTracker can only be created for function position!");
4414 case IRPosition::IRP_RETURNED:
4415 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
4416 break;
4417 case IRPosition::IRP_CALL_SITE_RETURNED:
4418 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
4419 break;
4420 case IRPosition::IRP_CALL_SITE:
4421 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
4422 break;
4423 case IRPosition::IRP_FUNCTION:
4424 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
4425 break;
4426 }
4427
4428 return *AA;
4429}
4430
4431AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
4432 Attributor &A) {
4433 AAExecutionDomainFunction *AA = nullptr;
4434 switch (IRP.getPositionKind()) {
4435 case IRPosition::IRP_INVALID:
4436 case IRPosition::IRP_FLOAT:
4437 case IRPosition::IRP_ARGUMENT:
4438 case IRPosition::IRP_CALL_SITE_ARGUMENT:
4439 case IRPosition::IRP_RETURNED:
4440 case IRPosition::IRP_CALL_SITE_RETURNED:
4441 case IRPosition::IRP_CALL_SITE:
4442    llvm_unreachable(
4443        "AAExecutionDomain can only be created for function position!");
4444 case IRPosition::IRP_FUNCTION:
4445 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
4446 break;
4447 }
4448
4449 return *AA;
4450}
4451
4452AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
4453 Attributor &A) {
4454 AAHeapToSharedFunction *AA = nullptr;
4455 switch (IRP.getPositionKind()) {
4456 case IRPosition::IRP_INVALID:
4457 case IRPosition::IRP_FLOAT:
4458 case IRPosition::IRP_ARGUMENT:
4459 case IRPosition::IRP_CALL_SITE_ARGUMENT:
4460 case IRPosition::IRP_RETURNED:
4461 case IRPosition::IRP_CALL_SITE_RETURNED:
4462 case IRPosition::IRP_CALL_SITE:
4463    llvm_unreachable(
4464        "AAHeapToShared can only be created for function position!");
4465 case IRPosition::IRP_FUNCTION:
4466 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
4467 break;
4468 }
4469
4470 return *AA;
4471}
4472
4473AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
4474 Attributor &A) {
4475 AAKernelInfo *AA = nullptr;
4476 switch (IRP.getPositionKind()) {
4477 case IRPosition::IRP_INVALID:
4478 case IRPosition::IRP_FLOAT:
4479 case IRPosition::IRP_ARGUMENT:
4480 case IRPosition::IRP_RETURNED:
4481 case IRPosition::IRP_CALL_SITE_RETURNED:
4482 case IRPosition::IRP_CALL_SITE_ARGUMENT:
4483    llvm_unreachable("KernelInfo can only be created for function position!");
4484 case IRPosition::IRP_CALL_SITE:
4485 AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
4486 break;
4487 case IRPosition::IRP_FUNCTION:
4488 AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
4489 break;
4490 }
4491
4492 return *AA;
4493}
4494
4495AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
4496 Attributor &A) {
4497 AAFoldRuntimeCall *AA = nullptr;
4498 switch (IRP.getPositionKind()) {
4499 case IRPosition::IRP_INVALID:
4500 case IRPosition::IRP_FLOAT:
4501 case IRPosition::IRP_ARGUMENT:
4502 case IRPosition::IRP_RETURNED:
4503 case IRPosition::IRP_FUNCTION:
4504 case IRPosition::IRP_CALL_SITE:
4505 case IRPosition::IRP_CALL_SITE_ARGUMENT:
4506 llvm_unreachable("KernelInfo can only be created for call site position!")::llvm::llvm_unreachable_internal("KernelInfo can only be created for call site position!"
, "/build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/llvm/lib/Transforms/IPO/OpenMPOpt.cpp"
, 4506)
;
4507 case IRPosition::IRP_CALL_SITE_RETURNED:
4508 AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
4509 break;
4510 }
4511
4512 return *AA;
4513}
4514
4515PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
4516 if (!containsOpenMP(M))
4517 return PreservedAnalyses::all();
4518 if (DisableOpenMPOptimizations)
4519 return PreservedAnalyses::all();
4520
4521 FunctionAnalysisManager &FAM =
4522 AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
4523 KernelSet Kernels = getDeviceKernels(M);
4524
4525 auto IsCalled = [&](Function &F) {
4526 if (Kernels.contains(&F))
4527 return true;
4528 for (const User *U : F.users())
4529 if (!isa<BlockAddress>(U))
4530 return true;
4531 return false;
4532 };
4533
4534 auto EmitRemark = [&](Function &F) {
4535 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
4536 ORE.emit([&]() {
4537      OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
4538 return ORA << "Could not internalize function. "
4539 << "Some optimizations may not be possible. [OMP140]";
4540 });
4541 };
4542
4543 // Create internal copies of each function if this is a kernel Module. This
4544  // allows interprocedural passes to see every call edge.
4545 DenseMap<Function *, Function *> InternalizedMap;
4546 if (isOpenMPDevice(M)) {
4547 SmallPtrSet<Function *, 16> InternalizeFns;
4548 for (Function &F : M)
4549 if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
4550 !DisableInternalization) {
4551 if (Attributor::isInternalizable(F)) {
4552 InternalizeFns.insert(&F);
4553 } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
4554 EmitRemark(F);
4555 }
4556 }
4557
4558 Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
4559 }
4560
4561 // Look at every function in the Module unless it was internalized.
4562 SmallVector<Function *, 16> SCC;
4563 for (Function &F : M)
4564 if (!F.isDeclaration() && !InternalizedMap.lookup(&F))
4565 SCC.push_back(&F);
4566
4567 if (SCC.empty())
4568 return PreservedAnalyses::all();
4569
4570 AnalysisGetter AG(FAM);
4571
4572 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4573 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4574 };
4575
4576 BumpPtrAllocator Allocator;
4577 CallGraphUpdater CGUpdater;
4578
4579 SetVector<Function *> Functions(SCC.begin(), SCC.end());
4580 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels);
4581
4582 unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4583 Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false,
4584               MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4585
4586 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4587 bool Changed = OMPOpt.run(true);
4588
4589 // Optionally inline device functions for potentially better performance.
4590 if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
4591 for (Function &F : M)
4592 if (!F.isDeclaration() && !Kernels.contains(&F) &&
4593 !F.hasFnAttribute(Attribute::NoInline))
4594 F.addFnAttr(Attribute::AlwaysInline);
4595
4596 if (PrintModuleAfterOptimizations)
4597    LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
4598
4599 if (Changed)
4600 return PreservedAnalyses::none();
4601
4602 return PreservedAnalyses::all();
4603}
4604
4605PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
4606 CGSCCAnalysisManager &AM,
4607 LazyCallGraph &CG,
4608 CGSCCUpdateResult &UR) {
4609 if (!containsOpenMP(*C.begin()->getFunction().getParent()))
4610 return PreservedAnalyses::all();
4611 if (DisableOpenMPOptimizations)
4612 return PreservedAnalyses::all();
4613
4614 SmallVector<Function *, 16> SCC;
4615  // If there are kernels in the module, we have to run on all SCCs.
4616 for (LazyCallGraph::Node &N : C) {
4617 Function *Fn = &N.getFunction();
4618 SCC.push_back(Fn);
4619 }
4620
4621 if (SCC.empty())
4622 return PreservedAnalyses::all();
4623
4624 Module &M = *C.begin()->getFunction().getParent();
4625
4626 KernelSet Kernels = getDeviceKernels(M);
4627
4628 FunctionAnalysisManager &FAM =
4629 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
4630
4631 AnalysisGetter AG(FAM);
4632
4633 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4634 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4635 };
4636
4637 BumpPtrAllocator Allocator;
4638 CallGraphUpdater CGUpdater;
4639 CGUpdater.initialize(CG, C, AM, UR);
4640
4641 SetVector<Function *> Functions(SCC.begin(), SCC.end());
4642 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
4643 /*CGSCC*/ Functions, Kernels);
4644
4645 unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4646 Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4647               MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4648
4649 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4650 bool Changed = OMPOpt.run(false);
4651
4652 if (PrintModuleAfterOptimizations)
4653    LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
4654
4655 if (Changed)
4656 return PreservedAnalyses::none();
4657
4658 return PreservedAnalyses::all();
4659}
4660
4661namespace {
4662
4663struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
4664 CallGraphUpdater CGUpdater;
4665 static char ID;
4666
4667 OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
4668 initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
4669 }
4670
4671 void getAnalysisUsage(AnalysisUsage &AU) const override {
4672 CallGraphSCCPass::getAnalysisUsage(AU);
4673 }
4674
4675 bool runOnSCC(CallGraphSCC &CGSCC) override {
4676 if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
4677 return false;
4678 if (DisableOpenMPOptimizations || skipSCC(CGSCC))
4679 return false;
4680
4681 SmallVector<Function *, 16> SCC;
4682    // If there are kernels in the module, we have to run on all SCCs.
4683 for (CallGraphNode *CGN : CGSCC) {
4684 Function *Fn = CGN->getFunction();
4685 if (!Fn || Fn->isDeclaration())
4686 continue;
4687 SCC.push_back(Fn);
4688 }
4689
4690 if (SCC.empty())
4691 return false;
4692
4693 Module &M = CGSCC.getCallGraph().getModule();
4694 KernelSet Kernels = getDeviceKernels(M);
4695
4696 CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
4697 CGUpdater.initialize(CG, CGSCC);
4698
4699 // Maintain a map of functions to avoid rebuilding the ORE
4700 DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
4701 auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
4702 std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
4703 if (!ORE)
4704 ORE = std::make_unique<OptimizationRemarkEmitter>(F);
4705 return *ORE;
4706 };
4707
4708 AnalysisGetter AG;
4709 SetVector<Function *> Functions(SCC.begin(), SCC.end());
4710 BumpPtrAllocator Allocator;
4711 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
4712 Allocator,
4713 /*CGSCC*/ Functions, Kernels);
4714
4715 unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4716 Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4717                 MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4718
4719 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4720 bool Result = OMPOpt.run(false);
4721
4722 if (PrintModuleAfterOptimizations)
4723      LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
4724
4725 return Result;
4726 }
4727
4728 bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
4729};
4730
4731} // end anonymous namespace
4732
4733KernelSet llvm::omp::getDeviceKernels(Module &M) {
4734 // TODO: Create a more cross-platform way of determining device kernels.
4735 NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
4736 KernelSet Kernels;
4737
4738 if (!MD)
4739 return Kernels;
4740
4741 for (auto *Op : MD->operands()) {
4742 if (Op->getNumOperands() < 2)
4743 continue;
4744 MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
4745 if (!KindID || KindID->getString() != "kernel")
4746 continue;
4747
4748 Function *KernelFn =
4749 mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
4750 if (!KernelFn)
4751 continue;
4752
4753 ++NumOpenMPTargetRegionKernels;
4754
4755 Kernels.insert(KernelFn);
4756 }
4757
4758 return Kernels;
4759}
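
For reference, a sketch of the metadata shape this walker consumes, built
with the C++ API (illustrative; real frontends typically attach additional
operands per kernel):

    LLVMContext &Ctx = M.getContext();
    // One entry per kernel: {function, "kernel", i32 1}.
    Metadata *Ops[] = {ConstantAsMetadata::get(KernelFn),
                       MDString::get(Ctx, "kernel"),
                       ConstantAsMetadata::get(
                           ConstantInt::get(Type::getInt32Ty(Ctx), 1))};
    M.getOrInsertNamedMetadata("nvvm.annotations")
        ->addOperand(MDNode::get(Ctx, Ops));
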
4760
4761bool llvm::omp::containsOpenMP(Module &M) {
4762 Metadata *MD = M.getModuleFlag("openmp");
4763 if (!MD)
4764 return false;
4765
4766 return true;
4767}
4768
4769bool llvm::omp::isOpenMPDevice(Module &M) {
4770 Metadata *MD = M.getModuleFlag("openmp-device");
4771 if (!MD)
4772 return false;
4773
4774 return true;
4775}
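
Both predicates only test for the presence of a module flag. A sketch of
how a frontend might set them (the value 50, for OpenMP 5.0, is an
assumption; only presence matters to the checks above):

    // Mark a module as OpenMP device code.
    M.addModuleFlag(llvm::Module::Max, "openmp", 50);
    M.addModuleFlag(llvm::Module::Max, "openmp-device", 50);
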
4776
4777char OpenMPOptCGSCCLegacyPass::ID = 0;
4778
4779INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4780                      "OpenMP specific optimizations", false, false)
4781INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
4782INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4783                    "OpenMP specific optimizations", false, false)
4784
4785Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
4786 return new OpenMPOptCGSCCLegacyPass();
4787}

/build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/llvm/include/llvm/Transforms/IPO/Attributor.h

1//===- Attributor.h --- Module-wide attribute deduction ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Attributor: An inter procedural (abstract) "attribute" deduction framework.
10//
11// The Attributor framework is an inter procedural abstract analysis (fixpoint
12// iteration analysis). The goal is to allow easy deduction of new attributes as
13// well as information exchange between abstract attributes in-flight.
14//
15// The Attributor class is the driver and the link between the various abstract
16// attributes. The Attributor will iterate until a fixpoint state is reached by
17// all abstract attributes in-flight, or until it will enforce a pessimistic fix
18// point because an iteration limit is reached.
19//
20// Abstract attributes, derived from the AbstractAttribute class, actually
21// describe properties of the code. They can correspond to actual LLVM-IR
22// attributes, or they can be more general, ultimately unrelated to LLVM-IR
23// attributes. The latter is useful when an abstract attribute provides
24// information to other abstract attributes in-flight but we might not want to
25// manifest the information. The Attributor allows to query in-flight abstract
26// attributes through the `Attributor::getAAFor` method (see the method
27// description for an example). If the method is used by an abstract attribute
28// P, and it results in an abstract attribute Q, the Attributor will
29// automatically capture a potential dependence from Q to P. This dependence
30// will cause P to be reevaluated whenever Q changes in the future.
31//
32// The Attributor will only reevaluate abstract attributes that might have
33// changed since the last iteration. That means that the Attributor will not
34// revisit all instructions/blocks/functions in the module but only query
35// an update from a subset of the abstract attributes.
36//
37// The update method `AbstractAttribute::updateImpl` is implemented by the
38// specific "abstract attribute" subclasses. The method is invoked whenever the
39// currently assumed state (see the AbstractState class) might not be valid
40// anymore. This can, for example, happen if the state was dependent on another
41// abstract attribute that changed. In every invocation, the update method has
42// to adjust the internal state of an abstract attribute to a point that is
43// justifiable by the underlying IR and the current state of abstract attributes
44// in-flight. Since the IR is given and assumed to be valid, the information
45// derived from it can be assumed to hold. However, information derived from
46// other abstract attributes is conditional on various things. If the justifying
47// state changed, the `updateImpl` has to revisit the situation and potentially
48// find another justification or limit the optimistic assumptions made.
49//
50// Change is the key in this framework. Until a state of no-change, thus a
51// fixpoint, is reached, the Attributor will query the abstract attributes
52// in-flight to re-evaluate their state. If the (current) state is too
53// optimistic, hence it cannot be justified anymore through other abstract
54// attributes or the state of the IR, the state of the abstract attribute will
55// have to change. Generally, we assume abstract attribute state to be a finite
56// height lattice and the update function to be monotone. However, these
57// conditions are not enforced because the iteration limit will guarantee
58// termination. If an optimistic fixpoint is reached, or a pessimistic fix
59// point is enforced after a timeout, the abstract attributes are tasked to
60// manifest their result in the IR for passes to come.
61//
62// Attribute manifestation is not mandatory. If desired, there is support to
63// generate a single or multiple LLVM-IR attributes already in the helper struct
64// IRAttribute. In the simplest case, a subclass inherits from IRAttribute with
65// a proper Attribute::AttrKind as template parameter. The Attributor
66// manifestation framework will then create and place a new attribute if it is
67// allowed to do so (based on the abstract state). Other use cases can be
68// achieved by overloading AbstractAttribute or IRAttribute methods.
69//
70//
71// The "mechanics" of adding a new "abstract attribute":
72// - Define a class (transitively) inheriting from AbstractAttribute and one
73// (which could be the same) that (transitively) inherits from AbstractState.
74// For the latter, consider the already available BooleanState and
75// {Inc,Dec,Bit}IntegerState if they fit your needs, e.g., you require only a
76// number tracking or bit-encoding.
77// - Implement all pure methods. Also use overloading if the attribute is not
78// conforming with the "default" behavior: A (set of) LLVM-IR attribute(s) for
79// an argument, call site argument, function return value, or function. See
80// the class and method descriptions for more information on the two
81// "Abstract" classes and their respective methods.
82// - Register opportunities for the new abstract attribute in the
83// `Attributor::identifyDefaultAbstractAttributes` method if it should be
84// counted as a 'default' attribute.
85// - Add sufficient tests.
86// - Add a Statistics object for bookkeeping. If it is a simple (set of)
87// attribute(s) manifested through the Attributor manifestation framework, see
88// the bookkeeping function in Attributor.cpp.
89// - If instructions with a certain opcode are interesting to the attribute, add
90// that opcode to the switch in `Attributor::identifyAbstractAttributes`. This
91// will make it possible to query all those instructions through the
92// `InformationCache::getOpcodeInstMapForFunction` interface and eliminate the
93// need to traverse the IR repeatedly.
94//
95//===----------------------------------------------------------------------===//
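
A minimal skeleton following the mechanics above, modeled on the
AAFoldRuntimeCall attribute from OpenMPOpt.cpp earlier in this report (the
name AAMyProperty is hypothetical):

    struct AAMyProperty : public StateWrapper<BooleanState, AbstractAttribute> {
      using Base = StateWrapper<BooleanState, AbstractAttribute>;
      AAMyProperty(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

      // Factory, name, and RTTI boilerplate every abstract attribute needs.
      static AAMyProperty &createForPosition(const IRPosition &IRP,
                                             Attributor &A);
      const std::string getName() const override { return "AAMyProperty"; }
      const char *getIdAddr() const override { return &ID; }
      static bool classof(const AbstractAttribute *AA) {
        return AA->getIdAddr() == &ID;
      }
      void trackStatistics() const override {}
      static const char ID;
      // Plus initialize(...), updateImpl(...), and manifest(...) overrides
      // implementing the actual deduction, as described above.
    };
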
96
97#ifndef LLVM_TRANSFORMS_IPO_ATTRIBUTOR_H
98#define LLVM_TRANSFORMS_IPO_ATTRIBUTOR_H
99
100#include "llvm/ADT/DenseSet.h"
101#include "llvm/ADT/GraphTraits.h"
102#include "llvm/ADT/MapVector.h"
103#include "llvm/ADT/STLExtras.h"
104#include "llvm/ADT/SetVector.h"
105#include "llvm/ADT/Triple.h"
106#include "llvm/ADT/iterator.h"
107#include "llvm/Analysis/AssumeBundleQueries.h"
108#include "llvm/Analysis/CFG.h"
109#include "llvm/Analysis/CGSCCPassManager.h"
110#include "llvm/Analysis/LazyCallGraph.h"
111#include "llvm/Analysis/LoopInfo.h"
112#include "llvm/Analysis/MustExecute.h"
113#include "llvm/Analysis/OptimizationRemarkEmitter.h"
114#include "llvm/Analysis/PostDominators.h"
115#include "llvm/Analysis/TargetLibraryInfo.h"
116#include "llvm/IR/AbstractCallSite.h"
117#include "llvm/IR/ConstantRange.h"
118#include "llvm/IR/PassManager.h"
119#include "llvm/Support/Allocator.h"
120#include "llvm/Support/Casting.h"
121#include "llvm/Support/GraphWriter.h"
122#include "llvm/Support/TimeProfiler.h"
123#include "llvm/Transforms/Utils/CallGraphUpdater.h"
124
125namespace llvm {
126
127struct AADepGraphNode;
128struct AADepGraph;
129struct Attributor;
130struct AbstractAttribute;
131struct InformationCache;
132struct AAIsDead;
133struct AttributorCallGraph;
134
135class AAManager;
136class AAResults;
137class Function;
138
139/// Abstract Attribute helper functions.
140namespace AA {
141
142/// Return true if \p V is dynamically unique, that is, there are no two
143/// "instances" of \p V at runtime with different values.
144bool isDynamicallyUnique(Attributor &A, const AbstractAttribute &QueryingAA,
145 const Value &V);
146
147/// Return true if \p V is a valid value in \p Scope, that is, a constant or an
148/// instruction/argument of \p Scope.
149bool isValidInScope(const Value &V, const Function *Scope);
150
151/// Return true if \p V is a valid value at position \p CtxI, that is, a
152/// constant, an argument of the same function as \p CtxI, or an instruction in
153/// that function that dominates \p CtxI.
154bool isValidAtPosition(const Value &V, const Instruction &CtxI,
155 InformationCache &InfoCache);
156
157/// Try to convert \p V to type \p Ty without introducing new instructions. If
158/// this is not possible return `nullptr`. Note: this function basically knows
159/// how to cast various constants.
160Value *getWithType(Value &V, Type &Ty);
161
162/// Return the combination of \p A and \p B such that the result is a possible
163/// value of both. \p B is potentially cast to match the type \p Ty or the
164/// type of \p A if \p Ty is null.
165///
166/// Examples:
167/// X + none => X
168/// not_none + undef => not_none
169/// V1 + V2 => nullptr
170Optional<Value *>
171combineOptionalValuesInAAValueLatice(const Optional<Value *> &A,
172 const Optional<Value *> &B, Type *Ty);
173
174/// Return the initial value of \p Obj with type \p Ty if that is a constant.
175Constant *getInitialValueForObj(Value &Obj, Type &Ty);
176
177/// Collect all potential underlying objects of \p Ptr at position \p CtxI in
178/// \p Objects. Assumed information is used and dependences onto \p QueryingAA
179/// are added appropriately.
180///
181/// \returns True if \p Objects contains all assumed underlying objects, and
182/// false if something went wrong and the objects could not be
183/// determined.
184bool getAssumedUnderlyingObjects(Attributor &A, const Value &Ptr,
185 SmallVectorImpl<Value *> &Objects,
186 const AbstractAttribute &QueryingAA,
187 const Instruction *CtxI);
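// Usage sketch (exampleCollectObjects is invented for illustration):
inline bool exampleCollectObjects(Attributor &A, const Value &Ptr,
                                  const AbstractAttribute &QueryingAA,
                                  const Instruction *CtxI) {
  SmallVector<Value *, 8> Objects;
  if (!getAssumedUnderlyingObjects(A, Ptr, Objects, QueryingAA, CtxI))
    return false; // The underlying objects could not be determined.
  for (Value *Obj : Objects)
    (void)Obj; // Visit each assumed underlying object of Ptr.
  return true;
}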
188
189/// Collect all potential values of the one stored by \p SI into
190/// \p PotentialCopies. That is, the only copies that were made via the
191/// store are assumed to be known and all in \p PotentialCopies. Dependences
192/// onto \p QueryingAA are properly tracked, \p UsedAssumedInformation will
193/// inform the caller if assumed information was used.
194///
195/// \returns True if the assumed potential copies are all in \p PotentialCopies,
196/// false if something went wrong and the copies could not be
197/// determined.
198bool getPotentialCopiesOfStoredValue(
199 Attributor &A, StoreInst &SI, SmallSetVector<Value *, 4> &PotentialCopies,
200 const AbstractAttribute &QueryingAA, bool &UsedAssumedInformation);
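// Usage sketch (exampleCollectCopies is invented for illustration): collect
// the assumed copies made by a store; on success they are all in
// PotentialCopies and UsedAssumedInformation tells if assumptions were used.
inline bool exampleCollectCopies(Attributor &A, StoreInst &SI,
                                 const AbstractAttribute &QueryingAA) {
  SmallSetVector<Value *, 4> PotentialCopies;
  bool UsedAssumedInformation = false;
  return getPotentialCopiesOfStoredValue(A, SI, PotentialCopies, QueryingAA,
                                         UsedAssumedInformation);
}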
201
202} // namespace AA
203
204/// The value passed to the command line option that defines the maximal
205/// initialization chain length.
206extern unsigned MaxInitializationChainLength;
207
208///{
209enum class ChangeStatus {
210 CHANGED,
211 UNCHANGED,
212};
213
214ChangeStatus operator|(ChangeStatus l, ChangeStatus r);
215ChangeStatus &operator|=(ChangeStatus &l, ChangeStatus r);
216ChangeStatus operator&(ChangeStatus l, ChangeStatus r);
217ChangeStatus &operator&=(ChangeStatus &l, ChangeStatus r);
218
219enum class DepClassTy {
220 REQUIRED, ///< The target cannot be valid if the source is not.
221 OPTIONAL, ///< The target may be valid if the source is not.
222 NONE, ///< Do not track a dependence between source and target.
223};
224///}
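// Semantics sketch of the combinators above: CHANGED is absorbing under `|`
// and UNCHANGED is absorbing under `&`. A typical update loop accumulates
// (updateOnce is a hypothetical callee):
//   ChangeStatus Changed = ChangeStatus::UNCHANGED;
//   Changed |= updateOnce(); // stays CHANGED once any step changed
//   return Changed;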
225
226/// The data structure for the nodes of a dependency graph
227struct AADepGraphNode {
228public:
229 virtual ~AADepGraphNode() {}
230 using DepTy = PointerIntPair<AADepGraphNode *, 1>;
231
232protected:
233 /// Set of dependency graph nodes which should be updated if this one
234 /// is updated. The bit encodes if it is optional.
235 TinyPtrVector<DepTy> Deps;
236
237 static AADepGraphNode *DepGetVal(DepTy &DT) { return DT.getPointer(); }
238 static AbstractAttribute *DepGetValAA(DepTy &DT) {
239 return cast<AbstractAttribute>(DT.getPointer());
240 }
241
242 operator AbstractAttribute *() { return cast<AbstractAttribute>(this); }
243
244public:
245 using iterator =
246 mapped_iterator<TinyPtrVector<DepTy>::iterator, decltype(&DepGetVal)>;
247 using aaiterator =
248 mapped_iterator<TinyPtrVector<DepTy>::iterator, decltype(&DepGetValAA)>;
249
250 aaiterator begin() { return aaiterator(Deps.begin(), &DepGetValAA); }
251 aaiterator end() { return aaiterator(Deps.end(), &DepGetValAA); }
252 iterator child_begin() { return iterator(Deps.begin(), &DepGetVal); }
253 iterator child_end() { return iterator(Deps.end(), &DepGetVal); }
254
255 virtual void print(raw_ostream &OS) const { OS << "AADepNode Impl\n"; }
256 TinyPtrVector<DepTy> &getDeps() { return Deps; }
257
258 friend struct Attributor;
259 friend struct AADepGraph;
260};
261
262/// The data structure for the dependency graph
263///
264/// Note that in this graph if there is an edge from A to B (A -> B),
265/// then it means that B depends on A, and when the state of A is
266/// updated, node B should also be updated.
267struct AADepGraph {
268 AADepGraph() {}
269 ~AADepGraph() {}
270
271 using DepTy = AADepGraphNode::DepTy;
272 static AADepGraphNode *DepGetVal(DepTy &DT) { return DT.getPointer(); }
273 using iterator =
274 mapped_iterator<TinyPtrVector<DepTy>::iterator, decltype(&DepGetVal)>;
275
276 /// There is no root node for the dependency graph. But the SCCIterator
277 /// requires a single entry point, so we maintain a fake ("synthetic") root
278 /// node that depends on every node.
279 AADepGraphNode SyntheticRoot;
280 AADepGraphNode *GetEntryNode() { return &SyntheticRoot; }
281
282 iterator begin() { return SyntheticRoot.child_begin(); }
283 iterator end() { return SyntheticRoot.child_end(); }
284
285 void viewGraph();
286
287 /// Dump graph to file
288 void dumpGraph();
289
290 /// Print dependency graph
291 void print();
292};
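// Traversal sketch: thanks to the synthetic root, range-iterating the graph
// visits every node hanging off of it:
//   AADepGraph G;
//   for (AADepGraphNode *N : G) // children of the synthetic root
//     N->print(errs());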
293
294/// Helper to describe and deal with positions in the LLVM-IR.
295///
296/// A position in the IR is described by an anchor value and an "offset" that
297/// could be the argument number, for call sites and arguments, or an indicator
298/// of the "position kind". The kinds, specified in the Kind enum below, include
299/// the locations in the attribute list, i.a., function scope and return value,
300/// as well as a distinction between call sites and functions. Finally, there
301/// are floating values that do not have a corresponding attribute list
302/// position.
303struct IRPosition {
304 // NOTE: In the future this definition can be changed to support recursive
305 // functions.
306 using CallBaseContext = CallBase;
307
308 /// The positions we distinguish in the IR.
309 enum Kind : char {
310 IRP_INVALID, ///< An invalid position.
311 IRP_FLOAT, ///< A position that is not associated with a spot suitable
312 ///< for attributes. This could be any value or instruction.
313 IRP_RETURNED, ///< An attribute for the function return value.
314 IRP_CALL_SITE_RETURNED, ///< An attribute for a call site return value.
315 IRP_FUNCTION, ///< An attribute for a function (scope).
316 IRP_CALL_SITE, ///< An attribute for a call site (function scope).
317 IRP_ARGUMENT, ///< An attribute for a function argument.
318 IRP_CALL_SITE_ARGUMENT, ///< An attribute for a call site argument.
319 };
320
321 /// Default constructor available to create invalid positions implicitly. All
322 /// other positions need to be created explicitly through the appropriate
323 /// static member function.
324 IRPosition() : Enc(nullptr, ENC_VALUE) { verify(); }
325
326 /// Create a position describing the value of \p V.
327 static const IRPosition value(const Value &V,
328 const CallBaseContext *CBContext = nullptr) {
329 if (auto *Arg = dyn_cast<Argument>(&V))
330 return IRPosition::argument(*Arg, CBContext);
331 if (auto *CB = dyn_cast<CallBase>(&V))
332 return IRPosition::callsite_returned(*CB);
333 return IRPosition(const_cast<Value &>(V), IRP_FLOAT, CBContext);
334 }
335
336 /// Create a position describing the function scope of \p F.
337 /// \p CBContext is used for call base specific analysis.
338 static const IRPosition function(const Function &F,
339 const CallBaseContext *CBContext = nullptr) {
340 return IRPosition(const_cast<Function &>(F), IRP_FUNCTION, CBContext);
341 }
342
343 /// Create a position describing the returned value of \p F.
344 /// \p CBContext is used for call base specific analysis.
345 static const IRPosition returned(const Function &F,
346 const CallBaseContext *CBContext = nullptr) {
347 return IRPosition(const_cast<Function &>(F), IRP_RETURNED, CBContext);
348 }
349
350 /// Create a position describing the argument \p Arg.
351 /// \p CBContext is used for call base specific analysis.
352 static const IRPosition argument(const Argument &Arg,
353 const CallBaseContext *CBContext = nullptr) {
354 return IRPosition(const_cast<Argument &>(Arg), IRP_ARGUMENT, CBContext);
355 }
356
357 /// Create a position describing the function scope of \p CB.
358 static const IRPosition callsite_function(const CallBase &CB) {
359 return IRPosition(const_cast<CallBase &>(CB), IRP_CALL_SITE);
360 }
361
362 /// Create a position describing the returned value of \p CB.
363 static const IRPosition callsite_returned(const CallBase &CB) {
364 return IRPosition(const_cast<CallBase &>(CB), IRP_CALL_SITE_RETURNED);
365 }
366
367 /// Create a position describing the argument of \p CB at position \p ArgNo.
368 static const IRPosition callsite_argument(const CallBase &CB,
369 unsigned ArgNo) {
370 return IRPosition(const_cast<Use &>(CB.getArgOperandUse(ArgNo)),
371 IRP_CALL_SITE_ARGUMENT);
372 }
373
374 /// Create a position describing the argument of \p ACS at position \p ArgNo.
375 static const IRPosition callsite_argument(AbstractCallSite ACS,
376 unsigned ArgNo) {
377 if (ACS.getNumArgOperands() <= ArgNo)
378 return IRPosition();
379 int CSArgNo = ACS.getCallArgOperandNo(ArgNo);
380 if (CSArgNo >= 0)
381 return IRPosition::callsite_argument(
382 cast<CallBase>(*ACS.getInstruction()), CSArgNo);
383 return IRPosition();
384 }
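  // Usage sketch (CB is a hypothetical CallBase visible to the caller):
  //   IRPosition FnPos  = IRPosition::callsite_function(CB);    // callee scope
  //   IRPosition RetPos = IRPosition::callsite_returned(CB);    // return value
  //   IRPosition ArgPos = IRPosition::callsite_argument(CB, 0); // 1st argument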
385
386 /// Create a position with function scope matching the "context" of \p IRP.
387 /// If \p IRP is a call site (see isAnyCallSitePosition()) then the result
388 /// will be a call site position, otherwise the function position of the
389 /// associated function.
390 static const IRPosition
391 function_scope(const IRPosition &IRP,
392 const CallBaseContext *CBContext = nullptr) {
393 if (IRP.isAnyCallSitePosition()) {
394 return IRPosition::callsite_function(
395 cast<CallBase>(IRP.getAnchorValue()));
396 }
397    assert(IRP.getAssociatedFunction());
398 return IRPosition::function(*IRP.getAssociatedFunction(), CBContext);
399 }
400
401 bool operator==(const IRPosition &RHS) const {
402 return Enc == RHS.Enc && RHS.CBContext == CBContext;
403 }
404 bool operator!=(const IRPosition &RHS) const { return !(*this == RHS); }
405
406 /// Return the value this abstract attribute is anchored with.
407 ///
408 /// The anchor value might not be the associated value if the latter is not
409 /// sufficient to determine where arguments will be manifested. This is, so
410 /// far, only the case for call site arguments as the value is not sufficient
411 /// to pinpoint them. Instead, we can use the call site as an anchor.
412 Value &getAnchorValue() const {
413 switch (getEncodingBits()) {
414 case ENC_VALUE:
415 case ENC_RETURNED_VALUE:
416 case ENC_FLOATING_FUNCTION:
417 return *getAsValuePtr();
418 case ENC_CALL_SITE_ARGUMENT_USE:
419 return *(getAsUsePtr()->getUser());
420 default:
421 llvm_unreachable("Unkown encoding!")::llvm::llvm_unreachable_internal("Unkown encoding!", "/build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/llvm/include/llvm/Transforms/IPO/Attributor.h"
, 421)
;
422 };
423 }
424
425 /// Return the associated function, if any.
426 Function *getAssociatedFunction() const {
427 if (auto *CB = dyn_cast<CallBase>(&getAnchorValue())) {
428      // We reuse the logic that associates callback callees to arguments of a
429 // call site here to identify the callback callee as the associated
430 // function.
431 if (Argument *Arg = getAssociatedArgument())
432 return Arg->getParent();
433 return CB->getCalledFunction();
434 }
435 return getAnchorScope();
436 }
437
438 /// Return the associated argument, if any.
439 Argument *getAssociatedArgument() const;
440
441 /// Return true if the position refers to a function interface, that is the
442 /// function scope, the function return, or an argument.
443 bool isFnInterfaceKind() const {
444 switch (getPositionKind()) {
445 case IRPosition::IRP_FUNCTION:
446 case IRPosition::IRP_RETURNED:
447 case IRPosition::IRP_ARGUMENT:
448 return true;
449 default:
450 return false;
451 }
452 }
453
454 /// Return the Function surrounding the anchor value.
455 Function *getAnchorScope() const {
456 Value &V = getAnchorValue();
457 if (isa<Function>(V))
458 return &cast<Function>(V);
459 if (isa<Argument>(V))
460 return cast<Argument>(V).getParent();
461 if (isa<Instruction>(V))
462 return cast<Instruction>(V).getFunction();
463 return nullptr;
464 }
465
466 /// Return the context instruction, if any.
467 Instruction *getCtxI() const {
468 Value &V = getAnchorValue();
469 if (auto *I = dyn_cast<Instruction>(&V))
470 return I;
471 if (auto *Arg = dyn_cast<Argument>(&V))
472 if (!Arg->getParent()->isDeclaration())
473 return &Arg->getParent()->getEntryBlock().front();
474 if (auto *F = dyn_cast<Function>(&V))
475 if (!F->isDeclaration())
476 return &(F->getEntryBlock().front());
477 return nullptr;
478 }
479
480 /// Return the value this abstract attribute is associated with.
481 Value &getAssociatedValue() const {
482 if (getCallSiteArgNo() < 0 || isa<Argument>(&getAnchorValue()))
483 return getAnchorValue();
484 assert(isa<CallBase>(&getAnchorValue()) && "Expected a call base!")(static_cast <bool> (isa<CallBase>(&getAnchorValue
()) && "Expected a call base!") ? void (0) : __assert_fail
("isa<CallBase>(&getAnchorValue()) && \"Expected a call base!\""
, "/build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/llvm/include/llvm/Transforms/IPO/Attributor.h"
, 484, __extension__ __PRETTY_FUNCTION__))
;
485 return *cast<CallBase>(&getAnchorValue())
486 ->getArgOperand(getCallSiteArgNo());
487 }
488
489 /// Return the type this abstract attribute is associated with.
490 Type *getAssociatedType() const {
491 if (getPositionKind() == IRPosition::IRP_RETURNED)
492 return getAssociatedFunction()->getReturnType();
493 return getAssociatedValue().getType();
494 }
495
496 /// Return the callee argument number of the associated value if it is an
497 /// argument or call site argument, otherwise a negative value. In contrast to
498 /// `getCallSiteArgNo` this method will always return the "argument number"
499  /// from the perspective of the callee. This may not be the same as the call
500  /// site number if this is a callback call.
501 int getCalleeArgNo() const {
502 return getArgNo(/* CallbackCalleeArgIfApplicable */ true);
503 }
504
505 /// Return the call site argument number of the associated value if it is an
506 /// argument or call site argument, otherwise a negative value. In contrast to
507  /// `getCalleeArgNo` this method will always return the "operand number" from
508  /// the perspective of the call site. This may not be the same as the callee
509  /// perspective if this is a callback call.
510 int getCallSiteArgNo() const {
511 return getArgNo(/* CallbackCalleeArgIfApplicable */ false);
512 }
513
514 /// Return the index in the attribute list for this position.
515 unsigned getAttrIdx() const {
516 switch (getPositionKind()) {
517 case IRPosition::IRP_INVALID:
518 case IRPosition::IRP_FLOAT:
519 break;
520 case IRPosition::IRP_FUNCTION:
521 case IRPosition::IRP_CALL_SITE:
522 return AttributeList::FunctionIndex;
523 case IRPosition::IRP_RETURNED:
524 case IRPosition::IRP_CALL_SITE_RETURNED:
525 return AttributeList::ReturnIndex;
526 case IRPosition::IRP_ARGUMENT:
527 case IRPosition::IRP_CALL_SITE_ARGUMENT:
528 return getCallSiteArgNo() + AttributeList::FirstArgIndex;
529 }
530    llvm_unreachable(
531        "There is no attribute index for a floating or invalid position!");
532 }
533
534 /// Return the associated position kind.
535 Kind getPositionKind() const {
536 char EncodingBits = getEncodingBits();
537 if (EncodingBits == ENC_CALL_SITE_ARGUMENT_USE)
538 return IRP_CALL_SITE_ARGUMENT;
539 if (EncodingBits == ENC_FLOATING_FUNCTION)
540 return IRP_FLOAT;
541
542 Value *V = getAsValuePtr();
543 if (!V)
544 return IRP_INVALID;
545 if (isa<Argument>(V))
546 return IRP_ARGUMENT;
547 if (isa<Function>(V))
548 return isReturnPosition(EncodingBits) ? IRP_RETURNED : IRP_FUNCTION;
549 if (isa<CallBase>(V))
550 return isReturnPosition(EncodingBits) ? IRP_CALL_SITE_RETURNED
551 : IRP_CALL_SITE;
552 return IRP_FLOAT;
553 }
554
555 /// TODO: Figure out if the attribute related helper functions should live
556 /// here or somewhere else.
557
558 /// Return true if any kind in \p AKs existing in the IR at a position that
559 /// will affect this one. See also getAttrs(...).
560 /// \param IgnoreSubsumingPositions Flag to determine if subsuming positions,
561 /// e.g., the function position if this is an
562 /// argument position, should be ignored.
563 bool hasAttr(ArrayRef<Attribute::AttrKind> AKs,
564 bool IgnoreSubsumingPositions = false,
565 Attributor *A = nullptr) const;
566
567 /// Return the attributes of any kind in \p AKs existing in the IR at a
568 /// position that will affect this one. While each position can only have a
569 /// single attribute of any kind in \p AKs, there are "subsuming" positions
570 /// that could have an attribute as well. This method returns all attributes
571 /// found in \p Attrs.
572 /// \param IgnoreSubsumingPositions Flag to determine if subsuming positions,
573 /// e.g., the function position if this is an
574 /// argument position, should be ignored.
575 void getAttrs(ArrayRef<Attribute::AttrKind> AKs,
576 SmallVectorImpl<Attribute> &Attrs,
577 bool IgnoreSubsumingPositions = false,
578 Attributor *A = nullptr) const;
579
580 /// Remove the attribute of kind \p AKs existing in the IR at this position.
581 void removeAttrs(ArrayRef<Attribute::AttrKind> AKs) const {
582 if (getPositionKind() == IRP_INVALID || getPositionKind() == IRP_FLOAT)
583 return;
584
585 AttributeList AttrList;
586 auto *CB = dyn_cast<CallBase>(&getAnchorValue());
587 if (CB)
588 AttrList = CB->getAttributes();
589 else
590 AttrList = getAssociatedFunction()->getAttributes();
591
592 LLVMContext &Ctx = getAnchorValue().getContext();
593 for (Attribute::AttrKind AK : AKs)
594 AttrList = AttrList.removeAttributeAtIndex(Ctx, getAttrIdx(), AK);
595
596 if (CB)
597 CB->setAttributes(AttrList);
598 else
599 getAssociatedFunction()->setAttributes(AttrList);
600 }
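  // Example (illustrative): IRP.removeAttrs({Attribute::NonNull}) strips
  // `nonnull` at getAttrIdx() from the call site or function attribute list.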
601
602 bool isAnyCallSitePosition() const {
603 switch (getPositionKind()) {
604 case IRPosition::IRP_CALL_SITE:
605 case IRPosition::IRP_CALL_SITE_RETURNED:
606 case IRPosition::IRP_CALL_SITE_ARGUMENT:
607 return true;
608 default:
609 return false;
610 }
611 }
612
613 /// Return true if the position is an argument or call site argument.
614 bool isArgumentPosition() const {
615 switch (getPositionKind()) {
616 case IRPosition::IRP_ARGUMENT:
617 case IRPosition::IRP_CALL_SITE_ARGUMENT:
618 return true;
619 default:
620 return false;
621 }
622 }
623
624 /// Return the same position without the call base context.
625 IRPosition stripCallBaseContext() const {
626 IRPosition Result = *this;
627 Result.CBContext = nullptr;
628 return Result;
629 }
630
631 /// Get the call base context from the position.
632 const CallBaseContext *getCallBaseContext() const { return CBContext; }
633
634 /// Check if the position has any call base context.
635 bool hasCallBaseContext() const { return CBContext != nullptr; }
636
637 /// Special DenseMap key values.
638 ///
639 ///{
640 static const IRPosition EmptyKey;
641 static const IRPosition TombstoneKey;
642 ///}
643
644 /// Conversion into a void * to allow reuse of pointer hashing.
645 operator void *() const { return Enc.getOpaqueValue(); }
646
647private:
648 /// Private constructor for special values only!
649 explicit IRPosition(void *Ptr, const CallBaseContext *CBContext = nullptr)
650 : CBContext(CBContext) {
651 Enc.setFromOpaqueValue(Ptr);
652 }
653
654  /// IRPosition anchored at \p AnchorVal with kind/argument number \p PK.
655 explicit IRPosition(Value &AnchorVal, Kind PK,
656 const CallBaseContext *CBContext = nullptr)
657 : CBContext(CBContext) {
658 switch (PK) {
659 case IRPosition::IRP_INVALID:
660 llvm_unreachable("Cannot create invalid IRP with an anchor value!")::llvm::llvm_unreachable_internal("Cannot create invalid IRP with an anchor value!"
, "/build/llvm-toolchain-snapshot-14~++20210926122410+d23fd8ae8906/llvm/include/llvm/Transforms/IPO/Attributor.h"
, 660)
;
661 break;
662 case IRPosition::IRP_FLOAT:
663 // Special case for floating functions.
664 if (isa<Function>(AnchorVal))
665 Enc = {&AnchorVal, ENC_FLOATING_FUNCTION};
666 else
667 Enc = {&AnchorVal, ENC_VALUE};
668 break;
669 case IRPosition::IRP_FUNCTION:
670 case IRPosition::IRP_CALL_SITE:
671 Enc = {&AnchorVal, ENC_VALUE};
672 break;
673 case IRPosition::IRP_RETURNED:
674 case IRPosition::IRP_CALL_SITE_RETURNED:
675 Enc = {&AnchorVal, ENC_RETURNED_VALUE};
676 break;
677 case IRPosition::IRP_ARGUMENT:
678 Enc = {&AnchorVal, ENC_VALUE};
679 break;
680 case IRPosition::IRP_CALL_SITE_ARGUMENT:
681      llvm_unreachable(
682          "Cannot create call site argument IRP with an anchor value!");
683 break;
684 }
685 verify();
686 }
687
688 /// Return the callee argument number of the associated value if it is an
689 /// argument or call site argument. See also `getCalleeArgNo` and
690 /// `getCallSiteArgNo`.
691 int getArgNo(bool CallbackCalleeArgIfApplicable) const {
692 if (CallbackCalleeArgIfApplicable)
693 if (Argument *Arg = getAssociatedArgument())
694 return Arg->getArgNo();
695 switch (getPositionKind()) {
696 case IRPosition::IRP_ARGUMENT:
697 return cast<Argument>(getAsValuePtr())->getArgNo();
698 case IRPosition::IRP_CALL_SITE_ARGUMENT: {
699 Use &U = *getAsUsePtr();
700 return cast<CallBase>(U.getUser())->getArgOperandNo(&U);
701 }
702 default:
703 return -1;
704 }
705 }
706
707 /// IRPosition for the use \p U. The position kind \p PK needs to be
708 /// IRP_CALL_SITE_ARGUMENT, the anchor value is the user, the associated value
709 /// the used value.
710 explicit IRPosition(Use &U, Kind PK) {
711    assert(PK == IRP_CALL_SITE_ARGUMENT &&
712           "Use constructor is for call site arguments only!");
713 Enc = {&U, ENC_CALL_SITE_ARGUMENT_USE};
714 verify();
715 }
716
717 /// Verify internal invariants.
718 void verify();
719
720 /// Return the attributes of kind \p AK existing in the IR as attribute.
721 bool getAttrsFromIRAttr(Attribute::AttrKind AK,
722 SmallVectorImpl<Attribute> &Attrs) const;
723
724 /// Return the attributes of kind \p AK existing in the IR as operand bundles
725 /// of an llvm.assume.
726 bool getAttrsFromAssumes(Attribute::AttrKind AK,
727 SmallVectorImpl<Attribute> &Attrs,
728 Attributor &A) const;
729
730 /// Return the underlying pointer as Value *, valid for all positions but
731 /// IRP_CALL_SITE_ARGUMENT.
732 Value *getAsValuePtr() const {
733    assert(getEncodingBits() != ENC_CALL_SITE_ARGUMENT_USE &&
734           "Not a value pointer!");
735 return reinterpret_cast<Value *>(Enc.getPointer());
736 }
737
738 /// Return the underlying pointer as Use *, valid only for
739 /// IRP_CALL_SITE_ARGUMENT positions.
740 Use *getAsUsePtr() const {
741    assert(getEncodingBits() == ENC_CALL_SITE_ARGUMENT_USE &&
742           "Not a value pointer!");
743 return reinterpret_cast<Use *>(Enc.getPointer());
744 }
745
746 /// Return true if \p EncodingBits describe a returned or call site returned
747 /// position.
748 static bool isReturnPosition(char EncodingBits) {
749 return EncodingBits == ENC_RETURNED_VALUE;
750 }
751
752 /// Return true if the encoding bits describe a returned or call site returned
753 /// position.
754 bool isReturnPosition() const { return isReturnPosition(getEncodingBits()); }
755
756 /// The encoding of the IRPosition is a combination of a pointer and two
757 /// encoding bits. The values of the encoding bits are defined in the enum
758 /// below. The pointer is either a Value* (for the first three encoding bit
759 /// combinations) or Use* (for ENC_CALL_SITE_ARGUMENT_USE).
760 ///
761 ///{
762 enum {
763 ENC_VALUE = 0b00,
764 ENC_RETURNED_VALUE = 0b01,
765 ENC_FLOATING_FUNCTION = 0b10,
766 ENC_CALL_SITE_ARGUMENT_USE = 0b11,
767 };
768
769 // Reserve the maximal amount of bits so there is no need to mask out the
770 // remaining ones. We will not encode anything else in the pointer anyway.
771 static constexpr int NumEncodingBits =
772 PointerLikeTypeTraits<void *>::NumLowBitsAvailable;
773 static_assert(NumEncodingBits >= 2, "At least two bits are required!");
774
775 /// The pointer with the encoding bits.
776 PointerIntPair<void *, NumEncodingBits, char> Enc;
777 ///}
778
779 /// Call base context. Used for callsite specific analysis.
780 const CallBaseContext *CBContext = nullptr;
781
782 /// Return the encoding bits.
783 char getEncodingBits() const { return Enc.getInt(); }
784};
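// Minimal sketch of the tagging scheme above (exampleEncoding is invented;
// it assumes the anchor is sufficiently aligned that two low bits are free):
inline char exampleEncoding(Value &AnchorVal) {
  PointerIntPair<void *, 2, char> Enc(&AnchorVal, /*ENC_RETURNED_VALUE=*/1);
  return Enc.getInt(); // The kind tag travels in the pointer's low bits.
}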
785
786/// Helper that allows IRPosition as a key in a DenseMap.
787template <> struct DenseMapInfo<IRPosition> {
788 static inline IRPosition getEmptyKey() { return IRPosition::EmptyKey; }
789 static inline IRPosition getTombstoneKey() {
790 return IRPosition::TombstoneKey;
791 }
792 static unsigned getHashValue(const IRPosition &IRP) {
793 return (DenseMapInfo<void *>::getHashValue(IRP) << 4) ^
794 (DenseMapInfo<Value *>::getHashValue(IRP.getCallBaseContext()));
795 }
796
797 static bool isEqual(const IRPosition &a, const IRPosition &b) {
798 return a == b;
799 }
800};
801
802/// A visitor class for IR positions.
803///
804/// Given a position P, the SubsumingPositionIterator allows visiting "subsuming
805/// positions" wrt. attributes/information. Thus, if a piece of information
806/// holds for a subsuming position, it also holds for the position P.
807///
808/// The subsuming positions always include the initial position and then,
809/// depending on the position kind, additionally the following ones:
810/// - for IRP_RETURNED:
811/// - the function (IRP_FUNCTION)
812/// - for IRP_ARGUMENT:
813/// - the function (IRP_FUNCTION)
814/// - for IRP_CALL_SITE:
815/// - the callee (IRP_FUNCTION), if known
816/// - for IRP_CALL_SITE_RETURNED:
817/// - the callee (IRP_RETURNED), if known
818/// - the call site (IRP_FUNCTION)
819/// - the callee (IRP_FUNCTION), if known
820/// - for IRP_CALL_SITE_ARGUMENT:
821/// - the argument of the callee (IRP_ARGUMENT), if known
822/// - the callee (IRP_FUNCTION), if known
823/// - the position the call site argument is associated with if it is not
824/// anchored to the call site, e.g., if it is an argument then the argument
825/// (IRP_ARGUMENT)
826class SubsumingPositionIterator {
827 SmallVector<IRPosition, 4> IRPositions;
828 using iterator = decltype(IRPositions)::iterator;
829
830public:
831 SubsumingPositionIterator(const IRPosition &IRP);
832 iterator begin() { return IRPositions.begin(); }
833 iterator end() { return IRPositions.end(); }
834};
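// Usage sketch: information found at any visited position also holds at the
// queried one, so a lookup can stop at the first hit:
//   for (const IRPosition &SubIRP : SubsumingPositionIterator(IRP))
//     if (SubIRP.hasAttr({Attribute::NoAlias}))
//       return true;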
835
836/// Wrapper for FunctionAnalysisManager.
837struct AnalysisGetter {
838 template <typename Analysis>
839 typename Analysis::Result *getAnalysis(const Function &F) {
840 if (!FAM || !F.getParent())
841 return nullptr;
842 return &FAM->getResult<Analysis>(const_cast<Function &>(F));
843 }
844
845 AnalysisGetter(FunctionAnalysisManager &FAM) : FAM(&FAM) {}
846 AnalysisGetter() {}
847
848private:
849 FunctionAnalysisManager *FAM = nullptr;
850};
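// Usage sketch (exampleGetLoopInfo is invented; FAM must outlive the getter):
inline LoopInfo *exampleGetLoopInfo(FunctionAnalysisManager &FAM,
                                    const Function &F) {
  AnalysisGetter AG(FAM);
  return AG.getAnalysis<LoopAnalysis>(F); // nullptr if unavailable
}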
851
852/// Data structure to hold cached (LLVM-IR) information.
853///
854/// All attributes are given an InformationCache object at creation time to
855/// avoid inspection of the IR by all of them individually. This default
856/// InformationCache will hold information required by 'default' attributes,
857/// thus the ones deduced when Attributor::identifyDefaultAbstractAttributes(..)
858/// is called.
859///
860/// If custom abstract attributes, registered manually through
861/// Attributor::registerAA(...), need more information, especially if it is not
862/// reusable, it is advised to inherit from the InformationCache and cast the
863/// instance down in the abstract attributes.
864struct InformationCache {
865 InformationCache(const Module &M, AnalysisGetter &AG,
866 BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC)
867 : DL(M.getDataLayout()), Allocator(Allocator),
868 Explorer(
869 /* ExploreInterBlock */ true, /* ExploreCFGForward */ true,
870 /* ExploreCFGBackward */ true,
871 /* LIGetter */
872 [&](const Function &F) { return AG.getAnalysis<LoopAnalysis>(F); },
873 /* DTGetter */
874 [&](const Function &F) {
875 return AG.getAnalysis<DominatorTreeAnalysis>(F);
876 },
877 /* PDTGetter */
878 [&](const Function &F) {
879 return AG.getAnalysis<PostDominatorTreeAnalysis>(F);
880 }),
881 AG(AG), CGSCC(CGSCC), TargetTriple(M.getTargetTriple()) {
882 if (CGSCC)
883 initializeModuleSlice(*CGSCC);
884 }
885
886 ~InformationCache() {
887    // The FunctionInfo objects are allocated via a BumpPtrAllocator, so we
888    // call the destructors manually.
889 for (auto &It : FuncInfoMap)
890 It.getSecond()->~FunctionInfo();
891 }
892
893 /// Apply \p CB to all uses of \p F. If \p LookThroughConstantExprUses is
894 /// true, constant expression users are not given to \p CB but their uses are
895 /// traversed transitively.
896 template <typename CBTy>
897 static void foreachUse(Function &F, CBTy CB,
898 bool LookThroughConstantExprUses = true) {
899 SmallVector<Use *, 8> Worklist(make_pointer_range(F.uses()));
900
901 for (unsigned Idx = 0; Idx < Worklist.size(); ++Idx) {
902 Use &U = *Worklist[Idx];
903
904 // Allow use in constant bitcasts and simply look through them.
905 if (LookThroughConstantExprUses && isa<ConstantExpr>(U.getUser())) {
906 for (Use &CEU : cast<ConstantExpr>(U.getUser())->uses())
907 Worklist.push_back(&CEU);
908 continue;
909 }
910
911 CB(U);
912 }
913 }
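  // Usage sketch: count the direct call sites of F.
  //   unsigned NumCalls = 0;
  //   InformationCache::foreachUse(F, [&](Use &U) {
  //     if (auto *CB = dyn_cast<CallBase>(U.getUser()))
  //       if (CB->isCallee(&U))
  //         ++NumCalls;
  //   });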
914
915  /// Initialize the ModuleSlice member based on \p SCC. ModuleSlice contains
916 /// (a subset of) all functions that we can look at during this SCC traversal.
917 /// This includes functions (transitively) called from the SCC and the
918 /// (transitive) callers of SCC functions. We also can look at a function if
919 /// there is a "reference edge", i.a., if the function somehow uses (!=calls)
920 /// a function in the SCC or a caller of a function in the SCC.
921 void initializeModuleSlice(SetVector<Function *> &SCC) {
922 ModuleSlice.insert(SCC.begin(), SCC.end());
923
924 SmallPtrSet<Function *, 16> Seen;
925 SmallVector<Function *, 16> Worklist(SCC.begin(), SCC.end());
926 while (!Worklist.empty()) {
927 Function *F = Worklist.pop_back_val();
928 ModuleSlice.insert(F);
929
930 for (Instruction &I : instructions(*F))
931 if (auto *CB = dyn_cast<CallBase>(&I))