File: | build/source/mlir/lib/IR/Operation.cpp |
Warning: | line 94, column 3 Potential leak of memory pointed to by 'mallocMem' |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- Operation.cpp - Operation support code -----------------------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | ||||
9 | #include "mlir/IR/Operation.h" | |||
10 | #include "mlir/IR/BuiltinTypes.h" | |||
11 | #include "mlir/IR/Dialect.h" | |||
12 | #include "mlir/IR/IRMapping.h" | |||
13 | #include "mlir/IR/OpImplementation.h" | |||
14 | #include "mlir/IR/PatternMatch.h" | |||
15 | #include "mlir/IR/TypeUtilities.h" | |||
16 | #include "mlir/Interfaces/FoldInterfaces.h" | |||
17 | #include "llvm/ADT/StringExtras.h" | |||
18 | #include <numeric> | |||
19 | ||||
20 | using namespace mlir; | |||
21 | ||||
22 | //===----------------------------------------------------------------------===// | |||
23 | // Operation | |||
24 | //===----------------------------------------------------------------------===// | |||
25 | ||||
26 | /// Create a new Operation from operation state. | |||
27 | Operation *Operation::create(const OperationState &state) { | |||
28 | return create(state.location, state.name, state.types, state.operands, | |||
29 | state.attributes.getDictionary(state.getContext()), | |||
30 | state.successors, state.regions); | |||
31 | } | |||
32 | ||||
33 | /// Create a new Operation with the specific fields. | |||
34 | Operation *Operation::create(Location location, OperationName name, | |||
35 | TypeRange resultTypes, ValueRange operands, | |||
36 | NamedAttrList &&attributes, BlockRange successors, | |||
37 | RegionRange regions) { | |||
38 | unsigned numRegions = regions.size(); | |||
39 | Operation *op = create(location, name, resultTypes, operands, | |||
40 | std::move(attributes), successors, numRegions); | |||
41 | for (unsigned i = 0; i < numRegions; ++i) | |||
42 | if (regions[i]) | |||
43 | op->getRegion(i).takeBody(*regions[i]); | |||
44 | return op; | |||
45 | } | |||
46 | ||||
47 | /// Create a new Operation with the specific fields. | |||
48 | Operation *Operation::create(Location location, OperationName name, | |||
49 | TypeRange resultTypes, ValueRange operands, | |||
50 | NamedAttrList &&attributes, BlockRange successors, | |||
51 | unsigned numRegions) { | |||
52 | // Populate default attributes. | |||
53 | name.populateDefaultAttrs(attributes); | |||
54 | ||||
55 | return create(location, name, resultTypes, operands, | |||
56 | attributes.getDictionary(location.getContext()), successors, | |||
57 | numRegions); | |||
58 | } | |||
59 | ||||
60 | /// Overload of create that takes an existing DictionaryAttr to avoid | |||
61 | /// unnecessarily uniquing a list of attributes. | |||
62 | Operation *Operation::create(Location location, OperationName name, | |||
63 | TypeRange resultTypes, ValueRange operands, | |||
64 | DictionaryAttr attributes, BlockRange successors, | |||
65 | unsigned numRegions) { | |||
66 | assert(llvm::all_of(resultTypes, [](Type t) { return t; }) &&(static_cast <bool> (llvm::all_of(resultTypes, [](Type t ) { return t; }) && "unexpected null result type") ? void (0) : __assert_fail ("llvm::all_of(resultTypes, [](Type t) { return t; }) && \"unexpected null result type\"" , "mlir/lib/IR/Operation.cpp", 67, __extension__ __PRETTY_FUNCTION__ )) | |||
67 | "unexpected null result type")(static_cast <bool> (llvm::all_of(resultTypes, [](Type t ) { return t; }) && "unexpected null result type") ? void (0) : __assert_fail ("llvm::all_of(resultTypes, [](Type t) { return t; }) && \"unexpected null result type\"" , "mlir/lib/IR/Operation.cpp", 67, __extension__ __PRETTY_FUNCTION__ )); | |||
68 | ||||
69 | // We only need to allocate additional memory for a subset of results. | |||
70 | unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); | |||
71 | unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); | |||
72 | unsigned numSuccessors = successors.size(); | |||
73 | unsigned numOperands = operands.size(); | |||
74 | unsigned numResults = resultTypes.size(); | |||
75 | ||||
76 | // If the operation is known to have no operands, don't allocate an operand | |||
77 | // storage. | |||
78 | bool needsOperandStorage = | |||
79 | operands.empty() ? !name.hasTrait<OpTrait::ZeroOperands>() : true; | |||
80 | ||||
81 | // Compute the byte size for the operation and the operand storage. This takes | |||
82 | // into account the size of the operation, its trailing objects, and its | |||
83 | // prefixed objects. | |||
84 | size_t byteSize = | |||
85 | totalSizeToAlloc<detail::OperandStorage, BlockOperand, Region, OpOperand>( | |||
86 | needsOperandStorage
| |||
87 | size_t prefixByteSize = llvm::alignTo( | |||
88 | Operation::prefixAllocSize(numTrailingResults, numInlineResults), | |||
89 | alignof(Operation)); | |||
90 | char *mallocMem = reinterpret_cast<char *>(malloc(byteSize + prefixByteSize)); | |||
91 | void *rawMem = mallocMem + prefixByteSize; | |||
92 | ||||
93 | // Create the new Operation. | |||
94 | Operation *op = | |||
| ||||
95 | ::new (rawMem) Operation(location, name, numResults, numSuccessors, | |||
96 | numRegions, attributes, needsOperandStorage); | |||
97 | ||||
98 | assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) &&(static_cast <bool> ((numSuccessors == 0 || op->mightHaveTrait <OpTrait::IsTerminator>()) && "unexpected successors in a non-terminator operation" ) ? void (0) : __assert_fail ("(numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && \"unexpected successors in a non-terminator operation\"" , "mlir/lib/IR/Operation.cpp", 99, __extension__ __PRETTY_FUNCTION__ )) | |||
99 | "unexpected successors in a non-terminator operation")(static_cast <bool> ((numSuccessors == 0 || op->mightHaveTrait <OpTrait::IsTerminator>()) && "unexpected successors in a non-terminator operation" ) ? void (0) : __assert_fail ("(numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) && \"unexpected successors in a non-terminator operation\"" , "mlir/lib/IR/Operation.cpp", 99, __extension__ __PRETTY_FUNCTION__ )); | |||
100 | ||||
101 | // Initialize the results. | |||
102 | auto resultTypeIt = resultTypes.begin(); | |||
103 | for (unsigned i = 0; i < numInlineResults; ++i, ++resultTypeIt) | |||
104 | new (op->getInlineOpResult(i)) detail::InlineOpResult(*resultTypeIt, i); | |||
105 | for (unsigned i = 0; i < numTrailingResults; ++i, ++resultTypeIt) { | |||
106 | new (op->getOutOfLineOpResult(i)) | |||
107 | detail::OutOfLineOpResult(*resultTypeIt, i); | |||
108 | } | |||
109 | ||||
110 | // Initialize the regions. | |||
111 | for (unsigned i = 0; i != numRegions; ++i) | |||
112 | new (&op->getRegion(i)) Region(op); | |||
113 | ||||
114 | // Initialize the operands. | |||
115 | if (needsOperandStorage) { | |||
116 | new (&op->getOperandStorage()) detail::OperandStorage( | |||
117 | op, op->getTrailingObjects<OpOperand>(), operands); | |||
118 | } | |||
119 | ||||
120 | // Initialize the successors. | |||
121 | auto blockOperands = op->getBlockOperands(); | |||
122 | for (unsigned i = 0; i != numSuccessors; ++i) | |||
123 | new (&blockOperands[i]) BlockOperand(op, successors[i]); | |||
124 | ||||
125 | return op; | |||
126 | } | |||
127 | ||||
128 | Operation::Operation(Location location, OperationName name, unsigned numResults, | |||
129 | unsigned numSuccessors, unsigned numRegions, | |||
130 | DictionaryAttr attributes, bool hasOperandStorage) | |||
131 | : location(location), numResults(numResults), numSuccs(numSuccessors), | |||
132 | numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name), | |||
133 | attrs(attributes) { | |||
134 | assert(attributes && "unexpected null attribute dictionary")(static_cast <bool> (attributes && "unexpected null attribute dictionary" ) ? void (0) : __assert_fail ("attributes && \"unexpected null attribute dictionary\"" , "mlir/lib/IR/Operation.cpp", 134, __extension__ __PRETTY_FUNCTION__ )); | |||
135 | #ifndef NDEBUG | |||
136 | if (!getDialect() && !getContext()->allowsUnregisteredDialects()) | |||
137 | llvm::report_fatal_error( | |||
138 | name.getStringRef() + | |||
139 | " created with unregistered dialect. If this is intended, please call " | |||
140 | "allowUnregisteredDialects() on the MLIRContext, or use " | |||
141 | "-allow-unregistered-dialect with the MLIR tool used."); | |||
142 | #endif | |||
143 | } | |||
144 | ||||
145 | // Operations are deleted through the destroy() member because they are | |||
146 | // allocated via malloc. | |||
147 | Operation::~Operation() { | |||
148 | assert(block == nullptr && "operation destroyed but still in a block")(static_cast <bool> (block == nullptr && "operation destroyed but still in a block" ) ? void (0) : __assert_fail ("block == nullptr && \"operation destroyed but still in a block\"" , "mlir/lib/IR/Operation.cpp", 148, __extension__ __PRETTY_FUNCTION__ )); | |||
149 | #ifndef NDEBUG | |||
150 | if (!use_empty()) { | |||
151 | { | |||
152 | InFlightDiagnostic diag = | |||
153 | emitOpError("operation destroyed but still has uses"); | |||
154 | for (Operation *user : getUsers()) | |||
155 | diag.attachNote(user->getLoc()) << "- use: " << *user << "\n"; | |||
156 | } | |||
157 | llvm::report_fatal_error("operation destroyed but still has uses"); | |||
158 | } | |||
159 | #endif | |||
160 | // Explicitly run the destructors for the operands. | |||
161 | if (hasOperandStorage) | |||
162 | getOperandStorage().~OperandStorage(); | |||
163 | ||||
164 | // Explicitly run the destructors for the successors. | |||
165 | for (auto &successor : getBlockOperands()) | |||
166 | successor.~BlockOperand(); | |||
167 | ||||
168 | // Explicitly destroy the regions. | |||
169 | for (auto ®ion : getRegions()) | |||
170 | region.~Region(); | |||
171 | } | |||
172 | ||||
173 | /// Destroy this operation or one of its subclasses. | |||
174 | void Operation::destroy() { | |||
175 | // Operations may have additional prefixed allocation, which needs to be | |||
176 | // accounted for here when computing the address to free. | |||
177 | char *rawMem = reinterpret_cast<char *>(this) - | |||
178 | llvm::alignTo(prefixAllocSize(), alignof(Operation)); | |||
179 | this->~Operation(); | |||
180 | free(rawMem); | |||
181 | } | |||
182 | ||||
183 | /// Return true if this operation is a proper ancestor of the `other` | |||
184 | /// operation. | |||
185 | bool Operation::isProperAncestor(Operation *other) { | |||
186 | while ((other = other->getParentOp())) | |||
187 | if (this == other) | |||
188 | return true; | |||
189 | return false; | |||
190 | } | |||
191 | ||||
192 | /// Replace any uses of 'from' with 'to' within this operation. | |||
193 | void Operation::replaceUsesOfWith(Value from, Value to) { | |||
194 | if (from == to) | |||
195 | return; | |||
196 | for (auto &operand : getOpOperands()) | |||
197 | if (operand.get() == from) | |||
198 | operand.set(to); | |||
199 | } | |||
200 | ||||
201 | /// Replace the current operands of this operation with the ones provided in | |||
202 | /// 'operands'. | |||
203 | void Operation::setOperands(ValueRange operands) { | |||
204 | if (LLVM_LIKELY(hasOperandStorage)__builtin_expect((bool)(hasOperandStorage), true)) | |||
205 | return getOperandStorage().setOperands(this, operands); | |||
206 | assert(operands.empty() && "setting operands without an operand storage")(static_cast <bool> (operands.empty() && "setting operands without an operand storage" ) ? void (0) : __assert_fail ("operands.empty() && \"setting operands without an operand storage\"" , "mlir/lib/IR/Operation.cpp", 206, __extension__ __PRETTY_FUNCTION__ )); | |||
207 | } | |||
208 | ||||
209 | /// Replace the operands beginning at 'start' and ending at 'start' + 'length' | |||
210 | /// with the ones provided in 'operands'. 'operands' may be smaller or larger | |||
211 | /// than the range pointed to by 'start'+'length'. | |||
212 | void Operation::setOperands(unsigned start, unsigned length, | |||
213 | ValueRange operands) { | |||
214 | assert((start + length) <= getNumOperands() &&(static_cast <bool> ((start + length) <= getNumOperands () && "invalid operand range specified") ? void (0) : __assert_fail ("(start + length) <= getNumOperands() && \"invalid operand range specified\"" , "mlir/lib/IR/Operation.cpp", 215, __extension__ __PRETTY_FUNCTION__ )) | |||
215 | "invalid operand range specified")(static_cast <bool> ((start + length) <= getNumOperands () && "invalid operand range specified") ? void (0) : __assert_fail ("(start + length) <= getNumOperands() && \"invalid operand range specified\"" , "mlir/lib/IR/Operation.cpp", 215, __extension__ __PRETTY_FUNCTION__ )); | |||
216 | if (LLVM_LIKELY(hasOperandStorage)__builtin_expect((bool)(hasOperandStorage), true)) | |||
217 | return getOperandStorage().setOperands(this, start, length, operands); | |||
218 | assert(operands.empty() && "setting operands without an operand storage")(static_cast <bool> (operands.empty() && "setting operands without an operand storage" ) ? void (0) : __assert_fail ("operands.empty() && \"setting operands without an operand storage\"" , "mlir/lib/IR/Operation.cpp", 218, __extension__ __PRETTY_FUNCTION__ )); | |||
219 | } | |||
220 | ||||
221 | /// Insert the given operands into the operand list at the given 'index'. | |||
222 | void Operation::insertOperands(unsigned index, ValueRange operands) { | |||
223 | if (LLVM_LIKELY(hasOperandStorage)__builtin_expect((bool)(hasOperandStorage), true)) | |||
224 | return setOperands(index, /*length=*/0, operands); | |||
225 | assert(operands.empty() && "inserting operands without an operand storage")(static_cast <bool> (operands.empty() && "inserting operands without an operand storage" ) ? void (0) : __assert_fail ("operands.empty() && \"inserting operands without an operand storage\"" , "mlir/lib/IR/Operation.cpp", 225, __extension__ __PRETTY_FUNCTION__ )); | |||
226 | } | |||
227 | ||||
228 | //===----------------------------------------------------------------------===// | |||
229 | // Diagnostics | |||
230 | //===----------------------------------------------------------------------===// | |||
231 | ||||
232 | /// Emit an error about fatal conditions with this operation, reporting up to | |||
233 | /// any diagnostic handlers that may be listening. | |||
234 | InFlightDiagnostic Operation::emitError(const Twine &message) { | |||
235 | InFlightDiagnostic diag = mlir::emitError(getLoc(), message); | |||
236 | if (getContext()->shouldPrintOpOnDiagnostic()) { | |||
237 | diag.attachNote(getLoc()) | |||
238 | .append("see current operation: ") | |||
239 | .appendOp(*this, OpPrintingFlags().printGenericOpForm()); | |||
240 | } | |||
241 | return diag; | |||
242 | } | |||
243 | ||||
244 | /// Emit a warning about this operation, reporting up to any diagnostic | |||
245 | /// handlers that may be listening. | |||
246 | InFlightDiagnostic Operation::emitWarning(const Twine &message) { | |||
247 | InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message); | |||
248 | if (getContext()->shouldPrintOpOnDiagnostic()) | |||
249 | diag.attachNote(getLoc()) << "see current operation: " << *this; | |||
250 | return diag; | |||
251 | } | |||
252 | ||||
253 | /// Emit a remark about this operation, reporting up to any diagnostic | |||
254 | /// handlers that may be listening. | |||
255 | InFlightDiagnostic Operation::emitRemark(const Twine &message) { | |||
256 | InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message); | |||
257 | if (getContext()->shouldPrintOpOnDiagnostic()) | |||
258 | diag.attachNote(getLoc()) << "see current operation: " << *this; | |||
259 | return diag; | |||
260 | } | |||
261 | ||||
262 | //===----------------------------------------------------------------------===// | |||
263 | // Operation Ordering | |||
264 | //===----------------------------------------------------------------------===// | |||
265 | ||||
// Operation::kInvalidOrderIdx and Operation::kOrderStride are `static
// constexpr` data members with in-class initializers. Since C++17 such members
// are implicitly inline variables, so out-of-line redeclarations of them are
// redundant and deprecated (diagnosed by clang's -Wdeprecated). They are
// therefore no longer repeated here.
268 | ||||
269 | /// Given an operation 'other' that is within the same parent block, return | |||
270 | /// whether the current operation is before 'other' in the operation list | |||
271 | /// of the parent block. | |||
272 | /// Note: This function has an average complexity of O(1), but worst case may | |||
273 | /// take O(N) where N is the number of operations within the parent block. | |||
274 | bool Operation::isBeforeInBlock(Operation *other) { | |||
275 | assert(block && "Operations without parent blocks have no order.")(static_cast <bool> (block && "Operations without parent blocks have no order." ) ? void (0) : __assert_fail ("block && \"Operations without parent blocks have no order.\"" , "mlir/lib/IR/Operation.cpp", 275, __extension__ __PRETTY_FUNCTION__ )); | |||
276 | assert(other && other->block == block &&(static_cast <bool> (other && other->block == block && "Expected other operation to have the same parent block." ) ? void (0) : __assert_fail ("other && other->block == block && \"Expected other operation to have the same parent block.\"" , "mlir/lib/IR/Operation.cpp", 277, __extension__ __PRETTY_FUNCTION__ )) | |||
277 | "Expected other operation to have the same parent block.")(static_cast <bool> (other && other->block == block && "Expected other operation to have the same parent block." ) ? void (0) : __assert_fail ("other && other->block == block && \"Expected other operation to have the same parent block.\"" , "mlir/lib/IR/Operation.cpp", 277, __extension__ __PRETTY_FUNCTION__ )); | |||
278 | // If the order of the block is already invalid, directly recompute the | |||
279 | // parent. | |||
280 | if (!block->isOpOrderValid()) { | |||
281 | block->recomputeOpOrder(); | |||
282 | } else { | |||
283 | // Update the order either operation if necessary. | |||
284 | updateOrderIfNecessary(); | |||
285 | other->updateOrderIfNecessary(); | |||
286 | } | |||
287 | ||||
288 | return orderIndex < other->orderIndex; | |||
289 | } | |||
290 | ||||
291 | /// Update the order index of this operation of this operation if necessary, | |||
292 | /// potentially recomputing the order of the parent block. | |||
293 | void Operation::updateOrderIfNecessary() { | |||
294 | assert(block && "expected valid parent")(static_cast <bool> (block && "expected valid parent" ) ? void (0) : __assert_fail ("block && \"expected valid parent\"" , "mlir/lib/IR/Operation.cpp", 294, __extension__ __PRETTY_FUNCTION__ )); | |||
295 | ||||
296 | // If the order is valid for this operation there is nothing to do. | |||
297 | if (hasValidOrder()) | |||
298 | return; | |||
299 | Operation *blockFront = &block->front(); | |||
300 | Operation *blockBack = &block->back(); | |||
301 | ||||
302 | // This method is expected to only be invoked on blocks with more than one | |||
303 | // operation. | |||
304 | assert(blockFront != blockBack && "expected more than one operation")(static_cast <bool> (blockFront != blockBack && "expected more than one operation") ? void (0) : __assert_fail ("blockFront != blockBack && \"expected more than one operation\"" , "mlir/lib/IR/Operation.cpp", 304, __extension__ __PRETTY_FUNCTION__ )); | |||
305 | ||||
306 | // If the operation is at the end of the block. | |||
307 | if (this == blockBack) { | |||
308 | Operation *prevNode = getPrevNode(); | |||
309 | if (!prevNode->hasValidOrder()) | |||
310 | return block->recomputeOpOrder(); | |||
311 | ||||
312 | // Add the stride to the previous operation. | |||
313 | orderIndex = prevNode->orderIndex + kOrderStride; | |||
314 | return; | |||
315 | } | |||
316 | ||||
317 | // If this is the first operation try to use the next operation to compute the | |||
318 | // ordering. | |||
319 | if (this == blockFront) { | |||
320 | Operation *nextNode = getNextNode(); | |||
321 | if (!nextNode->hasValidOrder()) | |||
322 | return block->recomputeOpOrder(); | |||
323 | // There is no order to give this operation. | |||
324 | if (nextNode->orderIndex == 0) | |||
325 | return block->recomputeOpOrder(); | |||
326 | ||||
327 | // If we can't use the stride, just take the middle value left. This is safe | |||
328 | // because we know there is at least one valid index to assign to. | |||
329 | if (nextNode->orderIndex <= kOrderStride) | |||
330 | orderIndex = (nextNode->orderIndex / 2); | |||
331 | else | |||
332 | orderIndex = kOrderStride; | |||
333 | return; | |||
334 | } | |||
335 | ||||
336 | // Otherwise, this operation is between two others. Place this operation in | |||
337 | // the middle of the previous and next if possible. | |||
338 | Operation *prevNode = getPrevNode(), *nextNode = getNextNode(); | |||
339 | if (!prevNode->hasValidOrder() || !nextNode->hasValidOrder()) | |||
340 | return block->recomputeOpOrder(); | |||
341 | unsigned prevOrder = prevNode->orderIndex, nextOrder = nextNode->orderIndex; | |||
342 | ||||
343 | // Check to see if there is a valid order between the two. | |||
344 | if (prevOrder + 1 == nextOrder) | |||
345 | return block->recomputeOpOrder(); | |||
346 | orderIndex = prevOrder + ((nextOrder - prevOrder) / 2); | |||
347 | } | |||
348 | ||||
349 | //===----------------------------------------------------------------------===// | |||
350 | // ilist_traits for Operation | |||
351 | //===----------------------------------------------------------------------===// | |||
352 | ||||
// Out-of-line definitions of the node <-> value pointer conversions used by
// the intrusive list that holds Operations. Each one forwards to the generic
// NodeAccess helpers with this specialization's OptionsT.
// NOTE(review): presumably defined out of line so that headers can declare
// the list with an incomplete Operation type; confirm against Operation.h.

/// Convert an Operation pointer to its intrusive-list node pointer.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

/// Const overload of getNodePtr.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getNodePtr(const_pointer n)
    -> const node_type * {
  return NodeAccess::getNodePtr<OptionsT>(n);
}

/// Convert an intrusive-list node pointer back to its Operation pointer.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}

/// Const overload of getValuePtr.
auto llvm::ilist_detail::SpecificNodeAccess<
    typename llvm::ilist_detail::compute_node_options<
        ::mlir::Operation>::type>::getValuePtr(const node_type *n)
    -> const_pointer {
  return NodeAccess::getValuePtr<OptionsT>(n);
}
378 | ||||
379 | void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) { | |||
380 | op->destroy(); | |||
381 | } | |||
382 | ||||
/// Return the Block that owns this operation list. The list object is a
/// member of Block (reached via Block::getSublistAccess), so the owning Block's
/// address is recovered by subtracting that member's byte offset from the
/// list's own address.
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
  // Byte offset of the operation list inside Block, computed by applying the
  // list member pointer to a null Block pointer (an offsetof-style idiom).
  // NOTE(review): applying a member pointer to a null object pointer is
  // formally undefined behaviour, although it is a long-standing LLVM idiom.
  size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
  // `this` is the traits base of the iplist that stores the operations.
  iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this));
  return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset);
}
388 | ||||
389 | /// This is a trait method invoked when an operation is added to a block. We | |||
390 | /// keep the block pointer up to date. | |||
391 | void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { | |||
392 | assert(!op->getBlock() && "already in an operation block!")(static_cast <bool> (!op->getBlock() && "already in an operation block!" ) ? void (0) : __assert_fail ("!op->getBlock() && \"already in an operation block!\"" , "mlir/lib/IR/Operation.cpp", 392, __extension__ __PRETTY_FUNCTION__ )); | |||
393 | op->block = getContainingBlock(); | |||
394 | ||||
395 | // Invalidate the order on the operation. | |||
396 | op->orderIndex = Operation::kInvalidOrderIdx; | |||
397 | } | |||
398 | ||||
399 | /// This is a trait method invoked when an operation is removed from a block. | |||
400 | /// We keep the block pointer up to date. | |||
401 | void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { | |||
402 | assert(op->block && "not already in an operation block!")(static_cast <bool> (op->block && "not already in an operation block!" ) ? void (0) : __assert_fail ("op->block && \"not already in an operation block!\"" , "mlir/lib/IR/Operation.cpp", 402, __extension__ __PRETTY_FUNCTION__ )); | |||
403 | op->block = nullptr; | |||
404 | } | |||
405 | ||||
406 | /// This is a trait method invoked when an operation is moved from one block | |||
407 | /// to another. We keep the block pointer up to date. | |||
408 | void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( | |||
409 | ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { | |||
410 | Block *curParent = getContainingBlock(); | |||
411 | ||||
412 | // Invalidate the ordering of the parent block. | |||
413 | curParent->invalidateOpOrder(); | |||
414 | ||||
415 | // If we are transferring operations within the same block, the block | |||
416 | // pointer doesn't need to be updated. | |||
417 | if (curParent == otherList.getContainingBlock()) | |||
418 | return; | |||
419 | ||||
420 | // Update the 'block' member of each operation. | |||
421 | for (; first != last; ++first) | |||
422 | first->block = curParent; | |||
423 | } | |||
424 | ||||
425 | /// Remove this operation (and its descendants) from its Block and delete | |||
426 | /// all of them. | |||
427 | void Operation::erase() { | |||
428 | if (auto *parent = getBlock()) | |||
429 | parent->getOperations().erase(this); | |||
430 | else | |||
431 | destroy(); | |||
432 | } | |||
433 | ||||
434 | /// Remove the operation from its parent block, but don't delete it. | |||
435 | void Operation::remove() { | |||
436 | if (Block *parent = getBlock()) | |||
437 | parent->getOperations().remove(this); | |||
438 | } | |||
439 | ||||
440 | /// Unlink this operation from its current block and insert it right before | |||
441 | /// `existingOp` which may be in the same or another block in the same | |||
442 | /// function. | |||
443 | void Operation::moveBefore(Operation *existingOp) { | |||
444 | moveBefore(existingOp->getBlock(), existingOp->getIterator()); | |||
445 | } | |||
446 | ||||
447 | /// Unlink this operation from its current basic block and insert it right | |||
448 | /// before `iterator` in the specified basic block. | |||
449 | void Operation::moveBefore(Block *block, | |||
450 | llvm::iplist<Operation>::iterator iterator) { | |||
451 | block->getOperations().splice(iterator, getBlock()->getOperations(), | |||
452 | getIterator()); | |||
453 | } | |||
454 | ||||
455 | /// Unlink this operation from its current block and insert it right after | |||
456 | /// `existingOp` which may be in the same or another block in the same function. | |||
457 | void Operation::moveAfter(Operation *existingOp) { | |||
458 | moveAfter(existingOp->getBlock(), existingOp->getIterator()); | |||
459 | } | |||
460 | ||||
461 | /// Unlink this operation from its current block and insert it right after | |||
462 | /// `iterator` in the specified block. | |||
463 | void Operation::moveAfter(Block *block, | |||
464 | llvm::iplist<Operation>::iterator iterator) { | |||
465 | assert(iterator != block->end() && "cannot move after end of block")(static_cast <bool> (iterator != block->end() && "cannot move after end of block") ? void (0) : __assert_fail ("iterator != block->end() && \"cannot move after end of block\"" , "mlir/lib/IR/Operation.cpp", 465, __extension__ __PRETTY_FUNCTION__ )); | |||
466 | moveBefore(block, std::next(iterator)); | |||
467 | } | |||
468 | ||||
469 | /// This drops all operand uses from this operation, which is an essential | |||
470 | /// step in breaking cyclic dependences between references when they are to | |||
471 | /// be deleted. | |||
472 | void Operation::dropAllReferences() { | |||
473 | for (auto &op : getOpOperands()) | |||
474 | op.drop(); | |||
475 | ||||
476 | for (auto ®ion : getRegions()) | |||
477 | region.dropAllReferences(); | |||
478 | ||||
479 | for (auto &dest : getBlockOperands()) | |||
480 | dest.drop(); | |||
481 | } | |||
482 | ||||
483 | /// This drops all uses of any values defined by this operation or its nested | |||
484 | /// regions, wherever they are located. | |||
485 | void Operation::dropAllDefinedValueUses() { | |||
486 | dropAllUses(); | |||
487 | ||||
488 | for (auto ®ion : getRegions()) | |||
489 | for (auto &block : region) | |||
490 | block.dropAllDefinedValueUses(); | |||
491 | } | |||
492 | ||||
493 | void Operation::setSuccessor(Block *block, unsigned index) { | |||
494 | assert(index < getNumSuccessors())(static_cast <bool> (index < getNumSuccessors()) ? void (0) : __assert_fail ("index < getNumSuccessors()", "mlir/lib/IR/Operation.cpp" , 494, __extension__ __PRETTY_FUNCTION__)); | |||
495 | getBlockOperands()[index].set(block); | |||
496 | } | |||
497 | ||||
498 | /// Attempt to fold this operation using the Op's registered foldHook. | |||
499 | LogicalResult Operation::fold(ArrayRef<Attribute> operands, | |||
500 | SmallVectorImpl<OpFoldResult> &results) { | |||
501 | // If we have a registered operation definition matching this one, use it to | |||
502 | // try to constant fold the operation. | |||
503 | if (succeeded(name.foldHook(this, operands, results))) | |||
504 | return success(); | |||
505 | ||||
506 | // Otherwise, fall back on the dialect hook to handle it. | |||
507 | Dialect *dialect = getDialect(); | |||
508 | if (!dialect) | |||
509 | return failure(); | |||
510 | ||||
511 | auto *interface = dyn_cast<DialectFoldInterface>(dialect); | |||
512 | if (!interface) | |||
513 | return failure(); | |||
514 | ||||
515 | return interface->fold(this, operands, results); | |||
516 | } | |||
517 | ||||
518 | /// Emit an error with the op name prefixed, like "'dim' op " which is | |||
519 | /// convenient for verifiers. | |||
520 | InFlightDiagnostic Operation::emitOpError(const Twine &message) { | |||
521 | return emitError() << "'" << getName() << "' op " << message; | |||
522 | } | |||
523 | ||||
524 | //===----------------------------------------------------------------------===// | |||
525 | // Operation Cloning | |||
526 | //===----------------------------------------------------------------------===// | |||
527 | ||||
528 | Operation::CloneOptions::CloneOptions() | |||
529 | : cloneRegionsFlag(false), cloneOperandsFlag(false) {} | |||
530 | ||||
531 | Operation::CloneOptions::CloneOptions(bool cloneRegions, bool cloneOperands) | |||
532 | : cloneRegionsFlag(cloneRegions), cloneOperandsFlag(cloneOperands) {} | |||
533 | ||||
534 | Operation::CloneOptions Operation::CloneOptions::all() { | |||
535 | return CloneOptions().cloneRegions().cloneOperands(); | |||
536 | } | |||
537 | ||||
538 | Operation::CloneOptions &Operation::CloneOptions::cloneRegions(bool enable) { | |||
539 | cloneRegionsFlag = enable; | |||
540 | return *this; | |||
541 | } | |||
542 | ||||
543 | Operation::CloneOptions &Operation::CloneOptions::cloneOperands(bool enable) { | |||
544 | cloneOperandsFlag = enable; | |||
545 | return *this; | |||
546 | } | |||
547 | ||||
548 | /// Create a deep copy of this operation but keep the operation regions empty. | |||
549 | /// Operands are remapped using `mapper` (if present), and `mapper` is updated | |||
550 | /// to contain the results. The `mapResults` flag specifies whether the results | |||
551 | /// of the cloned operation should be added to the map. | |||
552 | Operation *Operation::cloneWithoutRegions(IRMapping &mapper) { | |||
553 | return clone(mapper, CloneOptions::all().cloneRegions(false)); | |||
554 | } | |||
555 | ||||
556 | Operation *Operation::cloneWithoutRegions() { | |||
557 | IRMapping mapper; | |||
558 | return cloneWithoutRegions(mapper); | |||
559 | } | |||
560 | ||||
561 | /// Create a deep copy of this operation, remapping any operands that use | |||
562 | /// values outside of the operation using the map that is provided (leaving | |||
563 | /// them alone if no entry is present). Replaces references to cloned | |||
564 | /// sub-operations to the corresponding operation that is copied, and adds | |||
565 | /// those mappings to the map. | |||
566 | Operation *Operation::clone(IRMapping &mapper, CloneOptions options) { | |||
567 | SmallVector<Value, 8> operands; | |||
568 | SmallVector<Block *, 2> successors; | |||
569 | ||||
570 | // Remap the operands. | |||
571 | if (options.shouldCloneOperands()) { | |||
572 | operands.reserve(getNumOperands()); | |||
573 | for (auto opValue : getOperands()) | |||
574 | operands.push_back(mapper.lookupOrDefault(opValue)); | |||
575 | } | |||
576 | ||||
577 | // Remap the successors. | |||
578 | successors.reserve(getNumSuccessors()); | |||
579 | for (Block *successor : getSuccessors()) | |||
580 | successors.push_back(mapper.lookupOrDefault(successor)); | |||
581 | ||||
582 | // Create the new operation. | |||
583 | auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs, | |||
584 | successors, getNumRegions()); | |||
585 | mapper.map(this, newOp); | |||
586 | ||||
587 | // Clone the regions. | |||
588 | if (options.shouldCloneRegions()) { | |||
589 | for (unsigned i = 0; i != numRegions; ++i) | |||
590 | getRegion(i).cloneInto(&newOp->getRegion(i), mapper); | |||
591 | } | |||
592 | ||||
593 | // Remember the mapping of any results. | |||
594 | for (unsigned i = 0, e = getNumResults(); i != e; ++i) | |||
595 | mapper.map(getResult(i), newOp->getResult(i)); | |||
596 | ||||
597 | return newOp; | |||
598 | } | |||
599 | ||||
600 | Operation *Operation::clone(CloneOptions options) { | |||
601 | IRMapping mapper; | |||
602 | return clone(mapper, options); | |||
| ||||
603 | } | |||
604 | ||||
605 | //===----------------------------------------------------------------------===// | |||
606 | // OpState trait class. | |||
607 | //===----------------------------------------------------------------------===// | |||
608 | ||||
609 | // The fallback for the parser is to try for a dialect operation parser. | |||
610 | // Otherwise, reject the custom assembly form. | |||
611 | ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { | |||
612 | if (auto parseFn = result.name.getDialect()->getParseOperationHook( | |||
613 | result.name.getStringRef())) | |||
614 | return (*parseFn)(parser, result); | |||
615 | return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); | |||
616 | } | |||
617 | ||||
618 | // The fallback for the printer is to try for a dialect operation printer. | |||
619 | // Otherwise, it prints the generic form. | |||
620 | void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { | |||
621 | if (auto printFn = op->getDialect()->getOperationPrinter(op)) { | |||
622 | printOpName(op, p, defaultDialect); | |||
623 | printFn(op, p); | |||
624 | } else { | |||
625 | p.printGenericOp(op); | |||
626 | } | |||
627 | } | |||
628 | ||||
629 | /// Print an operation name, eliding the dialect prefix if necessary and doesn't | |||
630 | /// lead to ambiguities. | |||
631 | void OpState::printOpName(Operation *op, OpAsmPrinter &p, | |||
632 | StringRef defaultDialect) { | |||
633 | StringRef name = op->getName().getStringRef(); | |||
634 | if (name.startswith((defaultDialect + ".").str()) && name.count('.') == 1) | |||
635 | name = name.drop_front(defaultDialect.size() + 1); | |||
636 | p.getStream() << name; | |||
637 | } | |||
638 | ||||
639 | /// Emit an error about fatal conditions with this operation, reporting up to | |||
640 | /// any diagnostic handlers that may be listening. | |||
641 | InFlightDiagnostic OpState::emitError(const Twine &message) { | |||
642 | return getOperation()->emitError(message); | |||
643 | } | |||
644 | ||||
645 | /// Emit an error with the op name prefixed, like "'dim' op " which is | |||
646 | /// convenient for verifiers. | |||
647 | InFlightDiagnostic OpState::emitOpError(const Twine &message) { | |||
648 | return getOperation()->emitOpError(message); | |||
649 | } | |||
650 | ||||
651 | /// Emit a warning about this operation, reporting up to any diagnostic | |||
652 | /// handlers that may be listening. | |||
653 | InFlightDiagnostic OpState::emitWarning(const Twine &message) { | |||
654 | return getOperation()->emitWarning(message); | |||
655 | } | |||
656 | ||||
657 | /// Emit a remark about this operation, reporting up to any diagnostic | |||
658 | /// handlers that may be listening. | |||
659 | InFlightDiagnostic OpState::emitRemark(const Twine &message) { | |||
660 | return getOperation()->emitRemark(message); | |||
661 | } | |||
662 | ||||
663 | //===----------------------------------------------------------------------===// | |||
664 | // Op Trait implementations | |||
665 | //===----------------------------------------------------------------------===// | |||
666 | ||||
667 | OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) { | |||
668 | if (op->getNumOperands() == 1) { | |||
669 | auto *argumentOp = op->getOperand(0).getDefiningOp(); | |||
670 | if (argumentOp && op->getName() == argumentOp->getName()) { | |||
671 | // Replace the outer operation output with the inner operation. | |||
672 | return op->getOperand(0); | |||
673 | } | |||
674 | } else if (op->getOperand(0) == op->getOperand(1)) { | |||
675 | return op->getOperand(0); | |||
676 | } | |||
677 | ||||
678 | return {}; | |||
679 | } | |||
680 | ||||
681 | OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { | |||
682 | auto *argumentOp = op->getOperand(0).getDefiningOp(); | |||
683 | if (argumentOp && op->getName() == argumentOp->getName()) { | |||
684 | // Replace the outer involutions output with inner's input. | |||
685 | return argumentOp->getOperand(0); | |||
686 | } | |||
687 | ||||
688 | return {}; | |||
689 | } | |||
690 | ||||
691 | LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { | |||
692 | if (op->getNumOperands() != 0) | |||
693 | return op->emitOpError() << "requires zero operands"; | |||
694 | return success(); | |||
695 | } | |||
696 | ||||
697 | LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { | |||
698 | if (op->getNumOperands() != 1) | |||
699 | return op->emitOpError() << "requires a single operand"; | |||
700 | return success(); | |||
701 | } | |||
702 | ||||
703 | LogicalResult OpTrait::impl::verifyNOperands(Operation *op, | |||
704 | unsigned numOperands) { | |||
705 | if (op->getNumOperands() != numOperands) { | |||
706 | return op->emitOpError() << "expected " << numOperands | |||
707 | << " operands, but found " << op->getNumOperands(); | |||
708 | } | |||
709 | return success(); | |||
710 | } | |||
711 | ||||
712 | LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, | |||
713 | unsigned numOperands) { | |||
714 | if (op->getNumOperands() < numOperands) | |||
715 | return op->emitOpError() | |||
716 | << "expected " << numOperands << " or more operands, but found " | |||
717 | << op->getNumOperands(); | |||
718 | return success(); | |||
719 | } | |||
720 | ||||
721 | /// If this is a vector type, or a tensor type, return the scalar element type | |||
722 | /// that it is built around, otherwise return the type unmodified. | |||
723 | static Type getTensorOrVectorElementType(Type type) { | |||
724 | if (auto vec = type.dyn_cast<VectorType>()) | |||
725 | return vec.getElementType(); | |||
726 | ||||
727 | // Look through tensor<vector<...>> to find the underlying element type. | |||
728 | if (auto tensor = type.dyn_cast<TensorType>()) | |||
729 | return getTensorOrVectorElementType(tensor.getElementType()); | |||
730 | return type; | |||
731 | } | |||
732 | ||||
733 | LogicalResult OpTrait::impl::verifyIsIdempotent(Operation *op) { | |||
734 | // FIXME: Add back check for no side effects on operation. | |||
735 | // Currently adding it would cause the shared library build | |||
736 | // to fail since there would be a dependency of IR on SideEffectInterfaces | |||
737 | // which is cyclical. | |||
738 | return success(); | |||
739 | } | |||
740 | ||||
741 | LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { | |||
742 | // FIXME: Add back check for no side effects on operation. | |||
743 | // Currently adding it would cause the shared library build | |||
744 | // to fail since there would be a dependency of IR on SideEffectInterfaces | |||
745 | // which is cyclical. | |||
746 | return success(); | |||
747 | } | |||
748 | ||||
749 | LogicalResult | |||
750 | OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { | |||
751 | for (auto opType : op->getOperandTypes()) { | |||
752 | auto type = getTensorOrVectorElementType(opType); | |||
753 | if (!type.isSignlessIntOrIndex()) | |||
754 | return op->emitOpError() << "requires an integer or index type"; | |||
755 | } | |||
756 | return success(); | |||
757 | } | |||
758 | ||||
759 | LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { | |||
760 | for (auto opType : op->getOperandTypes()) { | |||
761 | auto type = getTensorOrVectorElementType(opType); | |||
762 | if (!type.isa<FloatType>()) | |||
763 | return op->emitOpError("requires a float type"); | |||
764 | } | |||
765 | return success(); | |||
766 | } | |||
767 | ||||
768 | LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { | |||
769 | // Zero or one operand always have the "same" type. | |||
770 | unsigned nOperands = op->getNumOperands(); | |||
771 | if (nOperands < 2) | |||
772 | return success(); | |||
773 | ||||
774 | auto type = op->getOperand(0).getType(); | |||
775 | for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) | |||
776 | if (opType != type) | |||
777 | return op->emitOpError() << "requires all operands to have the same type"; | |||
778 | return success(); | |||
779 | } | |||
780 | ||||
781 | LogicalResult OpTrait::impl::verifyZeroRegions(Operation *op) { | |||
782 | if (op->getNumRegions() != 0) | |||
783 | return op->emitOpError() << "requires zero regions"; | |||
784 | return success(); | |||
785 | } | |||
786 | ||||
787 | LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { | |||
788 | if (op->getNumRegions() != 1) | |||
789 | return op->emitOpError() << "requires one region"; | |||
790 | return success(); | |||
791 | } | |||
792 | ||||
793 | LogicalResult OpTrait::impl::verifyNRegions(Operation *op, | |||
794 | unsigned numRegions) { | |||
795 | if (op->getNumRegions() != numRegions) | |||
796 | return op->emitOpError() << "expected " << numRegions << " regions"; | |||
797 | return success(); | |||
798 | } | |||
799 | ||||
800 | LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, | |||
801 | unsigned numRegions) { | |||
802 | if (op->getNumRegions() < numRegions) | |||
803 | return op->emitOpError() << "expected " << numRegions << " or more regions"; | |||
804 | return success(); | |||
805 | } | |||
806 | ||||
807 | LogicalResult OpTrait::impl::verifyZeroResults(Operation *op) { | |||
808 | if (op->getNumResults() != 0) | |||
809 | return op->emitOpError() << "requires zero results"; | |||
810 | return success(); | |||
811 | } | |||
812 | ||||
813 | LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { | |||
814 | if (op->getNumResults() != 1) | |||
815 | return op->emitOpError() << "requires one result"; | |||
816 | return success(); | |||
817 | } | |||
818 | ||||
819 | LogicalResult OpTrait::impl::verifyNResults(Operation *op, | |||
820 | unsigned numOperands) { | |||
821 | if (op->getNumResults() != numOperands) | |||
822 | return op->emitOpError() << "expected " << numOperands << " results"; | |||
823 | return success(); | |||
824 | } | |||
825 | ||||
826 | LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, | |||
827 | unsigned numOperands) { | |||
828 | if (op->getNumResults() < numOperands) | |||
829 | return op->emitOpError() | |||
830 | << "expected " << numOperands << " or more results"; | |||
831 | return success(); | |||
832 | } | |||
833 | ||||
834 | LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { | |||
835 | if (failed(verifyAtLeastNOperands(op, 1))) | |||
836 | return failure(); | |||
837 | ||||
838 | if (failed(verifyCompatibleShapes(op->getOperandTypes()))) | |||
839 | return op->emitOpError() << "requires the same shape for all operands"; | |||
840 | ||||
841 | return success(); | |||
842 | } | |||
843 | ||||
844 | LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { | |||
845 | if (failed(verifyAtLeastNOperands(op, 1)) || | |||
846 | failed(verifyAtLeastNResults(op, 1))) | |||
847 | return failure(); | |||
848 | ||||
849 | SmallVector<Type, 8> types(op->getOperandTypes()); | |||
850 | types.append(llvm::to_vector<4>(op->getResultTypes())); | |||
851 | ||||
852 | if (failed(verifyCompatibleShapes(types))) | |||
853 | return op->emitOpError() | |||
854 | << "requires the same shape for all operands and results"; | |||
855 | ||||
856 | return success(); | |||
857 | } | |||
858 | ||||
859 | LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { | |||
860 | if (failed(verifyAtLeastNOperands(op, 1))) | |||
861 | return failure(); | |||
862 | auto elementType = getElementTypeOrSelf(op->getOperand(0)); | |||
863 | ||||
864 | for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { | |||
865 | if (getElementTypeOrSelf(operand) != elementType) | |||
866 | return op->emitOpError("requires the same element type for all operands"); | |||
867 | } | |||
868 | ||||
869 | return success(); | |||
870 | } | |||
871 | ||||
872 | LogicalResult | |||
873 | OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { | |||
874 | if (failed(verifyAtLeastNOperands(op, 1)) || | |||
875 | failed(verifyAtLeastNResults(op, 1))) | |||
876 | return failure(); | |||
877 | ||||
878 | auto elementType = getElementTypeOrSelf(op->getResult(0)); | |||
879 | ||||
880 | // Verify result element type matches first result's element type. | |||
881 | for (auto result : llvm::drop_begin(op->getResults(), 1)) { | |||
882 | if (getElementTypeOrSelf(result) != elementType) | |||
883 | return op->emitOpError( | |||
884 | "requires the same element type for all operands and results"); | |||
885 | } | |||
886 | ||||
887 | // Verify operand's element type matches first result's element type. | |||
888 | for (auto operand : op->getOperands()) { | |||
889 | if (getElementTypeOrSelf(operand) != elementType) | |||
890 | return op->emitOpError( | |||
891 | "requires the same element type for all operands and results"); | |||
892 | } | |||
893 | ||||
894 | return success(); | |||
895 | } | |||
896 | ||||
897 | LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { | |||
898 | if (failed(verifyAtLeastNOperands(op, 1)) || | |||
899 | failed(verifyAtLeastNResults(op, 1))) | |||
900 | return failure(); | |||
901 | ||||
902 | auto type = op->getResult(0).getType(); | |||
903 | auto elementType = getElementTypeOrSelf(type); | |||
904 | Attribute encoding = nullptr; | |||
905 | if (auto rankedType = dyn_cast<RankedTensorType>(type)) | |||
906 | encoding = rankedType.getEncoding(); | |||
907 | for (auto resultType : llvm::drop_begin(op->getResultTypes())) { | |||
908 | if (getElementTypeOrSelf(resultType) != elementType || | |||
909 | failed(verifyCompatibleShape(resultType, type))) | |||
910 | return op->emitOpError() | |||
911 | << "requires the same type for all operands and results"; | |||
912 | if (encoding) | |||
913 | if (auto rankedType = dyn_cast<RankedTensorType>(resultType); | |||
914 | encoding != rankedType.getEncoding()) | |||
915 | return op->emitOpError() | |||
916 | << "requires the same encoding for all operands and results"; | |||
917 | } | |||
918 | for (auto opType : op->getOperandTypes()) { | |||
919 | if (getElementTypeOrSelf(opType) != elementType || | |||
920 | failed(verifyCompatibleShape(opType, type))) | |||
921 | return op->emitOpError() | |||
922 | << "requires the same type for all operands and results"; | |||
923 | if (encoding) | |||
924 | if (auto rankedType = dyn_cast<RankedTensorType>(opType); | |||
925 | encoding != rankedType.getEncoding()) | |||
926 | return op->emitOpError() | |||
927 | << "requires the same encoding for all operands and results"; | |||
928 | } | |||
929 | return success(); | |||
930 | } | |||
931 | ||||
932 | LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { | |||
933 | Block *block = op->getBlock(); | |||
934 | // Verify that the operation is at the end of the respective parent block. | |||
935 | if (!block || &block->back() != op) | |||
936 | return op->emitOpError("must be the last operation in the parent block"); | |||
937 | return success(); | |||
938 | } | |||
939 | ||||
940 | static LogicalResult verifyTerminatorSuccessors(Operation *op) { | |||
941 | auto *parent = op->getParentRegion(); | |||
942 | ||||
943 | // Verify that the operands lines up with the BB arguments in the successor. | |||
944 | for (Block *succ : op->getSuccessors()) | |||
945 | if (succ->getParent() != parent) | |||
946 | return op->emitError("reference to block defined in another region"); | |||
947 | return success(); | |||
948 | } | |||
949 | ||||
950 | LogicalResult OpTrait::impl::verifyZeroSuccessors(Operation *op) { | |||
951 | if (op->getNumSuccessors() != 0) { | |||
952 | return op->emitOpError("requires 0 successors but found ") | |||
953 | << op->getNumSuccessors(); | |||
954 | } | |||
955 | return success(); | |||
956 | } | |||
957 | ||||
958 | LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) { | |||
959 | if (op->getNumSuccessors() != 1) { | |||
960 | return op->emitOpError("requires 1 successor but found ") | |||
961 | << op->getNumSuccessors(); | |||
962 | } | |||
963 | return verifyTerminatorSuccessors(op); | |||
964 | } | |||
965 | LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op, | |||
966 | unsigned numSuccessors) { | |||
967 | if (op->getNumSuccessors() != numSuccessors) { | |||
968 | return op->emitOpError("requires ") | |||
969 | << numSuccessors << " successors but found " | |||
970 | << op->getNumSuccessors(); | |||
971 | } | |||
972 | return verifyTerminatorSuccessors(op); | |||
973 | } | |||
974 | LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op, | |||
975 | unsigned numSuccessors) { | |||
976 | if (op->getNumSuccessors() < numSuccessors) { | |||
977 | return op->emitOpError("requires at least ") | |||
978 | << numSuccessors << " successors but found " | |||
979 | << op->getNumSuccessors(); | |||
980 | } | |||
981 | return verifyTerminatorSuccessors(op); | |||
982 | } | |||
983 | ||||
984 | LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { | |||
985 | for (auto resultType : op->getResultTypes()) { | |||
986 | auto elementType = getTensorOrVectorElementType(resultType); | |||
987 | bool isBoolType = elementType.isInteger(1); | |||
988 | if (!isBoolType) | |||
989 | return op->emitOpError() << "requires a bool result type"; | |||
990 | } | |||
991 | ||||
992 | return success(); | |||
993 | } | |||
994 | ||||
995 | LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { | |||
996 | for (auto resultType : op->getResultTypes()) | |||
997 | if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) | |||
998 | return op->emitOpError() << "requires a floating point type"; | |||
999 | ||||
1000 | return success(); | |||
1001 | } | |||
1002 | ||||
1003 | LogicalResult | |||
1004 | OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) { | |||
1005 | for (auto resultType : op->getResultTypes()) | |||
1006 | if (!getTensorOrVectorElementType(resultType).isSignlessIntOrIndex()) | |||
1007 | return op->emitOpError() << "requires an integer or index type"; | |||
1008 | return success(); | |||
1009 | } | |||
1010 | ||||
1011 | LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, | |||
1012 | StringRef attrName, | |||
1013 | StringRef valueGroupName, | |||
1014 | size_t expectedCount) { | |||
1015 | auto sizeAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrName); | |||
1016 | if (!sizeAttr) | |||
1017 | return op->emitOpError("requires dense i32 array attribute '") | |||
1018 | << attrName << "'"; | |||
1019 | ||||
1020 | ArrayRef<int32_t> sizes = sizeAttr.asArrayRef(); | |||
1021 | if (llvm::any_of(sizes, [](int32_t element) { return element < 0; })) | |||
1022 | return op->emitOpError("'") | |||
1023 | << attrName << "' attribute cannot have negative elements"; | |||
1024 | ||||
1025 | size_t totalCount = | |||
1026 | std::accumulate(sizes.begin(), sizes.end(), 0, | |||
1027 | [](unsigned all, int32_t one) { return all + one; }); | |||
1028 | ||||
1029 | if (totalCount != expectedCount) | |||
1030 | return op->emitOpError() | |||
1031 | << valueGroupName << " count (" << expectedCount | |||
1032 | << ") does not match with the total size (" << totalCount | |||
1033 | << ") specified in attribute '" << attrName << "'"; | |||
1034 | return success(); | |||
1035 | } | |||
1036 | ||||
1037 | LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, | |||
1038 | StringRef attrName) { | |||
1039 | return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); | |||
1040 | } | |||
1041 | ||||
1042 | LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, | |||
1043 | StringRef attrName) { | |||
1044 | return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); | |||
1045 | } | |||
1046 | ||||
1047 | LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { | |||
1048 | for (Region ®ion : op->getRegions()) { | |||
1049 | if (region.empty()) | |||
1050 | continue; | |||
1051 | ||||
1052 | if (region.getNumArguments() != 0) { | |||
1053 | if (op->getNumRegions() > 1) | |||
1054 | return op->emitOpError("region #") | |||
1055 | << region.getRegionNumber() << " should have no arguments"; | |||
1056 | return op->emitOpError("region should have no arguments"); | |||
1057 | } | |||
1058 | } | |||
1059 | return success(); | |||
1060 | } | |||
1061 | ||||
1062 | LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { | |||
1063 | auto isMappableType = [](Type type) { | |||
1064 | return type.isa<VectorType, TensorType>(); | |||
1065 | }; | |||
1066 | auto resultMappableTypes = llvm::to_vector<1>( | |||
1067 | llvm::make_filter_range(op->getResultTypes(), isMappableType)); | |||
1068 | auto operandMappableTypes = llvm::to_vector<2>( | |||
1069 | llvm::make_filter_range(op->getOperandTypes(), isMappableType)); | |||
1070 | ||||
1071 | // If the op only has scalar operand/result types, then we have nothing to | |||
1072 | // check. | |||
1073 | if (resultMappableTypes.empty() && operandMappableTypes.empty()) | |||
1074 | return success(); | |||
1075 | ||||
1076 | if (!resultMappableTypes.empty() && operandMappableTypes.empty()) | |||
1077 | return op->emitOpError("if a result is non-scalar, then at least one " | |||
1078 | "operand must be non-scalar"); | |||
1079 | ||||
1080 | assert(!operandMappableTypes.empty())(static_cast <bool> (!operandMappableTypes.empty()) ? void (0) : __assert_fail ("!operandMappableTypes.empty()", "mlir/lib/IR/Operation.cpp" , 1080, __extension__ __PRETTY_FUNCTION__)); | |||
1081 | ||||
1082 | if (resultMappableTypes.empty()) | |||
1083 | return op->emitOpError("if an operand is non-scalar, then there must be at " | |||
1084 | "least one non-scalar result"); | |||
1085 | ||||
1086 | if (resultMappableTypes.size() != op->getNumResults()) | |||
1087 | return op->emitOpError( | |||
1088 | "if an operand is non-scalar, then all results must be non-scalar"); | |||
1089 | ||||
1090 | SmallVector<Type, 4> types = llvm::to_vector<2>( | |||
1091 | llvm::concat<Type>(operandMappableTypes, resultMappableTypes)); | |||
1092 | TypeID expectedBaseTy = types.front().getTypeID(); | |||
1093 | if (!llvm::all_of(types, | |||
1094 | [&](Type t) { return t.getTypeID() == expectedBaseTy; }) || | |||
1095 | failed(verifyCompatibleShapes(types))) { | |||
1096 | return op->emitOpError() << "all non-scalar operands/results must have the " | |||
1097 | "same shape and base type"; | |||
1098 | } | |||
1099 | ||||
1100 | return success(); | |||
1101 | } | |||
1102 | ||||
1103 | /// Check for any values used by operations regions attached to the | |||
1104 | /// specified "IsIsolatedFromAbove" operation defined outside of it. | |||
1105 | LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) { | |||
1106 | assert(isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() &&(static_cast <bool> (isolatedOp->hasTrait<OpTrait ::IsIsolatedFromAbove>() && "Intended to check IsolatedFromAbove ops" ) ? void (0) : __assert_fail ("isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && \"Intended to check IsolatedFromAbove ops\"" , "mlir/lib/IR/Operation.cpp", 1107, __extension__ __PRETTY_FUNCTION__ )) | |||
1107 | "Intended to check IsolatedFromAbove ops")(static_cast <bool> (isolatedOp->hasTrait<OpTrait ::IsIsolatedFromAbove>() && "Intended to check IsolatedFromAbove ops" ) ? void (0) : __assert_fail ("isolatedOp->hasTrait<OpTrait::IsIsolatedFromAbove>() && \"Intended to check IsolatedFromAbove ops\"" , "mlir/lib/IR/Operation.cpp", 1107, __extension__ __PRETTY_FUNCTION__ )); | |||
1108 | ||||
1109 | // List of regions to analyze. Each region is processed independently, with | |||
1110 | // respect to the common `limit` region, so we can look at them in any order. | |||
1111 | // Therefore, use a simple vector and push/pop back the current region. | |||
1112 | SmallVector<Region *, 8> pendingRegions; | |||
1113 | for (auto ®ion : isolatedOp->getRegions()) { | |||
1114 | pendingRegions.push_back(®ion); | |||
1115 | ||||
1116 | // Traverse all operations in the region. | |||
1117 | while (!pendingRegions.empty()) { | |||
1118 | for (Operation &op : pendingRegions.pop_back_val()->getOps()) { | |||
1119 | for (Value operand : op.getOperands()) { | |||
1120 | // Check that any value that is used by an operation is defined in the | |||
1121 | // same region as either an operation result. | |||
1122 | auto *operandRegion = operand.getParentRegion(); | |||
1123 | if (!operandRegion) | |||
1124 | return op.emitError("operation's operand is unlinked"); | |||
1125 | if (!region.isAncestor(operandRegion)) { | |||
1126 | return op.emitOpError("using value defined outside the region") | |||
1127 | .attachNote(isolatedOp->getLoc()) | |||
1128 | << "required by region isolation constraints"; | |||
1129 | } | |||
1130 | } | |||
1131 | ||||
1132 | // Schedule any regions in the operation for further checking. Don't | |||
1133 | // recurse into other IsolatedFromAbove ops, because they will check | |||
1134 | // themselves. | |||
1135 | if (op.getNumRegions() && | |||
1136 | !op.hasTrait<OpTrait::IsIsolatedFromAbove>()) { | |||
1137 | for (Region &subRegion : op.getRegions()) | |||
1138 | pendingRegions.push_back(&subRegion); | |||
1139 | } | |||
1140 | } | |||
1141 | } | |||
1142 | } | |||
1143 | ||||
1144 | return success(); | |||
1145 | } | |||
1146 | ||||
1147 | bool OpTrait::hasElementwiseMappableTraits(Operation *op) { | |||
1148 | return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() && | |||
1149 | op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>(); | |||
1150 | } | |||
1151 | ||||
1152 | //===----------------------------------------------------------------------===// | |||
1153 | // CastOpInterface | |||
1154 | //===----------------------------------------------------------------------===// | |||
1155 | ||||
1156 | /// Attempt to fold the given cast operation. | |||
1157 | LogicalResult | |||
1158 | impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands, | |||
1159 | SmallVectorImpl<OpFoldResult> &foldResults) { | |||
1160 | OperandRange operands = op->getOperands(); | |||
1161 | if (operands.empty()) | |||
1162 | return failure(); | |||
1163 | ResultRange results = op->getResults(); | |||
1164 | ||||
1165 | // Check for the case where the input and output types match 1-1. | |||
1166 | if (operands.getTypes() == results.getTypes()) { | |||
1167 | foldResults.append(operands.begin(), operands.end()); | |||
1168 | return success(); | |||
1169 | } | |||
1170 | ||||
1171 | return failure(); | |||
1172 | } | |||
1173 | ||||
1174 | /// Attempt to verify the given cast operation. | |||
1175 | LogicalResult impl::verifyCastInterfaceOp( | |||
1176 | Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible) { | |||
1177 | auto resultTypes = op->getResultTypes(); | |||
1178 | if (resultTypes.empty()) | |||
1179 | return op->emitOpError() | |||
1180 | << "expected at least one result for cast operation"; | |||
1181 | ||||
1182 | auto operandTypes = op->getOperandTypes(); | |||
1183 | if (!areCastCompatible(operandTypes, resultTypes)) { | |||
1184 | InFlightDiagnostic diag = op->emitOpError("operand type"); | |||
1185 | if (operandTypes.empty()) | |||
1186 | diag << "s []"; | |||
1187 | else if (llvm::size(operandTypes) == 1) | |||
1188 | diag << " " << *operandTypes.begin(); | |||
1189 | else | |||
1190 | diag << "s " << operandTypes; | |||
1191 | return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ") | |||
1192 | << resultTypes << " are cast incompatible"; | |||
1193 | } | |||
1194 | ||||
1195 | return success(); | |||
1196 | } | |||
1197 | ||||
1198 | //===----------------------------------------------------------------------===// | |||
1199 | // Misc. utils | |||
1200 | //===----------------------------------------------------------------------===// | |||
1201 | ||||
1202 | /// Insert an operation, generated by `buildTerminatorOp`, at the end of the | |||
1203 | /// region's only block if it does not have a terminator already. If the region | |||
1204 | /// is empty, insert a new block first. `buildTerminatorOp` should return the | |||
1205 | /// terminator operation to insert. | |||
1206 | void impl::ensureRegionTerminator( | |||
1207 | Region ®ion, OpBuilder &builder, Location loc, | |||
1208 | function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { | |||
1209 | OpBuilder::InsertionGuard guard(builder); | |||
1210 | if (region.empty()) | |||
1211 | builder.createBlock(®ion); | |||
1212 | ||||
1213 | Block &block = region.back(); | |||
1214 | if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) | |||
1215 | return; | |||
1216 | ||||
1217 | builder.setInsertionPointToEnd(&block); | |||
1218 | builder.insert(buildTerminatorOp(builder, loc)); | |||
1219 | } | |||
1220 | ||||
1221 | /// Create a simple OpBuilder and forward to the OpBuilder version of this | |||
1222 | /// function. | |||
1223 | void impl::ensureRegionTerminator( | |||
1224 | Region ®ion, Builder &builder, Location loc, | |||
1225 | function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp) { | |||
1226 | OpBuilder opBuilder(builder.getContext()); | |||
1227 | ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); | |||
1228 | } |