File: | build/source/mlir/lib/IR/Operation.cpp |
Warning: | line 85, column 3 Potential leak of memory pointed to by 'mallocMem' |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- Operation.cpp - Operation support code -----------------------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | ||||
9 | #include "mlir/IR/Operation.h" | |||
10 | #include "mlir/IR/BuiltinTypes.h" | |||
11 | #include "mlir/IR/Dialect.h" | |||
12 | #include "mlir/IR/IRMapping.h" | |||
13 | #include "mlir/IR/OpImplementation.h" | |||
14 | #include "mlir/IR/PatternMatch.h" | |||
15 | #include "mlir/IR/TypeUtilities.h" | |||
16 | #include "mlir/Interfaces/FoldInterfaces.h" | |||
17 | #include "llvm/ADT/StringExtras.h" | |||
18 | #include <numeric> | |||
19 | ||||
20 | using namespace mlir; | |||
21 | ||||
22 | //===----------------------------------------------------------------------===// | |||
23 | // Operation | |||
24 | //===----------------------------------------------------------------------===// | |||
25 | ||||
26 | /// Create a new Operation from operation state. | |||
27 | Operation *Operation::create(const OperationState &state) { | |||
28 | return create(state.location, state.name, state.types, state.operands, | |||
29 | state.attributes.getDictionary(state.getContext()), | |||
30 | state.successors, state.regions); | |||
31 | } | |||
32 | ||||
33 | /// Create a new Operation with the specific fields. | |||
34 | Operation *Operation::create(Location location, OperationName name, | |||
35 | TypeRange resultTypes, ValueRange operands, | |||
36 | NamedAttrList &&attributes, BlockRange successors, | |||
37 | RegionRange regions) { | |||
38 | unsigned numRegions = regions.size(); | |||
39 | Operation *op = create(location, name, resultTypes, operands, | |||
40 | std::move(attributes), successors, numRegions); | |||
41 | for (unsigned i = 0; i < numRegions; ++i) | |||
42 | if (regions[i]) | |||
43 | op->getRegion(i).takeBody(*regions[i]); | |||
44 | return op; | |||
45 | } | |||
46 | ||||
47 | /// Overload of create that takes an existing DictionaryAttr to avoid | |||
48 | /// unnecessarily uniquing a list of attributes. | |||
49 | Operation *Operation::create(Location location, OperationName name, | |||
50 | TypeRange resultTypes, ValueRange operands, | |||
51 | NamedAttrList &&attributes, BlockRange successors, | |||
52 | unsigned numRegions) { | |||
53 | assert(llvm::all_of(resultTypes, [](Type t) { return t; }) &&(static_cast <bool> (llvm::all_of(resultTypes, [](Type t ) { return t; }) && "unexpected null result type") ? void (0) : __assert_fail ("llvm::all_of(resultTypes, [](Type t) { return t; }) && \"unexpected null result type\"" , "mlir/lib/IR/Operation.cpp", 54, __extension__ __PRETTY_FUNCTION__ )) | |||
54 | "unexpected null result type")(static_cast <bool> (llvm::all_of(resultTypes, [](Type t ) { return t; }) && "unexpected null result type") ? void (0) : __assert_fail ("llvm::all_of(resultTypes, [](Type t) { return t; }) && \"unexpected null result type\"" , "mlir/lib/IR/Operation.cpp", 54, __extension__ __PRETTY_FUNCTION__ )); | |||
55 | ||||
56 | // We only need to allocate additional memory for a subset of results. | |||
57 | unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); | |||
58 | unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); | |||
59 | unsigned numSuccessors = successors.size(); | |||
60 | unsigned numOperands = operands.size(); | |||
61 | unsigned numResults = resultTypes.size(); | |||
62 | ||||
63 | // If the operation is known to have no operands, don't allocate an operand | |||
64 | // storage. | |||
65 | bool needsOperandStorage = | |||
66 | operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true; | |||
67 | ||||
68 | // Compute the byte size for the operation and the operand storage. This takes | |||
69 | // into account the size of the operation, its trailing objects, and its | |||
70 | // prefixed objects. | |||
71 | size_t byteSize = | |||
72 | totalSizeToAlloc<detail::OperandStorage, BlockOperand, Region, OpOperand>( | |||
73 | needsOperandStorage
| |||
74 | size_t prefixByteSize = llvm::alignTo( | |||
75 | Operation::prefixAllocSize(numTrailingResults, numInlineResults), | |||
76 | alignof(Operation)); | |||
77 | char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); | |||
78 | void *rawMem = mallocMem + prefixByteSize; | |||
79 | ||||
80 | // Populate default attributes. | |||
81 | if (Optional<RegisteredOperationName> info = name.getRegisteredInfo()) | |||
82 | info->populateDefaultAttrs(attributes); | |||
83 | ||||
84 | // Create the new Operation. | |||
85 | Operation *op = ::new (rawMem) Operation( | |||
| ||||
86 | location, name, numResults, numSuccessors, numRegions, | |||
87 | attributes.getDictionary(location.getContext()), needsOperandStorage); | |||
88 | ||||
89 | assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) &&(static_cast <bool> ((numSuccessors == 0 || op->mightHaveTrait <OpTrait::IsTerminator>()) && "unexpected successors in a non-terminator operation" ) ? void (0) : __assert_fail ("(numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && \"unexpected successors in a non-terminator operation\"" , "mlir/lib/IR/Operation.cpp", 90, __extension__ __PRETTY_FUNCTION__ )) | |||
90 | "unexpected successors in a non-terminator operation")(static_cast <bool> ((numSuccessors == 0 || op->mightHaveTrait <OpTrait::IsTerminator>()) && "unexpected successors in a non-terminator operation" ) ? void (0) : __assert_fail ("(numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && \"unexpected successors in a non-terminator operation\"" , "mlir/lib/IR/Operation.cpp", 90, __extension__ __PRETTY_FUNCTION__ )); | |||
91 | ||||
92 | // Initialize the results. | |||
93 | auto resultTypeIt = resultTypes.begin(); | |||
94 | for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) | |||
95 | new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); | |||
96 | for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { | |||
97 | new (op->getOutOfLineOpResult(i)) | |||
98 | detail::OutOfLineOpResult(*resultTypeIt, i); | |||
99 | } | |||
100 | ||||
101 | // Initialize the regions. | |||
102 | for (unsigned i = 0; i != numRegions; ++i) | |||
103 | new (&op->getRegion(i)) Region(op); | |||
104 | ||||
105 | // Initialize the operands. | |||
106 | if (needsOperandStorage) { | |||
107 | new (&op->getOperandStorage()) detail::OperandStorage( | |||
108 | op, op->getTrailingObjects<OpOperand>(), operands); | |||
109 | } | |||
110 | ||||
111 | // Initialize the successors. | |||
112 | auto blockOperands = op->getBlockOperands(); | |||
113 | for (unsigned i = 0; i != numSuccessors; ++i) | |||
114 | new (&blockOperands[i]) BlockOperand(op, successors[i]); | |||
115 | ||||
116 | return op; | |||
117 | } | |||
118 | ||||
119 | Operation::Operation(Location location, OperationName name, unsigned numResults, | |||
120 | unsigned numSuccessors, unsigned numRegions, | |||
121 | DictionaryAttr attributes, bool hasOperandStorage) | |||
122 | : location(location), numResults(numResults), numSuccs(numSuccessors), | |||
123 | numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name), | |||
124 | attrs(attributes) { | |||
125 | assert(attributes && "unexpected null attribute dictionary")(static_cast <bool> (attributes && "unexpected null attribute dictionary" ) ? void (0) : __assert_fail ("attributes && \"unexpected null attribute dictionary\"" , "mlir/lib/IR/Operation.cpp", 125, __extension__ __PRETTY_FUNCTION__ )); | |||
126 | #ifndef NDEBUG | |||
127 | if (!getDialect() && !getContext()->allowsUnregisteredDialects()) | |||
128 | llvm::report_fatal_error( | |||
129 | name.getStringRef() + | |||
130 | " created with unregistered dialect. If this is intended, please call " | |||
131 | "allowUnregisteredDialects() on the MLIRContext, or use " | |||
132 | "-allow-unregistered-dialect with the MLIR tool used."); | |||
133 | #endif | |||
134 | } | |||
135 | ||||
136 | // Operations are deleted through the destroy() member because they are | |||
137 | // allocated via malloc. | |||
138 | Operation::~Operation() { | |||
139 | assert(block == nullptr && "operation destroyed but still in a block")(static_cast <bool> (block == nullptr && "operation destroyed but still in a block" ) ? void (0) : __assert_fail ("block == nullptr && \"operation destroyed but still in a block\"" , "mlir/lib/IR/Operation.cpp", 139, __extension__ __PRETTY_FUNCTION__ )); | |||
140 | #ifndef NDEBUG | |||
141 | if (!use_empty()) { | |||
142 | { | |||
143 | InFlightDiagnostic diag = | |||
144 | emitOpError("operation destroyed but still has uses"); | |||
145 | for (Operation *user : getUsers()) | |||
146 | diag.attachNote(user->getLoc()) << "- use: " << *user << "\n"; | |||
147 | } | |||
148 | llvm::report_fatal_error("operation destroyed but still has uses"); | |||
149 | } | |||
150 | #endif | |||
151 | // Explicitly run the destructors for the operands. | |||
152 | if (hasOperandStorage) | |||
153 | getOperandStorage().~OperandStorage(); | |||
154 | ||||
155 | // Explicitly run the destructors for the successors. | |||
156 | for (auto &successor : getBlockOperands()) | |||
157 | successor.~BlockOperand(); | |||
158 | ||||
159 | // Explicitly destroy the regions. | |||
160 | for (auto ®ion : getRegions()) | |||
161 | region.~Region(); | |||
162 | } | |||
163 | ||||
164 | /// Destroy this operation or one of its subclasses. | |||
165 | void Operation::destroy() { | |||
166 | // Operations may have additional prefixed allocation, which needs to be | |||
167 | // accounted for here when computing the address to free. | |||
168 | char *rawMem = reinterpret_cast<char *>(this) - | |||
169 | llvm::alignTo(prefixAllocSize(), alignof(Operation)); | |||
170 | this->~Operation(); | |||
171 | free(rawMem); | |||
172 | } | |||
173 | ||||
174 | /// Return true if this operation is a proper ancestor of the `other` | |||
175 | /// operation. | |||
176 | bool Operation::isProperAncestor(Operation *other) { | |||
177 | while ((other = other->getParentOp())) | |||
178 | if (this == other) | |||
179 | return true; | |||
180 | return false; | |||
181 | } | |||
182 | ||||
183 | /// Replace any uses of 'from' with 'to' within this operation. | |||
184 | void Operation::replaceUsesOfWith(Value from, Value to) { | |||
185 | if (from == to) | |||
186 | return; | |||
187 | for (auto &operand : getOpOperands()) | |||
188 | if (operand.get() == from) | |||
189 | operand.set(to); | |||
190 | } | |||
191 | ||||
192 | /// Replace the current operands of this operation with the ones provided in | |||
193 | /// 'operands'. | |||
194 | void Operation::setOperands(ValueRange operands) { | |||
195 | if (LLVM_LIKELY(hasOperandStorage)__builtin_expect((bool)(hasOperandStorage), true)) | |||
196 | return getOperandStorage().setOperands(this, operands); | |||
197 | assert(operands.empty() && "setting operands without an operand storage")(static_cast <bool> (operands.empty() && "setting operands without an operand storage" ) ? void (0) : __assert_fail ("operands.empty() && \"setting operands without an operand storage\"" , "mlir/lib/IR/Operation.cpp", 197, __extension__ __PRETTY_FUNCTION__ )); | |||
198 | } | |||
199 | ||||
200 | /// Replace the operands beginning at 'start' and ending at 'start' + 'length' | |||
201 | /// with the ones provided in 'operands'. 'operands' may be smaller or larger | |||
202 | /// than the range pointed to by 'start'+'length'. | |||
203 | void Operation::setOperands(unsigned start, unsigned length, | |||
204 | ValueRange operands) { | |||
205 | assert((start + length) <= getNumOperands() &&(static_cast <bool> ((start + length) <= getNumOperands () && "invalid operand range specified") ? void (0) : __assert_fail ("(start + length) <= getNumOperands() && \"invalid operand range specified\"" , "mlir/lib/IR/Operation.cpp", 206, __extension__ __PRETTY_FUNCTION__ )) | |||
206 | "invalid operand range specified")(static_cast <bool> ((start + length) <= getNumOperands () && "invalid operand range specified") ? void (0) : __assert_fail ("(start + length) <= getNumOperands() && \"invalid operand range specified\"" , "mlir/lib/IR/Operation.cpp", 206, __extension__ __PRETTY_FUNCTION__ )); | |||
207 | if (LLVM_LIKELY(hasOperandStorage)__builtin_expect((bool)(hasOperandStorage), true)) | |||
208 | return getOperandStorage().setOperands(this, start, length, operands); | |||
209 | assert(operands.empty() && "setting operands without an operand storage")(static_cast <bool> (operands.empty() && "setting operands without an operand storage" ) ? void (0) : __assert_fail ("operands.empty() && \"setting operands without an operand storage\"" , "mlir/lib/IR/Operation.cpp", 209, __extension__ __PRETTY_FUNCTION__ )); | |||
210 | } | |||
211 | ||||
212 | /// Insert the given operands into the operand list at the given 'index'. | |||
213 | void Operation::insertOperands(unsigned index, ValueRange operands) { | |||
214 | if (LLVM_LIKELY(hasOperandStorage)__builtin_expect((bool)(hasOperandStorage), true)) | |||
215 | return setOperands(index, /*length=*/0, operands); | |||
216 | assert(operands.empty() && "inserting operands without an operand storage")(static_cast <bool> (operands.empty() && "inserting operands without an operand storage" ) ? void (0) : __assert_fail ("operands.empty() && \"inserting operands without an operand storage\"" , "mlir/lib/IR/Operation.cpp", 216, __extension__ __PRETTY_FUNCTION__ )); | |||
217 | } | |||
218 | ||||
219 | //===----------------------------------------------------------------------===// | |||
220 | // Diagnostics | |||
221 | //===----------------------------------------------------------------------===// | |||
222 | ||||
223 | /// Emit an error about fatal conditions with this operation, reporting up to | |||
224 | /// any diagnostic handlers that may be listening. | |||
225 | InFlightDiagnostic Operation::emitError(const Twine &message) { | |||
226 | InFlightDiagnostic diag = mlir::emitError(getLoc(), message); | |||
227 | if (getContext()->shouldPrintOpOnDiagnostic()) { | |||
228 | diag.attachNote(getLoc()) | |||
229 | .append("see current operation: ") | |||
230 | .appendOp(*this, OpPrintingFlags().printGenericOpForm()); | |||
231 | } | |||
232 | return diag; | |||
233 | } | |||
234 | ||||
235 | /// Emit a warning about this operation, reporting up to any diagnostic | |||
236 | /// handlers that may be listening. | |||
237 | InFlightDiagnostic Operation::emitWarning(const Twine &message) { | |||
238 | InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); | |||
239 | if (getContext()->shouldPrintOpOnDiagnostic()) | |||
240 | diag.attachNote(getLoc()) << "see current operation: " << *this; | |||
241 | return diag; | |||
242 | } | |||
243 | ||||
244 | /// Emit a remark about this operation, reporting up to any diagnostic | |||
245 | /// handlers that may be listening. | |||
246 | InFlightDiagnostic Operation::emitRemark(const Twine &message) { | |||
247 | InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message); | |||
248 | if (getContext()->shouldPrintOpOnDiagnostic()) | |||
249 | diag.attachNote(getLoc()) << "see current operation: " << *this; | |||
250 | return diag; | |||
251 | } | |||
252 | ||||
253 | //===----------------------------------------------------------------------===// | |||
254 | // Operation Ordering | |||
255 | //===----------------------------------------------------------------------===// | |||
256 | ||||
257 | constexpr unsigned Operation::kInvalidOrderIdx; | |||
258 | constexpr unsigned Operation::kOrderStride; | |||
259 | ||||
260 | /// Given an operation 'other' that is within the same parent block, return | |||
261 | /// whether the current operation is before 'other' in the operation list | |||
262 | /// of the parent block. | |||
263 | /// Note: This function has an average complexity of O(1), but worst case may | |||
264 | /// take O(N) where N is the number of operations within the parent block. | |||
265 | bool Operation::isBeforeInBlock(Operation *other) { | |||
266 | assert(block && "Operations without parent blocks have no order.")(static_cast <bool> (block && "Operations without parent blocks have no order." ) ? void (0) : __assert_fail ("block && \"Operations without parent blocks have no order.\"" , "mlir/lib/IR/Operation.cpp", 266, __extension__ __PRETTY_FUNCTION__ )); | |||
267 | assert(other && other->block == block &&(static_cast <bool> (other && other->block == block && "Expected other operation to have the same parent block." ) ? void (0) : __assert_fail ("other && other->block == block && \"Expected other operation to have the same parent block.\"" , "mlir/lib/IR/Operation.cpp", 268, __extension__ __PRETTY_FUNCTION__ )) | |||
268 | "Expected other operation to have the same parent block.")(static_cast <bool> (other && other->block == block && "Expected other operation to have the same parent block." ) ? void (0) : __assert_fail ("other && other->block == block && \"Expected other operation to have the same parent block.\"" , "mlir/lib/IR/Operation.cpp", 268, __extension__ __PRETTY_FUNCTION__ )); | |||
269 | // If the order of the block is already invalid, directly recompute the | |||
270 | // parent. | |||
271 | if (!block->isOpOrderValid()) { | |||
272 | block->recomputeOpOrder(); | |||
273 | } else { | |||
274 | // Update the order either operation if necessary. | |||
275 | updateOrderIfNecessary(); | |||
276 | other->updateOrderIfNecessary(); | |||
277 | } | |||
278 | ||||
279 | return orderIndex < other->orderIndex; | |||
280 | } | |||
281 | ||||
282 | /// Update the order index of this operation of this operation if necessary, | |||
283 | /// potentially recomputing the order of the parent block. | |||
284 | void Operation::updateOrderIfNecessary() { | |||
285 | assert(block && "expected valid parent")(static_cast <bool> (block && "expected valid parent" ) ? void (0) : __assert_fail ("block && \"expected valid parent\"" , "mlir/lib/IR/Operation.cpp", 285, __extension__ __PRETTY_FUNCTION__ )); | |||
286 | ||||
287 | // If the order is valid for this operation there is nothing to do. | |||
288 | if (hasValidOrder()) | |||
289 | return; | |||
290 | Operation *blockFront = &block->front(); | |||
291 | Operation *blockBack = &block->back(); | |||
292 | ||||
293 | // This method is expected to only be invoked on blocks with more than one | |||
294 | // operation. | |||
295 | assert(blockFront != blockBack && "expected more than one operation")(static_cast <bool> (blockFront != blockBack && "expected more than one operation") ? void (0) : __assert_fail ("blockFront != blockBack && \"expected more than one operation\"" , "mlir/lib/IR/Operation.cpp", 295, __extension__ __PRETTY_FUNCTION__ )); | |||
296 | ||||
297 | // If the operation is at the end of the block. | |||
298 | if (this == blockBack) { | |||
299 | Operation *prevNode = getPrevNode(); | |||
300 | if (!prevNode->hasValidOrder()) | |||
301 | return block->recomputeOpOrder(); | |||
302 | ||||
303 | // Add the stride to the previous operation. | |||
304 | orderIndex = prevNode->orderIndex + kOrderStride; | |||
305 | return; | |||
306 | } | |||
307 | ||||
308 | // If this is the first operation try to use the next operation to compute the | |||
309 | // ordering. | |||
310 | if (this == blockFront) { | |||
311 | Operation *nextNode = getNextNode(); | |||
312 | if (!nextNode->hasValidOrder()) | |||
313 | return block->recomputeOpOrder(); | |||
314 | // There is no order to give this operation. | |||
315 | if (nextNode->orderIndex == 0) | |||
316 | return block->recomputeOpOrder(); | |||
317 | ||||
318 | // If we can't use the stride, just take the middle value left. This is safe | |||
319 | // because we know there is at least one valid index to assign to. | |||
320 | if (nextNode->orderIndex <= kOrderStride) | |||
321 | orderIndex = (nextNode->orderIndex / 2); | |||
322 | else | |||
323 | orderIndex = kOrderStride; | |||
324 | return; | |||
325 | } | |||
326 | ||||
327 | // Otherwise, this operation is between two others. Place this operation in | |||
328 | // the middle of the previous and next if possible. | |||
329 | Operation *prevNode = getPrevNode(), *nextNode = getNextNode(); | |||
330 | if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder()) | |||
331 | return block->recomputeOpOrder(); | |||
332 | unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex; | |||
333 | ||||
334 | // Check to see if there is a valid order between the two. | |||
335 | if (prevOrder + 1 == nextOrder) | |||
336 | return block->recomputeOpOrder(); | |||
337 | orderIndex = prevOrder + ((nextOrder - prevOrder) / 2); | |||
338 | } | |||
339 | ||||
340 | //===----------------------------------------------------------------------===// | |||
341 | // ilist_traits for Operation | |||
342 | //===----------------------------------------------------------------------===// | |||
343 | ||||
344 | auto llvm::ilist_detail::SpecificNodeAccess< | |||
345 | typename llvm::ilist_detail::compute_node_options< | |||
346 | ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * { | |||
347 | return NodeAccess::getNodePtr<OptionsT>(n); | |||
348 | } | |||
349 | ||||
350 | auto llvm::ilist_detail::SpecificNodeAccess< | |||
351 | typename llvm::ilist_detail::compute_node_options< | |||
352 | ::mlir::Operation>::type>::getNodePtr(const_pointer n) | |||
353 | -> const node_type * { | |||
354 | return NodeAccess::getNodePtr<OptionsT>(n); | |||
355 | } | |||
356 | ||||
357 | auto llvm::ilist_detail::SpecificNodeAccess< | |||
358 | typename llvm::ilist_detail::compute_node_options< | |||
359 | ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer { | |||
360 | return NodeAccess::getValuePtr<OptionsT>(n); | |||
361 | } | |||
362 | ||||
363 | auto llvm::ilist_detail::SpecificNodeAccess< | |||
364 | typename llvm::ilist_detail::compute_node_options< | |||
365 | ::mlir::Operation>::type>::getValuePtr(const node_type *n) | |||
366 | -> const_pointer { | |||
367 | return NodeAccess::getValuePtr<OptionsT>(n); | |||
368 | } | |||
369 | ||||
370 | void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) { | |||
371 | op->destroy(); | |||
372 | } | |||
373 | ||||
374 | Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() { | |||
375 | size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); | |||
376 | iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this)); | |||
377 | return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset); | |||
378 | } | |||
379 | ||||
380 | /// This is a trait method invoked when an operation is added to a block. We | |||
381 | /// keep the block pointer up to date. | |||
382 | void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { | |||
383 | assert(!op->getBlock() && "already in an operation block!")(static_cast <bool> (!op->getBlock() && "already in an operation block!" ) ? void (0) : __assert_fail ("!op->getBlock() && \"already in an operation block!\"" , "mlir/lib/IR/Operation.cpp", 383, __extension__ __PRETTY_FUNCTION__ )); | |||
384 | op->block = getContainingBlock(); | |||
385 | ||||
386 | // Invalidate the order on the operation. | |||
387 | op->orderIndex = Operation::kInvalidOrderIdx; | |||
388 | } | |||
389 | ||||
390 | /// This is a trait method invoked when an operation is removed from a block. | |||
391 | /// We keep the block pointer up to date. | |||
392 | void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { | |||
393 | assert(op->block && "not already in an operation block!")(static_cast <bool> (op->block && "not already in an operation block!" ) ? void (0) : __assert_fail ("op->block && \"not already in an operation block!\"" , "mlir/lib/IR/Operation.cpp", 393, __extension__ __PRETTY_FUNCTION__ )); | |||
394 | op->block = nullptr; | |||
395 | } | |||
396 | ||||
397 | /// This is a trait method invoked when an operation is moved from one block | |||
398 | /// to another. We keep the block pointer up to date. | |||
399 | void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( | |||
400 | ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { | |||
401 | Block *curParent = getContainingBlock(); | |||
402 | ||||
403 | // Invalidate the ordering of the parent block. | |||
404 | curParent->invalidateOpOrder(); | |||
405 | ||||
406 | // If we are transferring operations within the same block, the block | |||
407 | // pointer doesn't need to be updated. | |||
408 | if (curParent == otherList.getContainingBlock()) | |||
409 | return; | |||
410 | ||||
411 | // Update the 'block' member of each operation. | |||
412 | for (; first != last; ++first) | |||
413 | first->block = curParent; | |||
414 | } | |||
415 | ||||
416 | /// Remove this operation (and its descendants) from its Block and delete | |||
417 | /// all of them. | |||
418 | void Operation::erase() { | |||
419 | if (auto *parent = getBlock()) | |||
420 | parent->getOperations().erase(this); | |||
421 | else | |||
422 | destroy(); | |||
423 | } | |||
424 | ||||
425 | /// Remove the operation from its parent block, but don't delete it. | |||
426 | void Operation::remove() { | |||
427 | if (Block *parent = getBlock()) | |||
428 | parent->getOperations().remove(this); | |||
429 | } | |||
430 | ||||
431 | /// Unlink this operation from its current block and insert it right before | |||
432 | /// `existingOp` which may be in the same or another block in the same | |||
433 | /// function. | |||
434 | void Operation::moveBefore(Operation *existingOp) { | |||
435 | moveBefore(existingOp->getBlock(), existingOp->getIterator()); | |||
436 | } | |||
437 | ||||
438 | /// Unlink this operation from its current basic block and insert it right | |||
439 | /// before `iterator` in the specified basic block. | |||
440 | void Operation::moveBefore(Block *block, | |||
441 | llvm::iplist<Operation>::iterator iterator) { | |||
442 | block->getOperations().splice(iterator, getBlock()->getOperations(), | |||
443 | getIterator()); | |||
444 | } | |||
445 | ||||
446 | /// Unlink this operation from its current block and insert it right after | |||
447 | /// `existingOp` which may be in the same or another block in the same function. | |||
448 | void Operation::moveAfter(Operation *existingOp) { | |||
449 | moveAfter(existingOp->getBlock(), existingOp->getIterator()); | |||
450 | } | |||
451 | ||||
452 | /// Unlink this operation from its current block and insert it right after | |||
453 | /// `iterator` in the specified block. | |||
454 | void Operation::moveAfter(Block *block, | |||
455 | llvm::iplist<Operation>::iterator iterator) { | |||
456 | assert(iterator != block->end() && "cannot move after end of block")(static_cast <bool> (iterator != block->end() && "cannot move after end of block") ? void (0) : __assert_fail ("iterator != block->end() && \"cannot move after end of block\"" , "mlir/lib/IR/Operation.cpp", 456, __extension__ __PRETTY_FUNCTION__ )); | |||
457 | moveBefore(block, std::next(iterator)); | |||
458 | } | |||
459 | ||||
460 | /// This drops all operand uses from this operation, which is an essential | |||
461 | /// step in breaking cyclic dependences between references when they are to | |||
462 | /// be deleted. | |||
463 | void Operation::dropAllReferences() { | |||
464 | for (auto &op : getOpOperands()) | |||
465 | op.drop(); | |||
466 | ||||
467 | for (auto ®ion : getRegions()) | |||
468 | region.dropAllReferences(); | |||
469 | ||||
470 | for (auto &dest : getBlockOperands()) | |||
471 | dest.drop(); | |||
472 | } | |||
473 | ||||
474 | /// This drops all uses of any values defined by this operation or its nested | |||
475 | /// regions, wherever they are located. | |||
476 | void Operation::dropAllDefinedValueUses() { | |||
477 | dropAllUses(); | |||
478 | ||||
479 | for (auto ®ion : getRegions()) | |||
480 | for (auto &block : region) | |||
481 | block.dropAllDefinedValueUses(); | |||
482 | } | |||
483 | ||||
484 | void Operation::setSuccessor(Block *block, unsigned index) { | |||
485 | assert(index < getNumSuccessors())(static_cast <bool> (index < getNumSuccessors()) ? void (0) : __assert_fail ("index < getNumSuccessors()", "mlir/lib/IR/Operation.cpp" , 485, __extension__ __PRETTY_FUNCTION__)); | |||
486 | getBlockOperands()[index].set(block); | |||
487 | } | |||
488 | ||||
489 | /// Attempt to fold this operation using the Op's registered foldHook. | |||
490 | LogicalResult Operation::fold(ArrayRef<Attribute> operands, | |||
491 | SmallVectorImpl<OpFoldResult> &results) { | |||
492 | // If we have a registered operation definition matching this one, use it to | |||
493 | // try to constant fold the operation. | |||
494 | Optional<RegisteredOperationName> info = getRegisteredInfo(); | |||
495 | if (info && succeeded(info->foldHook(this, operands, results))) | |||
496 | return success(); | |||
497 | ||||
498 | // Otherwise, fall back on the dialect hook to handle it. | |||
499 | Dialect *dialect = getDialect(); | |||
500 | if (!dialect) | |||
501 | return failure(); | |||
502 | ||||
503 | auto *interface = dyn_cast<DialectFoldInterface>(dialect); | |||
504 | if (!interface) | |||
505 | return failure(); | |||
506 | ||||
507 | return interface->fold(this, operands, results); | |||
508 | } | |||
509 | ||||
510 | /// Emit an error with the op name prefixed, like "'dim' op " which is | |||
511 | /// convenient for verifiers. | |||
512 | InFlightDiagnostic Operation::emitOpError(const Twine &message) { | |||
513 | return emitError() << "'" << getName() << "' op " << message; | |||
514 | } | |||
515 | ||||
516 | //===----------------------------------------------------------------------===// | |||
517 | // Operation Cloning | |||
518 | //===----------------------------------------------------------------------===// | |||
519 | ||||
520 | Operation::CloneOptions::CloneOptions() | |||
521 | : cloneRegionsFlag(false), cloneOperandsFlag(false) {} | |||
522 | ||||
523 | Operation::CloneOptions::CloneOptions(bool cloneRegions, bool cloneOperands) | |||
524 | : cloneRegionsFlag(cloneRegions), cloneOperandsFlag(cloneOperands) {} | |||
525 | ||||
526 | Operation::CloneOptions Operation::CloneOptions::all() { | |||
527 | return CloneOptions().cloneRegions().cloneOperands(); | |||
528 | } | |||
529 | ||||
530 | Operation::CloneOptions &Operation::CloneOptions::cloneRegions(bool enable) { | |||
531 | cloneRegionsFlag = enable; | |||
532 | return *this; | |||
533 | } | |||
534 | ||||
535 | Operation::CloneOptions &Operation::CloneOptions::cloneOperands(bool enable) { | |||
536 | cloneOperandsFlag = enable; | |||
537 | return *this; | |||
538 | } | |||
539 | ||||
540 | /// Create a deep copy of this operation but keep the operation regions empty. | |||
541 | /// Operands are remapped using `mapper` (if present), and `mapper` is updated | |||
542 | /// to contain the results. The `mapResults` flag specifies whether the results | |||
543 | /// of the cloned operation should be added to the map. | |||
544 | Operation *Operation::cloneWithoutRegions(IRMapping &mapper) { | |||
545 | return clone(mapper, CloneOptions::all().cloneRegions(false)); | |||
546 | } | |||
547 | ||||
548 | Operation *Operation::cloneWithoutRegions() { | |||
549 | IRMapping mapper; | |||
550 | return cloneWithoutRegions(mapper); | |||
551 | } | |||
552 | ||||
553 | /// Create a deep copy of this operation, remapping any operands that use | |||
554 | /// values outside of the operation using the map that is provided (leaving | |||
555 | /// them alone if no entry is present). Replaces references to cloned | |||
556 | /// sub-operations to the corresponding operation that is copied, and adds | |||
557 | /// those mappings to the map. | |||
558 | Operation *Operation::clone(IRMapping &mapper, CloneOptions options) { | |||
559 | SmallVector<Value, 8> operands; | |||
560 | SmallVector<Block *, 2> successors; | |||
561 | ||||
562 | // Remap the operands. | |||
563 | if (options.shouldCloneOperands()) { | |||
564 | operands.reserve(getNumOperands()); | |||
565 | for (auto opValue : getOperands()) | |||
566 | operands.push_back(mapper.lookupOrDefault(opValue)); | |||
567 | } | |||
568 | ||||
569 | // Remap the successors. | |||
570 | successors.reserve(getNumSuccessors()); | |||
571 | for (Block *successor : getSuccessors()) | |||
572 | successors.push_back(mapper.lookupOrDefault(successor)); | |||
573 | ||||
574 | // Create the new operation. | |||
575 | auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, | |||
576 | successors, getNumRegions()); | |||
577 | mapper.map(this, newOp); | |||
578 | ||||
579 | // Clone the regions. | |||
580 | if (options.shouldCloneRegions()) { | |||
581 | for (unsigned i = 0; i != numRegions; ++i) | |||
582 | getRegion(i).cloneInto(&newOp->getRegion(i), mapper); | |||
583 | } | |||
584 | ||||
585 | // Remember the mapping of any results. | |||
586 | for (unsigned i = 0, e = getNumResults(); i != e; ++i) | |||
587 | mapper.map(getResult(i), newOp->getResult(i)); | |||
588 | ||||
589 | return newOp; | |||
590 | } | |||
591 | ||||
592 | Operation *Operation::clone(CloneOptions options) { | |||
593 | IRMapping mapper; | |||
594 | return clone(mapper, options); | |||
| ||||
595 | } | |||
596 | ||||
597 | //===----------------------------------------------------------------------===// | |||
598 | // OpState trait class. | |||
599 | //===----------------------------------------------------------------------===// | |||
600 | ||||
601 | // The fallback for the parser is to try for a dialect operation parser. | |||
602 | // Otherwise, reject the custom assembly form. | |||
603 | ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { | |||
604 | if (auto parseFn = result.name.getDialect()->getParseOperationHook( | |||
605 | result.name.getStringRef())) | |||
606 | return (*parseFn)(parser, result); | |||
607 | return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); | |||
608 | } | |||
609 | ||||
610 | // The fallback for the printer is to try for a dialect operation printer. | |||
611 | // Otherwise, it prints the generic form. | |||
612 | void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { | |||
613 | if (auto printFn = op->getDialect()->getOperationPrinter(op)) { | |||
614 | printOpName(op, p, defaultDialect); | |||
615 | printFn(op, p); | |||
616 | } else { | |||
617 | p.printGenericOp(op); | |||
618 | } | |||
619 | } | |||
620 | ||||
621 | /// Print an operation name, eliding the dialect prefix if necessary and doesn't | |||
622 | /// lead to ambiguities. | |||
623 | void OpState::printOpName(Operation *op, OpAsmPrinter &p, | |||
624 | StringRef defaultDialect) { | |||
625 | StringRef name = op->getName().getStringRef(); | |||
626 | if (name.startswith((defaultDialect + ".").str()) && name.count('.') == 1) | |||
627 | name = name.drop_front(defaultDialect.size() + 1); | |||
628 | p.getStream() << name; | |||
629 | } | |||
630 | ||||
631 | /// Emit an error about fatal conditions with this operation, reporting up to | |||
632 | /// any diagnostic handlers that may be listening. | |||
633 | InFlightDiagnostic OpState::emitError(const Twine &message) { | |||
634 | return getOperation()->emitError(message); | |||
635 | } | |||
636 | ||||
637 | /// Emit an error with the op name prefixed, like "'dim' op " which is | |||
638 | /// convenient for verifiers. | |||
639 | InFlightDiagnostic OpState::emitOpError(const Twine &message) { | |||
640 | return getOperation()->emitOpError(message); | |||
641 | } | |||
642 | ||||
643 | /// Emit a warning about this operation, reporting up to any diagnostic | |||
644 | /// handlers that may be listening. | |||
645 | InFlightDiagnostic OpState::emitWarning(const Twine &message) { | |||
646 | return getOperation()->emitWarning(message); | |||
647 | } | |||
648 | ||||
649 | /// Emit a remark about this operation, reporting up to any diagnostic | |||
650 | /// handlers that may be listening. | |||
651 | InFlightDiagnostic OpState::emitRemark(const Twine &message) { | |||
652 | return getOperation()->emitRemark(message); | |||
653 | } | |||
654 | ||||
655 | //===----------------------------------------------------------------------===// | |||
656 | // Op Trait implementations | |||
657 | //===----------------------------------------------------------------------===// | |||
658 | ||||
659 | OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { | |||
660 | if (op->getNumOperands() == 1) { | |||
661 | auto *argumentOp = op->getOperand(0).getDefiningOp(); | |||
662 | if (argumentOp && op->getName() == argumentOp->getName()) { | |||
663 | // Replace the outer operation output with the inner operation. | |||
664 | return op->getOperand(0); | |||
665 | } | |||
666 | } else if (op->getOperand(0) == op->getOperand(1)) { | |||
667 | return op->getOperand(0); | |||
668 | } | |||
669 | ||||
670 | return {}; | |||
671 | } | |||
672 | ||||
673 | OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { | |||
674 | auto *argumentOp = op->getOperand(0).getDefiningOp(); | |||
675 | if (argumentOp && op->getName() == argumentOp->getName()) { | |||
676 | // Replace the outer involutions output with inner's input. | |||
677 | return argumentOp->getOperand(0); | |||
678 | } | |||
679 | ||||
680 | return {}; | |||
681 | } | |||
682 | ||||
683 | LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { | |||
684 | if (op->getNumOperands() != 0) | |||
685 | return op->emitOpError() << "requires zero operands"; | |||
686 | return success(); | |||
687 | } | |||
688 | ||||
689 | LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { | |||
690 | if (op->getNumOperands() != 1) | |||
691 | return op->emitOpError() << "requires a single operand"; | |||
692 | return success(); | |||
693 | } | |||
694 | ||||
695 | LogicalResult OpTrait::impl::verifyNOperands(Operation *op, | |||
696 | unsigned numOperands) { | |||
697 | if (op->getNumOperands() != numOperands) { | |||
698 | return op->emitOpError() << "expected " << numOperands | |||
699 | << " operands, but found " << op->getNumOperands(); | |||
700 | } | |||
701 | return success(); | |||
702 | } | |||
703 | ||||
704 | LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, | |||
705 | unsigned numOperands) { | |||
706 | if (op->getNumOperands() < numOperands) | |||
707 | return op->emitOpError() | |||
708 | << "expected " << numOperands << " or more operands, but found " | |||
709 | << op->getNumOperands(); | |||
710 | return success(); | |||
711 | } | |||
712 | ||||
713 | /// If this is a vector type, or a tensor type, return the scalar element type | |||
714 | /// that it is built around, otherwise return the type unmodified. | |||
715 | static Type getTensorOrVectorElementType(Type type) { | |||
716 | if (auto vec = type.dyn_cast<VectorType>()) | |||
717 | return vec.getElementType(); | |||
718 | ||||
719 | // Look through tensor<vector<...>> to find the underlying element type. | |||
720 | if (auto tensor = type.dyn_cast<TensorType>()) | |||
721 | return getTensorOrVectorElementType(tensor.getElementType()); | |||
722 | return type; | |||
723 | } | |||
724 | ||||
725 | LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { | |||
726 | // FIXME: Add back check for no side effects on operation. | |||
727 | // Currently adding it would cause the shared library build | |||
728 | // to fail since there would be a dependency of IR on SideEffectInterfaces | |||
729 | // which is cyclical. | |||
730 | return success(); | |||
731 | } | |||
732 | ||||
733 | LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { | |||
734 | // FIXME: Add back check for no side effects on operation. | |||
735 | // Currently adding it would cause the shared library build | |||
736 | // to fail since there would be a dependency of IR on SideEffectInterfaces | |||
737 | // which is cyclical. | |||
738 | return success(); | |||
739 | } | |||
740 | ||||
741 | LogicalResult | |||
742 | OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { | |||
743 | for (auto opType : op->getOperandTypes()) { | |||
744 | auto type = getTensorOrVectorElementType(opType); | |||
745 | if (!type.isSignlessIntOrIndex()) | |||
746 | return op->emitOpError() << "requires an integer or index type"; | |||
747 | } | |||
748 | return success(); | |||
749 | } | |||
750 | ||||
751 | LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { | |||
752 | for (auto opType : op->getOperandTypes()) { | |||
753 | auto type = getTensorOrVectorElementType(opType); | |||
754 | if (!type.isa<FloatType>()) | |||
755 | return op->emitOpError("requires a float type"); | |||
756 | } | |||
757 | return success(); | |||
758 | } | |||
759 | ||||
760 | LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { | |||
761 | // Zero or one operand always have the "same" type. | |||
762 | unsigned nOperands = op->getNumOperands(); | |||
763 | if (nOperands < 2) | |||
764 | return success(); | |||
765 | ||||
766 | auto type = op->getOperand(0).getType(); | |||
767 | for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) | |||
768 | if (opType != type) | |||
769 | return op->emitOpError() << "requires all operands to have the same type"; | |||
770 | return success(); | |||
771 | } | |||
772 | ||||
773 | LogicalResult OpTrait::impl::verifyZeroRegions(Operation *op) { | |||
774 | if (op->getNumRegions() != 0) | |||
775 | return op->emitOpError() << "requires zero regions"; | |||
776 | return success(); | |||
777 | } | |||
778 | ||||
779 | LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { | |||
780 | if (op->getNumRegions() != 1) | |||
781 | return op->emitOpError() << "requires one region"; | |||
782 | return success(); | |||
783 | } | |||
784 | ||||
785 | LogicalResult OpTrait::impl::verifyNRegions(Operation *op, | |||
786 | unsigned numRegions) { | |||
787 | if (op->getNumRegions() != numRegions) | |||
788 | return op->emitOpError() << "expected " << numRegions << " regions"; | |||
789 | return success(); | |||
790 | } | |||
791 | ||||
792 | LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, | |||
793 | unsigned numRegions) { | |||
794 | if (op->getNumRegions() < numRegions) | |||
795 | return op->emitOpError() << "expected " << numRegions << " or more regions"; | |||
796 | return success(); | |||
797 | } | |||
798 | ||||
799 | LogicalResult OpTrait::impl::verifyZeroResults(Operation *op) { | |||
800 | if (op->getNumResults() != 0) | |||
801 | return op->emitOpError() << "requires zero results"; | |||
802 | return success(); | |||
803 | } | |||
804 | ||||
805 | LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { | |||
806 | if (op->getNumResults() != 1) | |||
807 | return op->emitOpError() << "requires one result"; | |||
808 | return success(); | |||
809 | } | |||
810 | ||||
811 | LogicalResult OpTrait::impl::verifyNResults(Operation *op, | |||
812 | unsigned numOperands) { | |||
813 | if (op->getNumResults() != numOperands) | |||
814 | return op->emitOpError() << "expected " << numOperands << " results"; | |||
815 | return success(); | |||
816 | } | |||
817 | ||||
818 | LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, | |||
819 | unsigned numOperands) { | |||
820 | if (op->getNumResults() < numOperands) | |||
821 | return op->emitOpError() | |||
822 | << "expected " << numOperands << " or more results"; | |||
823 | return success(); | |||
824 | } | |||
825 | ||||
826 | LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { | |||
827 | if (failed(verifyAtLeastNOperands(op, 1))) | |||
828 | return failure(); | |||
829 | ||||
830 | if (failed(verifyCompatibleShapes(op->getOperandTypes()))) | |||
831 | return op->emitOpError() << "requires the same shape for all operands"; | |||
832 | ||||
833 | return success(); | |||
834 | } | |||
835 | ||||
836 | LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { | |||
837 | if (failed(verifyAtLeastNOperands(op, 1)) || | |||
838 | failed(verifyAtLeastNResults(op, 1))) | |||
839 | return failure(); | |||
840 | ||||
841 | SmallVector<Type, 8> types(op->getOperandTypes()); | |||
842 | types.append(llvm::to_vector<4>(op->getResultTypes())); | |||
843 | ||||
844 | if (failed(verifyCompatibleShapes(types))) | |||
845 | return op->emitOpError() | |||
846 | << "requires the same shape for all operands and results"; | |||
847 | ||||
848 | return success(); | |||
849 | } | |||
850 | ||||
851 | LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { | |||
852 | if (failed(verifyAtLeastNOperands(op, 1))) | |||
853 | return failure(); | |||
854 | auto elementType = getElementTypeOrSelf(op->getOperand(0)); | |||
855 | ||||
856 | for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { | |||
857 | if (getElementTypeOrSelf(operand) != elementType) | |||
858 | return op->emitOpError("requires the same element type for all operands"); | |||
859 | } | |||
860 | ||||
861 | return success(); | |||
862 | } | |||
863 | ||||
864 | LogicalResult | |||
865 | OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { | |||
866 | if (failed(verifyAtLeastNOperands(op, 1)) || | |||
867 | failed(verifyAtLeastNResults(op, 1))) | |||
868 | return failure(); | |||
869 | ||||
870 | auto elementType = getElementTypeOrSelf(op->getResult(0)); | |||
871 | ||||
872 | // Verify result element type matches first result's element type. | |||
873 | for (auto result : llvm::drop_begin(op->getResults(), 1)) { | |||
874 | if (getElementTypeOrSelf(result) != elementType) | |||
875 | return op->emitOpError( | |||
876 | "requires the same element type for all operands and results"); | |||
877 | } | |||
878 | ||||
879 | // Verify operand's element type matches first result's element type. | |||
880 | for (auto operand : op->getOperands()) { | |||
881 | if (getElementTypeOrSelf(operand) != elementType) | |||
882 | return op->emitOpError( | |||
883 | "requires the same element type for all operands and results"); | |||
884 | } | |||
885 | ||||
886 | return success(); | |||
887 | } | |||
888 | ||||
889 | LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { | |||
890 | if (failed(verifyAtLeastNOperands(op, 1)) || | |||
891 | failed(verifyAtLeastNResults(op, 1))) | |||
892 | return failure(); | |||
893 | ||||
894 | auto type = op->getResult(0).getType(); | |||
895 | auto elementType = getElementTypeOrSelf(type); | |||
896 | Attribute encoding = nullptr; | |||
897 | if (auto rankedType = dyn_cast<RankedTensorType>(type)) | |||
898 | encoding = rankedType.getEncoding(); | |||
899 | for (auto resultType : llvm::drop_begin(op->getResultTypes())) { | |||
900 | if (getElementTypeOrSelf(resultType) != elementType || | |||
901 | failed(verifyCompatibleShape(resultType, type))) | |||
902 | return op->emitOpError() | |||
903 | << "requires the same type for all operands and results"; | |||
904 | if (encoding) | |||
905 | if (auto rankedType = dyn_cast<RankedTensorType>(resultType); | |||
906 | encoding != rankedType.getEncoding()) | |||
907 | return op->emitOpError() | |||
908 | << "requires the same encoding for all operands and results"; | |||
909 | } | |||
910 | for (auto opType : op->getOperandTypes()) { | |||
911 | if (getElementTypeOrSelf(opType) != elementType || | |||
912 | failed(verifyCompatibleShape(opType, type))) | |||
913 | return op->emitOpError() | |||
914 | << "requires the same type for all operands and results"; | |||
915 | if (encoding) | |||
916 | if (auto rankedType = dyn_cast<RankedTensorType>(opType); | |||
917 | encoding != rankedType.getEncoding()) | |||
918 | return op->emitOpError() | |||
919 | << "requires the same encoding for all operands and results"; | |||
920 | } | |||
921 | return success(); | |||
922 | } | |||
923 | ||||
924 | LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { | |||
925 | Block *block = op->getBlock(); | |||
926 | // Verify that the operation is at the end of the respective parent block. | |||
927 | if (!block || &block->back() != op) | |||
928 | return op->emitOpError("must be the last operation in the parent block"); | |||
929 | return success(); | |||
930 | } | |||
931 | ||||
932 | static LogicalResult verifyTerminatorSuccessors(Operation *op) { | |||
933 | auto *parent = op->getParentRegion(); | |||
934 | ||||
935 | // Verify that the operands lines up with the BB arguments in the successor. | |||
936 | for (Block *succ : op->getSuccessors()) | |||
937 | if (succ->getParent() != parent) | |||
938 | return op->emitError("reference to block defined in another region"); | |||
939 | return success(); | |||
940 | } | |||
941 | ||||
942 | LogicalResult OpTrait::impl::verifyZeroSuccessors(Operation *op) { | |||
943 | if (op->getNumSuccessors() != 0) { | |||
944 | return op->emitOpError("requires 0 successors but found ") | |||
945 | << op->getNumSuccessors(); | |||
946 | } | |||
947 | return success(); | |||
948 | } | |||
949 | ||||
950 | LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { | |||
951 | if (op->getNumSuccessors() != 1) { | |||
952 | return op->emitOpError("requires 1 successor but found ") | |||
953 | << op->getNumSuccessors(); | |||
954 | } | |||
955 | return verifyTerminatorSuccessors(op); | |||
956 | } | |||
957 | LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, | |||
958 | unsigned numSuccessors) { | |||
959 | if (op->getNumSuccessors() != numSuccessors) { | |||
960 | return op->emitOpError("requires ") | |||
961 | << numSuccessors << " successors but found " | |||
962 | << op->getNumSuccessors(); | |||
963 | } | |||
964 | return verifyTerminatorSuccessors(op); | |||
965 | } | |||
966 | LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, | |||
967 | unsigned numSuccessors) { | |||
968 | if (op->getNumSuccessors() < numSuccessors) { | |||
969 | return op->emitOpError("requires at least ") | |||
970 | << numSuccessors << " successors but found " | |||
971 | << op->getNumSuccessors(); | |||
972 | } | |||
973 | return verifyTerminatorSuccessors(op); | |||
974 | } | |||
975 | ||||
976 | LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { | |||
977 | for (auto resultType : op->getResultTypes()) { | |||
978 | auto elementType = getTensorOrVectorElementType(resultType); | |||
979 | bool isBoolType = elementType.isInteger(1); | |||
980 | if (!isBoolType) | |||
981 | return op->emitOpError() << "requires a bool result type"; | |||
982 | } | |||
983 | ||||
984 | return success(); | |||
985 | } | |||
986 | ||||
987 | LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { | |||
988 | for (auto resultType : op->getResultTypes()) | |||
989 | if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) | |||
990 | return op->emitOpError() << "requires a floating point type"; | |||
991 | ||||
992 | return success(); | |||
993 | } | |||
994 | ||||
995 | LogicalResult | |||
996 | OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { | |||
997 | for (auto resultType : op->getResultTypes()) | |||
998 | if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) | |||
999 | return op->emitOpError() << "requires an integer or index type"; | |||
1000 | return success(); | |||
1001 | } | |||
1002 | ||||
1003 | LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, | |||
1004 | StringRef attrName, | |||
1005 | StringRef valueGroupName, | |||
1006 | size_t expectedCount) { | |||
1007 | auto sizeAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrName); | |||
1008 | if (!sizeAttr) | |||
1009 | return op->emitOpError("requires dense i32 array attribute '") | |||
1010 | << attrName << "'"; | |||
1011 | ||||
1012 | ArrayRef<int32_t> sizes = sizeAttr.asArrayRef(); | |||
1013 | if (llvm::any_of(sizes, [](int32_t element) { return element < 0; })) | |||
1014 | return op->emitOpError("'") | |||
1015 | << attrName << "' attribute cannot have negative elements"; | |||
1016 | ||||
1017 | size_t totalCount = | |||
1018 | std::accumulate(sizes.begin(), sizes.end(), 0, | |||
1019 | [](unsigned all, int32_t one) { return all + one; }); | |||
1020 | ||||
1021 | if (totalCount != expectedCount) | |||
1022 | return op->emitOpError() | |||
1023 | << valueGroupName << " count (" << expectedCount | |||
1024 | << ") does not match with the total size (" << totalCount | |||
1025 | << ") specified in attribute '" << attrName << "'"; | |||
1026 | return success(); | |||
1027 | } | |||
1028 | ||||
1029 | LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, | |||
1030 | StringRef attrName) { | |||
1031 | return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); | |||
1032 | } | |||
1033 | ||||
1034 | LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, | |||
1035 | StringRef attrName) { | |||
1036 | return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); | |||
1037 | } | |||
1038 | ||||
1039 | LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { | |||
1040 | for (Region ®ion : op->getRegions()) { | |||
1041 | if (region.empty()) | |||
1042 | continue; | |||
1043 | ||||
1044 | if (region.getNumArguments() != 0) { | |||
1045 | if (op->getNumRegions() > 1) | |||
1046 | return op->emitOpError("region #") | |||
1047 | << region.getRegionNumber() << " should have no arguments"; | |||
1048 | return op->emitOpError("region should have no arguments"); | |||
1049 | } | |||
1050 | } | |||
1051 | return success(); | |||
1052 | } | |||
1053 | ||||
1054 | LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { | |||
1055 | auto isMappableType = [](Type type) { | |||
1056 | return type.isa<VectorType, TensorType>(); | |||
1057 | }; | |||
1058 | auto resultMappableTypes = llvm::to_vector<1>( | |||
1059 | llvm::make_filter_range(op->getResultTypes(), isMappableType)); | |||
1060 | auto operandMappableTypes = llvm::to_vector<2>( | |||
1061 | llvm::make_filter_range(op->getOperandTypes(), isMappableType)); | |||
1062 | ||||
1063 | // If the op only has scalar operand/result types, then we have nothing to | |||
1064 | // check. | |||
1065 | if (resultMappableTypes.empty() && operandMappableTypes.empty()) | |||
1066 | return success(); | |||
1067 | ||||
1068 | if (!resultMappableTypes.empty() && operandMappableTypes.empty()) | |||
1069 | return op->emitOpError("if a result is non-scalar, then at least one " | |||
1070 | "operand must be non-scalar"); | |||
1071 | ||||
1072 | assert(!operandMappableTypes.empty())(static_cast <bool> (!operandMappableTypes.empty()) ? void (0) : __assert_fail ("!operandMappableTypes.empty()", "mlir/lib/IR/Operation.cpp" , 1072, __extension__ __PRETTY_FUNCTION__)); | |||
1073 | ||||
1074 | if (resultMappableTypes.empty()) | |||
1075 | return op->emitOpError("if an operand is non-scalar, then there must be at " | |||
1076 | "least one non-scalar result"); | |||
1077 | ||||
1078 | if (resultMappableTypes.size() != op->getNumResults()) | |||
1079 | return op->emitOpError( | |||
1080 | "if an operand is non-scalar, then all results must be non-scalar"); | |||
1081 | ||||
1082 | SmallVector<Type, 4> types = llvm::to_vector<2>( | |||
1083 | llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); | |||
1084 | TypeID expectedBaseTy = types.front().getTypeID(); | |||
1085 | if (!llvm::all_of(types, | |||
1086 | [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || | |||
1087 | failed(verifyCompatibleShapes(types))) { | |||
1088 | return op->emitOpError() << "all non-scalar operands/results must have the " | |||
1089 | "same shape and base type"; | |||
1090 | } | |||
1091 | ||||
1092 | return success(); | |||
1093 | } | |||
1094 | ||||
1095 | /// Check for any values used by operations regions attached to the | |||
1096 | /// specified "IsIsolatedFromAbove" operation defined outside of it. | |||
1097 | LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { | |||
1098 | assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() &&(static_cast <bool> (isolatedOp->hasTrait<OpTrait ::IsIsolatedFromAbove>() && "Intended to check IsolatedFromAbove ops" ) ? void (0) : __assert_fail ("isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && \"Intended to check IsolatedFromAbove ops\"" , "mlir/lib/IR/Operation.cpp", 1099, __extension__ __PRETTY_FUNCTION__ )) | |||
1099 | "Intended to check IsolatedFromAbove ops")(static_cast <bool> (isolatedOp->hasTrait<OpTrait ::IsIsolatedFromAbove>() && "Intended to check IsolatedFromAbove ops" ) ? void (0) : __assert_fail ("isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && \"Intended to check IsolatedFromAbove ops\"" , "mlir/lib/IR/Operation.cpp", 1099, __extension__ __PRETTY_FUNCTION__ )); | |||
1100 | ||||
1101 | // List of regions to analyze. Each region is processed independently, with | |||
1102 | // respect to the common `limit` region, so we can look at them in any order. | |||
1103 | // Therefore, use a simple vector and push/pop back the current region. | |||
1104 | SmallVector<Region *, 8> pendingRegions; | |||
1105 | for (auto ®ion : isolatedOp->getRegions()) { | |||
1106 | pendingRegions.push_back(®ion); | |||
1107 | ||||
1108 | // Traverse all operations in the region. | |||
1109 | while (!pendingRegions.empty()) { | |||
1110 | for (Operation &op : pendingRegions.pop_back_val()->getOps()) { | |||
1111 | for (Value operand : op.getOperands()) { | |||
1112 | // Check that any value that is used by an operation is defined in the | |||
1113 | // same region as either an operation result. | |||
1114 | auto *operandRegion = operand.getParentRegion(); | |||
1115 | if (!operandRegion) | |||
1116 | return op.emitError("operation's operand is unlinked"); | |||
1117 | if (!region.isAncestor(operandRegion)) { | |||
1118 | return op.emitOpError("using value defined outside the region") | |||
1119 | .attachNote(isolatedOp->getLoc()) | |||
1120 | << "required by region isolation constraints"; | |||
1121 | } | |||
1122 | } | |||
1123 | ||||
1124 | // Schedule any regions in the operation for further checking. Don't | |||
1125 | // recurse into other IsolatedFromAbove ops, because they will check | |||
1126 | // themselves. | |||
1127 | if (op.getNumRegions() && | |||
1128 | !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { | |||
1129 | for (Region &subRegion : op.getRegions()) | |||
1130 | pendingRegions.push_back(&subRegion); | |||
1131 | } | |||
1132 | } | |||
1133 | } | |||
1134 | } | |||
1135 | ||||
1136 | return success(); | |||
1137 | } | |||
1138 | ||||
1139 | bool OpTrait::hasElementwiseMappableTraits(Operation *op) { | |||
1140 | return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && | |||
1141 | op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); | |||
1142 | } | |||
1143 | ||||
1144 | //===----------------------------------------------------------------------===// | |||
1145 | // CastOpInterface | |||
1146 | //===----------------------------------------------------------------------===// | |||
1147 | ||||
1148 | /// Attempt to fold the given cast operation. | |||
1149 | LogicalResult | |||
1150 | impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, | |||
1151 | SmallVectorImpl<OpFoldResult> &foldResults) { | |||
1152 | OperandRange operands = op->getOperands(); | |||
1153 | if (operands.empty()) | |||
1154 | return failure(); | |||
1155 | ResultRange results = op->getResults(); | |||
1156 | ||||
1157 | // Check for the case where the input and output types match 1-1. | |||
1158 | if (operands.getTypes() == results.getTypes()) { | |||
1159 | foldResults.append(operands.begin(), operands.end()); | |||
1160 | return success(); | |||
1161 | } | |||
1162 | ||||
1163 | return failure(); | |||
1164 | } | |||
1165 | ||||
1166 | /// Attempt to verify the given cast operation. | |||
1167 | LogicalResult impl::verifyCastInterfaceOp( | |||
1168 | Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) { | |||
1169 | auto resultTypes = op->getResultTypes(); | |||
1170 | if (resultTypes.empty()) | |||
1171 | return op->emitOpError() | |||
1172 | << "expected at least one result for cast operation"; | |||
1173 | ||||
1174 | auto operandTypes = op->getOperandTypes(); | |||
1175 | if (!areCastCompatible(operandTypes, resultTypes)) { | |||
1176 | InFlightDiagnostic diag = op->emitOpError("operand type"); | |||
1177 | if (operandTypes.empty()) | |||
1178 | diag << "s []"; | |||
1179 | else if (llvm::size(operandTypes) == 1) | |||
1180 | diag << " " << *operandTypes.begin(); | |||
1181 | else | |||
1182 | diag << "s " << operandTypes; | |||
1183 | return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ") | |||
1184 | << resultTypes << " are cast incompatible"; | |||
1185 | } | |||
1186 | ||||
1187 | return success(); | |||
1188 | } | |||
1189 | ||||
1190 | //===----------------------------------------------------------------------===// | |||
1191 | // Misc. utils | |||
1192 | //===----------------------------------------------------------------------===// | |||
1193 | ||||
1194 | /// Insert an operation, generated by `buildTerminatorOp`, at the end of the | |||
1195 | /// region's only block if it does not have a terminator already. If the region | |||
1196 | /// is empty, insert a new block first. `buildTerminatorOp` should return the | |||
1197 | /// terminator operation to insert. | |||
1198 | void impl::ensureRegionTerminator( | |||
1199 | Region ®ion, OpBuilder &builder, Location loc, | |||
1200 | function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { | |||
1201 | OpBuilder::InsertionGuard guard(builder); | |||
1202 | if (region.empty()) | |||
1203 | builder.createBlock(®ion); | |||
1204 | ||||
1205 | Block &block = region.back(); | |||
1206 | if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) | |||
1207 | return; | |||
1208 | ||||
1209 | builder.setInsertionPointToEnd(&block); | |||
1210 | builder.insert(buildTerminatorOp(builder, loc)); | |||
1211 | } | |||
1212 | ||||
1213 | /// Create a simple OpBuilder and forward to the OpBuilder version of this | |||
1214 | /// function. | |||
1215 | void impl::ensureRegionTerminator( | |||
1216 | Region ®ion, Builder &builder, Location loc, | |||
1217 | function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { | |||
1218 | OpBuilder opBuilder(builder.getContext()); | |||
1219 | ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); | |||
1220 | } |