File: | build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/mlir/lib/IR/Operation.cpp |
Warning: | line 91, column 3 Potential leak of memory pointed to by 'mallocMem' |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- Operation.cpp - Operation support code -----------------------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | ||||
9 | #include "mlir/IR/Operation.h" | |||
10 | #include "mlir/IR/BlockAndValueMapping.h" | |||
11 | #include "mlir/IR/BuiltinTypes.h" | |||
12 | #include "mlir/IR/Dialect.h" | |||
13 | #include "mlir/IR/OpImplementation.h" | |||
14 | #include "mlir/IR/PatternMatch.h" | |||
15 | #include "mlir/IR/TypeUtilities.h" | |||
16 | #include "mlir/Interfaces/FoldInterfaces.h" | |||
17 | #include "llvm/ADT/StringExtras.h" | |||
18 | #include <numeric> | |||
19 | ||||
20 | using namespace mlir; | |||
21 | ||||
22 | //===----------------------------------------------------------------------===// | |||
23 | // Operation | |||
24 | //===----------------------------------------------------------------------===// | |||
25 | ||||
26 | /// Create a new Operation with the specific fields. | |||
27 | Operation *Operation::create(Location location, OperationName name, | |||
28 | TypeRange resultTypes, ValueRange operands, | |||
29 | ArrayRef<NamedAttribute> attributes, | |||
30 | BlockRange successors, unsigned numRegions) { | |||
31 | return create(location, name, resultTypes, operands, | |||
32 | DictionaryAttr::get(location.getContext(), attributes), | |||
33 | successors, numRegions); | |||
34 | } | |||
35 | ||||
36 | /// Create a new Operation from operation state. | |||
37 | Operation *Operation::create(const OperationState &state) { | |||
38 | return create(state.location, state.name, state.types, state.operands, | |||
39 | state.attributes.getDictionary(state.getContext()), | |||
40 | state.successors, state.regions); | |||
41 | } | |||
42 | ||||
43 | /// Create a new Operation with the specific fields. | |||
44 | Operation *Operation::create(Location location, OperationName name, | |||
45 | TypeRange resultTypes, ValueRange operands, | |||
46 | DictionaryAttr attributes, BlockRange successors, | |||
47 | RegionRange regions) { | |||
48 | unsigned numRegions = regions.size(); | |||
49 | Operation *op = create(location, name, resultTypes, operands, attributes, | |||
50 | successors, numRegions); | |||
51 | for (unsigned i = 0; i < numRegions; ++i) | |||
52 | if (regions[i]) | |||
53 | op->getRegion(i).takeBody(*regions[i]); | |||
54 | return op; | |||
55 | } | |||
56 | ||||
57 | /// Overload of create that takes an existing DictionaryAttr to avoid | |||
58 | /// unnecessarily uniquing a list of attributes. | |||
59 | Operation *Operation::create(Location location, OperationName name, | |||
60 | TypeRange resultTypes, ValueRange operands, | |||
61 | DictionaryAttr attributes, BlockRange successors, | |||
62 | unsigned numRegions) { | |||
63 | assert(llvm::all_of(resultTypes, [](Type t) { return t; }) &&(static_cast <bool> (llvm::all_of(resultTypes, [](Type t ) { return t; }) && "unexpected null result type") ? void (0) : __assert_fail ("llvm::all_of(resultTypes, [](Type t) { return t; }) && \"unexpected null result type\"" , "mlir/lib/IR/Operation.cpp", 64, __extension__ __PRETTY_FUNCTION__ )) | |||
64 | "unexpected null result type")(static_cast <bool> (llvm::all_of(resultTypes, [](Type t ) { return t; }) && "unexpected null result type") ? void (0) : __assert_fail ("llvm::all_of(resultTypes, [](Type t) { return t; }) && \"unexpected null result type\"" , "mlir/lib/IR/Operation.cpp", 64, __extension__ __PRETTY_FUNCTION__ )); | |||
65 | ||||
66 | // We only need to allocate additional memory for a subset of results. | |||
67 | unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); | |||
68 | unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); | |||
69 | unsigned numSuccessors = successors.size(); | |||
70 | unsigned numOperands = operands.size(); | |||
71 | unsigned numResults = resultTypes.size(); | |||
72 | ||||
73 | // If the operation is known to have no operands, don't allocate an operand | |||
74 | // storage. | |||
75 | bool needsOperandStorage = | |||
76 | operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true; | |||
77 | ||||
78 | // Compute the byte size for the operation and the operand storage. This takes | |||
79 | // into account the size of the operation, its trailing objects, and its | |||
80 | // prefixed objects. | |||
81 | size_t byteSize = | |||
82 | totalSizeToAlloc<detail::OperandStorage, BlockOperand, Region, OpOperand>( | |||
83 | needsOperandStorage
| |||
84 | size_t prefixByteSize = llvm::alignTo( | |||
85 | Operation::prefixAllocSize(numTrailingResults, numInlineResults), | |||
86 | alignof(Operation)); | |||
87 | char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); | |||
88 | void *rawMem = mallocMem + prefixByteSize; | |||
89 | ||||
90 | // Create the new Operation. | |||
91 | Operation *op = | |||
| ||||
92 | ::new (rawMem) Operation(location, name, numResults, numSuccessors, | |||
93 | numRegions, attributes, needsOperandStorage); | |||
94 | ||||
95 | assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) &&(static_cast <bool> ((numSuccessors == 0 || op->mightHaveTrait <OpTrait::IsTerminator>()) && "unexpected successors in a non-terminator operation" ) ? void (0) : __assert_fail ("(numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && \"unexpected successors in a non-terminator operation\"" , "mlir/lib/IR/Operation.cpp", 96, __extension__ __PRETTY_FUNCTION__ )) | |||
96 | "unexpected successors in a non-terminator operation")(static_cast <bool> ((numSuccessors == 0 || op->mightHaveTrait <OpTrait::IsTerminator>()) && "unexpected successors in a non-terminator operation" ) ? void (0) : __assert_fail ("(numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && \"unexpected successors in a non-terminator operation\"" , "mlir/lib/IR/Operation.cpp", 96, __extension__ __PRETTY_FUNCTION__ )); | |||
97 | ||||
98 | // Initialize the results. | |||
99 | auto resultTypeIt = resultTypes.begin(); | |||
100 | for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) | |||
101 | new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); | |||
102 | for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { | |||
103 | new (op->getOutOfLineOpResult(i)) | |||
104 | detail::OutOfLineOpResult(*resultTypeIt, i); | |||
105 | } | |||
106 | ||||
107 | // Initialize the regions. | |||
108 | for (unsigned i = 0; i != numRegions; ++i) | |||
109 | new (&op->getRegion(i)) Region(op); | |||
110 | ||||
111 | // Initialize the operands. | |||
112 | if (needsOperandStorage) { | |||
113 | new (&op->getOperandStorage()) detail::OperandStorage( | |||
114 | op, op->getTrailingObjects<OpOperand>(), operands); | |||
115 | } | |||
116 | ||||
117 | // Initialize the successors. | |||
118 | auto blockOperands = op->getBlockOperands(); | |||
119 | for (unsigned i = 0; i != numSuccessors; ++i) | |||
120 | new (&blockOperands[i]) BlockOperand(op, successors[i]); | |||
121 | ||||
122 | return op; | |||
123 | } | |||
124 | ||||
125 | Operation::Operation(Location location, OperationName name, unsigned numResults, | |||
126 | unsigned numSuccessors, unsigned numRegions, | |||
127 | DictionaryAttr attributes, bool hasOperandStorage) | |||
128 | : location(location), numResults(numResults), numSuccs(numSuccessors), | |||
129 | numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name), | |||
130 | attrs(attributes) { | |||
131 | assert(attributes && "unexpected null attribute dictionary")(static_cast <bool> (attributes && "unexpected null attribute dictionary" ) ? void (0) : __assert_fail ("attributes && \"unexpected null attribute dictionary\"" , "mlir/lib/IR/Operation.cpp", 131, __extension__ __PRETTY_FUNCTION__ )); | |||
132 | #ifndef NDEBUG | |||
133 | if (!getDialect() && !getContext()->allowsUnregisteredDialects()) | |||
134 | llvm::report_fatal_error( | |||
135 | name.getStringRef() + | |||
136 | " created with unregistered dialect. If this is intended, please call " | |||
137 | "allowUnregisteredDialects() on the MLIRContext, or use " | |||
138 | "-allow-unregistered-dialect with the MLIR tool used."); | |||
139 | #endif | |||
140 | } | |||
141 | ||||
142 | // Operations are deleted through the destroy() member because they are | |||
143 | // allocated via malloc. | |||
144 | Operation::~Operation() { | |||
145 | assert(block == nullptr && "operation destroyed but still in a block")(static_cast <bool> (block == nullptr && "operation destroyed but still in a block" ) ? void (0) : __assert_fail ("block == nullptr && \"operation destroyed but still in a block\"" , "mlir/lib/IR/Operation.cpp", 145, __extension__ __PRETTY_FUNCTION__ )); | |||
146 | #ifndef NDEBUG | |||
147 | if (!use_empty()) { | |||
148 | { | |||
149 | InFlightDiagnostic diag = | |||
150 | emitOpError("operation destroyed but still has uses"); | |||
151 | for (Operation *user : getUsers()) | |||
152 | diag.attachNote(user->getLoc()) << "- use: " << *user << "\n"; | |||
153 | } | |||
154 | llvm::report_fatal_error("operation destroyed but still has uses"); | |||
155 | } | |||
156 | #endif | |||
157 | // Explicitly run the destructors for the operands. | |||
158 | if (hasOperandStorage) | |||
159 | getOperandStorage().~OperandStorage(); | |||
160 | ||||
161 | // Explicitly run the destructors for the successors. | |||
162 | for (auto &successor : getBlockOperands()) | |||
163 | successor.~BlockOperand(); | |||
164 | ||||
165 | // Explicitly destroy the regions. | |||
166 | for (auto ®ion : getRegions()) | |||
167 | region.~Region(); | |||
168 | } | |||
169 | ||||
170 | /// Destroy this operation or one of its subclasses. | |||
171 | void Operation::destroy() { | |||
172 | // Operations may have additional prefixed allocation, which needs to be | |||
173 | // accounted for here when computing the address to free. | |||
174 | char *rawMem = reinterpret_cast<char *>(this) - | |||
175 | llvm::alignTo(prefixAllocSize(), alignof(Operation)); | |||
176 | this->~Operation(); | |||
177 | free(rawMem); | |||
178 | } | |||
179 | ||||
180 | /// Return true if this operation is a proper ancestor of the `other` | |||
181 | /// operation. | |||
182 | bool Operation::isProperAncestor(Operation *other) { | |||
183 | while ((other = other->getParentOp())) | |||
184 | if (this == other) | |||
185 | return true; | |||
186 | return false; | |||
187 | } | |||
188 | ||||
189 | /// Replace any uses of 'from' with 'to' within this operation. | |||
190 | void Operation::replaceUsesOfWith(Value from, Value to) { | |||
191 | if (from == to) | |||
192 | return; | |||
193 | for (auto &operand : getOpOperands()) | |||
194 | if (operand.get() == from) | |||
195 | operand.set(to); | |||
196 | } | |||
197 | ||||
198 | /// Replace the current operands of this operation with the ones provided in | |||
199 | /// 'operands'. | |||
200 | void Operation::setOperands(ValueRange operands) { | |||
201 | if (LLVM_LIKELY(hasOperandStorage)__builtin_expect((bool)(hasOperandStorage), true)) | |||
202 | return getOperandStorage().setOperands(this, operands); | |||
203 | assert(operands.empty() && "setting operands without an operand storage")(static_cast <bool> (operands.empty() && "setting operands without an operand storage" ) ? void (0) : __assert_fail ("operands.empty() && \"setting operands without an operand storage\"" , "mlir/lib/IR/Operation.cpp", 203, __extension__ __PRETTY_FUNCTION__ )); | |||
204 | } | |||
205 | ||||
206 | /// Replace the operands beginning at 'start' and ending at 'start' + 'length' | |||
207 | /// with the ones provided in 'operands'. 'operands' may be smaller or larger | |||
208 | /// than the range pointed to by 'start'+'length'. | |||
209 | void Operation::setOperands(unsigned start, unsigned length, | |||
210 | ValueRange operands) { | |||
211 | assert((start + length) <= getNumOperands() &&(static_cast <bool> ((start + length) <= getNumOperands () && "invalid operand range specified") ? void (0) : __assert_fail ("(start + length) <= getNumOperands() && \"invalid operand range specified\"" , "mlir/lib/IR/Operation.cpp", 212, __extension__ __PRETTY_FUNCTION__ )) | |||
212 | "invalid operand range specified")(static_cast <bool> ((start + length) <= getNumOperands () && "invalid operand range specified") ? void (0) : __assert_fail ("(start + length) <= getNumOperands() && \"invalid operand range specified\"" , "mlir/lib/IR/Operation.cpp", 212, __extension__ __PRETTY_FUNCTION__ )); | |||
213 | if (LLVM_LIKELY(hasOperandStorage)__builtin_expect((bool)(hasOperandStorage), true)) | |||
214 | return getOperandStorage().setOperands(this, start, length, operands); | |||
215 | assert(operands.empty() && "setting operands without an operand storage")(static_cast <bool> (operands.empty() && "setting operands without an operand storage" ) ? void (0) : __assert_fail ("operands.empty() && \"setting operands without an operand storage\"" , "mlir/lib/IR/Operation.cpp", 215, __extension__ __PRETTY_FUNCTION__ )); | |||
216 | } | |||
217 | ||||
218 | /// Insert the given operands into the operand list at the given 'index'. | |||
219 | void Operation::insertOperands(unsigned index, ValueRange operands) { | |||
220 | if (LLVM_LIKELY(hasOperandStorage)__builtin_expect((bool)(hasOperandStorage), true)) | |||
221 | return setOperands(index, /*length=*/0, operands); | |||
222 | assert(operands.empty() && "inserting operands without an operand storage")(static_cast <bool> (operands.empty() && "inserting operands without an operand storage" ) ? void (0) : __assert_fail ("operands.empty() && \"inserting operands without an operand storage\"" , "mlir/lib/IR/Operation.cpp", 222, __extension__ __PRETTY_FUNCTION__ )); | |||
223 | } | |||
224 | ||||
225 | //===----------------------------------------------------------------------===// | |||
226 | // Diagnostics | |||
227 | //===----------------------------------------------------------------------===// | |||
228 | ||||
229 | /// Emit an error about fatal conditions with this operation, reporting up to | |||
230 | /// any diagnostic handlers that may be listening. | |||
231 | InFlightDiagnostic Operation::emitError(const Twine &message) { | |||
232 | InFlightDiagnostic diag = mlir::emitError(getLoc(), message); | |||
233 | if (getContext()->shouldPrintOpOnDiagnostic()) { | |||
234 | diag.attachNote(getLoc()) | |||
235 | .append("see current operation: ") | |||
236 | .appendOp(*this, OpPrintingFlags().printGenericOpForm()); | |||
237 | } | |||
238 | return diag; | |||
239 | } | |||
240 | ||||
241 | /// Emit a warning about this operation, reporting up to any diagnostic | |||
242 | /// handlers that may be listening. | |||
243 | InFlightDiagnostic Operation::emitWarning(const Twine &message) { | |||
244 | InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); | |||
245 | if (getContext()->shouldPrintOpOnDiagnostic()) | |||
246 | diag.attachNote(getLoc()) << "see current operation: " << *this; | |||
247 | return diag; | |||
248 | } | |||
249 | ||||
250 | /// Emit a remark about this operation, reporting up to any diagnostic | |||
251 | /// handlers that may be listening. | |||
252 | InFlightDiagnostic Operation::emitRemark(const Twine &message) { | |||
253 | InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message); | |||
254 | if (getContext()->shouldPrintOpOnDiagnostic()) | |||
255 | diag.attachNote(getLoc()) << "see current operation: " << *this; | |||
256 | return diag; | |||
257 | } | |||
258 | ||||
259 | //===----------------------------------------------------------------------===// | |||
260 | // Operation Ordering | |||
261 | //===----------------------------------------------------------------------===// | |||
262 | ||||
// Namespace-scope definitions of the two static constexpr members used by the
// ordering logic below (kInvalidOrderIdx marks an operation whose order must
// be recomputed; kOrderStride is the gap left between consecutive indices).
// NOTE(review): presumably declared with their initializers in Operation.h —
// confirm; these definitions are only needed for odr-use before C++17.
263 | constexpr unsigned Operation::kInvalidOrderIdx; | |||
264 | constexpr unsigned Operation::kOrderStride; | |||
265 | ||||
266 | /// Given an operation 'other' that is within the same parent block, return | |||
267 | /// whether the current operation is before 'other' in the operation list | |||
268 | /// of the parent block. | |||
269 | /// Note: This function has an average complexity of O(1), but worst case may | |||
270 | /// take O(N) where N is the number of operations within the parent block. | |||
271 | bool Operation::isBeforeInBlock(Operation *other) { | |||
272 | assert(block && "Operations without parent blocks have no order.")(static_cast <bool> (block && "Operations without parent blocks have no order." ) ? void (0) : __assert_fail ("block && \"Operations without parent blocks have no order.\"" , "mlir/lib/IR/Operation.cpp", 272, __extension__ __PRETTY_FUNCTION__ )); | |||
273 | assert(other && other->block == block &&(static_cast <bool> (other && other->block == block && "Expected other operation to have the same parent block." ) ? void (0) : __assert_fail ("other && other->block == block && \"Expected other operation to have the same parent block.\"" , "mlir/lib/IR/Operation.cpp", 274, __extension__ __PRETTY_FUNCTION__ )) | |||
274 | "Expected other operation to have the same parent block.")(static_cast <bool> (other && other->block == block && "Expected other operation to have the same parent block." ) ? void (0) : __assert_fail ("other && other->block == block && \"Expected other operation to have the same parent block.\"" , "mlir/lib/IR/Operation.cpp", 274, __extension__ __PRETTY_FUNCTION__ )); | |||
275 | // If the order of the block is already invalid, directly recompute the | |||
276 | // parent. | |||
277 | if (!block->isOpOrderValid()) { | |||
278 | block->recomputeOpOrder(); | |||
279 | } else { | |||
280 | // Update the order either operation if necessary. | |||
281 | updateOrderIfNecessary(); | |||
282 | other->updateOrderIfNecessary(); | |||
283 | } | |||
284 | ||||
285 | return orderIndex < other->orderIndex; | |||
286 | } | |||
287 | ||||
288 | /// Update the order index of this operation of this operation if necessary, | |||
289 | /// potentially recomputing the order of the parent block. | |||
290 | void Operation::updateOrderIfNecessary() { | |||
291 | assert(block && "expected valid parent")(static_cast <bool> (block && "expected valid parent" ) ? void (0) : __assert_fail ("block && \"expected valid parent\"" , "mlir/lib/IR/Operation.cpp", 291, __extension__ __PRETTY_FUNCTION__ )); | |||
292 | ||||
293 | // If the order is valid for this operation there is nothing to do. | |||
294 | if (hasValidOrder()) | |||
295 | return; | |||
296 | Operation *blockFront = &block->front(); | |||
297 | Operation *blockBack = &block->back(); | |||
298 | ||||
299 | // This method is expected to only be invoked on blocks with more than one | |||
300 | // operation. | |||
301 | assert(blockFront != blockBack && "expected more than one operation")(static_cast <bool> (blockFront != blockBack && "expected more than one operation") ? void (0) : __assert_fail ("blockFront != blockBack && \"expected more than one operation\"" , "mlir/lib/IR/Operation.cpp", 301, __extension__ __PRETTY_FUNCTION__ )); | |||
302 | ||||
303 | // If the operation is at the end of the block. | |||
304 | if (this == blockBack) { | |||
305 | Operation *prevNode = getPrevNode(); | |||
306 | if (!prevNode->hasValidOrder()) | |||
307 | return block->recomputeOpOrder(); | |||
308 | ||||
309 | // Add the stride to the previous operation. | |||
310 | orderIndex = prevNode->orderIndex + kOrderStride; | |||
311 | return; | |||
312 | } | |||
313 | ||||
314 | // If this is the first operation try to use the next operation to compute the | |||
315 | // ordering. | |||
316 | if (this == blockFront) { | |||
317 | Operation *nextNode = getNextNode(); | |||
318 | if (!nextNode->hasValidOrder()) | |||
319 | return block->recomputeOpOrder(); | |||
320 | // There is no order to give this operation. | |||
321 | if (nextNode->orderIndex == 0) | |||
322 | return block->recomputeOpOrder(); | |||
323 | ||||
324 | // If we can't use the stride, just take the middle value left. This is safe | |||
325 | // because we know there is at least one valid index to assign to. | |||
326 | if (nextNode->orderIndex <= kOrderStride) | |||
327 | orderIndex = (nextNode->orderIndex / 2); | |||
328 | else | |||
329 | orderIndex = kOrderStride; | |||
330 | return; | |||
331 | } | |||
332 | ||||
333 | // Otherwise, this operation is between two others. Place this operation in | |||
334 | // the middle of the previous and next if possible. | |||
335 | Operation *prevNode = getPrevNode(), *nextNode = getNextNode(); | |||
336 | if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder()) | |||
337 | return block->recomputeOpOrder(); | |||
338 | unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex; | |||
339 | ||||
340 | // Check to see if there is a valid order between the two. | |||
341 | if (prevOrder + 1 == nextOrder) | |||
342 | return block->recomputeOpOrder(); | |||
343 | orderIndex = prevOrder + ((nextOrder - prevOrder) / 2); | |||
344 | } | |||
345 | ||||
346 | //===----------------------------------------------------------------------===// | |||
347 | // ilist_traits for Operation | |||
348 | //===----------------------------------------------------------------------===// | |||
349 | ||||
350 | auto llvm::ilist_detail::SpecificNodeAccess< | |||
351 | typename llvm::ilist_detail::compute_node_options< | |||
352 | ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * { | |||
353 | return NodeAccess::getNodePtr<OptionsT>(n); | |||
354 | } | |||
355 | ||||
356 | auto llvm::ilist_detail::SpecificNodeAccess< | |||
357 | typename llvm::ilist_detail::compute_node_options< | |||
358 | ::mlir::Operation>::type>::getNodePtr(const_pointer n) | |||
359 | -> const node_type * { | |||
360 | return NodeAccess::getNodePtr<OptionsT>(n); | |||
361 | } | |||
362 | ||||
363 | auto llvm::ilist_detail::SpecificNodeAccess< | |||
364 | typename llvm::ilist_detail::compute_node_options< | |||
365 | ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer { | |||
366 | return NodeAccess::getValuePtr<OptionsT>(n); | |||
367 | } | |||
368 | ||||
369 | auto llvm::ilist_detail::SpecificNodeAccess< | |||
370 | typename llvm::ilist_detail::compute_node_options< | |||
371 | ::mlir::Operation>::type>::getValuePtr(const node_type *n) | |||
372 | -> const_pointer { | |||
373 | return NodeAccess::getValuePtr<OptionsT>(n); | |||
374 | } | |||
375 | ||||
376 | void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) { | |||
377 | op->destroy(); | |||
378 | } | |||
379 | ||||
380 | Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() { | |||
381 | size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); | |||
382 | iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this)); | |||
383 | return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset); | |||
384 | } | |||
385 | ||||
386 | /// This is a trait method invoked when an operation is added to a block. We | |||
387 | /// keep the block pointer up to date. | |||
388 | void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { | |||
389 | assert(!op->getBlock() && "already in an operation block!")(static_cast <bool> (!op->getBlock() && "already in an operation block!" ) ? void (0) : __assert_fail ("!op->getBlock() && \"already in an operation block!\"" , "mlir/lib/IR/Operation.cpp", 389, __extension__ __PRETTY_FUNCTION__ )); | |||
390 | op->block = getContainingBlock(); | |||
391 | ||||
392 | // Invalidate the order on the operation. | |||
393 | op->orderIndex = Operation::kInvalidOrderIdx; | |||
394 | } | |||
395 | ||||
396 | /// This is a trait method invoked when an operation is removed from a block. | |||
397 | /// We keep the block pointer up to date. | |||
398 | void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { | |||
399 | assert(op->block && "not already in an operation block!")(static_cast <bool> (op->block && "not already in an operation block!" ) ? void (0) : __assert_fail ("op->block && \"not already in an operation block!\"" , "mlir/lib/IR/Operation.cpp", 399, __extension__ __PRETTY_FUNCTION__ )); | |||
400 | op->block = nullptr; | |||
401 | } | |||
402 | ||||
403 | /// This is a trait method invoked when an operation is moved from one block | |||
404 | /// to another. We keep the block pointer up to date. | |||
405 | void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( | |||
406 | ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { | |||
407 | Block *curParent = getContainingBlock(); | |||
408 | ||||
409 | // Invalidate the ordering of the parent block. | |||
410 | curParent->invalidateOpOrder(); | |||
411 | ||||
412 | // If we are transferring operations within the same block, the block | |||
413 | // pointer doesn't need to be updated. | |||
414 | if (curParent == otherList.getContainingBlock()) | |||
415 | return; | |||
416 | ||||
417 | // Update the 'block' member of each operation. | |||
418 | for (; first != last; ++first) | |||
419 | first->block = curParent; | |||
420 | } | |||
421 | ||||
422 | /// Remove this operation (and its descendants) from its Block and delete | |||
423 | /// all of them. | |||
424 | void Operation::erase() { | |||
425 | if (auto *parent = getBlock()) | |||
426 | parent->getOperations().erase(this); | |||
427 | else | |||
428 | destroy(); | |||
429 | } | |||
430 | ||||
431 | /// Remove the operation from its parent block, but don't delete it. | |||
432 | void Operation::remove() { | |||
433 | if (Block *parent = getBlock()) | |||
434 | parent->getOperations().remove(this); | |||
435 | } | |||
436 | ||||
437 | /// Unlink this operation from its current block and insert it right before | |||
438 | /// `existingOp` which may be in the same or another block in the same | |||
439 | /// function. | |||
440 | void Operation::moveBefore(Operation *existingOp) { | |||
441 | moveBefore(existingOp->getBlock(), existingOp->getIterator()); | |||
442 | } | |||
443 | ||||
444 | /// Unlink this operation from its current basic block and insert it right | |||
445 | /// before `iterator` in the specified basic block. | |||
446 | void Operation::moveBefore(Block *block, | |||
447 | llvm::iplist<Operation>::iterator iterator) { | |||
448 | block->getOperations().splice(iterator, getBlock()->getOperations(), | |||
449 | getIterator()); | |||
450 | } | |||
451 | ||||
452 | /// Unlink this operation from its current block and insert it right after | |||
453 | /// `existingOp` which may be in the same or another block in the same function. | |||
454 | void Operation::moveAfter(Operation *existingOp) { | |||
455 | moveAfter(existingOp->getBlock(), existingOp->getIterator()); | |||
456 | } | |||
457 | ||||
458 | /// Unlink this operation from its current block and insert it right after | |||
459 | /// `iterator` in the specified block. | |||
460 | void Operation::moveAfter(Block *block, | |||
461 | llvm::iplist<Operation>::iterator iterator) { | |||
462 | assert(iterator != block->end() && "cannot move after end of block")(static_cast <bool> (iterator != block->end() && "cannot move after end of block") ? void (0) : __assert_fail ("iterator != block->end() && \"cannot move after end of block\"" , "mlir/lib/IR/Operation.cpp", 462, __extension__ __PRETTY_FUNCTION__ )); | |||
463 | moveBefore(block, std::next(iterator)); | |||
464 | } | |||
465 | ||||
466 | /// This drops all operand uses from this operation, which is an essential | |||
467 | /// step in breaking cyclic dependences between references when they are to | |||
468 | /// be deleted. | |||
469 | void Operation::dropAllReferences() { | |||
470 | for (auto &op : getOpOperands()) | |||
471 | op.drop(); | |||
472 | ||||
473 | for (auto ®ion : getRegions()) | |||
474 | region.dropAllReferences(); | |||
475 | ||||
476 | for (auto &dest : getBlockOperands()) | |||
477 | dest.drop(); | |||
478 | } | |||
479 | ||||
480 | /// This drops all uses of any values defined by this operation or its nested | |||
481 | /// regions, wherever they are located. | |||
482 | void Operation::dropAllDefinedValueUses() { | |||
483 | dropAllUses(); | |||
484 | ||||
485 | for (auto ®ion : getRegions()) | |||
486 | for (auto &block : region) | |||
487 | block.dropAllDefinedValueUses(); | |||
488 | } | |||
489 | ||||
490 | void Operation::setSuccessor(Block *block, unsigned index) { | |||
491 | assert(index < getNumSuccessors())(static_cast <bool> (index < getNumSuccessors()) ? void (0) : __assert_fail ("index < getNumSuccessors()", "mlir/lib/IR/Operation.cpp" , 491, __extension__ __PRETTY_FUNCTION__)); | |||
492 | getBlockOperands()[index].set(block); | |||
493 | } | |||
494 | ||||
495 | /// Attempt to fold this operation using the Op's registered foldHook. | |||
496 | LogicalResult Operation::fold(ArrayRef<Attribute> operands, | |||
497 | SmallVectorImpl<OpFoldResult> &results) { | |||
498 | // If we have a registered operation definition matching this one, use it to | |||
499 | // try to constant fold the operation. | |||
500 | Optional<RegisteredOperationName> info = getRegisteredInfo(); | |||
501 | if (info && succeeded(info->foldHook(this, operands, results))) | |||
502 | return success(); | |||
503 | ||||
504 | // Otherwise, fall back on the dialect hook to handle it. | |||
505 | Dialect *dialect = getDialect(); | |||
506 | if (!dialect) | |||
507 | return failure(); | |||
508 | ||||
509 | auto *interface = dyn_cast<DialectFoldInterface>(dialect); | |||
510 | if (!interface) | |||
511 | return failure(); | |||
512 | ||||
513 | return interface->fold(this, operands, results); | |||
514 | } | |||
515 | ||||
516 | /// Emit an error with the op name prefixed, like "'dim' op " which is | |||
517 | /// convenient for verifiers. | |||
518 | InFlightDiagnostic Operation::emitOpError(const Twine &message) { | |||
519 | return emitError() << "'" << getName() << "' op " << message; | |||
520 | } | |||
521 | ||||
522 | //===----------------------------------------------------------------------===// | |||
523 | // Operation Cloning | |||
524 | //===----------------------------------------------------------------------===// | |||
525 | ||||
526 | /// Create a deep copy of this operation but keep the operation regions empty. | |||
527 | /// Operands are remapped using `mapper` (if present), and `mapper` is updated | |||
528 | /// to contain the results. The `mapResults` flag specifies whether the results | |||
529 | /// of the cloned operation should be added to the map. | |||
530 | Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper, | |||
531 | bool mapResults) { | |||
532 | SmallVector<Value, 8> operands; | |||
533 | SmallVector<Block *, 2> successors; | |||
534 | ||||
535 | // Remap the operands. | |||
536 | operands.reserve(getNumOperands()); | |||
537 | for (auto opValue : getOperands()) | |||
538 | operands.push_back(mapper.lookupOrDefault(opValue)); | |||
539 | ||||
540 | // Remap the successors. | |||
541 | successors.reserve(getNumSuccessors()); | |||
542 | for (Block *successor : getSuccessors()) | |||
543 | successors.push_back(mapper.lookupOrDefault(successor)); | |||
544 | ||||
545 | // Create the new operation. | |||
546 | auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, | |||
547 | successors, getNumRegions()); | |||
548 | ||||
549 | // Remember the mapping of any results. | |||
550 | if (mapResults) { | |||
551 | for (unsigned i = 0, e = getNumResults(); i != e; ++i) | |||
552 | mapper.map(getResult(i), newOp->getResult(i)); | |||
553 | } | |||
554 | ||||
555 | return newOp; | |||
556 | } | |||
557 | ||||
558 | Operation *Operation::cloneWithoutRegions() { | |||
559 | BlockAndValueMapping mapper; | |||
560 | return cloneWithoutRegions(mapper); | |||
561 | } | |||
562 | ||||
563 | /// Create a deep copy of this operation, remapping any operands that use | |||
564 | /// values outside of the operation using the map that is provided (leaving | |||
565 | /// them alone if no entry is present). Replaces references to cloned | |||
566 | /// sub-operations to the corresponding operation that is copied, and adds | |||
567 | /// those mappings to the map. | |||
568 | Operation *Operation::clone(BlockAndValueMapping &mapper) { | |||
569 | auto *newOp = cloneWithoutRegions(mapper, /*mapResults=*/false); | |||
570 | ||||
571 | // Clone the regions. | |||
572 | for (unsigned i = 0; i != numRegions; ++i) | |||
573 | getRegion(i).cloneInto(&newOp->getRegion(i), mapper); | |||
574 | ||||
575 | for (unsigned i = 0, e = getNumResults(); i != e; ++i) | |||
576 | mapper.map(getResult(i), newOp->getResult(i)); | |||
577 | ||||
578 | return newOp; | |||
579 | } | |||
580 | ||||
581 | Operation *Operation::clone() { | |||
582 | BlockAndValueMapping mapper; | |||
583 | return clone(mapper); | |||
| ||||
584 | } | |||
585 | ||||
586 | //===----------------------------------------------------------------------===// | |||
587 | // OpState trait class. | |||
588 | //===----------------------------------------------------------------------===// | |||
589 | ||||
590 | // The fallback for the parser is to try for a dialect operation parser. | |||
591 | // Otherwise, reject the custom assembly form. | |||
592 | ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { | |||
593 | if (auto parseFn = result.name.getDialect()->getParseOperationHook( | |||
594 | result.name.getStringRef())) | |||
595 | return (*parseFn)(parser, result); | |||
596 | return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); | |||
597 | } | |||
598 | ||||
599 | // The fallback for the printer is to try for a dialect operation printer. | |||
600 | // Otherwise, it prints the generic form. | |||
601 | void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { | |||
602 | if (auto printFn = op->getDialect()->getOperationPrinter(op)) { | |||
603 | printOpName(op, p, defaultDialect); | |||
604 | printFn(op, p); | |||
605 | } else { | |||
606 | p.printGenericOp(op); | |||
607 | } | |||
608 | } | |||
609 | ||||
610 | /// Print an operation name, eliding the dialect prefix if necessary. | |||
611 | void OpState::printOpName(Operation *op, OpAsmPrinter &p, | |||
612 | StringRef defaultDialect) { | |||
613 | StringRef name = op->getName().getStringRef(); | |||
614 | if (name.startswith((defaultDialect + ".").str())) | |||
615 | name = name.drop_front(defaultDialect.size() + 1); | |||
616 | // TODO: remove this special case (and update test/IR/parser.mlir) | |||
617 | else if ((defaultDialect.empty() || defaultDialect == "builtin") && | |||
618 | name.startswith("func.")) | |||
619 | name = name.drop_front(5); | |||
620 | p.getStream() << name; | |||
621 | } | |||
622 | ||||
623 | /// Emit an error about fatal conditions with this operation, reporting up to | |||
624 | /// any diagnostic handlers that may be listening. | |||
625 | InFlightDiagnostic OpState::emitError(const Twine &message) { | |||
626 | return getOperation()->emitError(message); | |||
627 | } | |||
628 | ||||
629 | /// Emit an error with the op name prefixed, like "'dim' op " which is | |||
630 | /// convenient for verifiers. | |||
631 | InFlightDiagnostic OpState::emitOpError(const Twine &message) { | |||
632 | return getOperation()->emitOpError(message); | |||
633 | } | |||
634 | ||||
635 | /// Emit a warning about this operation, reporting up to any diagnostic | |||
636 | /// handlers that may be listening. | |||
637 | InFlightDiagnostic OpState::emitWarning(const Twine &message) { | |||
638 | return getOperation()->emitWarning(message); | |||
639 | } | |||
640 | ||||
641 | /// Emit a remark about this operation, reporting up to any diagnostic | |||
642 | /// handlers that may be listening. | |||
643 | InFlightDiagnostic OpState::emitRemark(const Twine &message) { | |||
644 | return getOperation()->emitRemark(message); | |||
645 | } | |||
646 | ||||
647 | //===----------------------------------------------------------------------===// | |||
648 | // Op Trait implementations | |||
649 | //===----------------------------------------------------------------------===// | |||
650 | ||||
651 | OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { | |||
652 | if (op->getNumOperands() == 1) { | |||
653 | auto *argumentOp = op->getOperand(0).getDefiningOp(); | |||
654 | if (argumentOp && op->getName() == argumentOp->getName()) { | |||
655 | // Replace the outer operation output with the inner operation. | |||
656 | return op->getOperand(0); | |||
657 | } | |||
658 | } else if (op->getOperand(0) == op->getOperand(1)) { | |||
659 | return op->getOperand(0); | |||
660 | } | |||
661 | ||||
662 | return {}; | |||
663 | } | |||
664 | ||||
665 | OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { | |||
666 | auto *argumentOp = op->getOperand(0).getDefiningOp(); | |||
667 | if (argumentOp && op->getName() == argumentOp->getName()) { | |||
668 | // Replace the outer involutions output with inner's input. | |||
669 | return argumentOp->getOperand(0); | |||
670 | } | |||
671 | ||||
672 | return {}; | |||
673 | } | |||
674 | ||||
675 | LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { | |||
676 | if (op->getNumOperands() != 0) | |||
677 | return op->emitOpError() << "requires zero operands"; | |||
678 | return success(); | |||
679 | } | |||
680 | ||||
681 | LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { | |||
682 | if (op->getNumOperands() != 1) | |||
683 | return op->emitOpError() << "requires a single operand"; | |||
684 | return success(); | |||
685 | } | |||
686 | ||||
687 | LogicalResult OpTrait::impl::verifyNOperands(Operation *op, | |||
688 | unsigned numOperands) { | |||
689 | if (op->getNumOperands() != numOperands) { | |||
690 | return op->emitOpError() << "expected " << numOperands | |||
691 | << " operands, but found " << op->getNumOperands(); | |||
692 | } | |||
693 | return success(); | |||
694 | } | |||
695 | ||||
696 | LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, | |||
697 | unsigned numOperands) { | |||
698 | if (op->getNumOperands() < numOperands) | |||
699 | return op->emitOpError() | |||
700 | << "expected " << numOperands << " or more operands, but found " | |||
701 | << op->getNumOperands(); | |||
702 | return success(); | |||
703 | } | |||
704 | ||||
705 | /// If this is a vector type, or a tensor type, return the scalar element type | |||
706 | /// that it is built around, otherwise return the type unmodified. | |||
707 | static Type getTensorOrVectorElementType(Type type) { | |||
708 | if (auto vec = type.dyn_cast<VectorType>()) | |||
709 | return vec.getElementType(); | |||
710 | ||||
711 | // Look through tensor<vector<...>> to find the underlying element type. | |||
712 | if (auto tensor = type.dyn_cast<TensorType>()) | |||
713 | return getTensorOrVectorElementType(tensor.getElementType()); | |||
714 | return type; | |||
715 | } | |||
716 | ||||
717 | LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { | |||
718 | // FIXME: Add back check for no side effects on operation. | |||
719 | // Currently adding it would cause the shared library build | |||
720 | // to fail since there would be a dependency of IR on SideEffectInterfaces | |||
721 | // which is cyclical. | |||
722 | return success(); | |||
723 | } | |||
724 | ||||
725 | LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { | |||
726 | // FIXME: Add back check for no side effects on operation. | |||
727 | // Currently adding it would cause the shared library build | |||
728 | // to fail since there would be a dependency of IR on SideEffectInterfaces | |||
729 | // which is cyclical. | |||
730 | return success(); | |||
731 | } | |||
732 | ||||
733 | LogicalResult | |||
734 | OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { | |||
735 | for (auto opType : op->getOperandTypes()) { | |||
736 | auto type = getTensorOrVectorElementType(opType); | |||
737 | if (!type.isSignlessIntOrIndex()) | |||
738 | return op->emitOpError() << "requires an integer or index type"; | |||
739 | } | |||
740 | return success(); | |||
741 | } | |||
742 | ||||
743 | LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { | |||
744 | for (auto opType : op->getOperandTypes()) { | |||
745 | auto type = getTensorOrVectorElementType(opType); | |||
746 | if (!type.isa<FloatType>()) | |||
747 | return op->emitOpError("requires a float type"); | |||
748 | } | |||
749 | return success(); | |||
750 | } | |||
751 | ||||
752 | LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { | |||
753 | // Zero or one operand always have the "same" type. | |||
754 | unsigned nOperands = op->getNumOperands(); | |||
755 | if (nOperands < 2) | |||
756 | return success(); | |||
757 | ||||
758 | auto type = op->getOperand(0).getType(); | |||
759 | for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) | |||
760 | if (opType != type) | |||
761 | return op->emitOpError() << "requires all operands to have the same type"; | |||
762 | return success(); | |||
763 | } | |||
764 | ||||
765 | LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) { | |||
766 | if (op->getNumRegions() != 0) | |||
767 | return op->emitOpError() << "requires zero regions"; | |||
768 | return success(); | |||
769 | } | |||
770 | ||||
771 | LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { | |||
772 | if (op->getNumRegions() != 1) | |||
773 | return op->emitOpError() << "requires one region"; | |||
774 | return success(); | |||
775 | } | |||
776 | ||||
777 | LogicalResult OpTrait::impl::verifyNRegions(Operation *op, | |||
778 | unsigned numRegions) { | |||
779 | if (op->getNumRegions() != numRegions) | |||
780 | return op->emitOpError() << "expected " << numRegions << " regions"; | |||
781 | return success(); | |||
782 | } | |||
783 | ||||
784 | LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, | |||
785 | unsigned numRegions) { | |||
786 | if (op->getNumRegions() < numRegions) | |||
787 | return op->emitOpError() << "expected " << numRegions << " or more regions"; | |||
788 | return success(); | |||
789 | } | |||
790 | ||||
791 | LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) { | |||
792 | if (op->getNumResults() != 0) | |||
793 | return op->emitOpError() << "requires zero results"; | |||
794 | return success(); | |||
795 | } | |||
796 | ||||
797 | LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { | |||
798 | if (op->getNumResults() != 1) | |||
799 | return op->emitOpError() << "requires one result"; | |||
800 | return success(); | |||
801 | } | |||
802 | ||||
803 | LogicalResult OpTrait::impl::verifyNResults(Operation *op, | |||
804 | unsigned numOperands) { | |||
805 | if (op->getNumResults() != numOperands) | |||
806 | return op->emitOpError() << "expected " << numOperands << " results"; | |||
807 | return success(); | |||
808 | } | |||
809 | ||||
810 | LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, | |||
811 | unsigned numOperands) { | |||
812 | if (op->getNumResults() < numOperands) | |||
813 | return op->emitOpError() | |||
814 | << "expected " << numOperands << " or more results"; | |||
815 | return success(); | |||
816 | } | |||
817 | ||||
818 | LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { | |||
819 | if (failed(verifyAtLeastNOperands(op, 1))) | |||
820 | return failure(); | |||
821 | ||||
822 | if (failed(verifyCompatibleShapes(op->getOperandTypes()))) | |||
823 | return op->emitOpError() << "requires the same shape for all operands"; | |||
824 | ||||
825 | return success(); | |||
826 | } | |||
827 | ||||
828 | LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { | |||
829 | if (failed(verifyAtLeastNOperands(op, 1)) || | |||
830 | failed(verifyAtLeastNResults(op, 1))) | |||
831 | return failure(); | |||
832 | ||||
833 | SmallVector<Type, 8> types(op->getOperandTypes()); | |||
834 | types.append(llvm::to_vector<4>(op->getResultTypes())); | |||
835 | ||||
836 | if (failed(verifyCompatibleShapes(types))) | |||
837 | return op->emitOpError() | |||
838 | << "requires the same shape for all operands and results"; | |||
839 | ||||
840 | return success(); | |||
841 | } | |||
842 | ||||
843 | LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { | |||
844 | if (failed(verifyAtLeastNOperands(op, 1))) | |||
845 | return failure(); | |||
846 | auto elementType = getElementTypeOrSelf(op->getOperand(0)); | |||
847 | ||||
848 | for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { | |||
849 | if (getElementTypeOrSelf(operand) != elementType) | |||
850 | return op->emitOpError("requires the same element type for all operands"); | |||
851 | } | |||
852 | ||||
853 | return success(); | |||
854 | } | |||
855 | ||||
856 | LogicalResult | |||
857 | OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { | |||
858 | if (failed(verifyAtLeastNOperands(op, 1)) || | |||
859 | failed(verifyAtLeastNResults(op, 1))) | |||
860 | return failure(); | |||
861 | ||||
862 | auto elementType = getElementTypeOrSelf(op->getResult(0)); | |||
863 | ||||
864 | // Verify result element type matches first result's element type. | |||
865 | for (auto result : llvm::drop_begin(op->getResults(), 1)) { | |||
866 | if (getElementTypeOrSelf(result) != elementType) | |||
867 | return op->emitOpError( | |||
868 | "requires the same element type for all operands and results"); | |||
869 | } | |||
870 | ||||
871 | // Verify operand's element type matches first result's element type. | |||
872 | for (auto operand : op->getOperands()) { | |||
873 | if (getElementTypeOrSelf(operand) != elementType) | |||
874 | return op->emitOpError( | |||
875 | "requires the same element type for all operands and results"); | |||
876 | } | |||
877 | ||||
878 | return success(); | |||
879 | } | |||
880 | ||||
881 | LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { | |||
882 | if (failed(verifyAtLeastNOperands(op, 1)) || | |||
883 | failed(verifyAtLeastNResults(op, 1))) | |||
884 | return failure(); | |||
885 | ||||
886 | auto type = op->getResult(0).getType(); | |||
887 | auto elementType = getElementTypeOrSelf(type); | |||
888 | for (auto resultType : llvm::drop_begin(op->getResultTypes())) { | |||
889 | if (getElementTypeOrSelf(resultType) != elementType || | |||
890 | failed(verifyCompatibleShape(resultType, type))) | |||
891 | return op->emitOpError() | |||
892 | << "requires the same type for all operands and results"; | |||
893 | } | |||
894 | for (auto opType : op->getOperandTypes()) { | |||
895 | if (getElementTypeOrSelf(opType) != elementType || | |||
896 | failed(verifyCompatibleShape(opType, type))) | |||
897 | return op->emitOpError() | |||
898 | << "requires the same type for all operands and results"; | |||
899 | } | |||
900 | return success(); | |||
901 | } | |||
902 | ||||
903 | LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { | |||
904 | Block *block = op->getBlock(); | |||
905 | // Verify that the operation is at the end of the respective parent block. | |||
906 | if (!block || &block->back() != op) | |||
907 | return op->emitOpError("must be the last operation in the parent block"); | |||
908 | return success(); | |||
909 | } | |||
910 | ||||
911 | static LogicalResult verifyTerminatorSuccessors(Operation *op) { | |||
912 | auto *parent = op->getParentRegion(); | |||
913 | ||||
914 | // Verify that the operands lines up with the BB arguments in the successor. | |||
915 | for (Block *succ : op->getSuccessors()) | |||
916 | if (succ->getParent() != parent) | |||
917 | return op->emitError("reference to block defined in another region"); | |||
918 | return success(); | |||
919 | } | |||
920 | ||||
921 | LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) { | |||
922 | if (op->getNumSuccessors() != 0) { | |||
923 | return op->emitOpError("requires 0 successors but found ") | |||
924 | << op->getNumSuccessors(); | |||
925 | } | |||
926 | return success(); | |||
927 | } | |||
928 | ||||
929 | LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { | |||
930 | if (op->getNumSuccessors() != 1) { | |||
931 | return op->emitOpError("requires 1 successor but found ") | |||
932 | << op->getNumSuccessors(); | |||
933 | } | |||
934 | return verifyTerminatorSuccessors(op); | |||
935 | } | |||
936 | LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, | |||
937 | unsigned numSuccessors) { | |||
938 | if (op->getNumSuccessors() != numSuccessors) { | |||
939 | return op->emitOpError("requires ") | |||
940 | << numSuccessors << " successors but found " | |||
941 | << op->getNumSuccessors(); | |||
942 | } | |||
943 | return verifyTerminatorSuccessors(op); | |||
944 | } | |||
945 | LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, | |||
946 | unsigned numSuccessors) { | |||
947 | if (op->getNumSuccessors() < numSuccessors) { | |||
948 | return op->emitOpError("requires at least ") | |||
949 | << numSuccessors << " successors but found " | |||
950 | << op->getNumSuccessors(); | |||
951 | } | |||
952 | return verifyTerminatorSuccessors(op); | |||
953 | } | |||
954 | ||||
955 | LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { | |||
956 | for (auto resultType : op->getResultTypes()) { | |||
957 | auto elementType = getTensorOrVectorElementType(resultType); | |||
958 | bool isBoolType = elementType.isInteger(1); | |||
959 | if (!isBoolType) | |||
960 | return op->emitOpError() << "requires a bool result type"; | |||
961 | } | |||
962 | ||||
963 | return success(); | |||
964 | } | |||
965 | ||||
966 | LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { | |||
967 | for (auto resultType : op->getResultTypes()) | |||
968 | if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) | |||
969 | return op->emitOpError() << "requires a floating point type"; | |||
970 | ||||
971 | return success(); | |||
972 | } | |||
973 | ||||
974 | LogicalResult | |||
975 | OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { | |||
976 | for (auto resultType : op->getResultTypes()) | |||
977 | if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) | |||
978 | return op->emitOpError() << "requires an integer or index type"; | |||
979 | return success(); | |||
980 | } | |||
981 | ||||
982 | LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, | |||
983 | StringRef attrName, | |||
984 | StringRef valueGroupName, | |||
985 | size_t expectedCount) { | |||
986 | auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName); | |||
987 | if (!sizeAttr) | |||
988 | return op->emitOpError("requires 1D i32 elements attribute '") | |||
989 | << attrName << "'"; | |||
990 | ||||
991 | auto sizeAttrType = sizeAttr.getType(); | |||
992 | if (sizeAttrType.getRank() != 1 || | |||
993 | !sizeAttrType.getElementType().isInteger(32)) | |||
994 | return op->emitOpError("requires 1D i32 elements attribute '") | |||
995 | << attrName << "'"; | |||
996 | ||||
997 | if (llvm::any_of(sizeAttr.getValues<APInt>(), [](const APInt &element) { | |||
998 | return !element.isNonNegative(); | |||
999 | })) | |||
1000 | return op->emitOpError("'") | |||
1001 | << attrName << "' attribute cannot have negative elements"; | |||
1002 | ||||
1003 | size_t totalCount = std::accumulate( | |||
1004 | sizeAttr.begin(), sizeAttr.end(), 0, | |||
1005 | [](unsigned all, const APInt &one) { return all + one.getZExtValue(); }); | |||
1006 | ||||
1007 | if (totalCount != expectedCount) | |||
1008 | return op->emitOpError() | |||
1009 | << valueGroupName << " count (" << expectedCount | |||
1010 | << ") does not match with the total size (" << totalCount | |||
1011 | << ") specified in attribute '" << attrName << "'"; | |||
1012 | return success(); | |||
1013 | } | |||
1014 | ||||
1015 | LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, | |||
1016 | StringRef attrName) { | |||
1017 | return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); | |||
1018 | } | |||
1019 | ||||
1020 | LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, | |||
1021 | StringRef attrName) { | |||
1022 | return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); | |||
1023 | } | |||
1024 | ||||
1025 | LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { | |||
1026 | for (Region ®ion : op->getRegions()) { | |||
1027 | if (region.empty()) | |||
1028 | continue; | |||
1029 | ||||
1030 | if (region.getNumArguments() != 0) { | |||
1031 | if (op->getNumRegions() > 1) | |||
1032 | return op->emitOpError("region #") | |||
1033 | << region.getRegionNumber() << " should have no arguments"; | |||
1034 | return op->emitOpError("region should have no arguments"); | |||
1035 | } | |||
1036 | } | |||
1037 | return success(); | |||
1038 | } | |||
1039 | ||||
1040 | LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { | |||
1041 | auto isMappableType = [](Type type) { | |||
1042 | return type.isa<VectorType, TensorType>(); | |||
1043 | }; | |||
1044 | auto resultMappableTypes = llvm::to_vector<1>( | |||
1045 | llvm::make_filter_range(op->getResultTypes(), isMappableType)); | |||
1046 | auto operandMappableTypes = llvm::to_vector<2>( | |||
1047 | llvm::make_filter_range(op->getOperandTypes(), isMappableType)); | |||
1048 | ||||
1049 | // If the op only has scalar operand/result types, then we have nothing to | |||
1050 | // check. | |||
1051 | if (resultMappableTypes.empty() && operandMappableTypes.empty()) | |||
1052 | return success(); | |||
1053 | ||||
1054 | if (!resultMappableTypes.empty() && operandMappableTypes.empty()) | |||
1055 | return op->emitOpError("if a result is non-scalar, then at least one " | |||
1056 | "operand must be non-scalar"); | |||
1057 | ||||
1058 | assert(!operandMappableTypes.empty())(static_cast <bool> (!operandMappableTypes.empty()) ? void (0) : __assert_fail ("!operandMappableTypes.empty()", "mlir/lib/IR/Operation.cpp" , 1058, __extension__ __PRETTY_FUNCTION__)); | |||
1059 | ||||
1060 | if (resultMappableTypes.empty()) | |||
1061 | return op->emitOpError("if an operand is non-scalar, then there must be at " | |||
1062 | "least one non-scalar result"); | |||
1063 | ||||
1064 | if (resultMappableTypes.size() != op->getNumResults()) | |||
1065 | return op->emitOpError( | |||
1066 | "if an operand is non-scalar, then all results must be non-scalar"); | |||
1067 | ||||
1068 | SmallVector<Type, 4> types = llvm::to_vector<2>( | |||
1069 | llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); | |||
1070 | TypeID expectedBaseTy = types.front().getTypeID(); | |||
1071 | if (!llvm::all_of(types, | |||
1072 | [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || | |||
1073 | failed(verifyCompatibleShapes(types))) { | |||
1074 | return op->emitOpError() << "all non-scalar operands/results must have the " | |||
1075 | "same shape and base type"; | |||
1076 | } | |||
1077 | ||||
1078 | return success(); | |||
1079 | } | |||
1080 | ||||
1081 | /// Check for any values used by operations regions attached to the | |||
1082 | /// specified "IsIsolatedFromAbove" operation defined outside of it. | |||
1083 | LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { | |||
1084 | assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() &&(static_cast <bool> (isolatedOp->hasTrait<OpTrait ::IsIsolatedFromAbove>() && "Intended to check IsolatedFromAbove ops" ) ? void (0) : __assert_fail ("isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && \"Intended to check IsolatedFromAbove ops\"" , "mlir/lib/IR/Operation.cpp", 1085, __extension__ __PRETTY_FUNCTION__ )) | |||
1085 | "Intended to check IsolatedFromAbove ops")(static_cast <bool> (isolatedOp->hasTrait<OpTrait ::IsIsolatedFromAbove>() && "Intended to check IsolatedFromAbove ops" ) ? void (0) : __assert_fail ("isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && \"Intended to check IsolatedFromAbove ops\"" , "mlir/lib/IR/Operation.cpp", 1085, __extension__ __PRETTY_FUNCTION__ )); | |||
1086 | ||||
1087 | // List of regions to analyze. Each region is processed independently, with | |||
1088 | // respect to the common `limit` region, so we can look at them in any order. | |||
1089 | // Therefore, use a simple vector and push/pop back the current region. | |||
1090 | SmallVector<Region *, 8> pendingRegions; | |||
1091 | for (auto ®ion : isolatedOp->getRegions()) { | |||
1092 | pendingRegions.push_back(®ion); | |||
1093 | ||||
1094 | // Traverse all operations in the region. | |||
1095 | while (!pendingRegions.empty()) { | |||
1096 | for (Operation &op : pendingRegions.pop_back_val()->getOps()) { | |||
1097 | for (Value operand : op.getOperands()) { | |||
1098 | // Check that any value that is used by an operation is defined in the | |||
1099 | // same region as either an operation result. | |||
1100 | auto *operandRegion = operand.getParentRegion(); | |||
1101 | if (!operandRegion) | |||
1102 | return op.emitError("operation's operand is unlinked"); | |||
1103 | if (!region.isAncestor(operandRegion)) { | |||
1104 | return op.emitOpError("using value defined outside the region") | |||
1105 | .attachNote(isolatedOp->getLoc()) | |||
1106 | << "required by region isolation constraints"; | |||
1107 | } | |||
1108 | } | |||
1109 | ||||
1110 | // Schedule any regions in the operation for further checking. Don't | |||
1111 | // recurse into other IsolatedFromAbove ops, because they will check | |||
1112 | // themselves. | |||
1113 | if (op.getNumRegions() && | |||
1114 | !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { | |||
1115 | for (Region &subRegion : op.getRegions()) | |||
1116 | pendingRegions.push_back(&subRegion); | |||
1117 | } | |||
1118 | } | |||
1119 | } | |||
1120 | } | |||
1121 | ||||
1122 | return success(); | |||
1123 | } | |||
1124 | ||||
1125 | bool OpTrait::hasElementwiseMappableTraits(Operation *op) { | |||
1126 | return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && | |||
1127 | op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); | |||
1128 | } | |||
1129 | ||||
1130 | //===----------------------------------------------------------------------===// | |||
1131 | // CastOpInterface | |||
1132 | //===----------------------------------------------------------------------===// | |||
1133 | ||||
1134 | /// Attempt to fold the given cast operation. | |||
1135 | LogicalResult | |||
1136 | impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, | |||
1137 | SmallVectorImpl<OpFoldResult> &foldResults) { | |||
1138 | OperandRange operands = op->getOperands(); | |||
1139 | if (operands.empty()) | |||
1140 | return failure(); | |||
1141 | ResultRange results = op->getResults(); | |||
1142 | ||||
1143 | // Check for the case where the input and output types match 1-1. | |||
1144 | if (operands.getTypes() == results.getTypes()) { | |||
1145 | foldResults.append(operands.begin(), operands.end()); | |||
1146 | return success(); | |||
1147 | } | |||
1148 | ||||
1149 | return failure(); | |||
1150 | } | |||
1151 | ||||
1152 | /// Attempt to verify the given cast operation. | |||
1153 | LogicalResult impl::verifyCastInterfaceOp( | |||
1154 | Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) { | |||
1155 | auto resultTypes = op->getResultTypes(); | |||
1156 | if (llvm::empty(resultTypes)) | |||
1157 | return op->emitOpError() | |||
1158 | << "expected at least one result for cast operation"; | |||
1159 | ||||
1160 | auto operandTypes = op->getOperandTypes(); | |||
1161 | if (!areCastCompatible(operandTypes, resultTypes)) { | |||
1162 | InFlightDiagnostic diag = op->emitOpError("operand type"); | |||
1163 | if (llvm::empty(operandTypes)) | |||
1164 | diag << "s []"; | |||
1165 | else if (llvm::size(operandTypes) == 1) | |||
1166 | diag << " " << *operandTypes.begin(); | |||
1167 | else | |||
1168 | diag << "s " << operandTypes; | |||
1169 | return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ") | |||
1170 | << resultTypes << " are cast incompatible"; | |||
1171 | } | |||
1172 | ||||
1173 | return success(); | |||
1174 | } | |||
1175 | ||||
1176 | //===----------------------------------------------------------------------===// | |||
1177 | // Misc. utils | |||
1178 | //===----------------------------------------------------------------------===// | |||
1179 | ||||
1180 | /// Insert an operation, generated by `buildTerminatorOp`, at the end of the | |||
1181 | /// region's only block if it does not have a terminator already. If the region | |||
1182 | /// is empty, insert a new block first. `buildTerminatorOp` should return the | |||
1183 | /// terminator operation to insert. | |||
1184 | void impl::ensureRegionTerminator( | |||
1185 | Region ®ion, OpBuilder &builder, Location loc, | |||
1186 | function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { | |||
1187 | OpBuilder::InsertionGuard guard(builder); | |||
1188 | if (region.empty()) | |||
1189 | builder.createBlock(®ion); | |||
1190 | ||||
1191 | Block &block = region.back(); | |||
1192 | if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) | |||
1193 | return; | |||
1194 | ||||
1195 | builder.setInsertionPointToEnd(&block); | |||
1196 | builder.insert(buildTerminatorOp(builder, loc)); | |||
1197 | } | |||
1198 | ||||
1199 | /// Create a simple OpBuilder and forward to the OpBuilder version of this | |||
1200 | /// function. | |||
1201 | void impl::ensureRegionTerminator( | |||
1202 | Region ®ion, Builder &builder, Location loc, | |||
1203 | function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { | |||
1204 | OpBuilder opBuilder(builder.getContext()); | |||
1205 | ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); | |||
1206 | } |