File: | build/source/mlir/lib/IR/Operation.cpp |
Warning: | line 84, column 3 Potential leak of memory pointed to by 'mallocMem' |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- Operation.cpp - Operation support code -----------------------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | ||||
9 | #include "mlir/IR/Operation.h" | |||
10 | #include "mlir/IR/BuiltinTypes.h" | |||
11 | #include "mlir/IR/Dialect.h" | |||
12 | #include "mlir/IR/IRMapping.h" | |||
13 | #include "mlir/IR/OpImplementation.h" | |||
14 | #include "mlir/IR/PatternMatch.h" | |||
15 | #include "mlir/IR/TypeUtilities.h" | |||
16 | #include "mlir/Interfaces/FoldInterfaces.h" | |||
17 | #include "llvm/ADT/StringExtras.h" | |||
18 | #include <numeric> | |||
19 | ||||
20 | using namespace mlir; | |||
21 | ||||
22 | //===----------------------------------------------------------------------===// | |||
23 | // Operation | |||
24 | //===----------------------------------------------------------------------===// | |||
25 | ||||
26 | /// Create a new Operation from operation state. | |||
27 | Operation *Operation::create(const OperationState &state) { | |||
28 | return create(state.location, state.name, state.types, state.operands, | |||
29 | state.attributes.getDictionary(state.getContext()), | |||
30 | state.successors, state.regions); | |||
31 | } | |||
32 | ||||
33 | /// Create a new Operation with the specific fields. | |||
34 | Operation *Operation::create(Location location, OperationName name, | |||
35 | TypeRange resultTypes, ValueRange operands, | |||
36 | NamedAttrList &&attributes, BlockRange successors, | |||
37 | RegionRange regions) { | |||
38 | unsigned numRegions = regions.size(); | |||
39 | Operation *op = create(location, name, resultTypes, operands, | |||
40 | std::move(attributes), successors, numRegions); | |||
41 | for (unsigned i = 0; i < numRegions; ++i) | |||
42 | if (regions[i]) | |||
43 | op->getRegion(i).takeBody(*regions[i]); | |||
44 | return op; | |||
45 | } | |||
46 | ||||
47 | /// Overload of create that takes an existing DictionaryAttr to avoid | |||
48 | /// unnecessarily uniquing a list of attributes. | |||
49 | Operation *Operation::create(Location location, OperationName name, | |||
50 | TypeRange resultTypes, ValueRange operands, | |||
51 | NamedAttrList &&attributes, BlockRange successors, | |||
52 | unsigned numRegions) { | |||
53 | assert(llvm::all_of(resultTypes, [](Type t) { return t; }) &&(static_cast <bool> (llvm::all_of(resultTypes, [](Type t ) { return t; }) && "unexpected null result type") ? void (0) : __assert_fail ("llvm::all_of(resultTypes, [](Type t) { return t; }) && \"unexpected null result type\"" , "mlir/lib/IR/Operation.cpp", 54, __extension__ __PRETTY_FUNCTION__ )) | |||
54 | "unexpected null result type")(static_cast <bool> (llvm::all_of(resultTypes, [](Type t ) { return t; }) && "unexpected null result type") ? void (0) : __assert_fail ("llvm::all_of(resultTypes, [](Type t) { return t; }) && \"unexpected null result type\"" , "mlir/lib/IR/Operation.cpp", 54, __extension__ __PRETTY_FUNCTION__ )); | |||
55 | ||||
56 | // We only need to allocate additional memory for a subset of results. | |||
57 | unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); | |||
58 | unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); | |||
59 | unsigned numSuccessors = successors.size(); | |||
60 | unsigned numOperands = operands.size(); | |||
61 | unsigned numResults = resultTypes.size(); | |||
62 | ||||
63 | // If the operation is known to have no operands, don't allocate an operand | |||
64 | // storage. | |||
65 | bool needsOperandStorage = | |||
66 | operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true; | |||
67 | ||||
68 | // Compute the byte size for the operation and the operand storage. This takes | |||
69 | // into account the size of the operation, its trailing objects, and its | |||
70 | // prefixed objects. | |||
71 | size_t byteSize = | |||
72 | totalSizeToAlloc<detail::OperandStorage, BlockOperand, Region, OpOperand>( | |||
73 | needsOperandStorage
| |||
74 | size_t prefixByteSize = llvm::alignTo( | |||
75 | Operation::prefixAllocSize(numTrailingResults, numInlineResults), | |||
76 | alignof(Operation)); | |||
77 | char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); | |||
78 | void *rawMem = mallocMem + prefixByteSize; | |||
79 | ||||
80 | // Populate default attributes. | |||
81 | name.populateDefaultAttrs(attributes); | |||
82 | ||||
83 | // Create the new Operation. | |||
84 | Operation *op = ::new (rawMem) Operation( | |||
| ||||
85 | location, name, numResults, numSuccessors, numRegions, | |||
86 | attributes.getDictionary(location.getContext()), needsOperandStorage); | |||
87 | ||||
88 | assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) &&(static_cast <bool> ((numSuccessors == 0 || op->mightHaveTrait <OpTrait::IsTerminator>()) && "unexpected successors in a non-terminator operation" ) ? void (0) : __assert_fail ("(numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && \"unexpected successors in a non-terminator operation\"" , "mlir/lib/IR/Operation.cpp", 89, __extension__ __PRETTY_FUNCTION__ )) | |||
89 | "unexpected successors in a non-terminator operation")(static_cast <bool> ((numSuccessors == 0 || op->mightHaveTrait <OpTrait::IsTerminator>()) && "unexpected successors in a non-terminator operation" ) ? void (0) : __assert_fail ("(numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && \"unexpected successors in a non-terminator operation\"" , "mlir/lib/IR/Operation.cpp", 89, __extension__ __PRETTY_FUNCTION__ )); | |||
90 | ||||
91 | // Initialize the results. | |||
92 | auto resultTypeIt = resultTypes.begin(); | |||
93 | for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) | |||
94 | new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); | |||
95 | for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { | |||
96 | new (op->getOutOfLineOpResult(i)) | |||
97 | detail::OutOfLineOpResult(*resultTypeIt, i); | |||
98 | } | |||
99 | ||||
100 | // Initialize the regions. | |||
101 | for (unsigned i = 0; i != numRegions; ++i) | |||
102 | new (&op->getRegion(i)) Region(op); | |||
103 | ||||
104 | // Initialize the operands. | |||
105 | if (needsOperandStorage) { | |||
106 | new (&op->getOperandStorage()) detail::OperandStorage( | |||
107 | op, op->getTrailingObjects<OpOperand>(), operands); | |||
108 | } | |||
109 | ||||
110 | // Initialize the successors. | |||
111 | auto blockOperands = op->getBlockOperands(); | |||
112 | for (unsigned i = 0; i != numSuccessors; ++i) | |||
113 | new (&blockOperands[i]) BlockOperand(op, successors[i]); | |||
114 | ||||
115 | return op; | |||
116 | } | |||
117 | ||||
118 | Operation::Operation(Location location, OperationName name, unsigned numResults, | |||
119 | unsigned numSuccessors, unsigned numRegions, | |||
120 | DictionaryAttr attributes, bool hasOperandStorage) | |||
121 | : location(location), numResults(numResults), numSuccs(numSuccessors), | |||
122 | numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name), | |||
123 | attrs(attributes) { | |||
124 | assert(attributes && "unexpected null attribute dictionary")(static_cast <bool> (attributes && "unexpected null attribute dictionary" ) ? void (0) : __assert_fail ("attributes && \"unexpected null attribute dictionary\"" , "mlir/lib/IR/Operation.cpp", 124, __extension__ __PRETTY_FUNCTION__ )); | |||
125 | #ifndef NDEBUG | |||
126 | if (!getDialect() && !getContext()->allowsUnregisteredDialects()) | |||
127 | llvm::report_fatal_error( | |||
128 | name.getStringRef() + | |||
129 | " created with unregistered dialect. If this is intended, please call " | |||
130 | "allowUnregisteredDialects() on the MLIRContext, or use " | |||
131 | "-allow-unregistered-dialect with the MLIR tool used."); | |||
132 | #endif | |||
133 | } | |||
134 | ||||
135 | // Operations are deleted through the destroy() member because they are | |||
136 | // allocated via malloc. | |||
137 | Operation::~Operation() { | |||
138 | assert(block == nullptr && "operation destroyed but still in a block")(static_cast <bool> (block == nullptr && "operation destroyed but still in a block" ) ? void (0) : __assert_fail ("block == nullptr && \"operation destroyed but still in a block\"" , "mlir/lib/IR/Operation.cpp", 138, __extension__ __PRETTY_FUNCTION__ )); | |||
139 | #ifndef NDEBUG | |||
140 | if (!use_empty()) { | |||
141 | { | |||
142 | InFlightDiagnostic diag = | |||
143 | emitOpError("operation destroyed but still has uses"); | |||
144 | for (Operation *user : getUsers()) | |||
145 | diag.attachNote(user->getLoc()) << "- use: " << *user << "\n"; | |||
146 | } | |||
147 | llvm::report_fatal_error("operation destroyed but still has uses"); | |||
148 | } | |||
149 | #endif | |||
150 | // Explicitly run the destructors for the operands. | |||
151 | if (hasOperandStorage) | |||
152 | getOperandStorage().~OperandStorage(); | |||
153 | ||||
154 | // Explicitly run the destructors for the successors. | |||
155 | for (auto &successor : getBlockOperands()) | |||
156 | successor.~BlockOperand(); | |||
157 | ||||
158 | // Explicitly destroy the regions. | |||
159 | for (auto ®ion : getRegions()) | |||
160 | region.~Region(); | |||
161 | } | |||
162 | ||||
163 | /// Destroy this operation or one of its subclasses. | |||
164 | void Operation::destroy() { | |||
165 | // Operations may have additional prefixed allocation, which needs to be | |||
166 | // accounted for here when computing the address to free. | |||
167 | char *rawMem = reinterpret_cast<char *>(this) - | |||
168 | llvm::alignTo(prefixAllocSize(), alignof(Operation)); | |||
169 | this->~Operation(); | |||
170 | free(rawMem); | |||
171 | } | |||
172 | ||||
173 | /// Return true if this operation is a proper ancestor of the `other` | |||
174 | /// operation. | |||
175 | bool Operation::isProperAncestor(Operation *other) { | |||
176 | while ((other = other->getParentOp())) | |||
177 | if (this == other) | |||
178 | return true; | |||
179 | return false; | |||
180 | } | |||
181 | ||||
182 | /// Replace any uses of 'from' with 'to' within this operation. | |||
183 | void Operation::replaceUsesOfWith(Value from, Value to) { | |||
184 | if (from == to) | |||
185 | return; | |||
186 | for (auto &operand : getOpOperands()) | |||
187 | if (operand.get() == from) | |||
188 | operand.set(to); | |||
189 | } | |||
190 | ||||
191 | /// Replace the current operands of this operation with the ones provided in | |||
192 | /// 'operands'. | |||
193 | void Operation::setOperands(ValueRange operands) { | |||
194 | if (LLVM_LIKELY(hasOperandStorage)__builtin_expect((bool)(hasOperandStorage), true)) | |||
195 | return getOperandStorage().setOperands(this, operands); | |||
196 | assert(operands.empty() && "setting operands without an operand storage")(static_cast <bool> (operands.empty() && "setting operands without an operand storage" ) ? void (0) : __assert_fail ("operands.empty() && \"setting operands without an operand storage\"" , "mlir/lib/IR/Operation.cpp", 196, __extension__ __PRETTY_FUNCTION__ )); | |||
197 | } | |||
198 | ||||
199 | /// Replace the operands beginning at 'start' and ending at 'start' + 'length' | |||
200 | /// with the ones provided in 'operands'. 'operands' may be smaller or larger | |||
201 | /// than the range pointed to by 'start'+'length'. | |||
202 | void Operation::setOperands(unsigned start, unsigned length, | |||
203 | ValueRange operands) { | |||
204 | assert((start + length) <= getNumOperands() &&(static_cast <bool> ((start + length) <= getNumOperands () && "invalid operand range specified") ? void (0) : __assert_fail ("(start + length) <= getNumOperands() && \"invalid operand range specified\"" , "mlir/lib/IR/Operation.cpp", 205, __extension__ __PRETTY_FUNCTION__ )) | |||
205 | "invalid operand range specified")(static_cast <bool> ((start + length) <= getNumOperands () && "invalid operand range specified") ? void (0) : __assert_fail ("(start + length) <= getNumOperands() && \"invalid operand range specified\"" , "mlir/lib/IR/Operation.cpp", 205, __extension__ __PRETTY_FUNCTION__ )); | |||
206 | if (LLVM_LIKELY(hasOperandStorage)__builtin_expect((bool)(hasOperandStorage), true)) | |||
207 | return getOperandStorage().setOperands(this, start, length, operands); | |||
208 | assert(operands.empty() && "setting operands without an operand storage")(static_cast <bool> (operands.empty() && "setting operands without an operand storage" ) ? void (0) : __assert_fail ("operands.empty() && \"setting operands without an operand storage\"" , "mlir/lib/IR/Operation.cpp", 208, __extension__ __PRETTY_FUNCTION__ )); | |||
209 | } | |||
210 | ||||
211 | /// Insert the given operands into the operand list at the given 'index'. | |||
212 | void Operation::insertOperands(unsigned index, ValueRange operands) { | |||
213 | if (LLVM_LIKELY(hasOperandStorage)__builtin_expect((bool)(hasOperandStorage), true)) | |||
214 | return setOperands(index, /*length=*/0, operands); | |||
215 | assert(operands.empty() && "inserting operands without an operand storage")(static_cast <bool> (operands.empty() && "inserting operands without an operand storage" ) ? void (0) : __assert_fail ("operands.empty() && \"inserting operands without an operand storage\"" , "mlir/lib/IR/Operation.cpp", 215, __extension__ __PRETTY_FUNCTION__ )); | |||
216 | } | |||
217 | ||||
218 | //===----------------------------------------------------------------------===// | |||
219 | // Diagnostics | |||
220 | //===----------------------------------------------------------------------===// | |||
221 | ||||
222 | /// Emit an error about fatal conditions with this operation, reporting up to | |||
223 | /// any diagnostic handlers that may be listening. | |||
224 | InFlightDiagnostic Operation::emitError(const Twine &message) { | |||
225 | InFlightDiagnostic diag = mlir::emitError(getLoc(), message); | |||
226 | if (getContext()->shouldPrintOpOnDiagnostic()) { | |||
227 | diag.attachNote(getLoc()) | |||
228 | .append("see current operation: ") | |||
229 | .appendOp(*this, OpPrintingFlags().printGenericOpForm()); | |||
230 | } | |||
231 | return diag; | |||
232 | } | |||
233 | ||||
234 | /// Emit a warning about this operation, reporting up to any diagnostic | |||
235 | /// handlers that may be listening. | |||
236 | InFlightDiagnostic Operation::emitWarning(const Twine &message) { | |||
237 | InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); | |||
238 | if (getContext()->shouldPrintOpOnDiagnostic()) | |||
239 | diag.attachNote(getLoc()) << "see current operation: " << *this; | |||
240 | return diag; | |||
241 | } | |||
242 | ||||
243 | /// Emit a remark about this operation, reporting up to any diagnostic | |||
244 | /// handlers that may be listening. | |||
245 | InFlightDiagnostic Operation::emitRemark(const Twine &message) { | |||
246 | InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message); | |||
247 | if (getContext()->shouldPrintOpOnDiagnostic()) | |||
248 | diag.attachNote(getLoc()) << "see current operation: " << *this; | |||
249 | return diag; | |||
250 | } | |||
251 | ||||
252 | //===----------------------------------------------------------------------===// | |||
253 | // Operation Ordering | |||
254 | //===----------------------------------------------------------------------===// | |||
255 | ||||
256 | constexpr unsigned Operation::kInvalidOrderIdx; | |||
257 | constexpr unsigned Operation::kOrderStride; | |||
258 | ||||
259 | /// Given an operation 'other' that is within the same parent block, return | |||
260 | /// whether the current operation is before 'other' in the operation list | |||
261 | /// of the parent block. | |||
262 | /// Note: This function has an average complexity of O(1), but worst case may | |||
263 | /// take O(N) where N is the number of operations within the parent block. | |||
264 | bool Operation::isBeforeInBlock(Operation *other) { | |||
265 | assert(block && "Operations without parent blocks have no order.")(static_cast <bool> (block && "Operations without parent blocks have no order." ) ? void (0) : __assert_fail ("block && \"Operations without parent blocks have no order.\"" , "mlir/lib/IR/Operation.cpp", 265, __extension__ __PRETTY_FUNCTION__ )); | |||
266 | assert(other && other->block == block &&(static_cast <bool> (other && other->block == block && "Expected other operation to have the same parent block." ) ? void (0) : __assert_fail ("other && other->block == block && \"Expected other operation to have the same parent block.\"" , "mlir/lib/IR/Operation.cpp", 267, __extension__ __PRETTY_FUNCTION__ )) | |||
267 | "Expected other operation to have the same parent block.")(static_cast <bool> (other && other->block == block && "Expected other operation to have the same parent block." ) ? void (0) : __assert_fail ("other && other->block == block && \"Expected other operation to have the same parent block.\"" , "mlir/lib/IR/Operation.cpp", 267, __extension__ __PRETTY_FUNCTION__ )); | |||
268 | // If the order of the block is already invalid, directly recompute the | |||
269 | // parent. | |||
270 | if (!block->isOpOrderValid()) { | |||
271 | block->recomputeOpOrder(); | |||
272 | } else { | |||
273 | // Update the order either operation if necessary. | |||
274 | updateOrderIfNecessary(); | |||
275 | other->updateOrderIfNecessary(); | |||
276 | } | |||
277 | ||||
278 | return orderIndex < other->orderIndex; | |||
279 | } | |||
280 | ||||
281 | /// Update the order index of this operation of this operation if necessary, | |||
282 | /// potentially recomputing the order of the parent block. | |||
283 | void Operation::updateOrderIfNecessary() { | |||
284 | assert(block && "expected valid parent")(static_cast <bool> (block && "expected valid parent" ) ? void (0) : __assert_fail ("block && \"expected valid parent\"" , "mlir/lib/IR/Operation.cpp", 284, __extension__ __PRETTY_FUNCTION__ )); | |||
285 | ||||
286 | // If the order is valid for this operation there is nothing to do. | |||
287 | if (hasValidOrder()) | |||
288 | return; | |||
289 | Operation *blockFront = &block->front(); | |||
290 | Operation *blockBack = &block->back(); | |||
291 | ||||
292 | // This method is expected to only be invoked on blocks with more than one | |||
293 | // operation. | |||
294 | assert(blockFront != blockBack && "expected more than one operation")(static_cast <bool> (blockFront != blockBack && "expected more than one operation") ? void (0) : __assert_fail ("blockFront != blockBack && \"expected more than one operation\"" , "mlir/lib/IR/Operation.cpp", 294, __extension__ __PRETTY_FUNCTION__ )); | |||
295 | ||||
296 | // If the operation is at the end of the block. | |||
297 | if (this == blockBack) { | |||
298 | Operation *prevNode = getPrevNode(); | |||
299 | if (!prevNode->hasValidOrder()) | |||
300 | return block->recomputeOpOrder(); | |||
301 | ||||
302 | // Add the stride to the previous operation. | |||
303 | orderIndex = prevNode->orderIndex + kOrderStride; | |||
304 | return; | |||
305 | } | |||
306 | ||||
307 | // If this is the first operation try to use the next operation to compute the | |||
308 | // ordering. | |||
309 | if (this == blockFront) { | |||
310 | Operation *nextNode = getNextNode(); | |||
311 | if (!nextNode->hasValidOrder()) | |||
312 | return block->recomputeOpOrder(); | |||
313 | // There is no order to give this operation. | |||
314 | if (nextNode->orderIndex == 0) | |||
315 | return block->recomputeOpOrder(); | |||
316 | ||||
317 | // If we can't use the stride, just take the middle value left. This is safe | |||
318 | // because we know there is at least one valid index to assign to. | |||
319 | if (nextNode->orderIndex <= kOrderStride) | |||
320 | orderIndex = (nextNode->orderIndex / 2); | |||
321 | else | |||
322 | orderIndex = kOrderStride; | |||
323 | return; | |||
324 | } | |||
325 | ||||
326 | // Otherwise, this operation is between two others. Place this operation in | |||
327 | // the middle of the previous and next if possible. | |||
328 | Operation *prevNode = getPrevNode(), *nextNode = getNextNode(); | |||
329 | if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder()) | |||
330 | return block->recomputeOpOrder(); | |||
331 | unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex; | |||
332 | ||||
333 | // Check to see if there is a valid order between the two. | |||
334 | if (prevOrder + 1 == nextOrder) | |||
335 | return block->recomputeOpOrder(); | |||
336 | orderIndex = prevOrder + ((nextOrder - prevOrder) / 2); | |||
337 | } | |||
338 | ||||
339 | //===----------------------------------------------------------------------===// | |||
340 | // ilist_traits for Operation | |||
341 | //===----------------------------------------------------------------------===// | |||
342 | ||||
343 | auto llvm::ilist_detail::SpecificNodeAccess< | |||
344 | typename llvm::ilist_detail::compute_node_options< | |||
345 | ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * { | |||
346 | return NodeAccess::getNodePtr<OptionsT>(n); | |||
347 | } | |||
348 | ||||
349 | auto llvm::ilist_detail::SpecificNodeAccess< | |||
350 | typename llvm::ilist_detail::compute_node_options< | |||
351 | ::mlir::Operation>::type>::getNodePtr(const_pointer n) | |||
352 | -> const node_type * { | |||
353 | return NodeAccess::getNodePtr<OptionsT>(n); | |||
354 | } | |||
355 | ||||
356 | auto llvm::ilist_detail::SpecificNodeAccess< | |||
357 | typename llvm::ilist_detail::compute_node_options< | |||
358 | ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer { | |||
359 | return NodeAccess::getValuePtr<OptionsT>(n); | |||
360 | } | |||
361 | ||||
362 | auto llvm::ilist_detail::SpecificNodeAccess< | |||
363 | typename llvm::ilist_detail::compute_node_options< | |||
364 | ::mlir::Operation>::type>::getValuePtr(const node_type *n) | |||
365 | -> const_pointer { | |||
366 | return NodeAccess::getValuePtr<OptionsT>(n); | |||
367 | } | |||
368 | ||||
369 | void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) { | |||
370 | op->destroy(); | |||
371 | } | |||
372 | ||||
373 | Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() { | |||
374 | size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); | |||
375 | iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this)); | |||
376 | return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset); | |||
377 | } | |||
378 | ||||
379 | /// This is a trait method invoked when an operation is added to a block. We | |||
380 | /// keep the block pointer up to date. | |||
381 | void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { | |||
382 | assert(!op->getBlock() && "already in an operation block!")(static_cast <bool> (!op->getBlock() && "already in an operation block!" ) ? void (0) : __assert_fail ("!op->getBlock() && \"already in an operation block!\"" , "mlir/lib/IR/Operation.cpp", 382, __extension__ __PRETTY_FUNCTION__ )); | |||
383 | op->block = getContainingBlock(); | |||
384 | ||||
385 | // Invalidate the order on the operation. | |||
386 | op->orderIndex = Operation::kInvalidOrderIdx; | |||
387 | } | |||
388 | ||||
389 | /// This is a trait method invoked when an operation is removed from a block. | |||
390 | /// We keep the block pointer up to date. | |||
391 | void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { | |||
392 | assert(op->block && "not already in an operation block!")(static_cast <bool> (op->block && "not already in an operation block!" ) ? void (0) : __assert_fail ("op->block && \"not already in an operation block!\"" , "mlir/lib/IR/Operation.cpp", 392, __extension__ __PRETTY_FUNCTION__ )); | |||
393 | op->block = nullptr; | |||
394 | } | |||
395 | ||||
396 | /// This is a trait method invoked when an operation is moved from one block | |||
397 | /// to another. We keep the block pointer up to date. | |||
398 | void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( | |||
399 | ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { | |||
400 | Block *curParent = getContainingBlock(); | |||
401 | ||||
402 | // Invalidate the ordering of the parent block. | |||
403 | curParent->invalidateOpOrder(); | |||
404 | ||||
405 | // If we are transferring operations within the same block, the block | |||
406 | // pointer doesn't need to be updated. | |||
407 | if (curParent == otherList.getContainingBlock()) | |||
408 | return; | |||
409 | ||||
410 | // Update the 'block' member of each operation. | |||
411 | for (; first != last; ++first) | |||
412 | first->block = curParent; | |||
413 | } | |||
414 | ||||
415 | /// Remove this operation (and its descendants) from its Block and delete | |||
416 | /// all of them. | |||
417 | void Operation::erase() { | |||
418 | if (auto *parent = getBlock()) | |||
419 | parent->getOperations().erase(this); | |||
420 | else | |||
421 | destroy(); | |||
422 | } | |||
423 | ||||
424 | /// Remove the operation from its parent block, but don't delete it. | |||
425 | void Operation::remove() { | |||
426 | if (Block *parent = getBlock()) | |||
427 | parent->getOperations().remove(this); | |||
428 | } | |||
429 | ||||
430 | /// Unlink this operation from its current block and insert it right before | |||
431 | /// `existingOp` which may be in the same or another block in the same | |||
432 | /// function. | |||
433 | void Operation::moveBefore(Operation *existingOp) { | |||
434 | moveBefore(existingOp->getBlock(), existingOp->getIterator()); | |||
435 | } | |||
436 | ||||
437 | /// Unlink this operation from its current basic block and insert it right | |||
438 | /// before `iterator` in the specified basic block. | |||
439 | void Operation::moveBefore(Block *block, | |||
440 | llvm::iplist<Operation>::iterator iterator) { | |||
441 | block->getOperations().splice(iterator, getBlock()->getOperations(), | |||
442 | getIterator()); | |||
443 | } | |||
444 | ||||
445 | /// Unlink this operation from its current block and insert it right after | |||
446 | /// `existingOp` which may be in the same or another block in the same function. | |||
447 | void Operation::moveAfter(Operation *existingOp) { | |||
448 | moveAfter(existingOp->getBlock(), existingOp->getIterator()); | |||
449 | } | |||
450 | ||||
451 | /// Unlink this operation from its current block and insert it right after | |||
452 | /// `iterator` in the specified block. | |||
453 | void Operation::moveAfter(Block *block, | |||
454 | llvm::iplist<Operation>::iterator iterator) { | |||
455 | assert(iterator != block->end() && "cannot move after end of block")(static_cast <bool> (iterator != block->end() && "cannot move after end of block") ? void (0) : __assert_fail ("iterator != block->end() && \"cannot move after end of block\"" , "mlir/lib/IR/Operation.cpp", 455, __extension__ __PRETTY_FUNCTION__ )); | |||
456 | moveBefore(block, std::next(iterator)); | |||
457 | } | |||
458 | ||||
459 | /// This drops all operand uses from this operation, which is an essential | |||
460 | /// step in breaking cyclic dependences between references when they are to | |||
461 | /// be deleted. | |||
462 | void Operation::dropAllReferences() { | |||
463 | for (auto &op : getOpOperands()) | |||
464 | op.drop(); | |||
465 | ||||
466 | for (auto ®ion : getRegions()) | |||
467 | region.dropAllReferences(); | |||
468 | ||||
469 | for (auto &dest : getBlockOperands()) | |||
470 | dest.drop(); | |||
471 | } | |||
472 | ||||
473 | /// This drops all uses of any values defined by this operation or its nested | |||
474 | /// regions, wherever they are located. | |||
475 | void Operation::dropAllDefinedValueUses() { | |||
476 | dropAllUses(); | |||
477 | ||||
478 | for (auto ®ion : getRegions()) | |||
479 | for (auto &block : region) | |||
480 | block.dropAllDefinedValueUses(); | |||
481 | } | |||
482 | ||||
483 | void Operation::setSuccessor(Block *block, unsigned index) { | |||
484 | assert(index < getNumSuccessors())(static_cast <bool> (index < getNumSuccessors()) ? void (0) : __assert_fail ("index < getNumSuccessors()", "mlir/lib/IR/Operation.cpp" , 484, __extension__ __PRETTY_FUNCTION__)); | |||
485 | getBlockOperands()[index].set(block); | |||
486 | } | |||
487 | ||||
488 | /// Attempt to fold this operation using the Op's registered foldHook. | |||
489 | LogicalResult Operation::fold(ArrayRef<Attribute> operands, | |||
490 | SmallVectorImpl<OpFoldResult> &results) { | |||
491 | // If we have a registered operation definition matching this one, use it to | |||
492 | // try to constant fold the operation. | |||
493 | if (succeeded(name.foldHook(this, operands, results))) | |||
494 | return success(); | |||
495 | ||||
496 | // Otherwise, fall back on the dialect hook to handle it. | |||
497 | Dialect *dialect = getDialect(); | |||
498 | if (!dialect) | |||
499 | return failure(); | |||
500 | ||||
501 | auto *interface = dyn_cast<DialectFoldInterface>(dialect); | |||
502 | if (!interface) | |||
503 | return failure(); | |||
504 | ||||
505 | return interface->fold(this, operands, results); | |||
506 | } | |||
507 | ||||
508 | /// Emit an error with the op name prefixed, like "'dim' op " which is | |||
509 | /// convenient for verifiers. | |||
510 | InFlightDiagnostic Operation::emitOpError(const Twine &message) { | |||
511 | return emitError() << "'" << getName() << "' op " << message; | |||
512 | } | |||
513 | ||||
514 | //===----------------------------------------------------------------------===// | |||
515 | // Operation Cloning | |||
516 | //===----------------------------------------------------------------------===// | |||
517 | ||||
518 | Operation::CloneOptions::CloneOptions() | |||
519 | : cloneRegionsFlag(false), cloneOperandsFlag(false) {} | |||
520 | ||||
521 | Operation::CloneOptions::CloneOptions(bool cloneRegions, bool cloneOperands) | |||
522 | : cloneRegionsFlag(cloneRegions), cloneOperandsFlag(cloneOperands) {} | |||
523 | ||||
524 | Operation::CloneOptions Operation::CloneOptions::all() { | |||
525 | return CloneOptions().cloneRegions().cloneOperands(); | |||
526 | } | |||
527 | ||||
528 | Operation::CloneOptions &Operation::CloneOptions::cloneRegions(bool enable) { | |||
529 | cloneRegionsFlag = enable; | |||
530 | return *this; | |||
531 | } | |||
532 | ||||
533 | Operation::CloneOptions &Operation::CloneOptions::cloneOperands(bool enable) { | |||
534 | cloneOperandsFlag = enable; | |||
535 | return *this; | |||
536 | } | |||
537 | ||||
538 | /// Create a deep copy of this operation but keep the operation regions empty. | |||
539 | /// Operands are remapped using `mapper` (if present), and `mapper` is updated | |||
540 | /// to contain the results. The `mapResults` flag specifies whether the results | |||
541 | /// of the cloned operation should be added to the map. | |||
542 | Operation *Operation::cloneWithoutRegions(IRMapping &mapper) { | |||
543 | return clone(mapper, CloneOptions::all().cloneRegions(false)); | |||
544 | } | |||
545 | ||||
546 | Operation *Operation::cloneWithoutRegions() { | |||
547 | IRMapping mapper; | |||
548 | return cloneWithoutRegions(mapper); | |||
549 | } | |||
550 | ||||
551 | /// Create a deep copy of this operation, remapping any operands that use | |||
552 | /// values outside of the operation using the map that is provided (leaving | |||
553 | /// them alone if no entry is present). Replaces references to cloned | |||
554 | /// sub-operations to the corresponding operation that is copied, and adds | |||
555 | /// those mappings to the map. | |||
556 | Operation *Operation::clone(IRMapping &mapper, CloneOptions options) { | |||
557 | SmallVector<Value, 8> operands; | |||
558 | SmallVector<Block *, 2> successors; | |||
559 | ||||
560 | // Remap the operands. | |||
561 | if (options.shouldCloneOperands()) { | |||
562 | operands.reserve(getNumOperands()); | |||
563 | for (auto opValue : getOperands()) | |||
564 | operands.push_back(mapper.lookupOrDefault(opValue)); | |||
565 | } | |||
566 | ||||
567 | // Remap the successors. | |||
568 | successors.reserve(getNumSuccessors()); | |||
569 | for (Block *successor : getSuccessors()) | |||
570 | successors.push_back(mapper.lookupOrDefault(successor)); | |||
571 | ||||
572 | // Create the new operation. | |||
573 | auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, | |||
574 | successors, getNumRegions()); | |||
575 | mapper.map(this, newOp); | |||
576 | ||||
577 | // Clone the regions. | |||
578 | if (options.shouldCloneRegions()) { | |||
579 | for (unsigned i = 0; i != numRegions; ++i) | |||
580 | getRegion(i).cloneInto(&newOp->getRegion(i), mapper); | |||
581 | } | |||
582 | ||||
583 | // Remember the mapping of any results. | |||
584 | for (unsigned i = 0, e = getNumResults(); i != e; ++i) | |||
585 | mapper.map(getResult(i), newOp->getResult(i)); | |||
586 | ||||
587 | return newOp; | |||
588 | } | |||
589 | ||||
590 | Operation *Operation::clone(CloneOptions options) { | |||
591 | IRMapping mapper; | |||
592 | return clone(mapper, options); | |||
| ||||
593 | } | |||
594 | ||||
595 | //===----------------------------------------------------------------------===// | |||
596 | // OpState trait class. | |||
597 | //===----------------------------------------------------------------------===// | |||
598 | ||||
599 | // The fallback for the parser is to try for a dialect operation parser. | |||
600 | // Otherwise, reject the custom assembly form. | |||
601 | ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { | |||
602 | if (auto parseFn = result.name.getDialect()->getParseOperationHook( | |||
603 | result.name.getStringRef())) | |||
604 | return (*parseFn)(parser, result); | |||
605 | return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); | |||
606 | } | |||
607 | ||||
608 | // The fallback for the printer is to try for a dialect operation printer. | |||
609 | // Otherwise, it prints the generic form. | |||
610 | void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { | |||
611 | if (auto printFn = op->getDialect()->getOperationPrinter(op)) { | |||
612 | printOpName(op, p, defaultDialect); | |||
613 | printFn(op, p); | |||
614 | } else { | |||
615 | p.printGenericOp(op); | |||
616 | } | |||
617 | } | |||
618 | ||||
619 | /// Print an operation name, eliding the dialect prefix if necessary and doesn't | |||
620 | /// lead to ambiguities. | |||
621 | void OpState::printOpName(Operation *op, OpAsmPrinter &p, | |||
622 | StringRef defaultDialect) { | |||
623 | StringRef name = op->getName().getStringRef(); | |||
624 | if (name.startswith((defaultDialect + ".").str()) && name.count('.') == 1) | |||
625 | name = name.drop_front(defaultDialect.size() + 1); | |||
626 | p.getStream() << name; | |||
627 | } | |||
628 | ||||
629 | /// Emit an error about fatal conditions with this operation, reporting up to | |||
630 | /// any diagnostic handlers that may be listening. | |||
631 | InFlightDiagnostic OpState::emitError(const Twine &message) { | |||
632 | return getOperation()->emitError(message); | |||
633 | } | |||
634 | ||||
635 | /// Emit an error with the op name prefixed, like "'dim' op " which is | |||
636 | /// convenient for verifiers. | |||
637 | InFlightDiagnostic OpState::emitOpError(const Twine &message) { | |||
638 | return getOperation()->emitOpError(message); | |||
639 | } | |||
640 | ||||
641 | /// Emit a warning about this operation, reporting up to any diagnostic | |||
642 | /// handlers that may be listening. | |||
643 | InFlightDiagnostic OpState::emitWarning(const Twine &message) { | |||
644 | return getOperation()->emitWarning(message); | |||
645 | } | |||
646 | ||||
647 | /// Emit a remark about this operation, reporting up to any diagnostic | |||
648 | /// handlers that may be listening. | |||
649 | InFlightDiagnostic OpState::emitRemark(const Twine &message) { | |||
650 | return getOperation()->emitRemark(message); | |||
651 | } | |||
652 | ||||
653 | //===----------------------------------------------------------------------===// | |||
654 | // Op Trait implementations | |||
655 | //===----------------------------------------------------------------------===// | |||
656 | ||||
657 | OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { | |||
658 | if (op->getNumOperands() == 1) { | |||
659 | auto *argumentOp = op->getOperand(0).getDefiningOp(); | |||
660 | if (argumentOp && op->getName() == argumentOp->getName()) { | |||
661 | // Replace the outer operation output with the inner operation. | |||
662 | return op->getOperand(0); | |||
663 | } | |||
664 | } else if (op->getOperand(0) == op->getOperand(1)) { | |||
665 | return op->getOperand(0); | |||
666 | } | |||
667 | ||||
668 | return {}; | |||
669 | } | |||
670 | ||||
671 | OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { | |||
672 | auto *argumentOp = op->getOperand(0).getDefiningOp(); | |||
673 | if (argumentOp && op->getName() == argumentOp->getName()) { | |||
674 | // Replace the outer involutions output with inner's input. | |||
675 | return argumentOp->getOperand(0); | |||
676 | } | |||
677 | ||||
678 | return {}; | |||
679 | } | |||
680 | ||||
681 | LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { | |||
682 | if (op->getNumOperands() != 0) | |||
683 | return op->emitOpError() << "requires zero operands"; | |||
684 | return success(); | |||
685 | } | |||
686 | ||||
687 | LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { | |||
688 | if (op->getNumOperands() != 1) | |||
689 | return op->emitOpError() << "requires a single operand"; | |||
690 | return success(); | |||
691 | } | |||
692 | ||||
693 | LogicalResult OpTrait::impl::verifyNOperands(Operation *op, | |||
694 | unsigned numOperands) { | |||
695 | if (op->getNumOperands() != numOperands) { | |||
696 | return op->emitOpError() << "expected " << numOperands | |||
697 | << " operands, but found " << op->getNumOperands(); | |||
698 | } | |||
699 | return success(); | |||
700 | } | |||
701 | ||||
702 | LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, | |||
703 | unsigned numOperands) { | |||
704 | if (op->getNumOperands() < numOperands) | |||
705 | return op->emitOpError() | |||
706 | << "expected " << numOperands << " or more operands, but found " | |||
707 | << op->getNumOperands(); | |||
708 | return success(); | |||
709 | } | |||
710 | ||||
711 | /// If this is a vector type, or a tensor type, return the scalar element type | |||
712 | /// that it is built around, otherwise return the type unmodified. | |||
713 | static Type getTensorOrVectorElementType(Type type) { | |||
714 | if (auto vec = type.dyn_cast<VectorType>()) | |||
715 | return vec.getElementType(); | |||
716 | ||||
717 | // Look through tensor<vector<...>> to find the underlying element type. | |||
718 | if (auto tensor = type.dyn_cast<TensorType>()) | |||
719 | return getTensorOrVectorElementType(tensor.getElementType()); | |||
720 | return type; | |||
721 | } | |||
722 | ||||
723 | LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { | |||
724 | // FIXME: Add back check for no side effects on operation. | |||
725 | // Currently adding it would cause the shared library build | |||
726 | // to fail since there would be a dependency of IR on SideEffectInterfaces | |||
727 | // which is cyclical. | |||
728 | return success(); | |||
729 | } | |||
730 | ||||
731 | LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { | |||
732 | // FIXME: Add back check for no side effects on operation. | |||
733 | // Currently adding it would cause the shared library build | |||
734 | // to fail since there would be a dependency of IR on SideEffectInterfaces | |||
735 | // which is cyclical. | |||
736 | return success(); | |||
737 | } | |||
738 | ||||
739 | LogicalResult | |||
740 | OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { | |||
741 | for (auto opType : op->getOperandTypes()) { | |||
742 | auto type = getTensorOrVectorElementType(opType); | |||
743 | if (!type.isSignlessIntOrIndex()) | |||
744 | return op->emitOpError() << "requires an integer or index type"; | |||
745 | } | |||
746 | return success(); | |||
747 | } | |||
748 | ||||
749 | LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { | |||
750 | for (auto opType : op->getOperandTypes()) { | |||
751 | auto type = getTensorOrVectorElementType(opType); | |||
752 | if (!type.isa<FloatType>()) | |||
753 | return op->emitOpError("requires a float type"); | |||
754 | } | |||
755 | return success(); | |||
756 | } | |||
757 | ||||
758 | LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { | |||
759 | // Zero or one operand always have the "same" type. | |||
760 | unsigned nOperands = op->getNumOperands(); | |||
761 | if (nOperands < 2) | |||
762 | return success(); | |||
763 | ||||
764 | auto type = op->getOperand(0).getType(); | |||
765 | for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) | |||
766 | if (opType != type) | |||
767 | return op->emitOpError() << "requires all operands to have the same type"; | |||
768 | return success(); | |||
769 | } | |||
770 | ||||
771 | LogicalResult OpTrait::impl::verifyZeroRegions(Operation *op) { | |||
772 | if (op->getNumRegions() != 0) | |||
773 | return op->emitOpError() << "requires zero regions"; | |||
774 | return success(); | |||
775 | } | |||
776 | ||||
777 | LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { | |||
778 | if (op->getNumRegions() != 1) | |||
779 | return op->emitOpError() << "requires one region"; | |||
780 | return success(); | |||
781 | } | |||
782 | ||||
783 | LogicalResult OpTrait::impl::verifyNRegions(Operation *op, | |||
784 | unsigned numRegions) { | |||
785 | if (op->getNumRegions() != numRegions) | |||
786 | return op->emitOpError() << "expected " << numRegions << " regions"; | |||
787 | return success(); | |||
788 | } | |||
789 | ||||
790 | LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, | |||
791 | unsigned numRegions) { | |||
792 | if (op->getNumRegions() < numRegions) | |||
793 | return op->emitOpError() << "expected " << numRegions << " or more regions"; | |||
794 | return success(); | |||
795 | } | |||
796 | ||||
797 | LogicalResult OpTrait::impl::verifyZeroResults(Operation *op) { | |||
798 | if (op->getNumResults() != 0) | |||
799 | return op->emitOpError() << "requires zero results"; | |||
800 | return success(); | |||
801 | } | |||
802 | ||||
803 | LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { | |||
804 | if (op->getNumResults() != 1) | |||
805 | return op->emitOpError() << "requires one result"; | |||
806 | return success(); | |||
807 | } | |||
808 | ||||
809 | LogicalResult OpTrait::impl::verifyNResults(Operation *op, | |||
810 | unsigned numOperands) { | |||
811 | if (op->getNumResults() != numOperands) | |||
812 | return op->emitOpError() << "expected " << numOperands << " results"; | |||
813 | return success(); | |||
814 | } | |||
815 | ||||
816 | LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, | |||
817 | unsigned numOperands) { | |||
818 | if (op->getNumResults() < numOperands) | |||
819 | return op->emitOpError() | |||
820 | << "expected " << numOperands << " or more results"; | |||
821 | return success(); | |||
822 | } | |||
823 | ||||
824 | LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { | |||
825 | if (failed(verifyAtLeastNOperands(op, 1))) | |||
826 | return failure(); | |||
827 | ||||
828 | if (failed(verifyCompatibleShapes(op->getOperandTypes()))) | |||
829 | return op->emitOpError() << "requires the same shape for all operands"; | |||
830 | ||||
831 | return success(); | |||
832 | } | |||
833 | ||||
834 | LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { | |||
835 | if (failed(verifyAtLeastNOperands(op, 1)) || | |||
836 | failed(verifyAtLeastNResults(op, 1))) | |||
837 | return failure(); | |||
838 | ||||
839 | SmallVector<Type, 8> types(op->getOperandTypes()); | |||
840 | types.append(llvm::to_vector<4>(op->getResultTypes())); | |||
841 | ||||
842 | if (failed(verifyCompatibleShapes(types))) | |||
843 | return op->emitOpError() | |||
844 | << "requires the same shape for all operands and results"; | |||
845 | ||||
846 | return success(); | |||
847 | } | |||
848 | ||||
849 | LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { | |||
850 | if (failed(verifyAtLeastNOperands(op, 1))) | |||
851 | return failure(); | |||
852 | auto elementType = getElementTypeOrSelf(op->getOperand(0)); | |||
853 | ||||
854 | for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { | |||
855 | if (getElementTypeOrSelf(operand) != elementType) | |||
856 | return op->emitOpError("requires the same element type for all operands"); | |||
857 | } | |||
858 | ||||
859 | return success(); | |||
860 | } | |||
861 | ||||
862 | LogicalResult | |||
863 | OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { | |||
864 | if (failed(verifyAtLeastNOperands(op, 1)) || | |||
865 | failed(verifyAtLeastNResults(op, 1))) | |||
866 | return failure(); | |||
867 | ||||
868 | auto elementType = getElementTypeOrSelf(op->getResult(0)); | |||
869 | ||||
870 | // Verify result element type matches first result's element type. | |||
871 | for (auto result : llvm::drop_begin(op->getResults(), 1)) { | |||
872 | if (getElementTypeOrSelf(result) != elementType) | |||
873 | return op->emitOpError( | |||
874 | "requires the same element type for all operands and results"); | |||
875 | } | |||
876 | ||||
877 | // Verify operand's element type matches first result's element type. | |||
878 | for (auto operand : op->getOperands()) { | |||
879 | if (getElementTypeOrSelf(operand) != elementType) | |||
880 | return op->emitOpError( | |||
881 | "requires the same element type for all operands and results"); | |||
882 | } | |||
883 | ||||
884 | return success(); | |||
885 | } | |||
886 | ||||
887 | LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { | |||
888 | if (failed(verifyAtLeastNOperands(op, 1)) || | |||
889 | failed(verifyAtLeastNResults(op, 1))) | |||
890 | return failure(); | |||
891 | ||||
892 | auto type = op->getResult(0).getType(); | |||
893 | auto elementType = getElementTypeOrSelf(type); | |||
894 | Attribute encoding = nullptr; | |||
895 | if (auto rankedType = dyn_cast<RankedTensorType>(type)) | |||
896 | encoding = rankedType.getEncoding(); | |||
897 | for (auto resultType : llvm::drop_begin(op->getResultTypes())) { | |||
898 | if (getElementTypeOrSelf(resultType) != elementType || | |||
899 | failed(verifyCompatibleShape(resultType, type))) | |||
900 | return op->emitOpError() | |||
901 | << "requires the same type for all operands and results"; | |||
902 | if (encoding) | |||
903 | if (auto rankedType = dyn_cast<RankedTensorType>(resultType); | |||
904 | encoding != rankedType.getEncoding()) | |||
905 | return op->emitOpError() | |||
906 | << "requires the same encoding for all operands and results"; | |||
907 | } | |||
908 | for (auto opType : op->getOperandTypes()) { | |||
909 | if (getElementTypeOrSelf(opType) != elementType || | |||
910 | failed(verifyCompatibleShape(opType, type))) | |||
911 | return op->emitOpError() | |||
912 | << "requires the same type for all operands and results"; | |||
913 | if (encoding) | |||
914 | if (auto rankedType = dyn_cast<RankedTensorType>(opType); | |||
915 | encoding != rankedType.getEncoding()) | |||
916 | return op->emitOpError() | |||
917 | << "requires the same encoding for all operands and results"; | |||
918 | } | |||
919 | return success(); | |||
920 | } | |||
921 | ||||
922 | LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { | |||
923 | Block *block = op->getBlock(); | |||
924 | // Verify that the operation is at the end of the respective parent block. | |||
925 | if (!block || &block->back() != op) | |||
926 | return op->emitOpError("must be the last operation in the parent block"); | |||
927 | return success(); | |||
928 | } | |||
929 | ||||
930 | static LogicalResult verifyTerminatorSuccessors(Operation *op) { | |||
931 | auto *parent = op->getParentRegion(); | |||
932 | ||||
933 | // Verify that the operands lines up with the BB arguments in the successor. | |||
934 | for (Block *succ : op->getSuccessors()) | |||
935 | if (succ->getParent() != parent) | |||
936 | return op->emitError("reference to block defined in another region"); | |||
937 | return success(); | |||
938 | } | |||
939 | ||||
940 | LogicalResult OpTrait::impl::verifyZeroSuccessors(Operation *op) { | |||
941 | if (op->getNumSuccessors() != 0) { | |||
942 | return op->emitOpError("requires 0 successors but found ") | |||
943 | << op->getNumSuccessors(); | |||
944 | } | |||
945 | return success(); | |||
946 | } | |||
947 | ||||
948 | LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { | |||
949 | if (op->getNumSuccessors() != 1) { | |||
950 | return op->emitOpError("requires 1 successor but found ") | |||
951 | << op->getNumSuccessors(); | |||
952 | } | |||
953 | return verifyTerminatorSuccessors(op); | |||
954 | } | |||
955 | LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, | |||
956 | unsigned numSuccessors) { | |||
957 | if (op->getNumSuccessors() != numSuccessors) { | |||
958 | return op->emitOpError("requires ") | |||
959 | << numSuccessors << " successors but found " | |||
960 | << op->getNumSuccessors(); | |||
961 | } | |||
962 | return verifyTerminatorSuccessors(op); | |||
963 | } | |||
964 | LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, | |||
965 | unsigned numSuccessors) { | |||
966 | if (op->getNumSuccessors() < numSuccessors) { | |||
967 | return op->emitOpError("requires at least ") | |||
968 | << numSuccessors << " successors but found " | |||
969 | << op->getNumSuccessors(); | |||
970 | } | |||
971 | return verifyTerminatorSuccessors(op); | |||
972 | } | |||
973 | ||||
974 | LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { | |||
975 | for (auto resultType : op->getResultTypes()) { | |||
976 | auto elementType = getTensorOrVectorElementType(resultType); | |||
977 | bool isBoolType = elementType.isInteger(1); | |||
978 | if (!isBoolType) | |||
979 | return op->emitOpError() << "requires a bool result type"; | |||
980 | } | |||
981 | ||||
982 | return success(); | |||
983 | } | |||
984 | ||||
985 | LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { | |||
986 | for (auto resultType : op->getResultTypes()) | |||
987 | if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) | |||
988 | return op->emitOpError() << "requires a floating point type"; | |||
989 | ||||
990 | return success(); | |||
991 | } | |||
992 | ||||
993 | LogicalResult | |||
994 | OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { | |||
995 | for (auto resultType : op->getResultTypes()) | |||
996 | if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) | |||
997 | return op->emitOpError() << "requires an integer or index type"; | |||
998 | return success(); | |||
999 | } | |||
1000 | ||||
1001 | LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, | |||
1002 | StringRef attrName, | |||
1003 | StringRef valueGroupName, | |||
1004 | size_t expectedCount) { | |||
1005 | auto sizeAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrName); | |||
1006 | if (!sizeAttr) | |||
1007 | return op->emitOpError("requires dense i32 array attribute '") | |||
1008 | << attrName << "'"; | |||
1009 | ||||
1010 | ArrayRef<int32_t> sizes = sizeAttr.asArrayRef(); | |||
1011 | if (llvm::any_of(sizes, [](int32_t element) { return element < 0; })) | |||
1012 | return op->emitOpError("'") | |||
1013 | << attrName << "' attribute cannot have negative elements"; | |||
1014 | ||||
1015 | size_t totalCount = | |||
1016 | std::accumulate(sizes.begin(), sizes.end(), 0, | |||
1017 | [](unsigned all, int32_t one) { return all + one; }); | |||
1018 | ||||
1019 | if (totalCount != expectedCount) | |||
1020 | return op->emitOpError() | |||
1021 | << valueGroupName << " count (" << expectedCount | |||
1022 | << ") does not match with the total size (" << totalCount | |||
1023 | << ") specified in attribute '" << attrName << "'"; | |||
1024 | return success(); | |||
1025 | } | |||
1026 | ||||
1027 | LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, | |||
1028 | StringRef attrName) { | |||
1029 | return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); | |||
1030 | } | |||
1031 | ||||
1032 | LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, | |||
1033 | StringRef attrName) { | |||
1034 | return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); | |||
1035 | } | |||
1036 | ||||
1037 | LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { | |||
1038 | for (Region ®ion : op->getRegions()) { | |||
1039 | if (region.empty()) | |||
1040 | continue; | |||
1041 | ||||
1042 | if (region.getNumArguments() != 0) { | |||
1043 | if (op->getNumRegions() > 1) | |||
1044 | return op->emitOpError("region #") | |||
1045 | << region.getRegionNumber() << " should have no arguments"; | |||
1046 | return op->emitOpError("region should have no arguments"); | |||
1047 | } | |||
1048 | } | |||
1049 | return success(); | |||
1050 | } | |||
1051 | ||||
1052 | LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { | |||
1053 | auto isMappableType = [](Type type) { | |||
1054 | return type.isa<VectorType, TensorType>(); | |||
1055 | }; | |||
1056 | auto resultMappableTypes = llvm::to_vector<1>( | |||
1057 | llvm::make_filter_range(op->getResultTypes(), isMappableType)); | |||
1058 | auto operandMappableTypes = llvm::to_vector<2>( | |||
1059 | llvm::make_filter_range(op->getOperandTypes(), isMappableType)); | |||
1060 | ||||
1061 | // If the op only has scalar operand/result types, then we have nothing to | |||
1062 | // check. | |||
1063 | if (resultMappableTypes.empty() && operandMappableTypes.empty()) | |||
1064 | return success(); | |||
1065 | ||||
1066 | if (!resultMappableTypes.empty() && operandMappableTypes.empty()) | |||
1067 | return op->emitOpError("if a result is non-scalar, then at least one " | |||
1068 | "operand must be non-scalar"); | |||
1069 | ||||
1070 | assert(!operandMappableTypes.empty())(static_cast <bool> (!operandMappableTypes.empty()) ? void (0) : __assert_fail ("!operandMappableTypes.empty()", "mlir/lib/IR/Operation.cpp" , 1070, __extension__ __PRETTY_FUNCTION__)); | |||
1071 | ||||
1072 | if (resultMappableTypes.empty()) | |||
1073 | return op->emitOpError("if an operand is non-scalar, then there must be at " | |||
1074 | "least one non-scalar result"); | |||
1075 | ||||
1076 | if (resultMappableTypes.size() != op->getNumResults()) | |||
1077 | return op->emitOpError( | |||
1078 | "if an operand is non-scalar, then all results must be non-scalar"); | |||
1079 | ||||
1080 | SmallVector<Type, 4> types = llvm::to_vector<2>( | |||
1081 | llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); | |||
1082 | TypeID expectedBaseTy = types.front().getTypeID(); | |||
1083 | if (!llvm::all_of(types, | |||
1084 | [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || | |||
1085 | failed(verifyCompatibleShapes(types))) { | |||
1086 | return op->emitOpError() << "all non-scalar operands/results must have the " | |||
1087 | "same shape and base type"; | |||
1088 | } | |||
1089 | ||||
1090 | return success(); | |||
1091 | } | |||
1092 | ||||
1093 | /// Check for any values used by operations regions attached to the | |||
1094 | /// specified "IsIsolatedFromAbove" operation defined outside of it. | |||
1095 | LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { | |||
1096 | assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() &&(static_cast <bool> (isolatedOp->hasTrait<OpTrait ::IsIsolatedFromAbove>() && "Intended to check IsolatedFromAbove ops" ) ? void (0) : __assert_fail ("isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && \"Intended to check IsolatedFromAbove ops\"" , "mlir/lib/IR/Operation.cpp", 1097, __extension__ __PRETTY_FUNCTION__ )) | |||
1097 | "Intended to check IsolatedFromAbove ops")(static_cast <bool> (isolatedOp->hasTrait<OpTrait ::IsIsolatedFromAbove>() && "Intended to check IsolatedFromAbove ops" ) ? void (0) : __assert_fail ("isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && \"Intended to check IsolatedFromAbove ops\"" , "mlir/lib/IR/Operation.cpp", 1097, __extension__ __PRETTY_FUNCTION__ )); | |||
1098 | ||||
1099 | // List of regions to analyze. Each region is processed independently, with | |||
1100 | // respect to the common `limit` region, so we can look at them in any order. | |||
1101 | // Therefore, use a simple vector and push/pop back the current region. | |||
1102 | SmallVector<Region *, 8> pendingRegions; | |||
1103 | for (auto ®ion : isolatedOp->getRegions()) { | |||
1104 | pendingRegions.push_back(®ion); | |||
1105 | ||||
1106 | // Traverse all operations in the region. | |||
1107 | while (!pendingRegions.empty()) { | |||
1108 | for (Operation &op : pendingRegions.pop_back_val()->getOps()) { | |||
1109 | for (Value operand : op.getOperands()) { | |||
1110 | // Check that any value that is used by an operation is defined in the | |||
1111 | // same region as either an operation result. | |||
1112 | auto *operandRegion = operand.getParentRegion(); | |||
1113 | if (!operandRegion) | |||
1114 | return op.emitError("operation's operand is unlinked"); | |||
1115 | if (!region.isAncestor(operandRegion)) { | |||
1116 | return op.emitOpError("using value defined outside the region") | |||
1117 | .attachNote(isolatedOp->getLoc()) | |||
1118 | << "required by region isolation constraints"; | |||
1119 | } | |||
1120 | } | |||
1121 | ||||
1122 | // Schedule any regions in the operation for further checking. Don't | |||
1123 | // recurse into other IsolatedFromAbove ops, because they will check | |||
1124 | // themselves. | |||
1125 | if (op.getNumRegions() && | |||
1126 | !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { | |||
1127 | for (Region &subRegion : op.getRegions()) | |||
1128 | pendingRegions.push_back(&subRegion); | |||
1129 | } | |||
1130 | } | |||
1131 | } | |||
1132 | } | |||
1133 | ||||
1134 | return success(); | |||
1135 | } | |||
1136 | ||||
1137 | bool OpTrait::hasElementwiseMappableTraits(Operation *op) { | |||
1138 | return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && | |||
1139 | op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); | |||
1140 | } | |||
1141 | ||||
1142 | //===----------------------------------------------------------------------===// | |||
1143 | // CastOpInterface | |||
1144 | //===----------------------------------------------------------------------===// | |||
1145 | ||||
1146 | /// Attempt to fold the given cast operation. | |||
1147 | LogicalResult | |||
1148 | impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, | |||
1149 | SmallVectorImpl<OpFoldResult> &foldResults) { | |||
1150 | OperandRange operands = op->getOperands(); | |||
1151 | if (operands.empty()) | |||
1152 | return failure(); | |||
1153 | ResultRange results = op->getResults(); | |||
1154 | ||||
1155 | // Check for the case where the input and output types match 1-1. | |||
1156 | if (operands.getTypes() == results.getTypes()) { | |||
1157 | foldResults.append(operands.begin(), operands.end()); | |||
1158 | return success(); | |||
1159 | } | |||
1160 | ||||
1161 | return failure(); | |||
1162 | } | |||
1163 | ||||
1164 | /// Attempt to verify the given cast operation. | |||
1165 | LogicalResult impl::verifyCastInterfaceOp( | |||
1166 | Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) { | |||
1167 | auto resultTypes = op->getResultTypes(); | |||
1168 | if (resultTypes.empty()) | |||
1169 | return op->emitOpError() | |||
1170 | << "expected at least one result for cast operation"; | |||
1171 | ||||
1172 | auto operandTypes = op->getOperandTypes(); | |||
1173 | if (!areCastCompatible(operandTypes, resultTypes)) { | |||
1174 | InFlightDiagnostic diag = op->emitOpError("operand type"); | |||
1175 | if (operandTypes.empty()) | |||
1176 | diag << "s []"; | |||
1177 | else if (llvm::size(operandTypes) == 1) | |||
1178 | diag << " " << *operandTypes.begin(); | |||
1179 | else | |||
1180 | diag << "s " << operandTypes; | |||
1181 | return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ") | |||
1182 | << resultTypes << " are cast incompatible"; | |||
1183 | } | |||
1184 | ||||
1185 | return success(); | |||
1186 | } | |||
1187 | ||||
1188 | //===----------------------------------------------------------------------===// | |||
1189 | // Misc. utils | |||
1190 | //===----------------------------------------------------------------------===// | |||
1191 | ||||
1192 | /// Insert an operation, generated by `buildTerminatorOp`, at the end of the | |||
1193 | /// region's only block if it does not have a terminator already. If the region | |||
1194 | /// is empty, insert a new block first. `buildTerminatorOp` should return the | |||
1195 | /// terminator operation to insert. | |||
1196 | void impl::ensureRegionTerminator( | |||
1197 | Region ®ion, OpBuilder &builder, Location loc, | |||
1198 | function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { | |||
1199 | OpBuilder::InsertionGuard guard(builder); | |||
1200 | if (region.empty()) | |||
1201 | builder.createBlock(®ion); | |||
1202 | ||||
1203 | Block &block = region.back(); | |||
1204 | if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) | |||
1205 | return; | |||
1206 | ||||
1207 | builder.setInsertionPointToEnd(&block); | |||
1208 | builder.insert(buildTerminatorOp(builder, loc)); | |||
1209 | } | |||
1210 | ||||
1211 | /// Create a simple OpBuilder and forward to the OpBuilder version of this | |||
1212 | /// function. | |||
1213 | void impl::ensureRegionTerminator( | |||
1214 | Region ®ion, Builder &builder, Location loc, | |||
1215 | function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { | |||
1216 | OpBuilder opBuilder(builder.getContext()); | |||
1217 | ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); | |||
1218 | } |