File: | build/llvm-toolchain-snapshot-16~++20220821100726+abce7acebd4c/mlir/lib/IR/Operation.cpp |
Warning: | line 81, column 3 Potential leak of memory pointed to by 'mallocMem' |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- Operation.cpp - Operation support code -----------------------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | ||||
9 | #include "mlir/IR/Operation.h" | |||
10 | #include "mlir/IR/BlockAndValueMapping.h" | |||
11 | #include "mlir/IR/BuiltinTypes.h" | |||
12 | #include "mlir/IR/Dialect.h" | |||
13 | #include "mlir/IR/OpImplementation.h" | |||
14 | #include "mlir/IR/PatternMatch.h" | |||
15 | #include "mlir/IR/TypeUtilities.h" | |||
16 | #include "mlir/Interfaces/FoldInterfaces.h" | |||
17 | #include "llvm/ADT/StringExtras.h" | |||
18 | #include <numeric> | |||
19 | ||||
20 | using namespace mlir; | |||
21 | ||||
22 | //===----------------------------------------------------------------------===// | |||
23 | // Operation | |||
24 | //===----------------------------------------------------------------------===// | |||
25 | ||||
26 | /// Create a new Operation from operation state. | |||
27 | Operation *Operation::create(const OperationState &state) { | |||
28 | return create(state.location, state.name, state.types, state.operands, | |||
29 | state.attributes.getDictionary(state.getContext()), | |||
30 | state.successors, state.regions); | |||
31 | } | |||
32 | ||||
33 | /// Create a new Operation with the specific fields. | |||
34 | Operation *Operation::create(Location location, OperationName name, | |||
35 | TypeRange resultTypes, ValueRange operands, | |||
36 | NamedAttrList &&attributes, BlockRange successors, | |||
37 | RegionRange regions) { | |||
38 | unsigned numRegions = regions.size(); | |||
39 | Operation *op = create(location, name, resultTypes, operands, | |||
40 | std::move(attributes), successors, numRegions); | |||
41 | for (unsigned i = 0; i < numRegions; ++i) | |||
42 | if (regions[i]) | |||
43 | op->getRegion(i).takeBody(*regions[i]); | |||
44 | return op; | |||
45 | } | |||
46 | ||||
47 | /// Overload of create that takes an existing DictionaryAttr to avoid | |||
48 | /// unnecessarily uniquing a list of attributes. | |||
49 | Operation *Operation::create(Location location, OperationName name, | |||
50 | TypeRange resultTypes, ValueRange operands, | |||
51 | NamedAttrList &&attributes, BlockRange successors, | |||
52 | unsigned numRegions) { | |||
53 | assert(llvm::all_of(resultTypes, [](Type t) { return t; }) &&(static_cast <bool> (llvm::all_of(resultTypes, [](Type t ) { return t; }) && "unexpected null result type") ? void (0) : __assert_fail ("llvm::all_of(resultTypes, [](Type t) { return t; }) && \"unexpected null result type\"" , "mlir/lib/IR/Operation.cpp", 54, __extension__ __PRETTY_FUNCTION__ )) | |||
54 | "unexpected null result type")(static_cast <bool> (llvm::all_of(resultTypes, [](Type t ) { return t; }) && "unexpected null result type") ? void (0) : __assert_fail ("llvm::all_of(resultTypes, [](Type t) { return t; }) && \"unexpected null result type\"" , "mlir/lib/IR/Operation.cpp", 54, __extension__ __PRETTY_FUNCTION__ )); | |||
55 | ||||
56 | // We only need to allocate additional memory for a subset of results. | |||
57 | unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); | |||
58 | unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); | |||
59 | unsigned numSuccessors = successors.size(); | |||
60 | unsigned numOperands = operands.size(); | |||
61 | unsigned numResults = resultTypes.size(); | |||
62 | ||||
63 | // If the operation is known to have no operands, don't allocate an operand | |||
64 | // storage. | |||
65 | bool needsOperandStorage = | |||
66 | operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true; | |||
67 | ||||
68 | // Compute the byte size for the operation and the operand storage. This takes | |||
69 | // into account the size of the operation, its trailing objects, and its | |||
70 | // prefixed objects. | |||
71 | size_t byteSize = | |||
72 | totalSizeToAlloc<detail::OperandStorage, BlockOperand, Region, OpOperand>( | |||
73 | needsOperandStorage
| |||
74 | size_t prefixByteSize = llvm::alignTo( | |||
75 | Operation::prefixAllocSize(numTrailingResults, numInlineResults), | |||
76 | alignof(Operation)); | |||
77 | char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); | |||
78 | void *rawMem = mallocMem + prefixByteSize; | |||
79 | ||||
80 | // Create the new Operation. | |||
81 | Operation *op = ::new (rawMem) Operation( | |||
| ||||
82 | location, name, numResults, numSuccessors, numRegions, | |||
83 | attributes.getDictionary(location.getContext()), needsOperandStorage); | |||
84 | ||||
85 | assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) &&(static_cast <bool> ((numSuccessors == 0 || op->mightHaveTrait <OpTrait::IsTerminator>()) && "unexpected successors in a non-terminator operation" ) ? void (0) : __assert_fail ("(numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && \"unexpected successors in a non-terminator operation\"" , "mlir/lib/IR/Operation.cpp", 86, __extension__ __PRETTY_FUNCTION__ )) | |||
86 | "unexpected successors in a non-terminator operation")(static_cast <bool> ((numSuccessors == 0 || op->mightHaveTrait <OpTrait::IsTerminator>()) && "unexpected successors in a non-terminator operation" ) ? void (0) : __assert_fail ("(numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && \"unexpected successors in a non-terminator operation\"" , "mlir/lib/IR/Operation.cpp", 86, __extension__ __PRETTY_FUNCTION__ )); | |||
87 | ||||
88 | // Initialize the results. | |||
89 | auto resultTypeIt = resultTypes.begin(); | |||
90 | for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) | |||
91 | new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); | |||
92 | for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { | |||
93 | new (op->getOutOfLineOpResult(i)) | |||
94 | detail::OutOfLineOpResult(*resultTypeIt, i); | |||
95 | } | |||
96 | ||||
97 | // Initialize the regions. | |||
98 | for (unsigned i = 0; i != numRegions; ++i) | |||
99 | new (&op->getRegion(i)) Region(op); | |||
100 | ||||
101 | // Initialize the operands. | |||
102 | if (needsOperandStorage) { | |||
103 | new (&op->getOperandStorage()) detail::OperandStorage( | |||
104 | op, op->getTrailingObjects<OpOperand>(), operands); | |||
105 | } | |||
106 | ||||
107 | // Initialize the successors. | |||
108 | auto blockOperands = op->getBlockOperands(); | |||
109 | for (unsigned i = 0; i != numSuccessors; ++i) | |||
110 | new (&blockOperands[i]) BlockOperand(op, successors[i]); | |||
111 | ||||
112 | return op; | |||
113 | } | |||
114 | ||||
115 | Operation::Operation(Location location, OperationName name, unsigned numResults, | |||
116 | unsigned numSuccessors, unsigned numRegions, | |||
117 | DictionaryAttr attributes, bool hasOperandStorage) | |||
118 | : location(location), numResults(numResults), numSuccs(numSuccessors), | |||
119 | numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name), | |||
120 | attrs(attributes) { | |||
121 | assert(attributes && "unexpected null attribute dictionary")(static_cast <bool> (attributes && "unexpected null attribute dictionary" ) ? void (0) : __assert_fail ("attributes && \"unexpected null attribute dictionary\"" , "mlir/lib/IR/Operation.cpp", 121, __extension__ __PRETTY_FUNCTION__ )); | |||
122 | #ifndef NDEBUG | |||
123 | if (!getDialect() && !getContext()->allowsUnregisteredDialects()) | |||
124 | llvm::report_fatal_error( | |||
125 | name.getStringRef() + | |||
126 | " created with unregistered dialect. If this is intended, please call " | |||
127 | "allowUnregisteredDialects() on the MLIRContext, or use " | |||
128 | "-allow-unregistered-dialect with the MLIR tool used."); | |||
129 | #endif | |||
130 | } | |||
131 | ||||
132 | // Operations are deleted through the destroy() member because they are | |||
133 | // allocated via malloc. | |||
134 | Operation::~Operation() { | |||
135 | assert(block == nullptr && "operation destroyed but still in a block")(static_cast <bool> (block == nullptr && "operation destroyed but still in a block" ) ? void (0) : __assert_fail ("block == nullptr && \"operation destroyed but still in a block\"" , "mlir/lib/IR/Operation.cpp", 135, __extension__ __PRETTY_FUNCTION__ )); | |||
136 | #ifndef NDEBUG | |||
137 | if (!use_empty()) { | |||
138 | { | |||
139 | InFlightDiagnostic diag = | |||
140 | emitOpError("operation destroyed but still has uses"); | |||
141 | for (Operation *user : getUsers()) | |||
142 | diag.attachNote(user->getLoc()) << "- use: " << *user << "\n"; | |||
143 | } | |||
144 | llvm::report_fatal_error("operation destroyed but still has uses"); | |||
145 | } | |||
146 | #endif | |||
147 | // Explicitly run the destructors for the operands. | |||
148 | if (hasOperandStorage) | |||
149 | getOperandStorage().~OperandStorage(); | |||
150 | ||||
151 | // Explicitly run the destructors for the successors. | |||
152 | for (auto &successor : getBlockOperands()) | |||
153 | successor.~BlockOperand(); | |||
154 | ||||
155 | // Explicitly destroy the regions. | |||
156 | for (auto ®ion : getRegions()) | |||
157 | region.~Region(); | |||
158 | } | |||
159 | ||||
160 | /// Destroy this operation or one of its subclasses. | |||
161 | void Operation::destroy() { | |||
162 | // Operations may have additional prefixed allocation, which needs to be | |||
163 | // accounted for here when computing the address to free. | |||
164 | char *rawMem = reinterpret_cast<char *>(this) - | |||
165 | llvm::alignTo(prefixAllocSize(), alignof(Operation)); | |||
166 | this->~Operation(); | |||
167 | free(rawMem); | |||
168 | } | |||
169 | ||||
170 | /// Return true if this operation is a proper ancestor of the `other` | |||
171 | /// operation. | |||
172 | bool Operation::isProperAncestor(Operation *other) { | |||
173 | while ((other = other->getParentOp())) | |||
174 | if (this == other) | |||
175 | return true; | |||
176 | return false; | |||
177 | } | |||
178 | ||||
179 | /// Replace any uses of 'from' with 'to' within this operation. | |||
180 | void Operation::replaceUsesOfWith(Value from, Value to) { | |||
181 | if (from == to) | |||
182 | return; | |||
183 | for (auto &operand : getOpOperands()) | |||
184 | if (operand.get() == from) | |||
185 | operand.set(to); | |||
186 | } | |||
187 | ||||
188 | /// Replace the current operands of this operation with the ones provided in | |||
189 | /// 'operands'. | |||
190 | void Operation::setOperands(ValueRange operands) { | |||
191 | if (LLVM_LIKELY(hasOperandStorage)__builtin_expect((bool)(hasOperandStorage), true)) | |||
192 | return getOperandStorage().setOperands(this, operands); | |||
193 | assert(operands.empty() && "setting operands without an operand storage")(static_cast <bool> (operands.empty() && "setting operands without an operand storage" ) ? void (0) : __assert_fail ("operands.empty() && \"setting operands without an operand storage\"" , "mlir/lib/IR/Operation.cpp", 193, __extension__ __PRETTY_FUNCTION__ )); | |||
194 | } | |||
195 | ||||
196 | /// Replace the operands beginning at 'start' and ending at 'start' + 'length' | |||
197 | /// with the ones provided in 'operands'. 'operands' may be smaller or larger | |||
198 | /// than the range pointed to by 'start'+'length'. | |||
199 | void Operation::setOperands(unsigned start, unsigned length, | |||
200 | ValueRange operands) { | |||
201 | assert((start + length) <= getNumOperands() &&(static_cast <bool> ((start + length) <= getNumOperands () && "invalid operand range specified") ? void (0) : __assert_fail ("(start + length) <= getNumOperands() && \"invalid operand range specified\"" , "mlir/lib/IR/Operation.cpp", 202, __extension__ __PRETTY_FUNCTION__ )) | |||
202 | "invalid operand range specified")(static_cast <bool> ((start + length) <= getNumOperands () && "invalid operand range specified") ? void (0) : __assert_fail ("(start + length) <= getNumOperands() && \"invalid operand range specified\"" , "mlir/lib/IR/Operation.cpp", 202, __extension__ __PRETTY_FUNCTION__ )); | |||
203 | if (LLVM_LIKELY(hasOperandStorage)__builtin_expect((bool)(hasOperandStorage), true)) | |||
204 | return getOperandStorage().setOperands(this, start, length, operands); | |||
205 | assert(operands.empty() && "setting operands without an operand storage")(static_cast <bool> (operands.empty() && "setting operands without an operand storage" ) ? void (0) : __assert_fail ("operands.empty() && \"setting operands without an operand storage\"" , "mlir/lib/IR/Operation.cpp", 205, __extension__ __PRETTY_FUNCTION__ )); | |||
206 | } | |||
207 | ||||
208 | /// Insert the given operands into the operand list at the given 'index'. | |||
209 | void Operation::insertOperands(unsigned index, ValueRange operands) { | |||
210 | if (LLVM_LIKELY(hasOperandStorage)__builtin_expect((bool)(hasOperandStorage), true)) | |||
211 | return setOperands(index, /*length=*/0, operands); | |||
212 | assert(operands.empty() && "inserting operands without an operand storage")(static_cast <bool> (operands.empty() && "inserting operands without an operand storage" ) ? void (0) : __assert_fail ("operands.empty() && \"inserting operands without an operand storage\"" , "mlir/lib/IR/Operation.cpp", 212, __extension__ __PRETTY_FUNCTION__ )); | |||
213 | } | |||
214 | ||||
215 | //===----------------------------------------------------------------------===// | |||
216 | // Diagnostics | |||
217 | //===----------------------------------------------------------------------===// | |||
218 | ||||
219 | /// Emit an error about fatal conditions with this operation, reporting up to | |||
220 | /// any diagnostic handlers that may be listening. | |||
221 | InFlightDiagnostic Operation::emitError(const Twine &message) { | |||
222 | InFlightDiagnostic diag = mlir::emitError(getLoc(), message); | |||
223 | if (getContext()->shouldPrintOpOnDiagnostic()) { | |||
224 | diag.attachNote(getLoc()) | |||
225 | .append("see current operation: ") | |||
226 | .appendOp(*this, OpPrintingFlags().printGenericOpForm()); | |||
227 | } | |||
228 | return diag; | |||
229 | } | |||
230 | ||||
231 | /// Emit a warning about this operation, reporting up to any diagnostic | |||
232 | /// handlers that may be listening. | |||
233 | InFlightDiagnostic Operation::emitWarning(const Twine &message) { | |||
234 | InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); | |||
235 | if (getContext()->shouldPrintOpOnDiagnostic()) | |||
236 | diag.attachNote(getLoc()) << "see current operation: " << *this; | |||
237 | return diag; | |||
238 | } | |||
239 | ||||
240 | /// Emit a remark about this operation, reporting up to any diagnostic | |||
241 | /// handlers that may be listening. | |||
242 | InFlightDiagnostic Operation::emitRemark(const Twine &message) { | |||
243 | InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message); | |||
244 | if (getContext()->shouldPrintOpOnDiagnostic()) | |||
245 | diag.attachNote(getLoc()) << "see current operation: " << *this; | |||
246 | return diag; | |||
247 | } | |||
248 | ||||
249 | //===----------------------------------------------------------------------===// | |||
250 | // Operation Ordering | |||
251 | //===----------------------------------------------------------------------===// | |||
252 | ||||
// Out-of-line definitions of the ordering constants declared in the class.
// They give the constants an address if they are odr-used; from C++17 on,
// static constexpr data members are implicitly inline, which makes these
// redundant.
constexpr unsigned Operation::kInvalidOrderIdx;
constexpr unsigned Operation::kOrderStride;
255 | ||||
256 | /// Given an operation 'other' that is within the same parent block, return | |||
257 | /// whether the current operation is before 'other' in the operation list | |||
258 | /// of the parent block. | |||
259 | /// Note: This function has an average complexity of O(1), but worst case may | |||
260 | /// take O(N) where N is the number of operations within the parent block. | |||
261 | bool Operation::isBeforeInBlock(Operation *other) { | |||
262 | assert(block && "Operations without parent blocks have no order.")(static_cast <bool> (block && "Operations without parent blocks have no order." ) ? void (0) : __assert_fail ("block && \"Operations without parent blocks have no order.\"" , "mlir/lib/IR/Operation.cpp", 262, __extension__ __PRETTY_FUNCTION__ )); | |||
263 | assert(other && other->block == block &&(static_cast <bool> (other && other->block == block && "Expected other operation to have the same parent block." ) ? void (0) : __assert_fail ("other && other->block == block && \"Expected other operation to have the same parent block.\"" , "mlir/lib/IR/Operation.cpp", 264, __extension__ __PRETTY_FUNCTION__ )) | |||
264 | "Expected other operation to have the same parent block.")(static_cast <bool> (other && other->block == block && "Expected other operation to have the same parent block." ) ? void (0) : __assert_fail ("other && other->block == block && \"Expected other operation to have the same parent block.\"" , "mlir/lib/IR/Operation.cpp", 264, __extension__ __PRETTY_FUNCTION__ )); | |||
265 | // If the order of the block is already invalid, directly recompute the | |||
266 | // parent. | |||
267 | if (!block->isOpOrderValid()) { | |||
268 | block->recomputeOpOrder(); | |||
269 | } else { | |||
270 | // Update the order either operation if necessary. | |||
271 | updateOrderIfNecessary(); | |||
272 | other->updateOrderIfNecessary(); | |||
273 | } | |||
274 | ||||
275 | return orderIndex < other->orderIndex; | |||
276 | } | |||
277 | ||||
/// Update the order index of this operation if necessary, potentially
/// recomputing the order of the parent block.
///
/// The new index is derived from the neighbouring operations:
///  - the last op in the block gets its predecessor's index + kOrderStride;
///  - the first op gets kOrderStride, or half of its successor's index when
///    the successor's index is not above kOrderStride;
///  - any other op gets the midpoint of its two neighbours' indices.
/// If a needed neighbour has no valid index, or there is no unused integer
/// left between the neighbours, the entire block is renumbered with
/// recomputeOpOrder() instead.
void Operation::updateOrderIfNecessary() {
  assert(block && "expected valid parent");

  // If the order is valid for this operation there is nothing to do.
  if (hasValidOrder())
    return;
  Operation *blockFront = &block->front();
  Operation *blockBack = &block->back();

  // This method is expected to only be invoked on blocks with more than one
  // operation.
  assert(blockFront != blockBack && "expected more than one operation");

  // If the operation is at the end of the block.
  if (this == blockBack) {
    Operation *prevNode = getPrevNode();
    if (!prevNode->hasValidOrder())
      return block->recomputeOpOrder();

    // Add the stride to the previous operation.
    orderIndex = prevNode->orderIndex + kOrderStride;
    return;
  }

  // If this is the first operation try to use the next operation to compute the
  // ordering.
  if (this == blockFront) {
    Operation *nextNode = getNextNode();
    if (!nextNode->hasValidOrder())
      return block->recomputeOpOrder();
    // There is no order to give this operation.
    if (nextNode->orderIndex == 0)
      return block->recomputeOpOrder();

    // If we can't use the stride, take the middle value between 0 and the next
    // index. This is safe because the next index is at least 1 here, so at
    // least one smaller index is free.
    if (nextNode->orderIndex <= kOrderStride)
      orderIndex = (nextNode->orderIndex / 2);
    else
      orderIndex = kOrderStride;
    return;
  }

  // Otherwise, this operation is between two others. Place this operation in
  // the middle of the previous and next if possible.
  Operation *prevNode = getPrevNode(), *nextNode = getNextNode();
  if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder())
    return block->recomputeOpOrder();
  unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex;

  // Check to see if there is a valid order between the two.
  if (prevOrder + 1 == nextOrder)
    return block->recomputeOpOrder();
  // Midpoint written as prev + (next - prev) / 2 so that it does not compute
  // prevOrder + nextOrder, which could overflow.
  orderIndex = prevOrder + ((nextOrder - prevOrder) / 2);
}
335 | ||||
336 | //===----------------------------------------------------------------------===// | |||
337 | // ilist_traits for Operation | |||
338 | //===----------------------------------------------------------------------===// | |||
339 | ||||
// Out-of-line definitions of the ilist node-access hooks for Operation's list
// options. Each overload forwards to the generic NodeAccess helper, which is
// instantiated with this specialization's OptionsT. getNodePtr maps an
// Operation pointer to its list node; getValuePtr maps a node back to the
// Operation. Const and non-const variants are otherwise identical.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(const_pointer n)
    -> const node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}

auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(const node_type *n)
    -> const_pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}
365 | ||||
366 | void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) { | |||
367 | op->destroy(); | |||
368 | } | |||
369 | ||||
/// Return the Block that owns this operation list. `this` is statically cast
/// to the iplist<Operation> itself, which is assumed to be a member of a Block.
/// Subtracting that member's byte offset from the list's address yields the
/// enclosing Block.
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
  // Byte offset of the operation list within Block. It is computed by applying
  // the pointer-to-member returned by Block::getSublistAccess to a null Block
  // pointer and taking the resulting address (an offsetof-style idiom).
  // NOTE(review): going through a null pointer this way is formally undefined
  // behavior, even though compilers accept it in practice.
  size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
  iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this));
  return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset);
}
375 | ||||
376 | /// This is a trait method invoked when an operation is added to a block. We | |||
377 | /// keep the block pointer up to date. | |||
378 | void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { | |||
379 | assert(!op->getBlock() && "already in an operation block!")(static_cast <bool> (!op->getBlock() && "already in an operation block!" ) ? void (0) : __assert_fail ("!op->getBlock() && \"already in an operation block!\"" , "mlir/lib/IR/Operation.cpp", 379, __extension__ __PRETTY_FUNCTION__ )); | |||
380 | op->block = getContainingBlock(); | |||
381 | ||||
382 | // Invalidate the order on the operation. | |||
383 | op->orderIndex = Operation::kInvalidOrderIdx; | |||
384 | } | |||
385 | ||||
386 | /// This is a trait method invoked when an operation is removed from a block. | |||
387 | /// We keep the block pointer up to date. | |||
388 | void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { | |||
389 | assert(op->block && "not already in an operation block!")(static_cast <bool> (op->block && "not already in an operation block!" ) ? void (0) : __assert_fail ("op->block && \"not already in an operation block!\"" , "mlir/lib/IR/Operation.cpp", 389, __extension__ __PRETTY_FUNCTION__ )); | |||
390 | op->block = nullptr; | |||
391 | } | |||
392 | ||||
393 | /// This is a trait method invoked when an operation is moved from one block | |||
394 | /// to another. We keep the block pointer up to date. | |||
395 | void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( | |||
396 | ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { | |||
397 | Block *curParent = getContainingBlock(); | |||
398 | ||||
399 | // Invalidate the ordering of the parent block. | |||
400 | curParent->invalidateOpOrder(); | |||
401 | ||||
402 | // If we are transferring operations within the same block, the block | |||
403 | // pointer doesn't need to be updated. | |||
404 | if (curParent == otherList.getContainingBlock()) | |||
405 | return; | |||
406 | ||||
407 | // Update the 'block' member of each operation. | |||
408 | for (; first != last; ++first) | |||
409 | first->block = curParent; | |||
410 | } | |||
411 | ||||
412 | /// Remove this operation (and its descendants) from its Block and delete | |||
413 | /// all of them. | |||
414 | void Operation::erase() { | |||
415 | if (auto *parent = getBlock()) | |||
416 | parent->getOperations().erase(this); | |||
417 | else | |||
418 | destroy(); | |||
419 | } | |||
420 | ||||
421 | /// Remove the operation from its parent block, but don't delete it. | |||
422 | void Operation::remove() { | |||
423 | if (Block *parent = getBlock()) | |||
424 | parent->getOperations().remove(this); | |||
425 | } | |||
426 | ||||
427 | /// Unlink this operation from its current block and insert it right before | |||
428 | /// `existingOp` which may be in the same or another block in the same | |||
429 | /// function. | |||
430 | void Operation::moveBefore(Operation *existingOp) { | |||
431 | moveBefore(existingOp->getBlock(), existingOp->getIterator()); | |||
432 | } | |||
433 | ||||
434 | /// Unlink this operation from its current basic block and insert it right | |||
435 | /// before `iterator` in the specified basic block. | |||
436 | void Operation::moveBefore(Block *block, | |||
437 | llvm::iplist<Operation>::iterator iterator) { | |||
438 | block->getOperations().splice(iterator, getBlock()->getOperations(), | |||
439 | getIterator()); | |||
440 | } | |||
441 | ||||
442 | /// Unlink this operation from its current block and insert it right after | |||
443 | /// `existingOp` which may be in the same or another block in the same function. | |||
444 | void Operation::moveAfter(Operation *existingOp) { | |||
445 | moveAfter(existingOp->getBlock(), existingOp->getIterator()); | |||
446 | } | |||
447 | ||||
448 | /// Unlink this operation from its current block and insert it right after | |||
449 | /// `iterator` in the specified block. | |||
450 | void Operation::moveAfter(Block *block, | |||
451 | llvm::iplist<Operation>::iterator iterator) { | |||
452 | assert(iterator != block->end() && "cannot move after end of block")(static_cast <bool> (iterator != block->end() && "cannot move after end of block") ? void (0) : __assert_fail ("iterator != block->end() && \"cannot move after end of block\"" , "mlir/lib/IR/Operation.cpp", 452, __extension__ __PRETTY_FUNCTION__ )); | |||
453 | moveBefore(block, std::next(iterator)); | |||
454 | } | |||
455 | ||||
456 | /// This drops all operand uses from this operation, which is an essential | |||
457 | /// step in breaking cyclic dependences between references when they are to | |||
458 | /// be deleted. | |||
459 | void Operation::dropAllReferences() { | |||
460 | for (auto &op : getOpOperands()) | |||
461 | op.drop(); | |||
462 | ||||
463 | for (auto ®ion : getRegions()) | |||
464 | region.dropAllReferences(); | |||
465 | ||||
466 | for (auto &dest : getBlockOperands()) | |||
467 | dest.drop(); | |||
468 | } | |||
469 | ||||
470 | /// This drops all uses of any values defined by this operation or its nested | |||
471 | /// regions, wherever they are located. | |||
472 | void Operation::dropAllDefinedValueUses() { | |||
473 | dropAllUses(); | |||
474 | ||||
475 | for (auto ®ion : getRegions()) | |||
476 | for (auto &block : region) | |||
477 | block.dropAllDefinedValueUses(); | |||
478 | } | |||
479 | ||||
480 | void Operation::setSuccessor(Block *block, unsigned index) { | |||
481 | assert(index < getNumSuccessors())(static_cast <bool> (index < getNumSuccessors()) ? void (0) : __assert_fail ("index < getNumSuccessors()", "mlir/lib/IR/Operation.cpp" , 481, __extension__ __PRETTY_FUNCTION__)); | |||
482 | getBlockOperands()[index].set(block); | |||
483 | } | |||
484 | ||||
485 | /// Attempt to fold this operation using the Op's registered foldHook. | |||
486 | LogicalResult Operation::fold(ArrayRef<Attribute> operands, | |||
487 | SmallVectorImpl<OpFoldResult> &results) { | |||
488 | // If we have a registered operation definition matching this one, use it to | |||
489 | // try to constant fold the operation. | |||
490 | Optional<RegisteredOperationName> info = getRegisteredInfo(); | |||
491 | if (info && succeeded(info->foldHook(this, operands, results))) | |||
492 | return success(); | |||
493 | ||||
494 | // Otherwise, fall back on the dialect hook to handle it. | |||
495 | Dialect *dialect = getDialect(); | |||
496 | if (!dialect) | |||
497 | return failure(); | |||
498 | ||||
499 | auto *interface = dyn_cast<DialectFoldInterface>(dialect); | |||
500 | if (!interface) | |||
501 | return failure(); | |||
502 | ||||
503 | return interface->fold(this, operands, results); | |||
504 | } | |||
505 | ||||
506 | /// Emit an error with the op name prefixed, like "'dim' op " which is | |||
507 | /// convenient for verifiers. | |||
508 | InFlightDiagnostic Operation::emitOpError(const Twine &message) { | |||
509 | return emitError() << "'" << getName() << "' op " << message; | |||
510 | } | |||
511 | ||||
512 | //===----------------------------------------------------------------------===// | |||
513 | // Operation Cloning | |||
514 | //===----------------------------------------------------------------------===// | |||
515 | ||||
516 | Operation::CloneOptions::CloneOptions() | |||
517 | : cloneRegionsFlag(false), cloneOperandsFlag(false) {} | |||
518 | ||||
519 | Operation::CloneOptions::CloneOptions(bool cloneRegions, bool cloneOperands) | |||
520 | : cloneRegionsFlag(cloneRegions), cloneOperandsFlag(cloneOperands) {} | |||
521 | ||||
522 | Operation::CloneOptions Operation::CloneOptions::all() { | |||
523 | return CloneOptions().cloneRegions().cloneOperands(); | |||
524 | } | |||
525 | ||||
526 | Operation::CloneOptions &Operation::CloneOptions::cloneRegions(bool enable) { | |||
527 | cloneRegionsFlag = enable; | |||
528 | return *this; | |||
529 | } | |||
530 | ||||
531 | Operation::CloneOptions &Operation::CloneOptions::cloneOperands(bool enable) { | |||
532 | cloneOperandsFlag = enable; | |||
533 | return *this; | |||
534 | } | |||
535 | ||||
536 | /// Create a deep copy of this operation but keep the operation regions empty. | |||
537 | /// Operands are remapped using `mapper` (if present), and `mapper` is updated | |||
538 | /// to contain the results. The `mapResults` flag specifies whether the results | |||
539 | /// of the cloned operation should be added to the map. | |||
540 | Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { | |||
541 | return clone(mapper, CloneOptions::all().cloneRegions(false)); | |||
542 | } | |||
543 | ||||
544 | Operation *Operation::cloneWithoutRegions() { | |||
545 | BlockAndValueMapping mapper; | |||
546 | return cloneWithoutRegions(mapper); | |||
547 | } | |||
548 | ||||
549 | /// Create a deep copy of this operation, remapping any operands that use | |||
550 | /// values outside of the operation using the map that is provided (leaving | |||
551 | /// them alone if no entry is present). Replaces references to cloned | |||
552 | /// sub-operations to the corresponding operation that is copied, and adds | |||
553 | /// those mappings to the map. | |||
554 | Operation *Operation::clone(BlockAndValueMapping &mapper, | |||
555 | CloneOptions options) { | |||
556 | SmallVector<Value, 8> operands; | |||
557 | SmallVector<Block *, 2> successors; | |||
558 | ||||
559 | // Remap the operands. | |||
560 | if (options.shouldCloneOperands()) { | |||
561 | operands.reserve(getNumOperands()); | |||
562 | for (auto opValue : getOperands()) | |||
563 | operands.push_back(mapper.lookupOrDefault(opValue)); | |||
564 | } | |||
565 | ||||
566 | // Remap the successors. | |||
567 | successors.reserve(getNumSuccessors()); | |||
568 | for (Block *successor : getSuccessors()) | |||
569 | successors.push_back(mapper.lookupOrDefault(successor)); | |||
570 | ||||
571 | // Create the new operation. | |||
572 | auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, | |||
573 | successors, getNumRegions()); | |||
574 | ||||
575 | // Clone the regions. | |||
576 | if (options.shouldCloneRegions()) { | |||
577 | for (unsigned i = 0; i != numRegions; ++i) | |||
578 | getRegion(i).cloneInto(&newOp->getRegion(i), mapper); | |||
579 | } | |||
580 | ||||
581 | // Remember the mapping of any results. | |||
582 | for (unsigned i = 0, e = getNumResults(); i != e; ++i) | |||
583 | mapper.map(getResult(i), newOp->getResult(i)); | |||
584 | ||||
585 | return newOp; | |||
586 | } | |||
587 | ||||
588 | Operation *Operation::clone(CloneOptions options) { | |||
589 | BlockAndValueMapping mapper; | |||
590 | return clone(mapper, options); | |||
| ||||
591 | } | |||
592 | ||||
593 | //===----------------------------------------------------------------------===// | |||
594 | // OpState trait class. | |||
595 | //===----------------------------------------------------------------------===// | |||
596 | ||||
597 | // The fallback for the parser is to try for a dialect operation parser. | |||
598 | // Otherwise, reject the custom assembly form. | |||
599 | ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { | |||
600 | if (auto parseFn = result.name.getDialect()->getParseOperationHook( | |||
601 | result.name.getStringRef())) | |||
602 | return (*parseFn)(parser, result); | |||
603 | return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); | |||
604 | } | |||
605 | ||||
606 | // The fallback for the printer is to try for a dialect operation printer. | |||
607 | // Otherwise, it prints the generic form. | |||
608 | void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { | |||
609 | if (auto printFn = op->getDialect()->getOperationPrinter(op)) { | |||
610 | printOpName(op, p, defaultDialect); | |||
611 | printFn(op, p); | |||
612 | } else { | |||
613 | p.printGenericOp(op); | |||
614 | } | |||
615 | } | |||
616 | ||||
617 | /// Print an operation name, eliding the dialect prefix if necessary and doesn't | |||
618 | /// lead to ambiguities. | |||
619 | void OpState::printOpName(Operation *op, OpAsmPrinter &p, | |||
620 | StringRef defaultDialect) { | |||
621 | StringRef name = op->getName().getStringRef(); | |||
622 | if (name.startswith((defaultDialect + ".").str()) && name.count('.') == 1) | |||
623 | name = name.drop_front(defaultDialect.size() + 1); | |||
624 | p.getStream() << name; | |||
625 | } | |||
626 | ||||
627 | /// Emit an error about fatal conditions with this operation, reporting up to | |||
628 | /// any diagnostic handlers that may be listening. | |||
629 | InFlightDiagnostic OpState::emitError(const Twine &message) { | |||
630 | return getOperation()->emitError(message); | |||
631 | } | |||
632 | ||||
633 | /// Emit an error with the op name prefixed, like "'dim' op " which is | |||
634 | /// convenient for verifiers. | |||
635 | InFlightDiagnostic OpState::emitOpError(const Twine &message) { | |||
636 | return getOperation()->emitOpError(message); | |||
637 | } | |||
638 | ||||
639 | /// Emit a warning about this operation, reporting up to any diagnostic | |||
640 | /// handlers that may be listening. | |||
641 | InFlightDiagnostic OpState::emitWarning(const Twine &message) { | |||
642 | return getOperation()->emitWarning(message); | |||
643 | } | |||
644 | ||||
645 | /// Emit a remark about this operation, reporting up to any diagnostic | |||
646 | /// handlers that may be listening. | |||
647 | InFlightDiagnostic OpState::emitRemark(const Twine &message) { | |||
648 | return getOperation()->emitRemark(message); | |||
649 | } | |||
650 | ||||
651 | //===----------------------------------------------------------------------===// | |||
652 | // Op Trait implementations | |||
653 | //===----------------------------------------------------------------------===// | |||
654 | ||||
655 | OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { | |||
656 | if (op->getNumOperands() == 1) { | |||
657 | auto *argumentOp = op->getOperand(0).getDefiningOp(); | |||
658 | if (argumentOp && op->getName() == argumentOp->getName()) { | |||
659 | // Replace the outer operation output with the inner operation. | |||
660 | return op->getOperand(0); | |||
661 | } | |||
662 | } else if (op->getOperand(0) == op->getOperand(1)) { | |||
663 | return op->getOperand(0); | |||
664 | } | |||
665 | ||||
666 | return {}; | |||
667 | } | |||
668 | ||||
669 | OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { | |||
670 | auto *argumentOp = op->getOperand(0).getDefiningOp(); | |||
671 | if (argumentOp && op->getName() == argumentOp->getName()) { | |||
672 | // Replace the outer involutions output with inner's input. | |||
673 | return argumentOp->getOperand(0); | |||
674 | } | |||
675 | ||||
676 | return {}; | |||
677 | } | |||
678 | ||||
679 | LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { | |||
680 | if (op->getNumOperands() != 0) | |||
681 | return op->emitOpError() << "requires zero operands"; | |||
682 | return success(); | |||
683 | } | |||
684 | ||||
685 | LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { | |||
686 | if (op->getNumOperands() != 1) | |||
687 | return op->emitOpError() << "requires a single operand"; | |||
688 | return success(); | |||
689 | } | |||
690 | ||||
691 | LogicalResult OpTrait::impl::verifyNOperands(Operation *op, | |||
692 | unsigned numOperands) { | |||
693 | if (op->getNumOperands() != numOperands) { | |||
694 | return op->emitOpError() << "expected " << numOperands | |||
695 | << " operands, but found " << op->getNumOperands(); | |||
696 | } | |||
697 | return success(); | |||
698 | } | |||
699 | ||||
700 | LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, | |||
701 | unsigned numOperands) { | |||
702 | if (op->getNumOperands() < numOperands) | |||
703 | return op->emitOpError() | |||
704 | << "expected " << numOperands << " or more operands, but found " | |||
705 | << op->getNumOperands(); | |||
706 | return success(); | |||
707 | } | |||
708 | ||||
709 | /// If this is a vector type, or a tensor type, return the scalar element type | |||
710 | /// that it is built around, otherwise return the type unmodified. | |||
711 | static Type getTensorOrVectorElementType(Type type) { | |||
712 | if (auto vec = type.dyn_cast<VectorType>()) | |||
713 | return vec.getElementType(); | |||
714 | ||||
715 | // Look through tensor<vector<...>> to find the underlying element type. | |||
716 | if (auto tensor = type.dyn_cast<TensorType>()) | |||
717 | return getTensorOrVectorElementType(tensor.getElementType()); | |||
718 | return type; | |||
719 | } | |||
720 | ||||
721 | LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { | |||
722 | // FIXME: Add back check for no side effects on operation. | |||
723 | // Currently adding it would cause the shared library build | |||
724 | // to fail since there would be a dependency of IR on SideEffectInterfaces | |||
725 | // which is cyclical. | |||
726 | return success(); | |||
727 | } | |||
728 | ||||
729 | LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { | |||
730 | // FIXME: Add back check for no side effects on operation. | |||
731 | // Currently adding it would cause the shared library build | |||
732 | // to fail since there would be a dependency of IR on SideEffectInterfaces | |||
733 | // which is cyclical. | |||
734 | return success(); | |||
735 | } | |||
736 | ||||
737 | LogicalResult | |||
738 | OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { | |||
739 | for (auto opType : op->getOperandTypes()) { | |||
740 | auto type = getTensorOrVectorElementType(opType); | |||
741 | if (!type.isSignlessIntOrIndex()) | |||
742 | return op->emitOpError() << "requires an integer or index type"; | |||
743 | } | |||
744 | return success(); | |||
745 | } | |||
746 | ||||
747 | LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { | |||
748 | for (auto opType : op->getOperandTypes()) { | |||
749 | auto type = getTensorOrVectorElementType(opType); | |||
750 | if (!type.isa<FloatType>()) | |||
751 | return op->emitOpError("requires a float type"); | |||
752 | } | |||
753 | return success(); | |||
754 | } | |||
755 | ||||
756 | LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { | |||
757 | // Zero or one operand always have the "same" type. | |||
758 | unsigned nOperands = op->getNumOperands(); | |||
759 | if (nOperands < 2) | |||
760 | return success(); | |||
761 | ||||
762 | auto type = op->getOperand(0).getType(); | |||
763 | for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) | |||
764 | if (opType != type) | |||
765 | return op->emitOpError() << "requires all operands to have the same type"; | |||
766 | return success(); | |||
767 | } | |||
768 | ||||
769 | LogicalResult OpTrait::impl::verifyZeroRegions(Operation *op) { | |||
770 | if (op->getNumRegions() != 0) | |||
771 | return op->emitOpError() << "requires zero regions"; | |||
772 | return success(); | |||
773 | } | |||
774 | ||||
775 | LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { | |||
776 | if (op->getNumRegions() != 1) | |||
777 | return op->emitOpError() << "requires one region"; | |||
778 | return success(); | |||
779 | } | |||
780 | ||||
781 | LogicalResult OpTrait::impl::verifyNRegions(Operation *op, | |||
782 | unsigned numRegions) { | |||
783 | if (op->getNumRegions() != numRegions) | |||
784 | return op->emitOpError() << "expected " << numRegions << " regions"; | |||
785 | return success(); | |||
786 | } | |||
787 | ||||
788 | LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, | |||
789 | unsigned numRegions) { | |||
790 | if (op->getNumRegions() < numRegions) | |||
791 | return op->emitOpError() << "expected " << numRegions << " or more regions"; | |||
792 | return success(); | |||
793 | } | |||
794 | ||||
795 | LogicalResult OpTrait::impl::verifyZeroResults(Operation *op) { | |||
796 | if (op->getNumResults() != 0) | |||
797 | return op->emitOpError() << "requires zero results"; | |||
798 | return success(); | |||
799 | } | |||
800 | ||||
801 | LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { | |||
802 | if (op->getNumResults() != 1) | |||
803 | return op->emitOpError() << "requires one result"; | |||
804 | return success(); | |||
805 | } | |||
806 | ||||
807 | LogicalResult OpTrait::impl::verifyNResults(Operation *op, | |||
808 | unsigned numOperands) { | |||
809 | if (op->getNumResults() != numOperands) | |||
810 | return op->emitOpError() << "expected " << numOperands << " results"; | |||
811 | return success(); | |||
812 | } | |||
813 | ||||
814 | LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, | |||
815 | unsigned numOperands) { | |||
816 | if (op->getNumResults() < numOperands) | |||
817 | return op->emitOpError() | |||
818 | << "expected " << numOperands << " or more results"; | |||
819 | return success(); | |||
820 | } | |||
821 | ||||
822 | LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { | |||
823 | if (failed(verifyAtLeastNOperands(op, 1))) | |||
824 | return failure(); | |||
825 | ||||
826 | if (failed(verifyCompatibleShapes(op->getOperandTypes()))) | |||
827 | return op->emitOpError() << "requires the same shape for all operands"; | |||
828 | ||||
829 | return success(); | |||
830 | } | |||
831 | ||||
832 | LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { | |||
833 | if (failed(verifyAtLeastNOperands(op, 1)) || | |||
834 | failed(verifyAtLeastNResults(op, 1))) | |||
835 | return failure(); | |||
836 | ||||
837 | SmallVector<Type, 8> types(op->getOperandTypes()); | |||
838 | types.append(llvm::to_vector<4>(op->getResultTypes())); | |||
839 | ||||
840 | if (failed(verifyCompatibleShapes(types))) | |||
841 | return op->emitOpError() | |||
842 | << "requires the same shape for all operands and results"; | |||
843 | ||||
844 | return success(); | |||
845 | } | |||
846 | ||||
847 | LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { | |||
848 | if (failed(verifyAtLeastNOperands(op, 1))) | |||
849 | return failure(); | |||
850 | auto elementType = getElementTypeOrSelf(op->getOperand(0)); | |||
851 | ||||
852 | for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { | |||
853 | if (getElementTypeOrSelf(operand) != elementType) | |||
854 | return op->emitOpError("requires the same element type for all operands"); | |||
855 | } | |||
856 | ||||
857 | return success(); | |||
858 | } | |||
859 | ||||
860 | LogicalResult | |||
861 | OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { | |||
862 | if (failed(verifyAtLeastNOperands(op, 1)) || | |||
863 | failed(verifyAtLeastNResults(op, 1))) | |||
864 | return failure(); | |||
865 | ||||
866 | auto elementType = getElementTypeOrSelf(op->getResult(0)); | |||
867 | ||||
868 | // Verify result element type matches first result's element type. | |||
869 | for (auto result : llvm::drop_begin(op->getResults(), 1)) { | |||
870 | if (getElementTypeOrSelf(result) != elementType) | |||
871 | return op->emitOpError( | |||
872 | "requires the same element type for all operands and results"); | |||
873 | } | |||
874 | ||||
875 | // Verify operand's element type matches first result's element type. | |||
876 | for (auto operand : op->getOperands()) { | |||
877 | if (getElementTypeOrSelf(operand) != elementType) | |||
878 | return op->emitOpError( | |||
879 | "requires the same element type for all operands and results"); | |||
880 | } | |||
881 | ||||
882 | return success(); | |||
883 | } | |||
884 | ||||
885 | LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { | |||
886 | if (failed(verifyAtLeastNOperands(op, 1)) || | |||
887 | failed(verifyAtLeastNResults(op, 1))) | |||
888 | return failure(); | |||
889 | ||||
890 | auto type = op->getResult(0).getType(); | |||
891 | auto elementType = getElementTypeOrSelf(type); | |||
892 | for (auto resultType : llvm::drop_begin(op->getResultTypes())) { | |||
893 | if (getElementTypeOrSelf(resultType) != elementType || | |||
894 | failed(verifyCompatibleShape(resultType, type))) | |||
895 | return op->emitOpError() | |||
896 | << "requires the same type for all operands and results"; | |||
897 | } | |||
898 | for (auto opType : op->getOperandTypes()) { | |||
899 | if (getElementTypeOrSelf(opType) != elementType || | |||
900 | failed(verifyCompatibleShape(opType, type))) | |||
901 | return op->emitOpError() | |||
902 | << "requires the same type for all operands and results"; | |||
903 | } | |||
904 | return success(); | |||
905 | } | |||
906 | ||||
907 | LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { | |||
908 | Block *block = op->getBlock(); | |||
909 | // Verify that the operation is at the end of the respective parent block. | |||
910 | if (!block || &block->back() != op) | |||
911 | return op->emitOpError("must be the last operation in the parent block"); | |||
912 | return success(); | |||
913 | } | |||
914 | ||||
915 | static LogicalResult verifyTerminatorSuccessors(Operation *op) { | |||
916 | auto *parent = op->getParentRegion(); | |||
917 | ||||
918 | // Verify that the operands lines up with the BB arguments in the successor. | |||
919 | for (Block *succ : op->getSuccessors()) | |||
920 | if (succ->getParent() != parent) | |||
921 | return op->emitError("reference to block defined in another region"); | |||
922 | return success(); | |||
923 | } | |||
924 | ||||
925 | LogicalResult OpTrait::impl::verifyZeroSuccessors(Operation *op) { | |||
926 | if (op->getNumSuccessors() != 0) { | |||
927 | return op->emitOpError("requires 0 successors but found ") | |||
928 | << op->getNumSuccessors(); | |||
929 | } | |||
930 | return success(); | |||
931 | } | |||
932 | ||||
933 | LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { | |||
934 | if (op->getNumSuccessors() != 1) { | |||
935 | return op->emitOpError("requires 1 successor but found ") | |||
936 | << op->getNumSuccessors(); | |||
937 | } | |||
938 | return verifyTerminatorSuccessors(op); | |||
939 | } | |||
940 | LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, | |||
941 | unsigned numSuccessors) { | |||
942 | if (op->getNumSuccessors() != numSuccessors) { | |||
943 | return op->emitOpError("requires ") | |||
944 | << numSuccessors << " successors but found " | |||
945 | << op->getNumSuccessors(); | |||
946 | } | |||
947 | return verifyTerminatorSuccessors(op); | |||
948 | } | |||
949 | LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, | |||
950 | unsigned numSuccessors) { | |||
951 | if (op->getNumSuccessors() < numSuccessors) { | |||
952 | return op->emitOpError("requires at least ") | |||
953 | << numSuccessors << " successors but found " | |||
954 | << op->getNumSuccessors(); | |||
955 | } | |||
956 | return verifyTerminatorSuccessors(op); | |||
957 | } | |||
958 | ||||
959 | LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { | |||
960 | for (auto resultType : op->getResultTypes()) { | |||
961 | auto elementType = getTensorOrVectorElementType(resultType); | |||
962 | bool isBoolType = elementType.isInteger(1); | |||
963 | if (!isBoolType) | |||
964 | return op->emitOpError() << "requires a bool result type"; | |||
965 | } | |||
966 | ||||
967 | return success(); | |||
968 | } | |||
969 | ||||
970 | LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { | |||
971 | for (auto resultType : op->getResultTypes()) | |||
972 | if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) | |||
973 | return op->emitOpError() << "requires a floating point type"; | |||
974 | ||||
975 | return success(); | |||
976 | } | |||
977 | ||||
978 | LogicalResult | |||
979 | OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { | |||
980 | for (auto resultType : op->getResultTypes()) | |||
981 | if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) | |||
982 | return op->emitOpError() << "requires an integer or index type"; | |||
983 | return success(); | |||
984 | } | |||
985 | ||||
986 | LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, | |||
987 | StringRef attrName, | |||
988 | StringRef valueGroupName, | |||
989 | size_t expectedCount) { | |||
990 | auto sizeAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrName); | |||
991 | if (!sizeAttr) | |||
992 | return op->emitOpError("requires dense i32 array attribute '") | |||
993 | << attrName << "'"; | |||
994 | ||||
995 | ArrayRef<int32_t> sizes = sizeAttr.asArrayRef(); | |||
996 | if (llvm::any_of(sizes, [](int32_t element) { return element < 0; })) | |||
997 | return op->emitOpError("'") | |||
998 | << attrName << "' attribute cannot have negative elements"; | |||
999 | ||||
1000 | size_t totalCount = | |||
1001 | std::accumulate(sizes.begin(), sizes.end(), 0, | |||
1002 | [](unsigned all, int32_t one) { return all + one; }); | |||
1003 | ||||
1004 | if (totalCount != expectedCount) | |||
1005 | return op->emitOpError() | |||
1006 | << valueGroupName << " count (" << expectedCount | |||
1007 | << ") does not match with the total size (" << totalCount | |||
1008 | << ") specified in attribute '" << attrName << "'"; | |||
1009 | return success(); | |||
1010 | } | |||
1011 | ||||
1012 | LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, | |||
1013 | StringRef attrName) { | |||
1014 | return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); | |||
1015 | } | |||
1016 | ||||
1017 | LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, | |||
1018 | StringRef attrName) { | |||
1019 | return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); | |||
1020 | } | |||
1021 | ||||
1022 | LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { | |||
1023 | for (Region ®ion : op->getRegions()) { | |||
1024 | if (region.empty()) | |||
1025 | continue; | |||
1026 | ||||
1027 | if (region.getNumArguments() != 0) { | |||
1028 | if (op->getNumRegions() > 1) | |||
1029 | return op->emitOpError("region #") | |||
1030 | << region.getRegionNumber() << " should have no arguments"; | |||
1031 | return op->emitOpError("region should have no arguments"); | |||
1032 | } | |||
1033 | } | |||
1034 | return success(); | |||
1035 | } | |||
1036 | ||||
1037 | LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { | |||
1038 | auto isMappableType = [](Type type) { | |||
1039 | return type.isa<VectorType, TensorType>(); | |||
1040 | }; | |||
1041 | auto resultMappableTypes = llvm::to_vector<1>( | |||
1042 | llvm::make_filter_range(op->getResultTypes(), isMappableType)); | |||
1043 | auto operandMappableTypes = llvm::to_vector<2>( | |||
1044 | llvm::make_filter_range(op->getOperandTypes(), isMappableType)); | |||
1045 | ||||
1046 | // If the op only has scalar operand/result types, then we have nothing to | |||
1047 | // check. | |||
1048 | if (resultMappableTypes.empty() && operandMappableTypes.empty()) | |||
1049 | return success(); | |||
1050 | ||||
1051 | if (!resultMappableTypes.empty() && operandMappableTypes.empty()) | |||
1052 | return op->emitOpError("if a result is non-scalar, then at least one " | |||
1053 | "operand must be non-scalar"); | |||
1054 | ||||
1055 | assert(!operandMappableTypes.empty())(static_cast <bool> (!operandMappableTypes.empty()) ? void (0) : __assert_fail ("!operandMappableTypes.empty()", "mlir/lib/IR/Operation.cpp" , 1055, __extension__ __PRETTY_FUNCTION__)); | |||
1056 | ||||
1057 | if (resultMappableTypes.empty()) | |||
1058 | return op->emitOpError("if an operand is non-scalar, then there must be at " | |||
1059 | "least one non-scalar result"); | |||
1060 | ||||
1061 | if (resultMappableTypes.size() != op->getNumResults()) | |||
1062 | return op->emitOpError( | |||
1063 | "if an operand is non-scalar, then all results must be non-scalar"); | |||
1064 | ||||
1065 | SmallVector<Type, 4> types = llvm::to_vector<2>( | |||
1066 | llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); | |||
1067 | TypeID expectedBaseTy = types.front().getTypeID(); | |||
1068 | if (!llvm::all_of(types, | |||
1069 | [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || | |||
1070 | failed(verifyCompatibleShapes(types))) { | |||
1071 | return op->emitOpError() << "all non-scalar operands/results must have the " | |||
1072 | "same shape and base type"; | |||
1073 | } | |||
1074 | ||||
1075 | return success(); | |||
1076 | } | |||
1077 | ||||
1078 | /// Check for any values used by operations regions attached to the | |||
1079 | /// specified "IsIsolatedFromAbove" operation defined outside of it. | |||
1080 | LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { | |||
1081 | assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() &&(static_cast <bool> (isolatedOp->hasTrait<OpTrait ::IsIsolatedFromAbove>() && "Intended to check IsolatedFromAbove ops" ) ? void (0) : __assert_fail ("isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && \"Intended to check IsolatedFromAbove ops\"" , "mlir/lib/IR/Operation.cpp", 1082, __extension__ __PRETTY_FUNCTION__ )) | |||
1082 | "Intended to check IsolatedFromAbove ops")(static_cast <bool> (isolatedOp->hasTrait<OpTrait ::IsIsolatedFromAbove>() && "Intended to check IsolatedFromAbove ops" ) ? void (0) : __assert_fail ("isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && \"Intended to check IsolatedFromAbove ops\"" , "mlir/lib/IR/Operation.cpp", 1082, __extension__ __PRETTY_FUNCTION__ )); | |||
1083 | ||||
1084 | // List of regions to analyze. Each region is processed independently, with | |||
1085 | // respect to the common `limit` region, so we can look at them in any order. | |||
1086 | // Therefore, use a simple vector and push/pop back the current region. | |||
1087 | SmallVector<Region *, 8> pendingRegions; | |||
1088 | for (auto ®ion : isolatedOp->getRegions()) { | |||
1089 | pendingRegions.push_back(®ion); | |||
1090 | ||||
1091 | // Traverse all operations in the region. | |||
1092 | while (!pendingRegions.empty()) { | |||
1093 | for (Operation &op : pendingRegions.pop_back_val()->getOps()) { | |||
1094 | for (Value operand : op.getOperands()) { | |||
1095 | // Check that any value that is used by an operation is defined in the | |||
1096 | // same region as either an operation result. | |||
1097 | auto *operandRegion = operand.getParentRegion(); | |||
1098 | if (!operandRegion) | |||
1099 | return op.emitError("operation's operand is unlinked"); | |||
1100 | if (!region.isAncestor(operandRegion)) { | |||
1101 | return op.emitOpError("using value defined outside the region") | |||
1102 | .attachNote(isolatedOp->getLoc()) | |||
1103 | << "required by region isolation constraints"; | |||
1104 | } | |||
1105 | } | |||
1106 | ||||
1107 | // Schedule any regions in the operation for further checking. Don't | |||
1108 | // recurse into other IsolatedFromAbove ops, because they will check | |||
1109 | // themselves. | |||
1110 | if (op.getNumRegions() && | |||
1111 | !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { | |||
1112 | for (Region &subRegion : op.getRegions()) | |||
1113 | pendingRegions.push_back(&subRegion); | |||
1114 | } | |||
1115 | } | |||
1116 | } | |||
1117 | } | |||
1118 | ||||
1119 | return success(); | |||
1120 | } | |||
1121 | ||||
1122 | bool OpTrait::hasElementwiseMappableTraits(Operation *op) { | |||
1123 | return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && | |||
1124 | op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); | |||
1125 | } | |||
1126 | ||||
1127 | //===----------------------------------------------------------------------===// | |||
1128 | // CastOpInterface | |||
1129 | //===----------------------------------------------------------------------===// | |||
1130 | ||||
1131 | /// Attempt to fold the given cast operation. | |||
1132 | LogicalResult | |||
1133 | impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, | |||
1134 | SmallVectorImpl<OpFoldResult> &foldResults) { | |||
1135 | OperandRange operands = op->getOperands(); | |||
1136 | if (operands.empty()) | |||
1137 | return failure(); | |||
1138 | ResultRange results = op->getResults(); | |||
1139 | ||||
1140 | // Check for the case where the input and output types match 1-1. | |||
1141 | if (operands.getTypes() == results.getTypes()) { | |||
1142 | foldResults.append(operands.begin(), operands.end()); | |||
1143 | return success(); | |||
1144 | } | |||
1145 | ||||
1146 | return failure(); | |||
1147 | } | |||
1148 | ||||
1149 | /// Attempt to verify the given cast operation. | |||
1150 | LogicalResult impl::verifyCastInterfaceOp( | |||
1151 | Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) { | |||
1152 | auto resultTypes = op->getResultTypes(); | |||
1153 | if (llvm::empty(resultTypes)) | |||
1154 | return op->emitOpError() | |||
1155 | << "expected at least one result for cast operation"; | |||
1156 | ||||
1157 | auto operandTypes = op->getOperandTypes(); | |||
1158 | if (!areCastCompatible(operandTypes, resultTypes)) { | |||
1159 | InFlightDiagnostic diag = op->emitOpError("operand type"); | |||
1160 | if (llvm::empty(operandTypes)) | |||
1161 | diag << "s []"; | |||
1162 | else if (llvm::size(operandTypes) == 1) | |||
1163 | diag << " " << *operandTypes.begin(); | |||
1164 | else | |||
1165 | diag << "s " << operandTypes; | |||
1166 | return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ") | |||
1167 | << resultTypes << " are cast incompatible"; | |||
1168 | } | |||
1169 | ||||
1170 | return success(); | |||
1171 | } | |||
1172 | ||||
1173 | //===----------------------------------------------------------------------===// | |||
1174 | // Misc. utils | |||
1175 | //===----------------------------------------------------------------------===// | |||
1176 | ||||
1177 | /// Insert an operation, generated by `buildTerminatorOp`, at the end of the | |||
1178 | /// region's only block if it does not have a terminator already. If the region | |||
1179 | /// is empty, insert a new block first. `buildTerminatorOp` should return the | |||
1180 | /// terminator operation to insert. | |||
1181 | void impl::ensureRegionTerminator( | |||
1182 | Region ®ion, OpBuilder &builder, Location loc, | |||
1183 | function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { | |||
1184 | OpBuilder::InsertionGuard guard(builder); | |||
1185 | if (region.empty()) | |||
1186 | builder.createBlock(®ion); | |||
1187 | ||||
1188 | Block &block = region.back(); | |||
1189 | if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) | |||
1190 | return; | |||
1191 | ||||
1192 | builder.setInsertionPointToEnd(&block); | |||
1193 | builder.insert(buildTerminatorOp(builder, loc)); | |||
1194 | } | |||
1195 | ||||
1196 | /// Create a simple OpBuilder and forward to the OpBuilder version of this | |||
1197 | /// function. | |||
1198 | void impl::ensureRegionTerminator( | |||
1199 | Region ®ion, Builder &builder, Location loc, | |||
1200 | function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { | |||
1201 | OpBuilder opBuilder(builder.getContext()); | |||
1202 | ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); | |||
1203 | } |