Line data Source code
1 : //===- Standard pass instrumentations handling ----------------*- C++ -*--===//
2 : //
3 : // The LLVM Compiler Infrastructure
4 : //
5 : // This file is distributed under the University of Illinois Open Source
6 : // License. See LICENSE.TXT for details.
7 : //
8 : //===----------------------------------------------------------------------===//
9 : /// \file
10 : ///
11 : /// This file defines IR-printing pass instrumentation callbacks as well as
12 : /// StandardInstrumentations class that manages standard pass instrumentations.
13 : ///
14 : //===----------------------------------------------------------------------===//
15 :
16 : #include "llvm/Passes/StandardInstrumentations.h"
17 : #include "llvm/Analysis/CallGraphSCCPass.h"
18 : #include "llvm/Analysis/LazyCallGraph.h"
19 : #include "llvm/Analysis/LoopInfo.h"
20 : #include "llvm/IR/Function.h"
21 : #include "llvm/IR/IRPrintingPasses.h"
22 : #include "llvm/IR/Module.h"
23 : #include "llvm/IR/PassInstrumentation.h"
24 : #include "llvm/Support/Debug.h"
25 : #include "llvm/Support/FormatVariadic.h"
26 : #include "llvm/Support/raw_ostream.h"
27 :
28 : using namespace llvm;
29 :
30 : namespace {
31 : namespace PrintIR {
32 :
33 : //===----------------------------------------------------------------------===//
34 : // IR-printing instrumentation
35 : //===----------------------------------------------------------------------===//
36 :
37 : /// Generic IR-printing helper that unpacks a pointer to IRUnit wrapped into
38 : /// llvm::Any and does actual print job.
39 50 : void unwrapAndPrint(StringRef Banner, Any IR) {
40 : SmallString<40> Extra{"\n"};
41 : const Module *M = nullptr;
42 50 : if (any_isa<const Module *>(IR)) {
43 : M = any_cast<const Module *>(IR);
44 34 : } else if (any_isa<const Function *>(IR)) {
45 : const Function *F = any_cast<const Function *>(IR);
46 24 : if (!llvm::isFunctionInPrintList(F->getName()))
47 : return;
48 17 : if (!llvm::forcePrintModuleIR()) {
49 11 : dbgs() << Banner << Extra << static_cast<const Value &>(*F);
50 11 : return;
51 : }
52 6 : M = F->getParent();
53 12 : Extra = formatv(" (function: {0})\n", F->getName());
54 10 : } else if (any_isa<const LazyCallGraph::SCC *>(IR)) {
55 : const LazyCallGraph::SCC *C = any_cast<const LazyCallGraph::SCC *>(IR);
56 : assert(C);
57 4 : if (!llvm::forcePrintModuleIR()) {
58 4 : Extra = formatv(" (scc: {0})\n", C->getName());
59 : bool BannerPrinted = false;
60 5 : for (const LazyCallGraph::Node &N : *C) {
61 3 : const Function &F = N.getFunction();
62 3 : if (!F.isDeclaration() && isFunctionInPrintList(F.getName())) {
63 3 : if (!BannerPrinted) {
64 2 : dbgs() << Banner << Extra;
65 : BannerPrinted = true;
66 : }
67 3 : F.print(dbgs());
68 : }
69 : }
70 : return;
71 : }
72 2 : for (const LazyCallGraph::Node &N : *C) {
73 2 : const Function &F = N.getFunction();
74 2 : if (!F.isDeclaration() && isFunctionInPrintList(F.getName())) {
75 2 : M = F.getParent();
76 2 : break;
77 : }
78 : }
79 2 : if (!M)
80 : return;
81 4 : Extra = formatv(" (for scc: {0})\n", C->getName());
82 6 : } else if (any_isa<const Loop *>(IR)) {
83 : const Loop *L = any_cast<const Loop *>(IR);
84 6 : const Function *F = L->getHeader()->getParent();
85 6 : if (!isFunctionInPrintList(F->getName()))
86 : return;
87 4 : if (!llvm::forcePrintModuleIR()) {
88 3 : llvm::printLoop(const_cast<Loop &>(*L), dbgs(), Banner);
89 3 : return;
90 : }
91 1 : M = F->getParent();
92 : {
93 : std::string LoopName;
94 1 : raw_string_ostream ss(LoopName);
95 1 : L->getHeader()->printAsOperand(ss, false);
96 1 : Extra = formatv(" (loop: {0})\n", ss.str());
97 : }
98 : }
99 25 : if (M) {
100 25 : dbgs() << Banner << Extra;
101 25 : M->print(dbgs(), nullptr, false);
102 : } else {
103 0 : llvm_unreachable("Unknown wrapped IR type");
104 : }
105 : }
106 :
107 17 : bool printBeforePass(StringRef PassID, Any IR) {
108 17 : if (!llvm::shouldPrintBeforePass(PassID))
109 : return true;
110 :
111 11 : if (PassID.startswith("PassManager<") || PassID.contains("PassAdaptor<"))
112 9 : return true;
113 :
114 8 : SmallString<20> Banner = formatv("*** IR Dump Before {0} ***", PassID);
115 16 : unwrapAndPrint(Banner, IR);
116 : return true;
117 : }
118 :
119 75 : void printAfterPass(StringRef PassID, Any IR) {
120 75 : if (!llvm::shouldPrintAfterPass(PassID))
121 : return;
122 :
123 53 : if (PassID.startswith("PassManager<") || PassID.contains("PassAdaptor<"))
124 33 : return;
125 :
126 42 : SmallString<20> Banner = formatv("*** IR Dump After {0} ***", PassID);
127 84 : unwrapAndPrint(Banner, IR);
128 : return;
129 : }
130 : } // namespace PrintIR
131 : } // namespace
132 :
133 936 : void StandardInstrumentations::registerCallbacks(
134 : PassInstrumentationCallbacks &PIC) {
135 936 : if (llvm::shouldPrintBeforePass())
136 1 : PIC.registerBeforePassCallback(PrintIR::printBeforePass);
137 936 : if (llvm::shouldPrintAfterPass())
138 7 : PIC.registerAfterPassCallback(PrintIR::printAfterPass);
139 936 : TimePasses.registerCallbacks(PIC);
140 936 : }
|