File: | build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/lib/Rewrite/PatternApplicator.cpp |
Warning: | line 194, column 7 Called C++ object pointer is null |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | // | |||
9 | // This file implements an applicator that applies pattern rewrites based upon a | |||
10 | // user defined cost model. | |||
11 | // | |||
12 | //===----------------------------------------------------------------------===// | |||
13 | ||||
#include "mlir/Rewrite/PatternApplicator.h"
#include "ByteCode.h"
#include "llvm/Support/Debug.h"
#include <cassert>
17 | ||||
// Debug-type tag checked by every LLVM_DEBUG(...) in this file (the expanded
// form tests isCurrentDebugType("pattern-application")). The replacement list
// must be exactly this single string literal; whitespace is required between
// the macro name and the replacement list.
#define DEBUG_TYPE "pattern-application"
19 | ||||
20 | using namespace mlir; | |||
21 | using namespace mlir::detail; | |||
22 | ||||
23 | PatternApplicator::PatternApplicator( | |||
24 | const FrozenRewritePatternSet &frozenPatternList) | |||
25 | : frozenPatternList(frozenPatternList) { | |||
26 | if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { | |||
27 | mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>(); | |||
28 | bytecode->initializeMutableState(*mutableByteCodeState); | |||
29 | } | |||
30 | } | |||
31 | PatternApplicator::~PatternApplicator() = default; | |||
32 | ||||
33 | #ifndef NDEBUG | |||
34 | /// Log a message for a pattern that is impossible to match. | |||
35 | static void logImpossibleToMatch(const Pattern &pattern) { | |||
36 | llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind() | |||
37 | << "' because it is impossible to match or cannot lead " | |||
38 | "to legal IR (by cost model)\n"; | |||
39 | } | |||
40 | ||||
41 | /// Log IR after pattern application. | |||
42 | static Operation *getDumpRootOp(Operation *op) { | |||
43 | return op->getParentWithTrait<mlir::OpTrait::IsIsolatedFromAbove>(); | |||
44 | } | |||
45 | static void logSucessfulPatternApplication(Operation *op) { | |||
46 | llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n"; | |||
47 | op->dump(); | |||
48 | llvm::dbgs() << "\n\n"; | |||
49 | } | |||
50 | #endif | |||
51 | ||||
52 | void PatternApplicator::applyCostModel(CostModel model) { | |||
53 | // Apply the cost model to the bytecode patterns first, and then the native | |||
54 | // patterns. | |||
55 | if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { | |||
56 | for (const auto &it : llvm::enumerate(bytecode->getPatterns())) | |||
57 | mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value())); | |||
58 | } | |||
59 | ||||
60 | // Copy over the patterns so that we can sort by benefit based on the cost | |||
61 | // model. Patterns that are already impossible to match are ignored. | |||
62 | patterns.clear(); | |||
63 | for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) { | |||
64 | for (const RewritePattern *pattern : it.second) { | |||
65 | if (pattern->getBenefit().isImpossibleToMatch()) | |||
66 | LLVM_DEBUG(logImpossibleToMatch(*pattern))do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("pattern-application")) { logImpossibleToMatch(*pattern); } } while (false); | |||
67 | else | |||
68 | patterns[it.first].push_back(pattern); | |||
69 | } | |||
70 | } | |||
71 | anyOpPatterns.clear(); | |||
72 | for (const RewritePattern &pattern : | |||
73 | frozenPatternList.getMatchAnyOpNativePatterns()) { | |||
74 | if (pattern.getBenefit().isImpossibleToMatch()) | |||
75 | LLVM_DEBUG(logImpossibleToMatch(pattern))do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("pattern-application")) { logImpossibleToMatch(pattern); } } while (false); | |||
76 | else | |||
77 | anyOpPatterns.push_back(&pattern); | |||
78 | } | |||
79 | ||||
80 | // Sort the patterns using the provided cost model. | |||
81 | llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits; | |||
82 | auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) { | |||
83 | return benefits[lhs] > benefits[rhs]; | |||
84 | }; | |||
85 | auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) { | |||
86 | // Special case for one pattern in the list, which is the most common case. | |||
87 | if (list.size() == 1) { | |||
88 | if (model(*list.front()).isImpossibleToMatch()) { | |||
89 | LLVM_DEBUG(logImpossibleToMatch(*list.front()))do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("pattern-application")) { logImpossibleToMatch(*list.front() ); } } while (false); | |||
90 | list.clear(); | |||
91 | } | |||
92 | return; | |||
93 | } | |||
94 | ||||
95 | // Collect the dynamic benefits for the current pattern list. | |||
96 | benefits.clear(); | |||
97 | for (const Pattern *pat : list) | |||
98 | benefits.try_emplace(pat, model(*pat)); | |||
99 | ||||
100 | // Sort patterns with highest benefit first, and remove those that are | |||
101 | // impossible to match. | |||
102 | std::stable_sort(list.begin(), list.end(), cmp); | |||
103 | while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) { | |||
104 | LLVM_DEBUG(logImpossibleToMatch(*list.back()))do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("pattern-application")) { logImpossibleToMatch(*list.back()) ; } } while (false); | |||
105 | list.pop_back(); | |||
106 | } | |||
107 | }; | |||
108 | for (auto &it : patterns) | |||
109 | processPatternList(it.second); | |||
110 | processPatternList(anyOpPatterns); | |||
111 | } | |||
112 | ||||
113 | void PatternApplicator::walkAllPatterns( | |||
114 | function_ref<void(const Pattern &)> walk) { | |||
115 | for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) | |||
116 | for (const auto &pattern : it.second) | |||
117 | walk(*pattern); | |||
118 | for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns()) | |||
119 | walk(it); | |||
120 | if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { | |||
121 | for (const Pattern &it : bytecode->getPatterns()) | |||
122 | walk(it); | |||
123 | } | |||
124 | } | |||
125 | ||||
126 | LogicalResult PatternApplicator::matchAndRewrite( | |||
127 | Operation *op, PatternRewriter &rewriter, | |||
128 | function_ref<bool(const Pattern &)> canApply, | |||
129 | function_ref<void(const Pattern &)> onFailure, | |||
130 | function_ref<LogicalResult(const Pattern &)> onSuccess) { | |||
131 | // Before checking native patterns, first match against the bytecode. This | |||
132 | // won't automatically perform any rewrites so there is no need to worry about | |||
133 | // conflicts. | |||
134 | SmallVector<PDLByteCode::MatchResult, 4> pdlMatches; | |||
135 | const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode(); | |||
| ||||
136 | if (bytecode) | |||
137 | bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState); | |||
138 | ||||
139 | // Check to see if there are patterns matching this specific operation type. | |||
140 | MutableArrayRef<const RewritePattern *> opPatterns; | |||
141 | auto patternIt = patterns.find(op->getName()); | |||
142 | if (patternIt != patterns.end()) | |||
143 | opPatterns = patternIt->second; | |||
144 | ||||
145 | // Process the patterns for that match the specific operation type, and any | |||
146 | // operation type in an interleaved fashion. | |||
147 | unsigned opIt = 0, opE = opPatterns.size(); | |||
148 | unsigned anyIt = 0, anyE = anyOpPatterns.size(); | |||
149 | unsigned pdlIt = 0, pdlE = pdlMatches.size(); | |||
150 | LogicalResult result = failure(); | |||
151 | do { | |||
152 | // Find the next pattern with the highest benefit. | |||
153 | const Pattern *bestPattern = nullptr; | |||
154 | unsigned *bestPatternIt = &opIt; | |||
155 | const PDLByteCode::MatchResult *pdlMatch = nullptr; | |||
156 | ||||
157 | /// Operation specific patterns. | |||
158 | if (opIt < opE) | |||
159 | bestPattern = opPatterns[opIt]; | |||
160 | /// Operation agnostic patterns. | |||
161 | if (anyIt < anyE && | |||
162 | (!bestPattern || | |||
163 | bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) { | |||
164 | bestPatternIt = &anyIt; | |||
165 | bestPattern = anyOpPatterns[anyIt]; | |||
166 | } | |||
167 | /// PDL patterns. | |||
168 | if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() < | |||
169 | pdlMatches[pdlIt].benefit)) { | |||
170 | bestPatternIt = &pdlIt; | |||
171 | pdlMatch = &pdlMatches[pdlIt]; | |||
172 | bestPattern = pdlMatch->pattern; | |||
173 | } | |||
174 | if (!bestPattern) | |||
175 | break; | |||
176 | ||||
177 | // Update the pattern iterator on failure so that this pattern isn't | |||
178 | // attempted again. | |||
179 | ++(*bestPatternIt); | |||
180 | ||||
181 | // Check that the pattern can be applied. | |||
182 | if (canApply && !canApply(*bestPattern)) | |||
183 | continue; | |||
184 | ||||
185 | // Try to match and rewrite this pattern. The patterns are sorted by | |||
186 | // benefit, so if we match we can immediately rewrite. For PDL patterns, the | |||
187 | // match has already been performed, we just need to rewrite. | |||
188 | rewriter.setInsertionPoint(op); | |||
189 | #ifndef NDEBUG | |||
190 | // Operation `op` may be invalidated after applying the rewrite pattern. | |||
191 | Operation *dumpRootOp = getDumpRootOp(op); | |||
192 | #endif | |||
193 | if (pdlMatch
| |||
194 | bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState); | |||
| ||||
195 | result = success(!onSuccess || succeeded(onSuccess(*bestPattern))); | |||
196 | } else { | |||
197 | const auto *pattern = static_cast<const RewritePattern *>(bestPattern); | |||
198 | ||||
199 | LLVM_DEBUG(llvm::dbgs()do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("pattern-application")) { llvm::dbgs() << "Trying to match \"" << pattern->getDebugName() << "\"\n"; } } while (false) | |||
200 | << "Trying to match \"" << pattern->getDebugName() << "\"\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("pattern-application")) { llvm::dbgs() << "Trying to match \"" << pattern->getDebugName() << "\"\n"; } } while (false); | |||
201 | result = pattern->matchAndRewrite(op, rewriter); | |||
202 | LLVM_DEBUG(llvm::dbgs() << "\"" << pattern->getDebugName() << "\" result "do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("pattern-application")) { llvm::dbgs() << "\"" << pattern->getDebugName() << "\" result " << succeeded (result) << "\n"; } } while (false) | |||
203 | << succeeded(result) << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("pattern-application")) { llvm::dbgs() << "\"" << pattern->getDebugName() << "\" result " << succeeded (result) << "\n"; } } while (false); | |||
204 | ||||
205 | if (succeeded(result) && onSuccess && failed(onSuccess(*pattern))) | |||
206 | result = failure(); | |||
207 | } | |||
208 | if (succeeded(result)) { | |||
209 | LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp))do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("pattern-application")) { logSucessfulPatternApplication(dumpRootOp ); } } while (false); | |||
210 | break; | |||
211 | } | |||
212 | ||||
213 | // Perform any necessary cleanups. | |||
214 | if (onFailure) | |||
215 | onFailure(*bestPattern); | |||
216 | } while (true); | |||
217 | ||||
218 | if (mutableByteCodeState) | |||
219 | mutableByteCodeState->cleanupAfterMatchAndRewrite(); | |||
220 | return result; | |||
221 | } |