Bug Summary

File: build/source/llvm/../mlir/include/mlir/IR/Builders.h
Warning: line 490, column 5
4th function call argument is an uninitialized value

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name CodeGen.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -relaxed-aliasing -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/source/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-17/lib/clang/17 -isystem /build/source/llvm/../mlir/include -isystem tools/mlir/include -isystem tools/clang/include -isystem /build/source/llvm/../clang/include -D FLANG_INCLUDE_TESTS=1 -D FLANG_LITTLE_ENDIAN=1 -D FLANG_VENDOR="Debian " -D _DEBUG -D _GLIBCXX_ASSERTIONS -D _GNU_SOURCE -D _LIBCPP_ENABLE_ASSERTIONS -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I tools/flang/lib/Optimizer/CodeGen -I /build/source/flang/lib/Optimizer/CodeGen -I /build/source/flang/include -I tools/flang/include -I include -I /build/source/llvm/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward 
-internal-isystem /usr/lib/llvm-17/lib/clang/17/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/source/= -fcoverage-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/source/= -source-date-epoch 1683717183 -O2 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -Wno-deprecated-copy -Wno-ctad-maybe-unsupported -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/source/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/source/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2023-05-10-133810-16478-1 -x c++ /build/source/flang/lib/Optimizer/CodeGen/CodeGen.cpp

/build/source/flang/lib/Optimizer/CodeGen/CodeGen.cpp

1//===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "flang/Optimizer/CodeGen/CodeGen.h"
14
15#include "CGOps.h"
16#include "flang/ISO_Fortran_binding.h"
17#include "flang/Optimizer/Dialect/FIRAttr.h"
18#include "flang/Optimizer/Dialect/FIROps.h"
19#include "flang/Optimizer/Dialect/FIRType.h"
20#include "flang/Optimizer/Support/InternalNames.h"
21#include "flang/Optimizer/Support/TypeCode.h"
22#include "flang/Optimizer/Support/Utils.h"
23#include "flang/Semantics/runtime-type-info.h"
24#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
25#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
26#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
27#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
28#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
29#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
30#include "mlir/Conversion/LLVMCommon/Pattern.h"
31#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
32#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
33#include "mlir/Conversion/MathToLibm/MathToLibm.h"
34#include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
35#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
36#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
37#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
38#include "mlir/Dialect/OpenACC/OpenACC.h"
39#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
40#include "mlir/IR/BuiltinTypes.h"
41#include "mlir/IR/Matchers.h"
42#include "mlir/Pass/Pass.h"
43#include "mlir/Pass/PassManager.h"
44#include "mlir/Target/LLVMIR/ModuleTranslation.h"
45#include "llvm/ADT/ArrayRef.h"
46#include "llvm/ADT/TypeSwitch.h"
47
48namespace fir {
49#define GEN_PASS_DEF_FIRTOLLVMLOWERING
50#include "flang/Optimizer/CodeGen/CGPasses.h.inc"
51} // namespace fir
52
// Debug-output tag for this file (selects -debug-only=flang-codegen output).
#define DEBUG_TYPE "flang-codegen"
54
55// fir::LLVMTypeConverter for converting to LLVM IR dialect types.
56#include "flang/Optimizer/CodeGen/TypeConverter.h"
57
58// TODO: This should really be recovered from the specified target.
59static constexpr unsigned defaultAlign = 8;
60
61/// `fir.box` attribute values as defined for CFI_attribute_t in
62/// flang/ISO_Fortran_binding.h.
63static constexpr unsigned kAttrPointer = CFI_attribute_pointer1;
64static constexpr unsigned kAttrAllocatable = CFI_attribute_allocatable2;
65
66static inline mlir::Type getVoidPtrType(mlir::MLIRContext *context) {
67 return mlir::LLVM::LLVMPointerType::get(mlir::IntegerType::get(context, 8));
68}
69
70static mlir::LLVM::ConstantOp
71genConstantIndex(mlir::Location loc, mlir::Type ity,
72 mlir::ConversionPatternRewriter &rewriter,
73 std::int64_t offset) {
74 auto cattr = rewriter.getI64IntegerAttr(offset);
75 return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr);
76}
77
78static mlir::Block *createBlock(mlir::ConversionPatternRewriter &rewriter,
79 mlir::Block *insertBefore) {
80 assert(insertBefore && "expected valid insertion block")(static_cast <bool> (insertBefore && "expected valid insertion block"
) ? void (0) : __assert_fail ("insertBefore && \"expected valid insertion block\""
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 80, __extension__
__PRETTY_FUNCTION__))
;
81 return rewriter.createBlock(insertBefore->getParent(),
82 mlir::Region::iterator(insertBefore));
83}
84
85/// Extract constant from a value if it is a result of one of the
86/// ConstantOp operations, otherwise, return std::nullopt.
87static std::optional<int64_t> getIfConstantIntValue(mlir::Value val) {
88 if (!val || !val.dyn_cast<mlir::OpResult>())
89 return {};
90
91 mlir::Operation *defop = val.getDefiningOp();
92
93 if (auto constOp = mlir::dyn_cast<mlir::arith::ConstantIntOp>(defop))
94 return constOp.value();
95 if (auto llConstOp = mlir::dyn_cast<mlir::LLVM::ConstantOp>(defop))
96 if (auto attr = llConstOp.getValue().dyn_cast<mlir::IntegerAttr>())
97 return attr.getValue().getSExtValue();
98
99 return {};
100}
101
102/// Extract constant from a value that must be the result of one of the
103/// ConstantOp operations.
104static int64_t getConstantIntValue(mlir::Value val) {
105 if (auto constVal = getIfConstantIntValue(val))
106 return *constVal;
107 fir::emitFatalError(val.getLoc(), "must be a constant");
108}
109
110static unsigned getTypeDescFieldId(mlir::Type ty) {
111 auto isArray = fir::dyn_cast_ptrOrBoxEleTy(ty).isa<fir::SequenceType>();
112 return isArray ? kOptTypePtrPosInBox : kDimsPosInBox;
113}
114
115namespace {
116/// FIR conversion pattern template
117template <typename FromOp>
118class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
119public:
120 explicit FIROpConversion(fir::LLVMTypeConverter &lowering,
121 const fir::FIRToLLVMPassOptions &options)
122 : mlir::ConvertOpToLLVMPattern<FromOp>(lowering), options(options) {}
123
124protected:
125 mlir::Type convertType(mlir::Type ty) const {
126 return lowerTy().convertType(ty);
127 }
128 mlir::Type voidPtrTy() const { return getVoidPtrType(); }
129
130 mlir::Type getVoidPtrType() const {
131 return mlir::LLVM::LLVMPointerType::get(
132 mlir::IntegerType::get(&lowerTy().getContext(), 8));
133 }
134
135 mlir::LLVM::ConstantOp
136 genI32Constant(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
137 int value) const {
138 mlir::Type i32Ty = rewriter.getI32Type();
139 mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value);
140 return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
141 }
142
143 mlir::LLVM::ConstantOp
144 genConstantOffset(mlir::Location loc,
145 mlir::ConversionPatternRewriter &rewriter,
146 int offset) const {
147 mlir::Type ity = lowerTy().offsetType();
148 mlir::IntegerAttr cattr = rewriter.getI32IntegerAttr(offset);
149 return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr);
150 }
151
152 /// Perform an extension or truncation as needed on an integer value. Lowering
153 /// to the specific target may involve some sign-extending or truncation of
154 /// values, particularly to fit them from abstract box types to the
155 /// appropriate reified structures.
156 mlir::Value integerCast(mlir::Location loc,
157 mlir::ConversionPatternRewriter &rewriter,
158 mlir::Type ty, mlir::Value val) const {
159 auto valTy = val.getType();
160 // If the value was not yet lowered, lower its type so that it can
161 // be used in getPrimitiveTypeSizeInBits.
162 if (!valTy.isa<mlir::IntegerType>())
163 valTy = convertType(valTy);
164 auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
165 auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy);
166 if (toSize < fromSize)
167 return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val);
168 if (toSize > fromSize)
169 return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val);
170 return val;
171 }
172
173 /// Construct code sequence to extract the specific value from a `fir.box`.
174 mlir::Value getValueFromBox(mlir::Location loc, mlir::Type boxTy,
175 mlir::Value box, mlir::Type resultTy,
176 mlir::ConversionPatternRewriter &rewriter,
177 int boxValue) const {
178 if (box.getType().isa<mlir::LLVM::LLVMPointerType>()) {
179 auto pty = mlir::LLVM::LLVMPointerType::get(resultTy);
180 auto p = rewriter.create<mlir::LLVM::GEPOp>(
181 loc, pty, box, llvm::ArrayRef<mlir::LLVM::GEPArg>{0, boxValue});
182 auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, resultTy, p);
183 attachTBAATag(loadOp, boxTy, nullptr, p);
184 return loadOp;
185 }
186 return rewriter.create<mlir::LLVM::ExtractValueOp>(loc, box, boxValue);
187 }
188
189 /// Method to construct code sequence to get the triple for dimension `dim`
190 /// from a box.
191 llvm::SmallVector<mlir::Value, 3>
192 getDimsFromBox(mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys,
193 mlir::Type boxTy, mlir::Value box, mlir::Value dim,
194 mlir::ConversionPatternRewriter &rewriter) const {
195 mlir::Value l0 =
196 loadDimFieldFromBox(loc, boxTy, box, dim, 0, retTys[0], rewriter);
197 mlir::Value l1 =
198 loadDimFieldFromBox(loc, boxTy, box, dim, 1, retTys[1], rewriter);
199 mlir::Value l2 =
200 loadDimFieldFromBox(loc, boxTy, box, dim, 2, retTys[2], rewriter);
201 return {l0, l1, l2};
202 }
203
204 llvm::SmallVector<mlir::Value, 3>
205 getDimsFromBox(mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys,
206 mlir::Type boxTy, mlir::Value box, int dim,
207 mlir::ConversionPatternRewriter &rewriter) const {
208 mlir::Value l0 =
209 getDimFieldFromBox(loc, boxTy, box, dim, 0, retTys[0], rewriter);
210 mlir::Value l1 =
211 getDimFieldFromBox(loc, boxTy, box, dim, 1, retTys[1], rewriter);
212 mlir::Value l2 =
213 getDimFieldFromBox(loc, boxTy, box, dim, 2, retTys[2], rewriter);
214 return {l0, l1, l2};
215 }
216
217 mlir::Value
218 loadDimFieldFromBox(mlir::Location loc, mlir::Type boxTy, mlir::Value box,
219 mlir::Value dim, int off, mlir::Type ty,
220 mlir::ConversionPatternRewriter &rewriter) const {
221 assert(box.getType().isa<mlir::LLVM::LLVMPointerType>() &&(static_cast <bool> (box.getType().isa<mlir::LLVM::LLVMPointerType
>() && "descriptor inquiry with runtime dim can only be done on descriptor "
"in memory") ? void (0) : __assert_fail ("box.getType().isa<mlir::LLVM::LLVMPointerType>() && \"descriptor inquiry with runtime dim can only be done on descriptor \" \"in memory\""
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 223, __extension__
__PRETTY_FUNCTION__))
222 "descriptor inquiry with runtime dim can only be done on descriptor "(static_cast <bool> (box.getType().isa<mlir::LLVM::LLVMPointerType
>() && "descriptor inquiry with runtime dim can only be done on descriptor "
"in memory") ? void (0) : __assert_fail ("box.getType().isa<mlir::LLVM::LLVMPointerType>() && \"descriptor inquiry with runtime dim can only be done on descriptor \" \"in memory\""
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 223, __extension__
__PRETTY_FUNCTION__))
223 "in memory")(static_cast <bool> (box.getType().isa<mlir::LLVM::LLVMPointerType
>() && "descriptor inquiry with runtime dim can only be done on descriptor "
"in memory") ? void (0) : __assert_fail ("box.getType().isa<mlir::LLVM::LLVMPointerType>() && \"descriptor inquiry with runtime dim can only be done on descriptor \" \"in memory\""
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 223, __extension__
__PRETTY_FUNCTION__))
;
224 auto pty = mlir::LLVM::LLVMPointerType::get(ty);
225 mlir::LLVM::GEPOp p = genGEP(loc, pty, rewriter, box, 0,
226 static_cast<int>(kDimsPosInBox), dim, off);
227 auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, ty, p);
228 attachTBAATag(loadOp, boxTy, nullptr, p);
229 return loadOp;
230 }
231
232 mlir::Value
233 getDimFieldFromBox(mlir::Location loc, mlir::Type boxTy, mlir::Value box,
234 int dim, int off, mlir::Type ty,
235 mlir::ConversionPatternRewriter &rewriter) const {
236 if (box.getType().isa<mlir::LLVM::LLVMPointerType>()) {
237 auto pty = mlir::LLVM::LLVMPointerType::get(ty);
238 mlir::LLVM::GEPOp p = genGEP(loc, pty, rewriter, box, 0,
239 static_cast<int>(kDimsPosInBox), dim, off);
240 auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, ty, p);
241 attachTBAATag(loadOp, boxTy, nullptr, p);
242 return loadOp;
243 }
244 return rewriter.create<mlir::LLVM::ExtractValueOp>(
245 loc, box, llvm::ArrayRef<std::int64_t>{kDimsPosInBox, dim, off});
246 }
247
248 mlir::Value
249 getStrideFromBox(mlir::Location loc, mlir::Type boxTy, mlir::Value box,
250 unsigned dim,
251 mlir::ConversionPatternRewriter &rewriter) const {
252 auto idxTy = lowerTy().indexType();
253 return getDimFieldFromBox(loc, boxTy, box, dim, kDimStridePos, idxTy,
254 rewriter);
255 }
256
257 /// Read base address from a fir.box. Returned address has type ty.
258 mlir::Value
259 getBaseAddrFromBox(mlir::Location loc, mlir::Type resultTy, mlir::Type boxTy,
260 mlir::Value box,
261 mlir::ConversionPatternRewriter &rewriter) const {
262 return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kAddrPosInBox);
263 }
264
265 mlir::Value
266 getElementSizeFromBox(mlir::Location loc, mlir::Type resultTy,
267 mlir::Type boxTy, mlir::Value box,
268 mlir::ConversionPatternRewriter &rewriter) const {
269 return getValueFromBox(loc, boxTy, box, resultTy, rewriter,
270 kElemLenPosInBox);
271 }
272
273 // Get the element type given an LLVM type that is of the form
274 // [llvm.ptr](array|struct|vector)+ and the provided indexes.
275 static mlir::Type getBoxEleTy(mlir::Type type,
276 llvm::ArrayRef<std::int64_t> indexes) {
277 if (auto t = type.dyn_cast<mlir::LLVM::LLVMPointerType>())
278 type = t.getElementType();
279 for (unsigned i : indexes) {
280 if (auto t = type.dyn_cast<mlir::LLVM::LLVMStructType>()) {
281 assert(!t.isOpaque() && i < t.getBody().size())(static_cast <bool> (!t.isOpaque() && i < t.
getBody().size()) ? void (0) : __assert_fail ("!t.isOpaque() && i < t.getBody().size()"
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 281, __extension__
__PRETTY_FUNCTION__))
;
282 type = t.getBody()[i];
283 } else if (auto t = type.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
284 type = t.getElementType();
285 } else if (auto t = type.dyn_cast<mlir::VectorType>()) {
286 type = t.getElementType();
287 } else {
288 fir::emitFatalError(mlir::UnknownLoc::get(type.getContext()),
289 "request for invalid box element type");
290 }
291 }
292 return type;
293 }
294
295 // Return LLVM type of the base address given the LLVM type
296 // of the related descriptor (lowered fir.box type).
297 static mlir::Type getBaseAddrTypeFromBox(mlir::Type type) {
298 return getBoxEleTy(type, {kAddrPosInBox});
299 }
300
301 /// Read the address of the type descriptor from a box.
302 mlir::Value
303 loadTypeDescAddress(mlir::Location loc, mlir::Type boxTy, mlir::Value box,
304 mlir::ConversionPatternRewriter &rewriter) const {
305 unsigned typeDescFieldId = getTypeDescFieldId(boxTy);
306 mlir::Type tdescType = lowerTy().convertTypeDescType(rewriter.getContext());
307 return getValueFromBox(loc, boxTy, box, tdescType, rewriter,
308 typeDescFieldId);
309 }
310
311 // Load the attribute from the \p box and perform a check against \p maskValue
312 // The final comparison is implemented as `(attribute & maskValue) != 0`.
313 mlir::Value genBoxAttributeCheck(mlir::Location loc, mlir::Type boxTy,
314 mlir::Value box,
315 mlir::ConversionPatternRewriter &rewriter,
316 unsigned maskValue) const {
317 mlir::Type attrTy = rewriter.getI32Type();
318 mlir::Value attribute =
319 getValueFromBox(loc, boxTy, box, attrTy, rewriter, kAttributePosInBox);
320 mlir::LLVM::ConstantOp attrMask =
321 genConstantOffset(loc, rewriter, maskValue);
322 auto maskRes =
323 rewriter.create<mlir::LLVM::AndOp>(loc, attrTy, attribute, attrMask);
324 mlir::LLVM::ConstantOp c0 = genConstantOffset(loc, rewriter, 0);
325 return rewriter.create<mlir::LLVM::ICmpOp>(
326 loc, mlir::LLVM::ICmpPredicate::ne, maskRes, c0);
327 }
328
329 template <typename... ARGS>
330 mlir::LLVM::GEPOp genGEP(mlir::Location loc, mlir::Type ty,
331 mlir::ConversionPatternRewriter &rewriter,
332 mlir::Value base, ARGS... args) const {
333 llvm::SmallVector<mlir::LLVM::GEPArg> cv = {args...};
334 return rewriter.create<mlir::LLVM::GEPOp>(loc, ty, base, cv);
335 }
336
337 // Find the LLVMFuncOp in whose entry block the alloca should be inserted.
338 // The order to find the LLVMFuncOp is as follows:
339 // 1. The parent operation of the current block if it is a LLVMFuncOp.
340 // 2. The first ancestor that is a LLVMFuncOp.
341 mlir::LLVM::LLVMFuncOp
342 getFuncForAllocaInsert(mlir::ConversionPatternRewriter &rewriter) const {
343 mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
344 return mlir::isa<mlir::LLVM::LLVMFuncOp>(parentOp)
345 ? mlir::cast<mlir::LLVM::LLVMFuncOp>(parentOp)
346 : parentOp->getParentOfType<mlir::LLVM::LLVMFuncOp>();
347 }
348
349 // Generate an alloca of size 1 and type \p toTy.
350 mlir::LLVM::AllocaOp
351 genAllocaWithType(mlir::Location loc, mlir::Type toTy, unsigned alignment,
352 mlir::ConversionPatternRewriter &rewriter) const {
353 auto thisPt = rewriter.saveInsertionPoint();
354 mlir::LLVM::LLVMFuncOp func = getFuncForAllocaInsert(rewriter);
355 rewriter.setInsertionPointToStart(&func.front());
356 auto size = genI32Constant(loc, rewriter, 1);
357 auto al = rewriter.create<mlir::LLVM::AllocaOp>(loc, toTy, size, alignment);
358 rewriter.restoreInsertionPoint(thisPt);
359 return al;
360 }
361
362 fir::LLVMTypeConverter &lowerTy() const {
363 return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
364 }
365
366 void attachTBAATag(mlir::LLVM::AliasAnalysisOpInterface op,
367 mlir::Type baseFIRType, mlir::Type accessFIRType,
368 mlir::LLVM::GEPOp gep) const {
369 lowerTy().attachTBAATag(op, baseFIRType, accessFIRType, gep);
370 }
371
372 const fir::FIRToLLVMPassOptions &options;
373};
374
375/// FIR conversion pattern template
376template <typename FromOp>
377class FIROpAndTypeConversion : public FIROpConversion<FromOp> {
378public:
379 using FIROpConversion<FromOp>::FIROpConversion;
380 using OpAdaptor = typename FromOp::Adaptor;
381
382 mlir::LogicalResult
383 matchAndRewrite(FromOp op, OpAdaptor adaptor,
384 mlir::ConversionPatternRewriter &rewriter) const final {
385 mlir::Type ty = this->convertType(op.getType());
386 return doRewrite(op, ty, adaptor, rewriter);
387 }
388
389 virtual mlir::LogicalResult
390 doRewrite(FromOp addr, mlir::Type ty, OpAdaptor adaptor,
391 mlir::ConversionPatternRewriter &rewriter) const = 0;
392};
393} // namespace
394
395namespace {
396/// Lower `fir.address_of` operation to `llvm.address_of` operation.
397struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> {
398 using FIROpConversion::FIROpConversion;
399
400 mlir::LogicalResult
401 matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
402 mlir::ConversionPatternRewriter &rewriter) const override {
403 auto ty = convertType(addr.getType());
404 rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
405 addr, ty, addr.getSymbol().getRootReference().getValue());
406 return mlir::success();
407 }
408};
409} // namespace
410
411/// Lookup the function to compute the memory size of this parametric derived
412/// type. The size of the object may depend on the LEN type parameters of the
413/// derived type.
414static mlir::LLVM::LLVMFuncOp
415getDependentTypeMemSizeFn(fir::RecordType recTy, fir::AllocaOp op,
416 mlir::ConversionPatternRewriter &rewriter) {
417 auto module = op->getParentOfType<mlir::ModuleOp>();
418 std::string name = recTy.getName().str() + "P.mem.size";
419 if (auto memSizeFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(name))
420 return memSizeFunc;
421 TODO(op.getLoc(), "did not find allocation function")do { fir::emitFatalError(op.getLoc(), llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "421" ": not yet implemented: ") + llvm::Twine("did not find allocation function"
), false); } while (false)
;
422}
423
424// Compute the alloc scale size (constant factors encoded in the array type).
425// We do this for arrays without a constant interior or arrays of character with
426// dynamic length arrays, since those are the only ones that get decayed to a
427// pointer to the element type.
428template <typename OP>
429static mlir::Value
430genAllocationScaleSize(OP op, mlir::Type ity,
431 mlir::ConversionPatternRewriter &rewriter) {
432 mlir::Location loc = op.getLoc();
433 mlir::Type dataTy = op.getInType();
434 auto seqTy = dataTy.dyn_cast<fir::SequenceType>();
435 fir::SequenceType::Extent constSize = 1;
436 if (seqTy) {
437 int constRows = seqTy.getConstantRows();
438 const fir::SequenceType::ShapeRef &shape = seqTy.getShape();
439 if (constRows != static_cast<int>(shape.size())) {
440 for (auto extent : shape) {
441 if (constRows-- > 0)
442 continue;
443 if (extent != fir::SequenceType::getUnknownExtent())
444 constSize *= extent;
445 }
446 }
447 }
448
449 if (constSize != 1) {
450 mlir::Value constVal{
451 genConstantIndex(loc, ity, rewriter, constSize).getResult()};
452 return constVal;
453 }
454 return nullptr;
455}
456
457namespace {
458/// convert to LLVM IR dialect `alloca`
459struct AllocaOpConversion : public FIROpConversion<fir::AllocaOp> {
460 using FIROpConversion::FIROpConversion;
461
462 mlir::LogicalResult
463 matchAndRewrite(fir::AllocaOp alloc, OpAdaptor adaptor,
464 mlir::ConversionPatternRewriter &rewriter) const override {
465 mlir::ValueRange operands = adaptor.getOperands();
466 auto loc = alloc.getLoc();
467 mlir::Type ity = lowerTy().indexType();
468 unsigned i = 0;
469 mlir::Value size = genConstantIndex(loc, ity, rewriter, 1).getResult();
470 mlir::Type ty = convertType(alloc.getType());
471 mlir::Type resultTy = ty;
472 if (alloc.hasLenParams()) {
473 unsigned end = alloc.numLenParams();
474 llvm::SmallVector<mlir::Value> lenParams;
475 for (; i < end; ++i)
476 lenParams.push_back(operands[i]);
477 mlir::Type scalarType = fir::unwrapSequenceType(alloc.getInType());
478 if (auto chrTy = scalarType.dyn_cast<fir::CharacterType>()) {
479 fir::CharacterType rawCharTy = fir::CharacterType::getUnknownLen(
480 chrTy.getContext(), chrTy.getFKind());
481 ty = mlir::LLVM::LLVMPointerType::get(convertType(rawCharTy));
482 assert(end == 1)(static_cast <bool> (end == 1) ? void (0) : __assert_fail
("end == 1", "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 482,
__extension__ __PRETTY_FUNCTION__))
;
483 size = integerCast(loc, rewriter, ity, lenParams[0]);
484 } else if (auto recTy = scalarType.dyn_cast<fir::RecordType>()) {
485 mlir::LLVM::LLVMFuncOp memSizeFn =
486 getDependentTypeMemSizeFn(recTy, alloc, rewriter);
487 if (!memSizeFn)
488 emitError(loc, "did not find allocation function");
489 mlir::NamedAttribute attr = rewriter.getNamedAttr(
490 "callee", mlir::SymbolRefAttr::get(memSizeFn));
491 auto call = rewriter.create<mlir::LLVM::CallOp>(
492 loc, ity, lenParams, llvm::ArrayRef<mlir::NamedAttribute>{attr});
493 size = call.getResult();
494 ty = ::getVoidPtrType(alloc.getContext());
495 } else {
496 return emitError(loc, "unexpected type ")
497 << scalarType << " with type parameters";
498 }
499 }
500 if (auto scaleSize = genAllocationScaleSize(alloc, ity, rewriter))
501 size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
502 if (alloc.hasShapeOperands()) {
503 unsigned end = operands.size();
504 for (; i < end; ++i)
505 size = rewriter.create<mlir::LLVM::MulOp>(
506 loc, ity, size, integerCast(loc, rewriter, ity, operands[i]));
507 }
508 if (ty == resultTy) {
509 // Do not emit the bitcast if ty and resultTy are the same.
510 rewriter.replaceOpWithNewOp<mlir::LLVM::AllocaOp>(alloc, ty, size,
511 alloc->getAttrs());
512 } else {
513 auto al = rewriter.create<mlir::LLVM::AllocaOp>(loc, ty, size,
514 alloc->getAttrs());
515 rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(alloc, resultTy, al);
516 }
517 return mlir::success();
518 }
519};
520} // namespace
521
522namespace {
523/// Lower `fir.box_addr` to the sequence of operations to extract the first
524/// element of the box.
525struct BoxAddrOpConversion : public FIROpConversion<fir::BoxAddrOp> {
526 using FIROpConversion::FIROpConversion;
527
528 mlir::LogicalResult
529 matchAndRewrite(fir::BoxAddrOp boxaddr, OpAdaptor adaptor,
530 mlir::ConversionPatternRewriter &rewriter) const override {
531 mlir::Value a = adaptor.getOperands()[0];
532 auto loc = boxaddr.getLoc();
533 mlir::Type ty = convertType(boxaddr.getType());
534 if (auto argty = boxaddr.getVal().getType().dyn_cast<fir::BaseBoxType>()) {
535 rewriter.replaceOp(boxaddr,
536 getBaseAddrFromBox(loc, ty, argty, a, rewriter));
537 } else {
538 rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(boxaddr, a, 0);
539 }
540 return mlir::success();
541 }
542};
543
544/// Convert `!fir.boxchar_len` to `!llvm.extractvalue` for the 2nd part of the
545/// boxchar.
546struct BoxCharLenOpConversion : public FIROpConversion<fir::BoxCharLenOp> {
547 using FIROpConversion::FIROpConversion;
548
549 mlir::LogicalResult
550 matchAndRewrite(fir::BoxCharLenOp boxCharLen, OpAdaptor adaptor,
551 mlir::ConversionPatternRewriter &rewriter) const override {
552 mlir::Value boxChar = adaptor.getOperands()[0];
553 mlir::Location loc = boxChar.getLoc();
554 mlir::Type returnValTy = boxCharLen.getResult().getType();
555
556 constexpr int boxcharLenIdx = 1;
557 auto len = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, boxChar,
558 boxcharLenIdx);
559 mlir::Value lenAfterCast = integerCast(loc, rewriter, returnValTy, len);
560 rewriter.replaceOp(boxCharLen, lenAfterCast);
561
562 return mlir::success();
563 }
564};
565
566/// Lower `fir.box_dims` to a sequence of operations to extract the requested
567/// dimension information from the boxed value.
568/// Result in a triple set of GEPs and loads.
569struct BoxDimsOpConversion : public FIROpConversion<fir::BoxDimsOp> {
570 using FIROpConversion::FIROpConversion;
571
572 mlir::LogicalResult
573 matchAndRewrite(fir::BoxDimsOp boxdims, OpAdaptor adaptor,
574 mlir::ConversionPatternRewriter &rewriter) const override {
575 llvm::SmallVector<mlir::Type, 3> resultTypes = {
576 convertType(boxdims.getResult(0).getType()),
577 convertType(boxdims.getResult(1).getType()),
578 convertType(boxdims.getResult(2).getType()),
579 };
580 auto results = getDimsFromBox(
581 boxdims.getLoc(), resultTypes, boxdims.getVal().getType(),
582 adaptor.getOperands()[0], adaptor.getOperands()[1], rewriter);
583 rewriter.replaceOp(boxdims, results);
584 return mlir::success();
585 }
586};
587
588/// Lower `fir.box_elesize` to a sequence of operations ro extract the size of
589/// an element in the boxed value.
590struct BoxEleSizeOpConversion : public FIROpConversion<fir::BoxEleSizeOp> {
591 using FIROpConversion::FIROpConversion;
592
593 mlir::LogicalResult
594 matchAndRewrite(fir::BoxEleSizeOp boxelesz, OpAdaptor adaptor,
595 mlir::ConversionPatternRewriter &rewriter) const override {
596 mlir::Value box = adaptor.getOperands()[0];
597 auto loc = boxelesz.getLoc();
598 auto ty = convertType(boxelesz.getType());
599 auto elemSize = getElementSizeFromBox(loc, ty, boxelesz.getVal().getType(),
600 box, rewriter);
601 rewriter.replaceOp(boxelesz, elemSize);
602 return mlir::success();
603 }
604};
605
606/// Lower `fir.box_isalloc` to a sequence of operations to determine if the
607/// boxed value was from an ALLOCATABLE entity.
608struct BoxIsAllocOpConversion : public FIROpConversion<fir::BoxIsAllocOp> {
609 using FIROpConversion::FIROpConversion;
610
611 mlir::LogicalResult
612 matchAndRewrite(fir::BoxIsAllocOp boxisalloc, OpAdaptor adaptor,
613 mlir::ConversionPatternRewriter &rewriter) const override {
614 mlir::Value box = adaptor.getOperands()[0];
615 auto loc = boxisalloc.getLoc();
616 mlir::Value check = genBoxAttributeCheck(loc, boxisalloc.getVal().getType(),
617 box, rewriter, kAttrAllocatable);
618 rewriter.replaceOp(boxisalloc, check);
619 return mlir::success();
620 }
621};
622
623/// Lower `fir.box_isarray` to a sequence of operations to determine if the
624/// boxed is an array.
625struct BoxIsArrayOpConversion : public FIROpConversion<fir::BoxIsArrayOp> {
626 using FIROpConversion::FIROpConversion;
627
628 mlir::LogicalResult
629 matchAndRewrite(fir::BoxIsArrayOp boxisarray, OpAdaptor adaptor,
630 mlir::ConversionPatternRewriter &rewriter) const override {
631 mlir::Value a = adaptor.getOperands()[0];
632 auto loc = boxisarray.getLoc();
633 auto rank = getValueFromBox(loc, boxisarray.getVal().getType(), a,
634 rewriter.getI32Type(), rewriter, kRankPosInBox);
635 auto c0 = genConstantOffset(loc, rewriter, 0);
636 rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
637 boxisarray, mlir::LLVM::ICmpPredicate::ne, rank, c0);
638 return mlir::success();
639 }
640};
641
642/// Lower `fir.box_isptr` to a sequence of operations to determined if the
643/// boxed value was from a POINTER entity.
644struct BoxIsPtrOpConversion : public FIROpConversion<fir::BoxIsPtrOp> {
645 using FIROpConversion::FIROpConversion;
646
647 mlir::LogicalResult
648 matchAndRewrite(fir::BoxIsPtrOp boxisptr, OpAdaptor adaptor,
649 mlir::ConversionPatternRewriter &rewriter) const override {
650 mlir::Value box = adaptor.getOperands()[0];
651 auto loc = boxisptr.getLoc();
652 mlir::Value check = genBoxAttributeCheck(loc, boxisptr.getVal().getType(),
653 box, rewriter, kAttrPointer);
654 rewriter.replaceOp(boxisptr, check);
655 return mlir::success();
656 }
657};
658
659/// Lower `fir.box_rank` to the sequence of operation to extract the rank from
660/// the box.
661struct BoxRankOpConversion : public FIROpConversion<fir::BoxRankOp> {
662 using FIROpConversion::FIROpConversion;
663
664 mlir::LogicalResult
665 matchAndRewrite(fir::BoxRankOp boxrank, OpAdaptor adaptor,
666 mlir::ConversionPatternRewriter &rewriter) const override {
667 mlir::Value a = adaptor.getOperands()[0];
668 auto loc = boxrank.getLoc();
669 mlir::Type ty = convertType(boxrank.getType());
670 auto result = getValueFromBox(loc, boxrank.getVal().getType(), a, ty,
671 rewriter, kRankPosInBox);
672 rewriter.replaceOp(boxrank, result);
673 return mlir::success();
674 }
675};
676
677/// Lower `fir.boxproc_host` operation. Extracts the host pointer from the
678/// boxproc.
679/// TODO: Part of supporting Fortran 2003 procedure pointers.
680struct BoxProcHostOpConversion : public FIROpConversion<fir::BoxProcHostOp> {
681 using FIROpConversion::FIROpConversion;
682
683 mlir::LogicalResult
684 matchAndRewrite(fir::BoxProcHostOp boxprochost, OpAdaptor adaptor,
685 mlir::ConversionPatternRewriter &rewriter) const override {
686 TODO(boxprochost.getLoc(), "fir.boxproc_host codegen")do { fir::emitFatalError(boxprochost.getLoc(), llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "686" ": not yet implemented: ") + llvm::Twine("fir.boxproc_host codegen"
), false); } while (false)
;
687 return mlir::failure();
688 }
689};
690
691/// Lower `fir.box_tdesc` to the sequence of operations to extract the type
692/// descriptor from the box.
693struct BoxTypeDescOpConversion : public FIROpConversion<fir::BoxTypeDescOp> {
694 using FIROpConversion::FIROpConversion;
695
696 mlir::LogicalResult
697 matchAndRewrite(fir::BoxTypeDescOp boxtypedesc, OpAdaptor adaptor,
698 mlir::ConversionPatternRewriter &rewriter) const override {
699 mlir::Value box = adaptor.getOperands()[0];
700 auto typeDescAddr = loadTypeDescAddress(
701 boxtypedesc.getLoc(), boxtypedesc.getBox().getType(), box, rewriter);
702 rewriter.replaceOp(boxtypedesc, typeDescAddr);
703 return mlir::success();
704 }
705};
706
707/// Lower `fir.box_typecode` to a sequence of operations to extract the type
708/// code in the boxed value.
709struct BoxTypeCodeOpConversion : public FIROpConversion<fir::BoxTypeCodeOp> {
710 using FIROpConversion::FIROpConversion;
711
712 mlir::LogicalResult
713 matchAndRewrite(fir::BoxTypeCodeOp op, OpAdaptor adaptor,
714 mlir::ConversionPatternRewriter &rewriter) const override {
715 mlir::Value box = adaptor.getOperands()[0];
716 auto loc = box.getLoc();
717 auto ty = convertType(op.getType());
718 auto typeCode = getValueFromBox(loc, op.getBox().getType(), box, ty,
719 rewriter, kTypePosInBox);
720 rewriter.replaceOp(op, typeCode);
721 return mlir::success();
722 }
723};
724
725/// Lower `fir.string_lit` to LLVM IR dialect operation.
726struct StringLitOpConversion : public FIROpConversion<fir::StringLitOp> {
727 using FIROpConversion::FIROpConversion;
728
729 mlir::LogicalResult
730 matchAndRewrite(fir::StringLitOp constop, OpAdaptor adaptor,
731 mlir::ConversionPatternRewriter &rewriter) const override {
732 auto ty = convertType(constop.getType());
733 auto attr = constop.getValue();
734 if (attr.isa<mlir::StringAttr>()) {
735 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(constop, ty, attr);
736 return mlir::success();
737 }
738
739 auto charTy = constop.getType().cast<fir::CharacterType>();
740 unsigned bits = lowerTy().characterBitsize(charTy);
741 mlir::Type intTy = rewriter.getIntegerType(bits);
742 mlir::Location loc = constop.getLoc();
743 mlir::Value cst = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
744 if (auto arr = attr.dyn_cast<mlir::DenseElementsAttr>()) {
745 cst = rewriter.create<mlir::LLVM::ConstantOp>(loc, ty, arr);
746 } else if (auto arr = attr.dyn_cast<mlir::ArrayAttr>()) {
747 for (auto a : llvm::enumerate(arr.getValue())) {
748 // convert each character to a precise bitsize
749 auto elemAttr = mlir::IntegerAttr::get(
750 intTy,
751 a.value().cast<mlir::IntegerAttr>().getValue().zextOrTrunc(bits));
752 auto elemCst =
753 rewriter.create<mlir::LLVM::ConstantOp>(loc, intTy, elemAttr);
754 cst = rewriter.create<mlir::LLVM::InsertValueOp>(loc, cst, elemCst,
755 a.index());
756 }
757 } else {
758 return mlir::failure();
759 }
760 rewriter.replaceOp(constop, cst);
761 return mlir::success();
762 }
763};
764
765/// `fir.call` -> `llvm.call`
766struct CallOpConversion : public FIROpConversion<fir::CallOp> {
767 using FIROpConversion::FIROpConversion;
768
769 mlir::LogicalResult
770 matchAndRewrite(fir::CallOp call, OpAdaptor adaptor,
771 mlir::ConversionPatternRewriter &rewriter) const override {
772 llvm::SmallVector<mlir::Type> resultTys;
773 for (auto r : call.getResults())
774 resultTys.push_back(convertType(r.getType()));
775 // Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
776 mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
777 attrConvert(call);
778 rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
779 call, resultTys, adaptor.getOperands(), attrConvert.getAttrs());
780 return mlir::success();
781 }
782};
783} // namespace
784
785static mlir::Type getComplexEleTy(mlir::Type complex) {
786 if (auto cc = complex.dyn_cast<mlir::ComplexType>())
787 return cc.getElementType();
788 return complex.cast<fir::ComplexType>().getElementType();
789}
790
791namespace {
792/// Compare complex values
793///
794/// Per 10.1, the only comparisons available are .EQ. (oeq) and .NE. (une).
795///
796/// For completeness, all other comparison are done on the real component only.
797struct CmpcOpConversion : public FIROpConversion<fir::CmpcOp> {
798 using FIROpConversion::FIROpConversion;
799
800 mlir::LogicalResult
801 matchAndRewrite(fir::CmpcOp cmp, OpAdaptor adaptor,
802 mlir::ConversionPatternRewriter &rewriter) const override {
803 mlir::ValueRange operands = adaptor.getOperands();
804 mlir::Type resTy = convertType(cmp.getType());
805 mlir::Location loc = cmp.getLoc();
806 llvm::SmallVector<mlir::Value, 2> rp = {
807 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, operands[0], 0),
808 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, operands[1], 0)};
809 auto rcp =
810 rewriter.create<mlir::LLVM::FCmpOp>(loc, resTy, rp, cmp->getAttrs());
811 llvm::SmallVector<mlir::Value, 2> ip = {
812 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, operands[0], 1),
813 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, operands[1], 1)};
814 auto icp =
815 rewriter.create<mlir::LLVM::FCmpOp>(loc, resTy, ip, cmp->getAttrs());
816 llvm::SmallVector<mlir::Value, 2> cp = {rcp, icp};
817 switch (cmp.getPredicate()) {
818 case mlir::arith::CmpFPredicate::OEQ: // .EQ.
819 rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmp, resTy, cp);
820 break;
821 case mlir::arith::CmpFPredicate::UNE: // .NE.
822 rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmp, resTy, cp);
823 break;
824 default:
825 rewriter.replaceOp(cmp, rcp.getResult());
826 break;
827 }
828 return mlir::success();
829 }
830};
831
832/// Lower complex constants
833struct ConstcOpConversion : public FIROpConversion<fir::ConstcOp> {
834 using FIROpConversion::FIROpConversion;
835
836 mlir::LogicalResult
837 matchAndRewrite(fir::ConstcOp conc, OpAdaptor,
838 mlir::ConversionPatternRewriter &rewriter) const override {
839 mlir::Location loc = conc.getLoc();
840 mlir::Type ty = convertType(conc.getType());
841 mlir::Type ety = convertType(getComplexEleTy(conc.getType()));
842 auto realPart = rewriter.create<mlir::LLVM::ConstantOp>(
843 loc, ety, getValue(conc.getReal()));
844 auto imPart = rewriter.create<mlir::LLVM::ConstantOp>(
845 loc, ety, getValue(conc.getImaginary()));
846 auto undef = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
847 auto setReal =
848 rewriter.create<mlir::LLVM::InsertValueOp>(loc, undef, realPart, 0);
849 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(conc, setReal,
850 imPart, 1);
851 return mlir::success();
852 }
853
854 inline llvm::APFloat getValue(mlir::Attribute attr) const {
855 return attr.cast<fir::RealAttr>().getValue();
856 }
857};
858
859/// convert value of from-type to value of to-type
860struct ConvertOpConversion : public FIROpConversion<fir::ConvertOp> {
861 using FIROpConversion::FIROpConversion;
862
863 static bool isFloatingPointTy(mlir::Type ty) {
864 return ty.isa<mlir::FloatType>();
865 }
866
867 mlir::LogicalResult
868 matchAndRewrite(fir::ConvertOp convert, OpAdaptor adaptor,
869 mlir::ConversionPatternRewriter &rewriter) const override {
870 auto fromFirTy = convert.getValue().getType();
871 auto toFirTy = convert.getRes().getType();
872 auto fromTy = convertType(fromFirTy);
873 auto toTy = convertType(toFirTy);
874 mlir::Value op0 = adaptor.getOperands()[0];
875
876 if (fromFirTy == toFirTy) {
877 rewriter.replaceOp(convert, op0);
878 return mlir::success();
879 }
880
881 auto loc = convert.getLoc();
882 auto i1Type = mlir::IntegerType::get(convert.getContext(), 1);
883
884 if (fromFirTy.isa<fir::LogicalType>() || toFirTy.isa<fir::LogicalType>()) {
885 // By specification fir::LogicalType value may be any number,
886 // where non-zero value represents .true. and zero value represents
887 // .false.
888 //
889 // integer<->logical conversion requires value normalization.
890 // Conversion from wide logical to narrow logical must set the result
891 // to non-zero iff the input is non-zero - the easiest way to implement
892 // it is to compare the input agains zero and set the result to
893 // the canonical 0/1.
894 // Conversion from narrow logical to wide logical may be implemented
895 // as a zero or sign extension of the input, but it may use value
896 // normalization as well.
897 if (!fromTy.isa<mlir::IntegerType>() || !toTy.isa<mlir::IntegerType>())
898 return mlir::emitError(loc)
899 << "unsupported types for logical conversion: " << fromTy
900 << " -> " << toTy;
901
902 // Do folding for constant inputs.
903 if (auto constVal = getIfConstantIntValue(op0)) {
904 mlir::Value normVal =
905 genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0);
906 rewriter.replaceOp(convert, normVal);
907 return mlir::success();
908 }
909
910 // If the input is i1, then we can just zero extend it, and
911 // the result will be normalized.
912 if (fromTy == i1Type) {
913 rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(convert, toTy, op0);
914 return mlir::success();
915 }
916
917 // Compare the input with zero.
918 mlir::Value zero = genConstantIndex(loc, fromTy, rewriter, 0);
919 auto isTrue = rewriter.create<mlir::LLVM::ICmpOp>(
920 loc, mlir::LLVM::ICmpPredicate::ne, op0, zero);
921
922 // Zero extend the i1 isTrue result to the required type (unless it is i1
923 // itself).
924 if (toTy != i1Type)
925 rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(convert, toTy, isTrue);
926 else
927 rewriter.replaceOp(convert, isTrue.getResult());
928
929 return mlir::success();
930 }
931
932 if (fromTy == toTy) {
933 rewriter.replaceOp(convert, op0);
934 return mlir::success();
935 }
936 auto convertFpToFp = [&](mlir::Value val, unsigned fromBits,
937 unsigned toBits, mlir::Type toTy) -> mlir::Value {
938 if (fromBits == toBits) {
939 // TODO: Converting between two floating-point representations with the
940 // same bitwidth is not allowed for now.
941 mlir::emitError(loc,
942 "cannot implicitly convert between two floating-point "
943 "representations of the same bitwidth");
944 return {};
945 }
946 if (fromBits > toBits)
947 return rewriter.create<mlir::LLVM::FPTruncOp>(loc, toTy, val);
948 return rewriter.create<mlir::LLVM::FPExtOp>(loc, toTy, val);
949 };
950 // Complex to complex conversion.
951 if (fir::isa_complex(fromFirTy) && fir::isa_complex(toFirTy)) {
952 // Special case: handle the conversion of a complex such that both the
953 // real and imaginary parts are converted together.
954 auto ty = convertType(getComplexEleTy(convert.getValue().getType()));
955 auto rp = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, op0, 0);
956 auto ip = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, op0, 1);
957 auto nt = convertType(getComplexEleTy(convert.getRes().getType()));
958 auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
959 auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(nt);
960 auto rc = convertFpToFp(rp, fromBits, toBits, nt);
961 auto ic = convertFpToFp(ip, fromBits, toBits, nt);
962 auto un = rewriter.create<mlir::LLVM::UndefOp>(loc, toTy);
963 auto i1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, un, rc, 0);
964 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(convert, i1, ic,
965 1);
966 return mlir::success();
967 }
968
969 // Floating point to floating point conversion.
970 if (isFloatingPointTy(fromTy)) {
971 if (isFloatingPointTy(toTy)) {
972 auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy);
973 auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy);
974 auto v = convertFpToFp(op0, fromBits, toBits, toTy);
975 rewriter.replaceOp(convert, v);
976 return mlir::success();
977 }
978 if (toTy.isa<mlir::IntegerType>()) {
979 rewriter.replaceOpWithNewOp<mlir::LLVM::FPToSIOp>(convert, toTy, op0);
980 return mlir::success();
981 }
982 } else if (fromTy.isa<mlir::IntegerType>()) {
983 // Integer to integer conversion.
984 if (toTy.isa<mlir::IntegerType>()) {
985 auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy);
986 auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy);
987 assert(fromBits != toBits)(static_cast <bool> (fromBits != toBits) ? void (0) : __assert_fail
("fromBits != toBits", "flang/lib/Optimizer/CodeGen/CodeGen.cpp"
, 987, __extension__ __PRETTY_FUNCTION__))
;
988 if (fromBits > toBits) {
989 rewriter.replaceOpWithNewOp<mlir::LLVM::TruncOp>(convert, toTy, op0);
990 return mlir::success();
991 }
992 if (fromFirTy == i1Type) {
993 rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(convert, toTy, op0);
994 return mlir::success();
995 }
996 rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>(convert, toTy, op0);
997 return mlir::success();
998 }
999 // Integer to floating point conversion.
1000 if (isFloatingPointTy(toTy)) {
1001 rewriter.replaceOpWithNewOp<mlir::LLVM::SIToFPOp>(convert, toTy, op0);
1002 return mlir::success();
1003 }
1004 // Integer to pointer conversion.
1005 if (toTy.isa<mlir::LLVM::LLVMPointerType>()) {
1006 rewriter.replaceOpWithNewOp<mlir::LLVM::IntToPtrOp>(convert, toTy, op0);
1007 return mlir::success();
1008 }
1009 } else if (fromTy.isa<mlir::LLVM::LLVMPointerType>()) {
1010 // Pointer to integer conversion.
1011 if (toTy.isa<mlir::IntegerType>()) {
1012 rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(convert, toTy, op0);
1013 return mlir::success();
1014 }
1015 // Pointer to pointer conversion.
1016 if (toTy.isa<mlir::LLVM::LLVMPointerType>()) {
1017 rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(convert, toTy, op0);
1018 return mlir::success();
1019 }
1020 }
1021 return emitError(loc) << "cannot convert " << fromTy << " to " << toTy;
1022 }
1023};
1024
1025/// `fir.disptach_table` operation has no specific CodeGen. The operation is
1026/// only used to carry information during FIR to FIR passes.
1027struct DispatchTableOpConversion
1028 : public FIROpConversion<fir::DispatchTableOp> {
1029 using FIROpConversion::FIROpConversion;
1030
1031 mlir::LogicalResult
1032 matchAndRewrite(fir::DispatchTableOp op, OpAdaptor,
1033 mlir::ConversionPatternRewriter &rewriter) const override {
1034 rewriter.eraseOp(op);
1035 return mlir::success();
1036 }
1037};
1038
1039/// `fir.dt_entry` operation has no specific CodeGen. The operation is only used
1040/// to carry information during FIR to FIR passes.
1041struct DTEntryOpConversion : public FIROpConversion<fir::DTEntryOp> {
1042 using FIROpConversion::FIROpConversion;
1043
1044 mlir::LogicalResult
1045 matchAndRewrite(fir::DTEntryOp op, OpAdaptor,
1046 mlir::ConversionPatternRewriter &rewriter) const override {
1047 rewriter.eraseOp(op);
1048 return mlir::success();
1049 }
1050};
1051
1052/// Lower `fir.global_len` operation.
1053struct GlobalLenOpConversion : public FIROpConversion<fir::GlobalLenOp> {
1054 using FIROpConversion::FIROpConversion;
1055
1056 mlir::LogicalResult
1057 matchAndRewrite(fir::GlobalLenOp globalLen, OpAdaptor adaptor,
1058 mlir::ConversionPatternRewriter &rewriter) const override {
1059 TODO(globalLen.getLoc(), "fir.global_len codegen")do { fir::emitFatalError(globalLen.getLoc(), llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "1059" ": not yet implemented: ") + llvm::Twine("fir.global_len codegen"
), false); } while (false)
;
1060 return mlir::failure();
1061 }
1062};
1063
1064/// Lower fir.len_param_index
1065struct LenParamIndexOpConversion
1066 : public FIROpConversion<fir::LenParamIndexOp> {
1067 using FIROpConversion::FIROpConversion;
1068
1069 // FIXME: this should be specialized by the runtime target
1070 mlir::LogicalResult
1071 matchAndRewrite(fir::LenParamIndexOp lenp, OpAdaptor,
1072 mlir::ConversionPatternRewriter &rewriter) const override {
1073 TODO(lenp.getLoc(), "fir.len_param_index codegen")do { fir::emitFatalError(lenp.getLoc(), llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "1073" ": not yet implemented: ") + llvm::Twine("fir.len_param_index codegen"
), false); } while (false)
;
1074 }
1075};
1076
1077/// Convert `!fir.emboxchar<!fir.char<KIND, ?>, #n>` into a sequence of
1078/// instructions that generate `!llvm.struct<(ptr<ik>, i64)>`. The 1st element
1079/// in this struct is a pointer. Its type is determined from `KIND`. The 2nd
1080/// element is the length of the character buffer (`#n`).
1081struct EmboxCharOpConversion : public FIROpConversion<fir::EmboxCharOp> {
1082 using FIROpConversion::FIROpConversion;
1083
1084 mlir::LogicalResult
1085 matchAndRewrite(fir::EmboxCharOp emboxChar, OpAdaptor adaptor,
1086 mlir::ConversionPatternRewriter &rewriter) const override {
1087 mlir::ValueRange operands = adaptor.getOperands();
1088
1089 mlir::Value charBuffer = operands[0];
1090 mlir::Value charBufferLen = operands[1];
1091
1092 mlir::Location loc = emboxChar.getLoc();
1093 mlir::Type llvmStructTy = convertType(emboxChar.getType());
1094 auto llvmStruct = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmStructTy);
1095
1096 mlir::Type lenTy =
1097 llvmStructTy.cast<mlir::LLVM::LLVMStructType>().getBody()[1];
1098 mlir::Value lenAfterCast = integerCast(loc, rewriter, lenTy, charBufferLen);
1099
1100 mlir::Type addrTy =
1101 llvmStructTy.cast<mlir::LLVM::LLVMStructType>().getBody()[0];
1102 if (addrTy != charBuffer.getType())
1103 charBuffer =
1104 rewriter.create<mlir::LLVM::BitcastOp>(loc, addrTy, charBuffer);
1105
1106 auto insertBufferOp = rewriter.create<mlir::LLVM::InsertValueOp>(
1107 loc, llvmStruct, charBuffer, 0);
1108 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
1109 emboxChar, insertBufferOp, lenAfterCast, 1);
1110
1111 return mlir::success();
1112 }
1113};
1114} // namespace
1115
1116/// Return the LLVMFuncOp corresponding to the standard malloc call.
1117static mlir::LLVM::LLVMFuncOp
1118getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
1119 auto module = op->getParentOfType<mlir::ModuleOp>();
1120 if (mlir::LLVM::LLVMFuncOp mallocFunc =
1121 module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("malloc"))
1122 return mallocFunc;
1123 mlir::OpBuilder moduleBuilder(
1124 op->getParentOfType<mlir::ModuleOp>().getBodyRegion());
1125 auto indexType = mlir::IntegerType::get(op.getContext(), 64);
1126 return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
1127 rewriter.getUnknownLoc(), "malloc",
1128 mlir::LLVM::LLVMFunctionType::get(getVoidPtrType(op.getContext()),
1129 indexType,
1130 /*isVarArg=*/false));
1131}
1132
1133/// Helper function for generating the LLVM IR that computes the distance
1134/// in bytes between adjacent elements pointed to by a pointer
1135/// of type \p ptrTy. The result is returned as a value of \p idxTy integer
1136/// type.
1137static mlir::Value
1138computeElementDistance(mlir::Location loc, mlir::Type ptrTy, mlir::Type idxTy,
1139 mlir::ConversionPatternRewriter &rewriter) {
1140 // Note that we cannot use something like
1141 // mlir::LLVM::getPrimitiveTypeSizeInBits() for the element type here. For
1142 // example, it returns 10 bytes for mlir::Float80Type for targets where it
1143 // occupies 16 bytes. Proper solution is probably to use
1144 // mlir::DataLayout::getTypeABIAlignment(), but DataLayout is not being set
1145 // yet (see llvm-project#57230). For the time being use the '(intptr_t)((type
1146 // *)0 + 1)' trick for all types. The generated instructions are optimized
1147 // into constant by the first pass of InstCombine, so it should not be a
1148 // performance issue.
1149 auto nullPtr = rewriter.create<mlir::LLVM::NullOp>(loc, ptrTy);
1150 auto gep = rewriter.create<mlir::LLVM::GEPOp>(
1151 loc, ptrTy, nullPtr, llvm::ArrayRef<mlir::LLVM::GEPArg>{1});
1152 return rewriter.create<mlir::LLVM::PtrToIntOp>(loc, idxTy, gep);
1153}
1154
1155/// Return value of the stride in bytes between adjacent elements
1156/// of LLVM type \p llTy. The result is returned as a value of
1157/// \p idxTy integer type.
1158static mlir::Value
1159genTypeStrideInBytes(mlir::Location loc, mlir::Type idxTy,
1160 mlir::ConversionPatternRewriter &rewriter,
1161 mlir::Type llTy) {
1162 // Create a pointer type and use computeElementDistance().
1163 auto ptrTy = mlir::LLVM::LLVMPointerType::get(llTy);
1164 return computeElementDistance(loc, ptrTy, idxTy, rewriter);
1165}
1166
1167namespace {
1168/// Lower a `fir.allocmem` instruction into `llvm.call @malloc`
1169struct AllocMemOpConversion : public FIROpConversion<fir::AllocMemOp> {
1170 using FIROpConversion::FIROpConversion;
1171
1172 mlir::LogicalResult
1173 matchAndRewrite(fir::AllocMemOp heap, OpAdaptor adaptor,
1174 mlir::ConversionPatternRewriter &rewriter) const override {
1175 mlir::Type heapTy = heap.getType();
1176 mlir::Type ty = convertType(heapTy);
1177 mlir::LLVM::LLVMFuncOp mallocFunc = getMalloc(heap, rewriter);
1178 mlir::Location loc = heap.getLoc();
1179 auto ity = lowerTy().indexType();
1180 mlir::Type dataTy = fir::unwrapRefType(heapTy);
1181 if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
1182 TODO(loc, "fir.allocmem codegen of derived type with length parameters")do { fir::emitFatalError(loc, llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "1182" ": not yet implemented: ") + llvm::Twine("fir.allocmem codegen of derived type with length parameters"
), false); } while (false)
;
1183 mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, ty);
1184 if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter))
1185 size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
1186 for (mlir::Value opnd : adaptor.getOperands())
1187 size = rewriter.create<mlir::LLVM::MulOp>(
1188 loc, ity, size, integerCast(loc, rewriter, ity, opnd));
1189 heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc));
1190 auto malloc = rewriter.create<mlir::LLVM::CallOp>(
1191 loc, ::getVoidPtrType(heap.getContext()), size, heap->getAttrs());
1192 rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(heap, ty,
1193 malloc.getResult());
1194 return mlir::success();
1195 }
1196
1197 /// Compute the allocation size in bytes of the element type of
1198 /// \p llTy pointer type. The result is returned as a value of \p idxTy
1199 /// integer type.
1200 mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
1201 mlir::ConversionPatternRewriter &rewriter,
1202 mlir::Type llTy) const {
1203 auto ptrTy = llTy.dyn_cast<mlir::LLVM::LLVMPointerType>();
1204 return computeElementDistance(loc, ptrTy, idxTy, rewriter);
1205 }
1206};
1207} // namespace
1208
1209/// Return the LLVMFuncOp corresponding to the standard free call.
1210static mlir::LLVM::LLVMFuncOp
1211getFree(fir::FreeMemOp op, mlir::ConversionPatternRewriter &rewriter) {
1212 auto module = op->getParentOfType<mlir::ModuleOp>();
1213 if (mlir::LLVM::LLVMFuncOp freeFunc =
1214 module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("free"))
1215 return freeFunc;
1216 mlir::OpBuilder moduleBuilder(module.getBodyRegion());
1217 auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext());
1218 return moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
1219 rewriter.getUnknownLoc(), "free",
1220 mlir::LLVM::LLVMFunctionType::get(voidType,
1221 getVoidPtrType(op.getContext()),
1222 /*isVarArg=*/false));
1223}
1224
1225static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
1226 unsigned result = 1;
1227 for (auto eleTy = ty.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>();
1228 eleTy;
1229 eleTy = eleTy.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>())
1230 ++result;
1231 return result;
1232}
1233
1234namespace {
1235/// Lower a `fir.freemem` instruction into `llvm.call @free`
1236struct FreeMemOpConversion : public FIROpConversion<fir::FreeMemOp> {
1237 using FIROpConversion::FIROpConversion;
1238
1239 mlir::LogicalResult
1240 matchAndRewrite(fir::FreeMemOp freemem, OpAdaptor adaptor,
1241 mlir::ConversionPatternRewriter &rewriter) const override {
1242 mlir::LLVM::LLVMFuncOp freeFunc = getFree(freemem, rewriter);
1243 mlir::Location loc = freemem.getLoc();
1244 auto bitcast = rewriter.create<mlir::LLVM::BitcastOp>(
1245 freemem.getLoc(), voidPtrTy(), adaptor.getOperands()[0]);
1246 freemem->setAttr("callee", mlir::SymbolRefAttr::get(freeFunc));
1247 rewriter.create<mlir::LLVM::CallOp>(
1248 loc, mlir::TypeRange{}, mlir::ValueRange{bitcast}, freemem->getAttrs());
1249 rewriter.eraseOp(freemem);
1250 return mlir::success();
1251 }
1252};
1253} // namespace
1254
1255/// Common base class for embox to descriptor conversion.
1256template <typename OP>
1257struct EmboxCommonConversion : public FIROpConversion<OP> {
1258 using FIROpConversion<OP>::FIROpConversion;
1259
1260 static int getCFIAttr(fir::BaseBoxType boxTy) {
1261 auto eleTy = boxTy.getEleTy();
1262 if (eleTy.isa<fir::PointerType>())
1263 return CFI_attribute_pointer1;
1264 if (eleTy.isa<fir::HeapType>())
1265 return CFI_attribute_allocatable2;
1266 return CFI_attribute_other0;
1267 }
1268
1269 // Get the element size and CFI type code of the boxed value.
1270 std::tuple<mlir::Value, mlir::Value> getSizeAndTypeCode(
1271 mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
1272 mlir::Type boxEleTy, mlir::ValueRange lenParams = {}) const {
1273 auto i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
1274 if (auto eleTy = fir::dyn_cast_ptrEleTy(boxEleTy))
1275 boxEleTy = eleTy;
1276 if (auto seqTy = boxEleTy.dyn_cast<fir::SequenceType>())
1277 return getSizeAndTypeCode(loc, rewriter, seqTy.getEleTy(), lenParams);
1278 if (boxEleTy.isa<mlir::NoneType>()) // unlimited polymorphic or assumed type
1279 return {rewriter.create<mlir::LLVM::ConstantOp>(loc, i64Ty, 0),
1280 this->genConstantOffset(loc, rewriter, CFI_type_other(-1))};
1281 mlir::Value typeCodeVal = this->genConstantOffset(
1282 loc, rewriter,
1283 fir::getTypeCode(boxEleTy, this->lowerTy().getKindMap()));
1284 if (fir::isa_integer(boxEleTy) || boxEleTy.dyn_cast<fir::LogicalType>() ||
1285 fir::isa_real(boxEleTy) || fir::isa_complex(boxEleTy))
1286 return {genTypeStrideInBytes(loc, i64Ty, rewriter,
1287 this->convertType(boxEleTy)),
1288 typeCodeVal};
1289 if (auto charTy = boxEleTy.dyn_cast<fir::CharacterType>()) {
1290 mlir::Value size =
1291 genTypeStrideInBytes(loc, i64Ty, rewriter, this->convertType(charTy));
1292 if (charTy.getLen() == fir::CharacterType::unknownLen()) {
1293 // Multiply the single character size by the length.
1294 assert(!lenParams.empty())(static_cast <bool> (!lenParams.empty()) ? void (0) : __assert_fail
("!lenParams.empty()", "flang/lib/Optimizer/CodeGen/CodeGen.cpp"
, 1294, __extension__ __PRETTY_FUNCTION__))
;
1295 auto len64 = FIROpConversion<OP>::integerCast(loc, rewriter, i64Ty,
1296 lenParams.back());
1297 size = rewriter.create<mlir::LLVM::MulOp>(loc, i64Ty, size, len64);
1298 }
1299 return {size, typeCodeVal};
1300 };
1301 if (fir::isa_ref_type(boxEleTy)) {
1302 auto ptrTy = mlir::LLVM::LLVMPointerType::get(
1303 mlir::LLVM::LLVMVoidType::get(rewriter.getContext()));
1304 return {genTypeStrideInBytes(loc, i64Ty, rewriter, ptrTy), typeCodeVal};
1305 }
1306 if (boxEleTy.isa<fir::RecordType>())
1307 return {genTypeStrideInBytes(loc, i64Ty, rewriter,
1308 this->convertType(boxEleTy)),
1309 typeCodeVal};
1310 fir::emitFatalError(loc, "unhandled type in fir.box code generation");
1311 }
1312
1313 /// Basic pattern to write a field in the descriptor
1314 mlir::Value insertField(mlir::ConversionPatternRewriter &rewriter,
1315 mlir::Location loc, mlir::Value dest,
1316 llvm::ArrayRef<std::int64_t> fldIndexes,
1317 mlir::Value value, bool bitcast = false) const {
1318 auto boxTy = dest.getType();
1319 auto fldTy = this->getBoxEleTy(boxTy, fldIndexes);
1320 if (bitcast)
1321 value = rewriter.create<mlir::LLVM::BitcastOp>(loc, fldTy, value);
1322 else
1323 value = this->integerCast(loc, rewriter, fldTy, value);
1324 return rewriter.create<mlir::LLVM::InsertValueOp>(loc, dest, value,
1325 fldIndexes);
1326 }
1327
1328 inline mlir::Value
1329 insertBaseAddress(mlir::ConversionPatternRewriter &rewriter,
1330 mlir::Location loc, mlir::Value dest,
1331 mlir::Value base) const {
1332 return insertField(rewriter, loc, dest, {kAddrPosInBox}, base,
1333 /*bitCast=*/true);
1334 }
1335
1336 inline mlir::Value insertLowerBound(mlir::ConversionPatternRewriter &rewriter,
1337 mlir::Location loc, mlir::Value dest,
1338 unsigned dim, mlir::Value lb) const {
1339 return insertField(rewriter, loc, dest,
1340 {kDimsPosInBox, dim, kDimLowerBoundPos}, lb);
1341 }
1342
1343 inline mlir::Value insertExtent(mlir::ConversionPatternRewriter &rewriter,
1344 mlir::Location loc, mlir::Value dest,
1345 unsigned dim, mlir::Value extent) const {
1346 return insertField(rewriter, loc, dest, {kDimsPosInBox, dim, kDimExtentPos},
1347 extent);
1348 }
1349
1350 inline mlir::Value insertStride(mlir::ConversionPatternRewriter &rewriter,
1351 mlir::Location loc, mlir::Value dest,
1352 unsigned dim, mlir::Value stride) const {
1353 return insertField(rewriter, loc, dest, {kDimsPosInBox, dim, kDimStridePos},
1354 stride);
1355 }
1356
1357 /// Get the address of the type descriptor global variable that was created by
1358 /// lowering for derived type \p recType.
1359 mlir::Value getTypeDescriptor(mlir::ModuleOp mod,
1360 mlir::ConversionPatternRewriter &rewriter,
1361 mlir::Location loc,
1362 fir::RecordType recType) const {
1363 std::string name =
1364 fir::NameUniquer::getTypeDescriptorName(recType.getName());
1365 if (auto global = mod.template lookupSymbol<fir::GlobalOp>(name)) {
1366 auto ty = mlir::LLVM::LLVMPointerType::get(
1367 this->lowerTy().convertType(global.getType()));
1368 return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ty,
1369 global.getSymName());
1370 }
1371 if (auto global = mod.template lookupSymbol<mlir::LLVM::GlobalOp>(name)) {
1372 // The global may have already been translated to LLVM.
1373 auto ty = mlir::LLVM::LLVMPointerType::get(global.getType());
1374 return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ty,
1375 global.getSymName());
1376 }
1377 // Type info derived types do not have type descriptors since they are the
1378 // types defining type descriptors.
1379 if (!this->options.ignoreMissingTypeDescriptors &&
1380 !fir::NameUniquer::belongsToModule(
1381 name, Fortran::semantics::typeInfoBuiltinModule))
1382 fir::emitFatalError(
1383 loc, "runtime derived type info descriptor was not generated");
1384 return rewriter.create<mlir::LLVM::NullOp>(
1385 loc, ::getVoidPtrType(mod.getContext()));
1386 }
1387
1388 mlir::Value populateDescriptor(mlir::Location loc, mlir::ModuleOp mod,
1389 fir::BaseBoxType boxTy, mlir::Type inputType,
1390 mlir::ConversionPatternRewriter &rewriter,
1391 unsigned rank, mlir::Value eleSize,
1392 mlir::Value cfiTy,
1393 mlir::Value typeDesc) const {
1394 auto convTy = this->lowerTy().convertBoxType(boxTy, rank);
1395 auto llvmBoxPtrTy = convTy.template cast<mlir::LLVM::LLVMPointerType>();
1396 auto llvmBoxTy = llvmBoxPtrTy.getElementType();
1397 bool isUnlimitedPolymorphic = fir::isUnlimitedPolymorphicType(boxTy);
1398 bool useInputType = fir::isPolymorphicType(boxTy) || isUnlimitedPolymorphic;
1399 mlir::Value descriptor =
1400 rewriter.create<mlir::LLVM::UndefOp>(loc, llvmBoxTy);
1401 descriptor =
1402 insertField(rewriter, loc, descriptor, {kElemLenPosInBox}, eleSize);
1403 descriptor = insertField(rewriter, loc, descriptor, {kVersionPosInBox},
1404 this->genI32Constant(loc, rewriter, CFI_VERSION20180515));
1405 descriptor = insertField(rewriter, loc, descriptor, {kRankPosInBox},
1406 this->genI32Constant(loc, rewriter, rank));
1407 descriptor = insertField(rewriter, loc, descriptor, {kTypePosInBox}, cfiTy);
1408 descriptor =
1409 insertField(rewriter, loc, descriptor, {kAttributePosInBox},
1410 this->genI32Constant(loc, rewriter, getCFIAttr(boxTy)));
1411 const bool hasAddendum = fir::boxHasAddendum(boxTy);
1412 descriptor =
1413 insertField(rewriter, loc, descriptor, {kF18AddendumPosInBox},
1414 this->genI32Constant(loc, rewriter, hasAddendum ? 1 : 0));
1415
1416 if (hasAddendum) {
1417 unsigned typeDescFieldId = getTypeDescFieldId(boxTy);
1418 if (!typeDesc) {
1419 if (useInputType) {
1420 mlir::Type innerType = fir::unwrapInnerType(inputType);
1421 if (innerType && innerType.template isa<fir::RecordType>()) {
1422 auto recTy = innerType.template dyn_cast<fir::RecordType>();
1423 typeDesc = getTypeDescriptor(mod, rewriter, loc, recTy);
1424 } else {
1425 // Unlimited polymorphic type descriptor with no record type. Set
1426 // type descriptor address to a clean state.
1427 typeDesc = rewriter.create<mlir::LLVM::NullOp>(
1428 loc, ::getVoidPtrType(mod.getContext()));
1429 }
1430 } else {
1431 typeDesc = getTypeDescriptor(mod, rewriter, loc,
1432 fir::unwrapIfDerived(boxTy));
1433 }
1434 }
1435 if (typeDesc)
1436 descriptor =
1437 insertField(rewriter, loc, descriptor, {typeDescFieldId}, typeDesc,
1438 /*bitCast=*/true);
1439 }
1440 return descriptor;
1441 }
1442
1443 // Template used for fir::EmboxOp and fir::cg::XEmboxOp
1444 template <typename BOX>
1445 std::tuple<fir::BaseBoxType, mlir::Value, mlir::Value>
1446 consDescriptorPrefix(BOX box, mlir::Type inputType,
1447 mlir::ConversionPatternRewriter &rewriter, unsigned rank,
1448 [[maybe_unused]] mlir::ValueRange substrParams,
1449 mlir::ValueRange lenParams, mlir::Value sourceBox = {},
1450 mlir::Type sourceBoxType = {}) const {
1451 auto loc = box.getLoc();
1452 auto boxTy = box.getType().template dyn_cast<fir::BaseBoxType>();
1453 bool useInputType = fir::isPolymorphicType(boxTy) &&
1454 !fir::isUnlimitedPolymorphicType(inputType);
1455 llvm::SmallVector<mlir::Value> typeparams = lenParams;
1456 if constexpr (!std::is_same_v<BOX, fir::EmboxOp>) {
1457 if (!box.getSubstr().empty() && fir::hasDynamicSize(boxTy.getEleTy()))
1458 typeparams.push_back(substrParams[1]);
1459 }
1460
1461 // Write each of the fields with the appropriate values.
1462 // When emboxing an element to a polymorphic descriptor, use the
1463 // input type since the destination descriptor type has not the exact
1464 // information.
1465 auto [eleSize, cfiTy] = getSizeAndTypeCode(
1466 loc, rewriter, useInputType ? inputType : boxTy.getEleTy(), typeparams);
1467
1468 mlir::Value typeDesc;
1469 // When emboxing to a polymorphic box, get the type descriptor, type code
1470 // and element size from the source box if any.
1471 if (fir::isPolymorphicType(boxTy) && sourceBox) {
1472 typeDesc =
1473 this->loadTypeDescAddress(loc, sourceBoxType, sourceBox, rewriter);
1474 mlir::Type idxTy = this->lowerTy().indexType();
1475 eleSize = this->getElementSizeFromBox(loc, idxTy, sourceBoxType,
1476 sourceBox, rewriter);
1477 cfiTy = this->getValueFromBox(loc, sourceBoxType, sourceBox,
1478 cfiTy.getType(), rewriter, kTypePosInBox);
1479 }
1480 auto mod = box->template getParentOfType<mlir::ModuleOp>();
1481 mlir::Value descriptor = populateDescriptor(
1482 loc, mod, boxTy, inputType, rewriter, rank, eleSize, cfiTy, typeDesc);
1483
1484 return {boxTy, descriptor, eleSize};
1485 }
1486
1487 std::tuple<fir::BaseBoxType, mlir::Value, mlir::Value>
1488 consDescriptorPrefix(fir::cg::XReboxOp box, mlir::Value loweredBox,
1489 mlir::ConversionPatternRewriter &rewriter, unsigned rank,
1490 mlir::ValueRange substrParams,
1491 mlir::ValueRange lenParams,
1492 mlir::Value typeDesc = {}) const {
1493 auto loc = box.getLoc();
1494 auto boxTy = box.getType().dyn_cast<fir::BaseBoxType>();
1495 auto inputBoxTy = box.getBox().getType().dyn_cast<fir::BaseBoxType>();
1496 llvm::SmallVector<mlir::Value> typeparams = lenParams;
1497 if (!box.getSubstr().empty() && fir::hasDynamicSize(boxTy.getEleTy()))
1498 typeparams.push_back(substrParams[1]);
1499
1500 auto [eleSize, cfiTy] =
1501 getSizeAndTypeCode(loc, rewriter, boxTy.getEleTy(), typeparams);
1502
1503 // Reboxing to a polymorphic entity. eleSize and type code need to
1504 // be retrieved from the initial box and propagated to the new box.
1505 // If the initial box has an addendum, the type desc must be propagated as
1506 // well.
1507 if (fir::isPolymorphicType(boxTy)) {
1508 mlir::Type idxTy = this->lowerTy().indexType();
1509 eleSize =
1510 this->getElementSizeFromBox(loc, idxTy, boxTy, loweredBox, rewriter);
1511 cfiTy = this->getValueFromBox(loc, boxTy, loweredBox, cfiTy.getType(),
1512 rewriter, kTypePosInBox);
1513 // TODO: For initial box that are unlimited polymorphic entities, this
1514 // code must be made conditional because unlimited polymorphic entities
1515 // with intrinsic type spec does not have addendum.
1516 if (fir::boxHasAddendum(inputBoxTy))
1517 typeDesc = this->loadTypeDescAddress(loc, box.getBox().getType(),
1518 loweredBox, rewriter);
1519 }
1520
1521 auto mod = box->template getParentOfType<mlir::ModuleOp>();
1522 mlir::Value descriptor =
1523 populateDescriptor(loc, mod, boxTy, box.getBox().getType(), rewriter,
1524 rank, eleSize, cfiTy, typeDesc);
1525
1526 return {boxTy, descriptor, eleSize};
1527 }
1528
1529 // Compute the base address of a fir.box given the indices from the slice.
1530 // The indices from the "outer" dimensions (every dimension after the first
1531 // one (inlcuded) that is not a compile time constant) must have been
1532 // multiplied with the related extents and added together into \p outerOffset.
1533 mlir::Value
1534 genBoxOffsetGep(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc,
1535 mlir::Value base, mlir::Value outerOffset,
1536 mlir::ValueRange cstInteriorIndices,
1537 mlir::ValueRange componentIndices,
1538 std::optional<mlir::Value> substringOffset) const {
1539 llvm::SmallVector<mlir::LLVM::GEPArg> gepArgs{outerOffset};
1540 mlir::Type resultTy =
1541 base.getType().cast<mlir::LLVM::LLVMPointerType>().getElementType();
1542 // Fortran is column major, llvm GEP is row major: reverse the indices here.
1543 for (mlir::Value interiorIndex : llvm::reverse(cstInteriorIndices)) {
1544 auto arrayTy = resultTy.dyn_cast<mlir::LLVM::LLVMArrayType>();
1545 if (!arrayTy)
1546 fir::emitFatalError(
1547 loc,
1548 "corrupted GEP generated being generated in fir.embox/fir.rebox");
1549 resultTy = arrayTy.getElementType();
1550 gepArgs.push_back(interiorIndex);
1551 }
1552 for (mlir::Value componentIndex : componentIndices) {
1553 // Component indices can be field index to select a component, or array
1554 // index, to select an element in an array component.
1555 if (auto structTy = resultTy.dyn_cast<mlir::LLVM::LLVMStructType>()) {
1556 std::int64_t cstIndex = getConstantIntValue(componentIndex);
1557 resultTy = structTy.getBody()[cstIndex];
1558 } else if (auto arrayTy =
1559 resultTy.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
1560 resultTy = arrayTy.getElementType();
1561 } else {
1562 fir::emitFatalError(loc, "corrupted component GEP generated being "
1563 "generated in fir.embox/fir.rebox");
1564 }
1565 gepArgs.push_back(componentIndex);
1566 }
1567 if (substringOffset) {
1568 if (auto arrayTy = resultTy.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
1569 gepArgs.push_back(*substringOffset);
1570 resultTy = arrayTy.getElementType();
1571 } else {
1572 // If the CHARACTER length is dynamic, the whole base type should have
1573 // degenerated to an llvm.ptr<i[width]>, and there should not be any
1574 // cstInteriorIndices/componentIndices. The substring offset can be
1575 // added to the outterOffset since it applies on the same LLVM type.
1576 if (gepArgs.size() != 1)
1577 fir::emitFatalError(loc,
1578 "corrupted substring GEP in fir.embox/fir.rebox");
1579 mlir::Type outterOffsetTy = gepArgs[0].get<mlir::Value>().getType();
1580 mlir::Value cast =
1581 this->integerCast(loc, rewriter, outterOffsetTy, *substringOffset);
1582
1583 gepArgs[0] = rewriter.create<mlir::LLVM::AddOp>(
1584 loc, outterOffsetTy, gepArgs[0].get<mlir::Value>(), cast);
1585 }
1586 }
1587 resultTy = mlir::LLVM::LLVMPointerType::get(resultTy);
1588 return rewriter.create<mlir::LLVM::GEPOp>(loc, resultTy, base, gepArgs);
1589 }
1590
1591 template <typename BOX>
1592 void
1593 getSubcomponentIndices(BOX xbox, mlir::Value memref,
1594 mlir::ValueRange operands,
1595 mlir::SmallVectorImpl<mlir::Value> &indices) const {
1596 // For each field in the path add the offset to base via the args list.
1597 // In the most general case, some offsets must be computed since
1598 // they are not be known until runtime.
1599 if (fir::hasDynamicSize(fir::unwrapSequenceType(
1600 fir::unwrapPassByRefType(memref.getType()))))
1601 TODO(xbox.getLoc(),do { fir::emitFatalError(xbox.getLoc(), llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "1602" ": not yet implemented: ") + llvm::Twine("fir.embox codegen dynamic size component in derived type"
), false); } while (false)
1602 "fir.embox codegen dynamic size component in derived type")do { fir::emitFatalError(xbox.getLoc(), llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "1602" ": not yet implemented: ") + llvm::Twine("fir.embox codegen dynamic size component in derived type"
), false); } while (false)
;
1603 indices.append(operands.begin() + xbox.subcomponentOffset(),
1604 operands.begin() + xbox.subcomponentOffset() +
1605 xbox.getSubcomponent().size());
1606 }
1607
1608 static bool isInGlobalOp(mlir::ConversionPatternRewriter &rewriter) {
1609 auto *thisBlock = rewriter.getInsertionBlock();
1610 return thisBlock &&
1611 mlir::isa<mlir::LLVM::GlobalOp>(thisBlock->getParentOp());
1612 }
1613
1614 /// If the embox is not in a globalOp body, allocate storage for the box;
1615 /// store the value inside and return the generated alloca. Return the input
1616 /// value otherwise.
1617 mlir::Value
1618 placeInMemoryIfNotGlobalInit(mlir::ConversionPatternRewriter &rewriter,
1619 mlir::Location loc, mlir::Type boxTy,
1620 mlir::Value boxValue) const {
1621 if (isInGlobalOp(rewriter))
1622 return boxValue;
1623 auto boxPtrTy = mlir::LLVM::LLVMPointerType::get(boxValue.getType());
1624 auto alloca =
1625 this->genAllocaWithType(loc, boxPtrTy, defaultAlign, rewriter);
1626 auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, alloca);
1627 this->attachTBAATag(storeOp, boxTy, boxTy, nullptr);
1628 return alloca;
1629 }
1630};
1631
1632/// Compute the extent of a triplet slice (lb:ub:step).
1633static mlir::Value
1634computeTripletExtent(mlir::ConversionPatternRewriter &rewriter,
1635 mlir::Location loc, mlir::Value lb, mlir::Value ub,
1636 mlir::Value step, mlir::Value zero, mlir::Type type) {
1637 mlir::Value extent = rewriter.create<mlir::LLVM::SubOp>(loc, type, ub, lb);
1638 extent = rewriter.create<mlir::LLVM::AddOp>(loc, type, extent, step);
1639 extent = rewriter.create<mlir::LLVM::SDivOp>(loc, type, extent, step);
1640 // If the resulting extent is negative (`ub-lb` and `step` have different
1641 // signs), zero must be returned instead.
1642 auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
1643 loc, mlir::LLVM::ICmpPredicate::sgt, extent, zero);
1644 return rewriter.create<mlir::LLVM::SelectOp>(loc, cmp, extent, zero);
1645}
1646
1647/// Create a generic box on a memory reference. This conversions lowers the
1648/// abstract box to the appropriate, initialized descriptor.
1649struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
1650 using EmboxCommonConversion::EmboxCommonConversion;
1651
1652 mlir::LogicalResult
1653 matchAndRewrite(fir::EmboxOp embox, OpAdaptor adaptor,
1654 mlir::ConversionPatternRewriter &rewriter) const override {
1655 mlir::ValueRange operands = adaptor.getOperands();
1656 mlir::Value sourceBox;
1657 mlir::Type sourceBoxType;
1658 if (embox.getSourceBox()) {
1659 sourceBox = operands[embox.getSourceBoxOffset()];
1660 sourceBoxType = embox.getSourceBox().getType();
1661 }
1662 assert(!embox.getShape() && "There should be no dims on this embox op")(static_cast <bool> (!embox.getShape() && "There should be no dims on this embox op"
) ? void (0) : __assert_fail ("!embox.getShape() && \"There should be no dims on this embox op\""
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 1662, __extension__
__PRETTY_FUNCTION__))
;
1663 auto [boxTy, dest, eleSize] = consDescriptorPrefix(
1664 embox, fir::unwrapRefType(embox.getMemref().getType()), rewriter,
1665 /*rank=*/0, /*substrParams=*/mlir::ValueRange{},
1666 adaptor.getTypeparams(), sourceBox, sourceBoxType);
1667 dest = insertBaseAddress(rewriter, embox.getLoc(), dest, operands[0]);
1668 if (fir::isDerivedTypeWithLenParams(boxTy)) {
1669 TODO(embox.getLoc(),do { fir::emitFatalError(embox.getLoc(), llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "1670" ": not yet implemented: ") + llvm::Twine("fir.embox codegen of derived with length parameters"
), false); } while (false)
1670 "fir.embox codegen of derived with length parameters")do { fir::emitFatalError(embox.getLoc(), llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "1670" ": not yet implemented: ") + llvm::Twine("fir.embox codegen of derived with length parameters"
), false); } while (false)
;
1671 return mlir::failure();
1672 }
1673 auto result =
1674 placeInMemoryIfNotGlobalInit(rewriter, embox.getLoc(), boxTy, dest);
1675 rewriter.replaceOp(embox, result);
1676 return mlir::success();
1677 }
1678};
1679
1680/// Create a generic box on a memory reference.
1681struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
1682 using EmboxCommonConversion::EmboxCommonConversion;
1683
1684 mlir::LogicalResult
1685 matchAndRewrite(fir::cg::XEmboxOp xbox, OpAdaptor adaptor,
1686 mlir::ConversionPatternRewriter &rewriter) const override {
1687 mlir::ValueRange operands = adaptor.getOperands();
1688 mlir::Value sourceBox;
1689 mlir::Type sourceBoxType;
1690 if (xbox.getSourceBox()) {
1691 sourceBox = operands[xbox.getSourceBoxOffset()];
1692 sourceBoxType = xbox.getSourceBox().getType();
1693 }
1694 auto [boxTy, dest, eleSize] = consDescriptorPrefix(
1695 xbox, fir::unwrapRefType(xbox.getMemref().getType()), rewriter,
1696 xbox.getOutRank(), adaptor.getSubstr(), adaptor.getLenParams(),
1697 sourceBox, sourceBoxType);
1698 // Generate the triples in the dims field of the descriptor
1699 auto i64Ty = mlir::IntegerType::get(xbox.getContext(), 64);
1700 mlir::Value base = operands[0];
1701 assert(!xbox.getShape().empty() && "must have a shape")(static_cast <bool> (!xbox.getShape().empty() &&
"must have a shape") ? void (0) : __assert_fail ("!xbox.getShape().empty() && \"must have a shape\""
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 1701, __extension__
__PRETTY_FUNCTION__))
;
1702 unsigned shapeOffset = xbox.shapeOffset();
1703 bool hasShift = !xbox.getShift().empty();
1704 unsigned shiftOffset = xbox.shiftOffset();
1705 bool hasSlice = !xbox.getSlice().empty();
1706 unsigned sliceOffset = xbox.sliceOffset();
1707 mlir::Location loc = xbox.getLoc();
1708 mlir::Value zero = genConstantIndex(loc, i64Ty, rewriter, 0);
1709 mlir::Value one = genConstantIndex(loc, i64Ty, rewriter, 1);
1710 mlir::Value prevPtrOff = one;
1711 mlir::Type eleTy = boxTy.getEleTy();
1712 const unsigned rank = xbox.getRank();
1713 llvm::SmallVector<mlir::Value> cstInteriorIndices;
1714 unsigned constRows = 0;
1715 mlir::Value ptrOffset = zero;
1716 mlir::Type memEleTy = fir::dyn_cast_ptrEleTy(xbox.getMemref().getType());
1717 assert(memEleTy.isa<fir::SequenceType>())(static_cast <bool> (memEleTy.isa<fir::SequenceType>
()) ? void (0) : __assert_fail ("memEleTy.isa<fir::SequenceType>()"
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 1717, __extension__
__PRETTY_FUNCTION__))
;
1718 auto seqTy = memEleTy.cast<fir::SequenceType>();
1719 mlir::Type seqEleTy = seqTy.getEleTy();
1720 // Adjust the element scaling factor if the element is a dependent type.
1721 if (fir::hasDynamicSize(seqEleTy)) {
1722 if (auto charTy = seqEleTy.dyn_cast<fir::CharacterType>()) {
1723 prevPtrOff = eleSize;
1724 } else if (seqEleTy.isa<fir::RecordType>()) {
1725 // prevPtrOff = ;
1726 TODO(loc, "generate call to calculate size of PDT")do { fir::emitFatalError(loc, llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "1726" ": not yet implemented: ") + llvm::Twine("generate call to calculate size of PDT"
), false); } while (false)
;
1727 } else {
1728 fir::emitFatalError(loc, "unexpected dynamic type");
1729 }
1730 } else {
1731 constRows = seqTy.getConstantRows();
1732 }
1733
1734 const auto hasSubcomp = !xbox.getSubcomponent().empty();
1735 const bool hasSubstr = !xbox.getSubstr().empty();
1736 // Initial element stride that will be use to compute the step in
1737 // each dimension.
1738 mlir::Value prevDimByteStride = eleSize;
1739 if (hasSubcomp) {
1740 // We have a subcomponent. The step value needs to be the number of
1741 // bytes per element (which is a derived type).
1742 prevDimByteStride =
1743 genTypeStrideInBytes(loc, i64Ty, rewriter, convertType(seqEleTy));
1744 } else if (hasSubstr) {
1745 // We have a substring. The step value needs to be the number of bytes
1746 // per CHARACTER element.
1747 auto charTy = seqEleTy.cast<fir::CharacterType>();
1748 if (fir::hasDynamicSize(charTy)) {
1749 prevDimByteStride = prevPtrOff;
1750 } else {
1751 prevDimByteStride = genConstantIndex(
1752 loc, i64Ty, rewriter,
1753 charTy.getLen() * lowerTy().characterBitsize(charTy) / 8);
1754 }
1755 }
1756
1757 // Process the array subspace arguments (shape, shift, etc.), if any,
1758 // translating everything to values in the descriptor wherever the entity
1759 // has a dynamic array dimension.
1760 for (unsigned di = 0, descIdx = 0; di < rank; ++di) {
1761 mlir::Value extent = operands[shapeOffset];
1762 mlir::Value outerExtent = extent;
1763 bool skipNext = false;
1764 if (hasSlice) {
1765 mlir::Value off = operands[sliceOffset];
1766 mlir::Value adj = one;
1767 if (hasShift)
1768 adj = operands[shiftOffset];
1769 auto ao = rewriter.create<mlir::LLVM::SubOp>(loc, i64Ty, off, adj);
1770 if (constRows > 0) {
1771 cstInteriorIndices.push_back(ao);
1772 } else {
1773 auto dimOff =
1774 rewriter.create<mlir::LLVM::MulOp>(loc, i64Ty, ao, prevPtrOff);
1775 ptrOffset =
1776 rewriter.create<mlir::LLVM::AddOp>(loc, i64Ty, dimOff, ptrOffset);
1777 }
1778 if (mlir::isa_and_nonnull<fir::UndefOp>(
1779 xbox.getSlice()[3 * di + 1].getDefiningOp())) {
1780 // This dimension contains a scalar expression in the array slice op.
1781 // The dimension is loop invariant, will be dropped, and will not
1782 // appear in the descriptor.
1783 skipNext = true;
1784 }
1785 }
1786 if (!skipNext) {
1787 // store extent
1788 if (hasSlice)
1789 extent = computeTripletExtent(rewriter, loc, operands[sliceOffset],
1790 operands[sliceOffset + 1],
1791 operands[sliceOffset + 2], zero, i64Ty);
1792 // Lower bound is normalized to 0 for BIND(C) interoperability.
1793 mlir::Value lb = zero;
1794 const bool isaPointerOrAllocatable =
1795 eleTy.isa<fir::PointerType>() || eleTy.isa<fir::HeapType>();
1796 // Lower bound is defaults to 1 for POINTER, ALLOCATABLE, and
1797 // denormalized descriptors.
1798 if (isaPointerOrAllocatable || !normalizedLowerBound(xbox))
1799 lb = one;
1800 // If there is a shifted origin, and no fir.slice, and this is not
1801 // a normalized descriptor then use the value from the shift op as
1802 // the lower bound.
1803 if (hasShift && !(hasSlice || hasSubcomp || hasSubstr) &&
1804 (isaPointerOrAllocatable || !normalizedLowerBound(xbox))) {
1805 lb = operands[shiftOffset];
1806 auto extentIsEmpty = rewriter.create<mlir::LLVM::ICmpOp>(
1807 loc, mlir::LLVM::ICmpPredicate::eq, extent, zero);
1808 lb = rewriter.create<mlir::LLVM::SelectOp>(loc, extentIsEmpty, one,
1809 lb);
1810 }
1811 dest = insertLowerBound(rewriter, loc, dest, descIdx, lb);
1812
1813 dest = insertExtent(rewriter, loc, dest, descIdx, extent);
1814
1815 // store step (scaled by shaped extent)
1816 mlir::Value step = prevDimByteStride;
1817 if (hasSlice)
1818 step = rewriter.create<mlir::LLVM::MulOp>(loc, i64Ty, step,
1819 operands[sliceOffset + 2]);
1820 dest = insertStride(rewriter, loc, dest, descIdx, step);
1821 ++descIdx;
1822 }
1823
1824 // compute the stride and offset for the next natural dimension
1825 prevDimByteStride = rewriter.create<mlir::LLVM::MulOp>(
1826 loc, i64Ty, prevDimByteStride, outerExtent);
1827 if (constRows == 0)
1828 prevPtrOff = rewriter.create<mlir::LLVM::MulOp>(loc, i64Ty, prevPtrOff,
1829 outerExtent);
1830 else
1831 --constRows;
1832
1833 // increment iterators
1834 ++shapeOffset;
1835 if (hasShift)
1836 ++shiftOffset;
1837 if (hasSlice)
1838 sliceOffset += 3;
1839 }
1840 if (hasSlice || hasSubcomp || hasSubstr) {
1841 // Shift the base address.
1842 llvm::SmallVector<mlir::Value> fieldIndices;
1843 std::optional<mlir::Value> substringOffset;
1844 if (hasSubcomp)
1845 getSubcomponentIndices(xbox, xbox.getMemref(), operands, fieldIndices);
1846 if (hasSubstr)
1847 substringOffset = operands[xbox.substrOffset()];
1848 base = genBoxOffsetGep(rewriter, loc, base, ptrOffset, cstInteriorIndices,
1849 fieldIndices, substringOffset);
1850 }
1851 dest = insertBaseAddress(rewriter, loc, dest, base);
1852 if (fir::isDerivedTypeWithLenParams(boxTy))
1853 TODO(loc, "fir.embox codegen of derived with length parameters")do { fir::emitFatalError(loc, llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "1853" ": not yet implemented: ") + llvm::Twine("fir.embox codegen of derived with length parameters"
), false); } while (false)
;
1854
1855 mlir::Value result =
1856 placeInMemoryIfNotGlobalInit(rewriter, loc, boxTy, dest);
1857 rewriter.replaceOp(xbox, result);
1858 return mlir::success();
1859 }
1860
1861 /// Return true if `xbox` has a normalized lower bounds attribute. A box value
1862 /// that is neither a POINTER nor an ALLOCATABLE should be normalized to a
1863 /// zero origin lower bound for interoperability with BIND(C).
1864 inline static bool normalizedLowerBound(fir::cg::XEmboxOp xbox) {
1865 return xbox->hasAttr(fir::getNormalizedLowerBoundAttrName());
1866 }
1867};
1868
1869/// Create a new box given a box reference.
1870struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
1871 using EmboxCommonConversion::EmboxCommonConversion;
1872
1873 mlir::LogicalResult
1874 matchAndRewrite(fir::cg::XReboxOp rebox, OpAdaptor adaptor,
1875 mlir::ConversionPatternRewriter &rewriter) const override {
1876 mlir::Location loc = rebox.getLoc();
1877 mlir::Type idxTy = lowerTy().indexType();
1878 mlir::Value loweredBox = adaptor.getOperands()[0];
1879 mlir::ValueRange operands = adaptor.getOperands();
1880
1881 // Inside a fir.global, the input box was produced as an llvm.struct<>
1882 // because objects cannot be handled in memory inside a fir.global body that
1883 // must be constant foldable. However, the type translation are not
1884 // contextual, so the fir.box<T> type of the operation that produced the
1885 // fir.box was translated to an llvm.ptr<llvm.struct<>> and the MLIR pass
1886 // manager inserted a builtin.unrealized_conversion_cast that was inserted
1887 // and needs to be removed here.
1888 if (isInGlobalOp(rewriter))
1889 if (auto unrealizedCast =
1890 loweredBox.getDefiningOp<mlir::UnrealizedConversionCastOp>())
1891 loweredBox = unrealizedCast.getInputs()[0];
1892
1893 // Create new descriptor and fill its non-shape related data.
1894 llvm::SmallVector<mlir::Value, 2> lenParams;
1895 mlir::Type inputEleTy = getInputEleTy(rebox);
1896 if (auto charTy = inputEleTy.dyn_cast<fir::CharacterType>()) {
1897 mlir::Value len = getElementSizeFromBox(
1898 loc, idxTy, rebox.getBox().getType(), loweredBox, rewriter);
1899 if (charTy.getFKind() != 1) {
1900 mlir::Value width =
1901 genConstantIndex(loc, idxTy, rewriter, charTy.getFKind());
1902 len = rewriter.create<mlir::LLVM::SDivOp>(loc, idxTy, len, width);
1903 }
1904 lenParams.emplace_back(len);
1905 } else if (auto recTy = inputEleTy.dyn_cast<fir::RecordType>()) {
1906 if (recTy.getNumLenParams() != 0)
1907 TODO(loc, "reboxing descriptor of derived type with length parameters")do { fir::emitFatalError(loc, llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "1907" ": not yet implemented: ") + llvm::Twine("reboxing descriptor of derived type with length parameters"
), false); } while (false)
;
1908 }
1909
1910 // Rebox on polymorphic entities needs to carry over the dynamic type.
1911 mlir::Value typeDescAddr;
1912 if (rebox.getBox().getType().isa<fir::ClassType>() &&
1913 rebox.getType().isa<fir::ClassType>())
1914 typeDescAddr = loadTypeDescAddress(loc, rebox.getBox().getType(),
1915 loweredBox, rewriter);
1916
1917 auto [boxTy, dest, eleSize] =
1918 consDescriptorPrefix(rebox, loweredBox, rewriter, rebox.getOutRank(),
1919 adaptor.getSubstr(), lenParams, typeDescAddr);
1920
1921 // Read input extents, strides, and base address
1922 llvm::SmallVector<mlir::Value> inputExtents;
1923 llvm::SmallVector<mlir::Value> inputStrides;
1924 const unsigned inputRank = rebox.getRank();
1925 for (unsigned dim = 0; dim < inputRank; ++dim) {
1926 llvm::SmallVector<mlir::Value, 3> dimInfo =
1927 getDimsFromBox(loc, {idxTy, idxTy, idxTy}, rebox.getBox().getType(),
1928 loweredBox, dim, rewriter);
1929 inputExtents.emplace_back(dimInfo[1]);
1930 inputStrides.emplace_back(dimInfo[2]);
1931 }
1932
1933 mlir::Type baseTy = getBaseAddrTypeFromBox(loweredBox.getType());
1934 mlir::Value baseAddr = getBaseAddrFromBox(
1935 loc, baseTy, rebox.getBox().getType(), loweredBox, rewriter);
1936
1937 if (!rebox.getSlice().empty() || !rebox.getSubcomponent().empty())
1938 return sliceBox(rebox, boxTy, dest, baseAddr, inputExtents, inputStrides,
1939 operands, rewriter);
1940 return reshapeBox(rebox, boxTy, dest, baseAddr, inputExtents, inputStrides,
1941 operands, rewriter);
1942 }
1943
1944private:
1945 /// Write resulting shape and base address in descriptor, and replace rebox
1946 /// op.
1947 mlir::LogicalResult
1948 finalizeRebox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
1949 mlir::Value base, mlir::ValueRange lbounds,
1950 mlir::ValueRange extents, mlir::ValueRange strides,
1951 mlir::ConversionPatternRewriter &rewriter) const {
1952 mlir::Location loc = rebox.getLoc();
1953 mlir::Value zero =
1954 genConstantIndex(loc, lowerTy().indexType(), rewriter, 0);
1955 mlir::Value one = genConstantIndex(loc, lowerTy().indexType(), rewriter, 1);
1956 for (auto iter : llvm::enumerate(llvm::zip(extents, strides))) {
1957 mlir::Value extent = std::get<0>(iter.value());
1958 unsigned dim = iter.index();
1959 mlir::Value lb = one;
1960 if (!lbounds.empty()) {
1961 lb = lbounds[dim];
1962 auto extentIsEmpty = rewriter.create<mlir::LLVM::ICmpOp>(
1963 loc, mlir::LLVM::ICmpPredicate::eq, extent, zero);
1964 lb = rewriter.create<mlir::LLVM::SelectOp>(loc, extentIsEmpty, one, lb);
1965 };
1966 dest = insertLowerBound(rewriter, loc, dest, dim, lb);
1967 dest = insertExtent(rewriter, loc, dest, dim, extent);
1968 dest = insertStride(rewriter, loc, dest, dim, std::get<1>(iter.value()));
1969 }
1970 dest = insertBaseAddress(rewriter, loc, dest, base);
1971 mlir::Value result =
1972 placeInMemoryIfNotGlobalInit(rewriter, rebox.getLoc(), destBoxTy, dest);
1973 rewriter.replaceOp(rebox, result);
1974 return mlir::success();
1975 }
1976
1977 // Apply slice given the base address, extents and strides of the input box.
1978 mlir::LogicalResult
1979 sliceBox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
1980 mlir::Value base, mlir::ValueRange inputExtents,
1981 mlir::ValueRange inputStrides, mlir::ValueRange operands,
1982 mlir::ConversionPatternRewriter &rewriter) const {
1983 mlir::Location loc = rebox.getLoc();
1984 mlir::Type voidPtrTy = ::getVoidPtrType(rebox.getContext());
1985 mlir::Type idxTy = lowerTy().indexType();
1986 mlir::Value zero = genConstantIndex(loc, idxTy, rewriter, 0);
1987 // Apply subcomponent and substring shift on base address.
1988 if (!rebox.getSubcomponent().empty() || !rebox.getSubstr().empty()) {
1989 // Cast to inputEleTy* so that a GEP can be used.
1990 mlir::Type inputEleTy = getInputEleTy(rebox);
1991 auto llvmElePtrTy =
1992 mlir::LLVM::LLVMPointerType::get(convertType(inputEleTy));
1993 base = rewriter.create<mlir::LLVM::BitcastOp>(loc, llvmElePtrTy, base);
1994
1995 llvm::SmallVector<mlir::Value> fieldIndices;
1996 std::optional<mlir::Value> substringOffset;
1997 if (!rebox.getSubcomponent().empty())
1998 getSubcomponentIndices(rebox, rebox.getBox(), operands, fieldIndices);
1999 if (!rebox.getSubstr().empty())
2000 substringOffset = operands[rebox.substrOffset()];
2001 base = genBoxOffsetGep(rewriter, loc, base, zero,
2002 /*cstInteriorIndices=*/std::nullopt, fieldIndices,
2003 substringOffset);
2004 }
2005
2006 if (rebox.getSlice().empty())
2007 // The array section is of the form array[%component][substring], keep
2008 // the input array extents and strides.
2009 return finalizeRebox(rebox, destBoxTy, dest, base,
2010 /*lbounds*/ std::nullopt, inputExtents, inputStrides,
2011 rewriter);
2012
2013 // Strides from the fir.box are in bytes.
2014 base = rewriter.create<mlir::LLVM::BitcastOp>(loc, voidPtrTy, base);
2015
2016 // The slice is of the form array(i:j:k)[%component]. Compute new extents
2017 // and strides.
2018 llvm::SmallVector<mlir::Value> slicedExtents;
2019 llvm::SmallVector<mlir::Value> slicedStrides;
2020 mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
2021 const bool sliceHasOrigins = !rebox.getShift().empty();
2022 unsigned sliceOps = rebox.sliceOffset();
2023 unsigned shiftOps = rebox.shiftOffset();
2024 auto strideOps = inputStrides.begin();
2025 const unsigned inputRank = inputStrides.size();
2026 for (unsigned i = 0; i < inputRank;
2027 ++i, ++strideOps, ++shiftOps, sliceOps += 3) {
2028 mlir::Value sliceLb =
2029 integerCast(loc, rewriter, idxTy, operands[sliceOps]);
2030 mlir::Value inputStride = *strideOps; // already idxTy
2031 // Apply origin shift: base += (lb-shift)*input_stride
2032 mlir::Value sliceOrigin =
2033 sliceHasOrigins
2034 ? integerCast(loc, rewriter, idxTy, operands[shiftOps])
2035 : one;
2036 mlir::Value diff =
2037 rewriter.create<mlir::LLVM::SubOp>(loc, idxTy, sliceLb, sliceOrigin);
2038 mlir::Value offset =
2039 rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, diff, inputStride);
2040 base = genGEP(loc, voidPtrTy, rewriter, base, offset);
2041 // Apply upper bound and step if this is a triplet. Otherwise, the
2042 // dimension is dropped and no extents/strides are computed.
2043 mlir::Value upper = operands[sliceOps + 1];
2044 const bool isTripletSlice =
2045 !mlir::isa_and_nonnull<mlir::LLVM::UndefOp>(upper.getDefiningOp());
2046 if (isTripletSlice) {
2047 mlir::Value step =
2048 integerCast(loc, rewriter, idxTy, operands[sliceOps + 2]);
2049 // extent = ub-lb+step/step
2050 mlir::Value sliceUb = integerCast(loc, rewriter, idxTy, upper);
2051 mlir::Value extent = computeTripletExtent(rewriter, loc, sliceLb,
2052 sliceUb, step, zero, idxTy);
2053 slicedExtents.emplace_back(extent);
2054 // stride = step*input_stride
2055 mlir::Value stride =
2056 rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, step, inputStride);
2057 slicedStrides.emplace_back(stride);
2058 }
2059 }
2060 return finalizeRebox(rebox, destBoxTy, dest, base, /*lbounds*/ std::nullopt,
2061 slicedExtents, slicedStrides, rewriter);
2062 }
2063
2064 /// Apply a new shape to the data described by a box given the base address,
2065 /// extents and strides of the box.
2066 mlir::LogicalResult
2067 reshapeBox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
2068 mlir::Value base, mlir::ValueRange inputExtents,
2069 mlir::ValueRange inputStrides, mlir::ValueRange operands,
2070 mlir::ConversionPatternRewriter &rewriter) const {
2071 mlir::ValueRange reboxShifts{operands.begin() + rebox.shiftOffset(),
2072 operands.begin() + rebox.shiftOffset() +
2073 rebox.getShift().size()};
2074 if (rebox.getShape().empty()) {
2075 // Only setting new lower bounds.
2076 return finalizeRebox(rebox, destBoxTy, dest, base, reboxShifts,
2077 inputExtents, inputStrides, rewriter);
2078 }
2079
2080 mlir::Location loc = rebox.getLoc();
2081 // Strides from the fir.box are in bytes.
2082 mlir::Type voidPtrTy = ::getVoidPtrType(rebox.getContext());
2083 base = rewriter.create<mlir::LLVM::BitcastOp>(loc, voidPtrTy, base);
2084
2085 llvm::SmallVector<mlir::Value> newStrides;
2086 llvm::SmallVector<mlir::Value> newExtents;
2087 mlir::Type idxTy = lowerTy().indexType();
2088 // First stride from input box is kept. The rest is assumed contiguous
2089 // (it is not possible to reshape otherwise). If the input is scalar,
2090 // which may be OK if all new extents are ones, the stride does not
2091 // matter, use one.
2092 mlir::Value stride = inputStrides.empty()
2093 ? genConstantIndex(loc, idxTy, rewriter, 1)
2094 : inputStrides[0];
2095 for (unsigned i = 0; i < rebox.getShape().size(); ++i) {
2096 mlir::Value rawExtent = operands[rebox.shapeOffset() + i];
2097 mlir::Value extent = integerCast(loc, rewriter, idxTy, rawExtent);
2098 newExtents.emplace_back(extent);
2099 newStrides.emplace_back(stride);
2100 // nextStride = extent * stride;
2101 stride = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, extent, stride);
2102 }
2103 return finalizeRebox(rebox, destBoxTy, dest, base, reboxShifts, newExtents,
2104 newStrides, rewriter);
2105 }
2106
2107 /// Return scalar element type of the input box.
2108 static mlir::Type getInputEleTy(fir::cg::XReboxOp rebox) {
2109 auto ty = fir::dyn_cast_ptrOrBoxEleTy(rebox.getBox().getType());
2110 if (auto seqTy = ty.dyn_cast<fir::SequenceType>())
2111 return seqTy.getEleTy();
2112 return ty;
2113 }
2114};
2115
2116/// Lower `fir.emboxproc` operation. Creates a procedure box.
2117/// TODO: Part of supporting Fortran 2003 procedure pointers.
2118struct EmboxProcOpConversion : public FIROpConversion<fir::EmboxProcOp> {
2119 using FIROpConversion::FIROpConversion;
2120
2121 mlir::LogicalResult
2122 matchAndRewrite(fir::EmboxProcOp emboxproc, OpAdaptor adaptor,
2123 mlir::ConversionPatternRewriter &rewriter) const override {
2124 TODO(emboxproc.getLoc(), "fir.emboxproc codegen")do { fir::emitFatalError(emboxproc.getLoc(), llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "2124" ": not yet implemented: ") + llvm::Twine("fir.emboxproc codegen"
), false); } while (false)
;
2125 return mlir::failure();
2126 }
2127};
2128
2129// Code shared between insert_value and extract_value Ops.
2130struct ValueOpCommon {
2131 // Translate the arguments pertaining to any multidimensional array to
2132 // row-major order for LLVM-IR.
2133 static void toRowMajor(llvm::SmallVectorImpl<int64_t> &indices,
2134 mlir::Type ty) {
2135 assert(ty && "type is null")(static_cast <bool> (ty && "type is null") ? void
(0) : __assert_fail ("ty && \"type is null\"", "flang/lib/Optimizer/CodeGen/CodeGen.cpp"
, 2135, __extension__ __PRETTY_FUNCTION__))
;
2136 const auto end = indices.size();
2137 for (std::remove_const_t<decltype(end)> i = 0; i < end; ++i) {
2138 if (auto seq = ty.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
2139 const auto dim = getDimension(seq);
2140 if (dim > 1) {
2141 auto ub = std::min(i + dim, end);
2142 std::reverse(indices.begin() + i, indices.begin() + ub);
2143 i += dim - 1;
2144 }
2145 ty = getArrayElementType(seq);
2146 } else if (auto st = ty.dyn_cast<mlir::LLVM::LLVMStructType>()) {
2147 ty = st.getBody()[indices[i]];
2148 } else {
2149 llvm_unreachable("index into invalid type")::llvm::llvm_unreachable_internal("index into invalid type", "flang/lib/Optimizer/CodeGen/CodeGen.cpp"
, 2149)
;
2150 }
2151 }
2152 }
2153
2154 static llvm::SmallVector<int64_t>
2155 collectIndices(mlir::ConversionPatternRewriter &rewriter,
2156 mlir::ArrayAttr arrAttr) {
2157 llvm::SmallVector<int64_t> indices;
2158 for (auto i = arrAttr.begin(), e = arrAttr.end(); i != e; ++i) {
2159 if (auto intAttr = i->dyn_cast<mlir::IntegerAttr>()) {
2160 indices.push_back(intAttr.getInt());
2161 } else {
2162 auto fieldName = i->cast<mlir::StringAttr>().getValue();
2163 ++i;
2164 auto ty = i->cast<mlir::TypeAttr>().getValue();
2165 auto index = ty.cast<fir::RecordType>().getFieldIndex(fieldName);
2166 indices.push_back(index);
2167 }
2168 }
2169 return indices;
2170 }
2171
2172private:
2173 static mlir::Type getArrayElementType(mlir::LLVM::LLVMArrayType ty) {
2174 auto eleTy = ty.getElementType();
2175 while (auto arrTy = eleTy.dyn_cast<mlir::LLVM::LLVMArrayType>())
2176 eleTy = arrTy.getElementType();
2177 return eleTy;
2178 }
2179};
2180
2181namespace {
2182/// Extract a subobject value from an ssa-value of aggregate type
2183struct ExtractValueOpConversion
2184 : public FIROpAndTypeConversion<fir::ExtractValueOp>,
2185 public ValueOpCommon {
2186 using FIROpAndTypeConversion::FIROpAndTypeConversion;
2187
2188 mlir::LogicalResult
2189 doRewrite(fir::ExtractValueOp extractVal, mlir::Type ty, OpAdaptor adaptor,
2190 mlir::ConversionPatternRewriter &rewriter) const override {
2191 mlir::ValueRange operands = adaptor.getOperands();
2192 auto indices = collectIndices(rewriter, extractVal.getCoor());
2193 toRowMajor(indices, operands[0].getType());
2194 rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
2195 extractVal, operands[0], indices);
2196 return mlir::success();
2197 }
2198};
2199
2200/// InsertValue is the generalized instruction for the composition of new
2201/// aggregate type values.
2202struct InsertValueOpConversion
2203 : public FIROpAndTypeConversion<fir::InsertValueOp>,
2204 public ValueOpCommon {
2205 using FIROpAndTypeConversion::FIROpAndTypeConversion;
2206
2207 mlir::LogicalResult
2208 doRewrite(fir::InsertValueOp insertVal, mlir::Type ty, OpAdaptor adaptor,
2209 mlir::ConversionPatternRewriter &rewriter) const override {
2210 mlir::ValueRange operands = adaptor.getOperands();
2211 auto indices = collectIndices(rewriter, insertVal.getCoor());
2212 toRowMajor(indices, operands[0].getType());
2213 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
2214 insertVal, operands[0], operands[1], indices);
2215 return mlir::success();
2216 }
2217};
2218
2219/// InsertOnRange inserts a value into a sequence over a range of offsets.
2220struct InsertOnRangeOpConversion
2221 : public FIROpAndTypeConversion<fir::InsertOnRangeOp> {
2222 using FIROpAndTypeConversion::FIROpAndTypeConversion;
2223
2224 // Increments an array of subscripts in a row major fasion.
2225 void incrementSubscripts(llvm::ArrayRef<int64_t> dims,
2226 llvm::SmallVectorImpl<int64_t> &subscripts) const {
2227 for (size_t i = dims.size(); i > 0; --i) {
2228 if (++subscripts[i - 1] < dims[i - 1]) {
2229 return;
2230 }
2231 subscripts[i - 1] = 0;
2232 }
2233 }
2234
2235 mlir::LogicalResult
2236 doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor,
2237 mlir::ConversionPatternRewriter &rewriter) const override {
2238
2239 llvm::SmallVector<std::int64_t> dims;
2240 auto type = adaptor.getOperands()[0].getType();
2241
2242 // Iteratively extract the array dimensions from the type.
2243 while (auto t = type.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
2244 dims.push_back(t.getNumElements());
2245 type = t.getElementType();
2246 }
2247
2248 llvm::SmallVector<std::int64_t> lBounds;
2249 llvm::SmallVector<std::int64_t> uBounds;
2250
2251 // Unzip the upper and lower bound and convert to a row major format.
2252 mlir::DenseIntElementsAttr coor = range.getCoor();
2253 auto reversedCoor = llvm::reverse(coor.getValues<int64_t>());
2254 for (auto i = reversedCoor.begin(), e = reversedCoor.end(); i != e; ++i) {
2255 uBounds.push_back(*i++);
2256 lBounds.push_back(*i);
2257 }
2258
2259 auto &subscripts = lBounds;
2260 auto loc = range.getLoc();
2261 mlir::Value lastOp = adaptor.getOperands()[0];
2262 mlir::Value insertVal = adaptor.getOperands()[1];
2263
2264 while (subscripts != uBounds) {
2265 lastOp = rewriter.create<mlir::LLVM::InsertValueOp>(
2266 loc, lastOp, insertVal, subscripts);
2267
2268 incrementSubscripts(dims, subscripts);
2269 }
2270
2271 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
2272 range, lastOp, insertVal, subscripts);
2273
2274 return mlir::success();
2275 }
2276};
2277} // namespace
2278
2279namespace {
/// XArrayCoor is the address arithmetic on a dynamically shaped, sliced,
/// shifted etc. array.
/// (See the static restriction on coordinate_of.) array_coor determines the
/// coordinate (location) of a specific element.
struct XArrayCoorOpConversion
    : public FIROpAndTypeConversion<fir::cg::XArrayCoorOp> {
  using FIROpAndTypeConversion::FIROpAndTypeConversion;

  /// Replace `coor` with the address of the selected element.
  ///
  /// For each dimension a zero-based index `diff` is computed:
  ///   diff = (index - lb) * step + (sliceLb - lb)  for a triplet slice,
  ///   diff = (index - lb)                          otherwise,
  /// where `lb` is the shift operand (1 without shift), and `step`/`sliceLb`
  /// come from the slice triplet. `diff` is scaled either by the byte stride
  /// read from the descriptor (boxed base) or by the running product of the
  /// previous extents (unboxed base, assumed contiguous). The result is
  /// accumulated into `offset`, which is finally applied to the base address
  /// with a GEP. Any subcomponent operands are appended to that GEP.
  mlir::LogicalResult
  doRewrite(fir::cg::XArrayCoorOp coor, mlir::Type ty, OpAdaptor adaptor,
            mlir::ConversionPatternRewriter &rewriter) const override {
    auto loc = coor.getLoc();
    mlir::ValueRange operands = adaptor.getOperands();
    unsigned rank = coor.getRank();
    assert(coor.getIndices().size() == rank);
    assert(coor.getShape().empty() || coor.getShape().size() == rank);
    assert(coor.getShift().empty() || coor.getShift().size() == rank);
    assert(coor.getSlice().empty() || coor.getSlice().size() == 3 * rank);
    mlir::Type idxTy = lowerTy().indexType();
    unsigned indexOffset = coor.indicesOffset();
    unsigned shapeOffset = coor.shapeOffset();
    unsigned shiftOffset = coor.shiftOffset();
    unsigned sliceOffset = coor.sliceOffset();
    // Walks the original (unconverted) slice values. Only used to recognize
    // dimensions whose slice upper bound is undefined.
    auto sliceOps = coor.getSlice().begin();
    mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
    // Running product of the extents seen so far: the element stride of the
    // current dimension in the unboxed (contiguous) case.
    mlir::Value prevExt = one;
    mlir::Value offset = genConstantIndex(loc, idxTy, rewriter, 0);
    const bool isShifted = !coor.getShift().empty();
    const bool isSliced = !coor.getSlice().empty();
    const bool baseIsBoxed = coor.getMemref().getType().isa<fir::BaseBoxType>();

    // For each dimension of the array, generate the offset calculation.
    for (unsigned i = 0; i < rank; ++i, ++indexOffset, ++shapeOffset,
                  ++shiftOffset, sliceOffset += 3, sliceOps += 3) {
      mlir::Value index =
          integerCast(loc, rewriter, idxTy, operands[indexOffset]);
      mlir::Value lb =
          isShifted ? integerCast(loc, rewriter, idxTy, operands[shiftOffset])
                    : one;
      mlir::Value step = one;
      bool normalSlice = isSliced;
      // Compute zero based index in dimension i of the element, applying
      // potential triplets and lower bounds.
      if (isSliced) {
        mlir::Value originalUb = *(sliceOps + 1);
        // An undefined upper bound means this dimension is not a triplet:
        // no step or slice lower bound is applied to it.
        normalSlice =
            !mlir::isa_and_nonnull<fir::UndefOp>(originalUb.getDefiningOp());
        if (normalSlice)
          step = integerCast(loc, rewriter, idxTy, operands[sliceOffset + 2]);
      }
      auto idx = rewriter.create<mlir::LLVM::SubOp>(loc, idxTy, index, lb);
      mlir::Value diff =
          rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, idx, step);
      if (normalSlice) {
        mlir::Value sliceLb =
            integerCast(loc, rewriter, idxTy, operands[sliceOffset]);
        auto adj = rewriter.create<mlir::LLVM::SubOp>(loc, idxTy, sliceLb, lb);
        diff = rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, diff, adj);
      }
      // Update the offset given the stride and the zero based index `diff`
      // that was just computed.
      if (baseIsBoxed) {
        // Use stride in bytes from the descriptor.
        mlir::Value stride = getStrideFromBox(loc, coor.getMemref().getType(),
                                              operands[0], i, rewriter);
        auto sc = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, diff, stride);
        offset = rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, sc, offset);
      } else {
        // Use stride computed at last iteration.
        auto sc = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, diff, prevExt);
        offset = rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, sc, offset);
        // Compute next stride assuming contiguity of the base array
        // (in element number).
        auto nextExt = integerCast(loc, rewriter, idxTy, operands[shapeOffset]);
        prevExt =
            rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, prevExt, nextExt);
      }
    }

    // Add computed offset to the base address.
    if (baseIsBoxed) {
      // Working with byte offsets. The base address is read from the fir.box.
      // and need to be casted to i8* to do the pointer arithmetic.
      mlir::Type baseTy = getBaseAddrTypeFromBox(operands[0].getType());
      mlir::Value base = getBaseAddrFromBox(
          loc, baseTy, coor.getMemref().getType(), operands[0], rewriter);
      mlir::Type voidPtrTy = getVoidPtrType();
      base = rewriter.create<mlir::LLVM::BitcastOp>(loc, voidPtrTy, base);
      llvm::SmallVector<mlir::LLVM::GEPArg> args{offset};
      auto addr =
          rewriter.create<mlir::LLVM::GEPOp>(loc, voidPtrTy, base, args);
      if (coor.getSubcomponent().empty()) {
        rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(coor, ty, addr);
        return mlir::success();
      }
      // Cast the element address from void* to the derived type so that the
      // derived type members can be addresses via a GEP using the index of
      // components.
      mlir::Type elementType =
          baseTy.cast<mlir::LLVM::LLVMPointerType>().getElementType();
      // Strip every LLVM array level to reach the scalar element type.
      while (auto arrayTy = elementType.dyn_cast<mlir::LLVM::LLVMArrayType>())
        elementType = arrayTy.getElementType();
      mlir::Type elementPtrType = mlir::LLVM::LLVMPointerType::get(elementType);
      auto casted =
          rewriter.create<mlir::LLVM::BitcastOp>(loc, elementPtrType, addr);
      // Leading 0 dereferences the element pointer; the subcomponent
      // operands then select the member.
      args.clear();
      args.push_back(0);
      if (!coor.getLenParams().empty()) {
        // If type parameters are present, then we don't want to use a GEPOp
        // as below, as the LLVM struct type cannot be statically defined.
        TODO(loc, "derived type with type parameters");
      }
      // TODO: array offset subcomponents must be converted to LLVM's
      // row-major layout here.
      for (auto i = coor.subcomponentOffset(); i != coor.indicesOffset(); ++i)
        args.push_back(operands[i]);
      rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(coor, ty, casted, args);
      return mlir::success();
    }

    // The array was not boxed, so it must be contiguous. offset is therefore an
    // element offset and the base type is kept in the GEP unless the element
    // type size is itself dynamic.
    mlir::Value base;
    if (coor.getSubcomponent().empty()) {
      // No subcomponent.
      if (!coor.getLenParams().empty()) {
        // Type parameters. Adjust element size explicitly.
        auto eleTy = fir::dyn_cast_ptrEleTy(coor.getType());
        assert(eleTy && "result must be a reference-like type");
        if (fir::characterWithDynamicLen(eleTy)) {
          // Scale the element offset by the dynamic character length so the
          // GEP below steps over whole strings.
          assert(coor.getLenParams().size() == 1);
          auto length = integerCast(loc, rewriter, idxTy,
                                    operands[coor.lenParamsOffset()]);
          offset =
              rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, offset, length);
        } else {
          TODO(loc, "compute size of derived type with type parameters");
        }
      }
      // Cast the base address to a pointer to T.
      base = rewriter.create<mlir::LLVM::BitcastOp>(loc, ty, operands[0]);
    } else {
      // Operand #0 must have a pointer type. For subcomponent slicing, we
      // want to cast away the array type and have a plain struct type.
      mlir::Type ty0 = operands[0].getType();
      auto ptrTy = ty0.dyn_cast<mlir::LLVM::LLVMPointerType>();
      assert(ptrTy && "expected pointer type");
      mlir::Type eleTy = ptrTy.getElementType();
      while (auto arrTy = eleTy.dyn_cast<mlir::LLVM::LLVMArrayType>())
        eleTy = arrTy.getElementType();
      auto newTy = mlir::LLVM::LLVMPointerType::get(eleTy);
      base = rewriter.create<mlir::LLVM::BitcastOp>(loc, newTy, operands[0]);
    }
    llvm::SmallVector<mlir::LLVM::GEPArg> args = {offset};
    for (auto i = coor.subcomponentOffset(); i != coor.indicesOffset(); ++i)
      args.push_back(operands[i]);
    rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(coor, ty, base, args);
    return mlir::success();
  }
};
2441} // namespace
2442
/// Convert to (memory) reference to a reference to a subobject.
/// The coordinate_of op is a Swiss army knife operation that can be used on
/// (memory) references to records, arrays, complex, etc. as well as boxes.
/// With unboxed arrays, there is the restriction that the array have a static
/// shape in all but the last column.
struct CoordinateOpConversion
    : public FIROpAndTypeConversion<fir::CoordinateOp> {
  using FIROpAndTypeConversion::FIROpAndTypeConversion;

  /// Dispatch on the base object kind. A complex base becomes a GEP that
  /// selects the real or imaginary part. Boxes go to doRewriteBox, and
  /// references, pointers and heap values go to doRewriteRefOrPtr. Any other
  /// base type is reported as a match failure.
  mlir::LogicalResult
  doRewrite(fir::CoordinateOp coor, mlir::Type ty, OpAdaptor adaptor,
            mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::ValueRange operands = adaptor.getOperands();

    mlir::Location loc = coor.getLoc();
    mlir::Value base = operands[0];
    mlir::Type baseObjectTy = coor.getBaseType();
    mlir::Type objectTy = fir::dyn_cast_ptrOrBoxEleTy(baseObjectTy);
    assert(objectTy && "fir.coordinate_of expects a reference type");

    // Complex type - basically, extract the real or imaginary part
    if (fir::isa_complex(objectTy)) {
      mlir::Value gep = genGEP(loc, ty, rewriter, base, 0, operands[1]);
      rewriter.replaceOp(coor, gep);
      return mlir::success();
    }

    // Boxed type - get the base pointer from the box
    if (baseObjectTy.dyn_cast<fir::BaseBoxType>())
      return doRewriteBox(coor, ty, operands, loc, rewriter);

    // Reference, pointer or a heap type
    if (baseObjectTy.isa<fir::ReferenceType, fir::PointerType, fir::HeapType>())
      return doRewriteRefOrPtr(coor, ty, operands, loc, rewriter);

    return rewriter.notifyMatchFailure(
        coor, "fir.coordinate_of base operand has unsupported type");
  }

  /// Return the position of the component of record `ty` selected by the
  /// coordinate `op`. For records with a dynamic size, the position is read
  /// from the "field" integer attribute of the op defining `op`; presumably
  /// this is the call produced by the fir.field_index lowering, which sets
  /// that attribute. Otherwise `op` is expected to be a constant (see
  /// getConstantIntValue).
  static unsigned getFieldNumber(fir::RecordType ty, mlir::Value op) {
    return fir::hasDynamicSize(ty)
               ? op.getDefiningOp()
                     ->getAttrOfType<mlir::IntegerAttr>("field")
                     .getInt()
               : getConstantIntValue(op);
  }

  /// True for aggregate types that coordinates can index into.
  static bool hasSubDimensions(mlir::Type type) {
    return type.isa<fir::SequenceType, fir::RecordType, mlir::TupleType>();
  }

  /// Check whether this form of `!fir.coordinate_of` is supported. These
  /// additional checks are required, because we are not yet able to convert
  /// all valid forms of `!fir.coordinate_of`.
  /// Supported forms are:
  ///  - a path made only of array/record/tuple steps that consumes every
  ///    coordinate, or
  ///  - exactly one coordinate applied to a non-aggregate type.
  /// TODO: Either implement the unsupported cases or extend the verifier
  /// in FIROps.cpp instead.
  static bool supportedCoordinate(mlir::Type type, mlir::ValueRange coors) {
    const std::size_t numOfCoors = coors.size();
    std::size_t i = 0;
    bool subEle = false;
    bool ptrEle = false;
    for (; i < numOfCoors; ++i) {
      mlir::Value nxtOpnd = coors[i];
      if (auto arrTy = type.dyn_cast<fir::SequenceType>()) {
        subEle = true;
        // One coordinate per array dimension.
        i += arrTy.getDimension() - 1;
        type = arrTy.getEleTy();
      } else if (auto recTy = type.dyn_cast<fir::RecordType>()) {
        subEle = true;
        type = recTy.getType(getFieldNumber(recTy, nxtOpnd));
      } else if (auto tupTy = type.dyn_cast<mlir::TupleType>()) {
        subEle = true;
        type = tupTy.getType(getConstantIntValue(nxtOpnd));
      } else {
        ptrEle = true;
      }
    }
    if (ptrEle)
      return (!subEle) && (numOfCoors == 1);
    return subEle && (i >= numOfCoors);
  }

  /// Walk the abstract memory layout and determine if the path traverses any
  /// array types with unknown shape. Return true iff all the array types have a
  /// constant shape along the path.
  static bool arraysHaveKnownShape(mlir::Type type, mlir::ValueRange coors) {
    for (std::size_t i = 0, sz = coors.size(); i < sz; ++i) {
      mlir::Value nxtOpnd = coors[i];
      if (auto arrTy = type.dyn_cast<fir::SequenceType>()) {
        if (fir::sequenceWithNonConstantShape(arrTy))
          return false;
        i += arrTy.getDimension() - 1;
        type = arrTy.getEleTy();
      } else if (auto strTy = type.dyn_cast<fir::RecordType>()) {
        type = strTy.getType(getFieldNumber(strTy, nxtOpnd));
      } else if (auto strTy = type.dyn_cast<mlir::TupleType>()) {
        type = strTy.getType(getConstantIntValue(nxtOpnd));
      } else {
        return true;
      }
    }
    return true;
  }

private:
  /// Lower a coordinate_of whose base is a box. The base address is read
  /// from the descriptor. Array coordinates are applied as byte offsets using
  /// the descriptor strides, and record coordinates as struct GEPs. The
  /// running address is kept as a void pointer between steps.
  mlir::LogicalResult
  doRewriteBox(fir::CoordinateOp coor, mlir::Type ty, mlir::ValueRange operands,
               mlir::Location loc,
               mlir::ConversionPatternRewriter &rewriter) const {
    mlir::Type boxObjTy = coor.getBaseType();
    assert(boxObjTy.dyn_cast<fir::BaseBoxType>() && "This is not a `fir.box`");

    mlir::Value boxBaseAddr = operands[0];

    // 1. SPECIAL CASE (uses `fir.len_param_index`):
    //   %box = ... : !fir.box<!fir.type<derived{len1:i32}>>
    //   %lenp = fir.len_param_index len1, !fir.type<derived{len1:i32}>
    //   %addr = coordinate_of %box, %lenp
    if (coor.getNumOperands() == 2) {
      mlir::Operation *coordinateDef =
          (*coor.getCoor().begin()).getDefiningOp();
      if (mlir::isa_and_nonnull<fir::LenParamIndexOp>(coordinateDef))
        TODO(loc,
             "fir.coordinate_of - fir.len_param_index is not supported yet");
    }

    // 2. GENERAL CASE:
    // 2.1. (`fir.array`)
    //   %box = ... : !fir.box<!fir.array<?xU>>
    //   %idx = ... : index
    //  %resultAddr = coordinate_of %box, %idx : !fir.ref<U>
    // 2.2 (`fir.derived`)
    //   %box = ... : !fir.box<!fir.type<derived_type{field_1:i32}>>
    //   %idx = ... : i32
    //  %resultAddr = coordinate_of %box, %idx : !fir.ref<i32>
    // 2.3 (`fir.derived` inside `fir.array`)
    //   %box = ... : !fir.box<!fir.array<10 x !fir.type<derived_1{
    //                   field_1:f32, field_2:f32}>>>
    //   %idx1 = ... : index
    //   %idx2 = ... : i32
    //   %resultAddr = coordinate_of %box, %idx1, %idx2 : !fir.ref<f32>
    // 2.4. TODO: Either document or disable any other case that the following
    //  implementation might convert.
    mlir::Value resultAddr =
        getBaseAddrFromBox(loc, getBaseAddrTypeFromBox(boxBaseAddr.getType()),
                           boxObjTy, boxBaseAddr, rewriter);
    // Component Type
    auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy(boxObjTy);
    mlir::Type voidPtrTy = ::getVoidPtrType(coor.getContext());

    for (unsigned i = 1, last = operands.size(); i < last; ++i) {
      if (auto arrTy = cpnTy.dyn_cast<fir::SequenceType>()) {
        if (i != 1)
          TODO(loc, "fir.array nested inside other array and/or derived type");
        // Applies byte strides from the box. Ignore lower bound from box
        // since fir.coordinate_of indexes are zero based. Lowering takes care
        // of lower bound aspects. This both accounts for dynamically sized
        // types and non contiguous arrays.
        auto idxTy = lowerTy().indexType();
        mlir::Value off = genConstantIndex(loc, idxTy, rewriter, 0);
        for (unsigned index = i, lastIndex = i + arrTy.getDimension();
             index < lastIndex; ++index) {
          mlir::Value stride =
              getStrideFromBox(loc, boxObjTy, operands[0], index - i, rewriter);
          // NOTE(review): operands[index] is used as-is (no integerCast);
          // this assumes the coordinate already has the index type.
          auto sc = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy,
                                                       operands[index], stride);
          off = rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, sc, off);
        }
        auto voidPtrBase =
            rewriter.create<mlir::LLVM::BitcastOp>(loc, voidPtrTy, resultAddr);
        resultAddr = rewriter.create<mlir::LLVM::GEPOp>(
            loc, voidPtrTy, voidPtrBase,
            llvm::ArrayRef<mlir::LLVM::GEPArg>{off});
        i += arrTy.getDimension() - 1;
        cpnTy = arrTy.getEleTy();
      } else if (auto recTy = cpnTy.dyn_cast<fir::RecordType>()) {
        auto recRefTy =
            mlir::LLVM::LLVMPointerType::get(lowerTy().convertType(recTy));
        mlir::Value nxtOpnd = operands[i];
        auto memObj =
            rewriter.create<mlir::LLVM::BitcastOp>(loc, recRefTy, resultAddr);
        cpnTy = recTy.getType(getFieldNumber(recTy, nxtOpnd));
        auto llvmCurrentObjTy = lowerTy().convertType(cpnTy);
        auto gep = rewriter.create<mlir::LLVM::GEPOp>(
            loc, mlir::LLVM::LLVMPointerType::get(llvmCurrentObjTy), memObj,
            llvm::ArrayRef<mlir::LLVM::GEPArg>{0, nxtOpnd});
        resultAddr =
            rewriter.create<mlir::LLVM::BitcastOp>(loc, voidPtrTy, gep);
      } else {
        fir::emitFatalError(loc, "unexpected type in coordinate_of");
      }
    }

    rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(coor, ty, resultAddr);
    return mlir::success();
  }

  /// Lower a coordinate_of whose base is a reference, pointer or heap value
  /// into a single GEP. Multidimensional array coordinates are reversed into
  /// LLVM's row-major order. A leading 0 index is only added when the whole
  /// path has a known shape and the base type is an aggregate. When only the
  /// outermost array extent is unknown ("column is deferred"), its coordinate
  /// is placed directly in the first GEP position.
  mlir::LogicalResult
  doRewriteRefOrPtr(fir::CoordinateOp coor, mlir::Type ty,
                    mlir::ValueRange operands, mlir::Location loc,
                    mlir::ConversionPatternRewriter &rewriter) const {
    mlir::Type baseObjectTy = coor.getBaseType();

    // Component Type
    mlir::Type cpnTy = fir::dyn_cast_ptrOrBoxEleTy(baseObjectTy);
    bool hasSubdimension = hasSubDimensions(cpnTy);
    bool columnIsDeferred = !hasSubdimension;

    if (!supportedCoordinate(cpnTy, operands.drop_front(1)))
      TODO(loc, "unsupported combination of coordinate operands");

    const bool hasKnownShape =
        arraysHaveKnownShape(cpnTy, operands.drop_front(1));

    // If only the column is `?`, then we can simply place the column value in
    // the 0-th GEP position.
    if (auto arrTy = cpnTy.dyn_cast<fir::SequenceType>()) {
      if (!hasKnownShape) {
        const unsigned sz = arrTy.getDimension();
        if (arraysHaveKnownShape(arrTy.getEleTy(),
                                 operands.drop_front(1 + sz))) {
          fir::SequenceType::ShapeRef shape = arrTy.getShape();
          bool allConst = true;
          // Every extent except the last (the column) must be constant.
          for (unsigned i = 0; i < sz - 1; ++i) {
            if (shape[i] < 0) {
              allConst = false;
              break;
            }
          }
          if (allConst)
            columnIsDeferred = true;
        }
      }
    }

    if (fir::hasDynamicSize(fir::unwrapSequenceType(cpnTy)))
      return mlir::emitError(
          loc, "fir.coordinate_of with a dynamic element size is unsupported");

    if (hasKnownShape || columnIsDeferred) {
      llvm::SmallVector<mlir::LLVM::GEPArg> offs;
      if (hasKnownShape && hasSubdimension) {
        offs.push_back(0);
      }
      // While set, `dims` counts the array coordinates still to collect for
      // the current array; they are buffered in `arrIdx` and then appended in
      // reverse.
      std::optional<int> dims;
      llvm::SmallVector<mlir::Value> arrIdx;
      for (std::size_t i = 1, sz = operands.size(); i < sz; ++i) {
        mlir::Value nxtOpnd = operands[i];

        if (!cpnTy)
          return mlir::emitError(loc, "invalid coordinate/check failed");

        // check if the i-th coordinate relates to an array
        if (dims) {
          arrIdx.push_back(nxtOpnd);
          int dimsLeft = *dims;
          if (dimsLeft > 1) {
            dims = dimsLeft - 1;
            continue;
          }
          cpnTy = cpnTy.cast<fir::SequenceType>().getEleTy();
          // append array range in reverse (FIR arrays are column-major)
          offs.append(arrIdx.rbegin(), arrIdx.rend());
          arrIdx.clear();
          dims.reset();
          continue;
        }
        if (auto arrTy = cpnTy.dyn_cast<fir::SequenceType>()) {
          int d = arrTy.getDimension() - 1;
          if (d > 0) {
            dims = d;
            arrIdx.push_back(nxtOpnd);
            continue;
          }
          cpnTy = cpnTy.cast<fir::SequenceType>().getEleTy();
          offs.push_back(nxtOpnd);
          continue;
        }

        // check if the i-th coordinate relates to a field
        if (auto recTy = cpnTy.dyn_cast<fir::RecordType>())
          cpnTy = recTy.getType(getFieldNumber(recTy, nxtOpnd));
        else if (auto tupTy = cpnTy.dyn_cast<mlir::TupleType>())
          cpnTy = tupTy.getType(getConstantIntValue(nxtOpnd));
        else
          cpnTy = nullptr;

        offs.push_back(nxtOpnd);
      }
      // Flush a partially indexed array (fewer coordinates than dimensions).
      if (dims)
        offs.append(arrIdx.rbegin(), arrIdx.rend());
      mlir::Value base = operands[0];
      mlir::Value retval = genGEP(loc, ty, rewriter, base, offs);
      rewriter.replaceOp(coor, retval);
      return mlir::success();
    }

    return mlir::emitError(
        loc, "fir.coordinate_of base operand has unsupported type");
  }
};
2742
2743/// Convert `fir.field_index`. The conversion depends on whether the size of
2744/// the record is static or dynamic.
2745struct FieldIndexOpConversion : public FIROpConversion<fir::FieldIndexOp> {
2746 using FIROpConversion::FIROpConversion;
2747
2748 // NB: most field references should be resolved by this point
2749 mlir::LogicalResult
2750 matchAndRewrite(fir::FieldIndexOp field, OpAdaptor adaptor,
2751 mlir::ConversionPatternRewriter &rewriter) const override {
2752 auto recTy = field.getOnType().cast<fir::RecordType>();
2753 unsigned index = recTy.getFieldIndex(field.getFieldId());
2754
2755 if (!fir::hasDynamicSize(recTy)) {
2756 // Derived type has compile-time constant layout. Return index of the
2757 // component type in the parent type (to be used in GEP).
2758 rewriter.replaceOp(field, mlir::ValueRange{genConstantOffset(
2759 field.getLoc(), rewriter, index)});
2760 return mlir::success();
2761 }
2762
2763 // Derived type has compile-time constant layout. Call the compiler
2764 // generated function to determine the byte offset of the field at runtime.
2765 // This returns a non-constant.
2766 mlir::FlatSymbolRefAttr symAttr = mlir::SymbolRefAttr::get(
2767 field.getContext(), getOffsetMethodName(recTy, field.getFieldId()));
2768 mlir::NamedAttribute callAttr = rewriter.getNamedAttr("callee", symAttr);
2769 mlir::NamedAttribute fieldAttr = rewriter.getNamedAttr(
2770 "field", mlir::IntegerAttr::get(lowerTy().indexType(), index));
2771 rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
2772 field, lowerTy().offsetType(), adaptor.getOperands(),
2773 llvm::ArrayRef<mlir::NamedAttribute>{callAttr, fieldAttr});
2774 return mlir::success();
2775 }
2776
2777 // Re-Construct the name of the compiler generated method that calculates the
2778 // offset
2779 inline static std::string getOffsetMethodName(fir::RecordType recTy,
2780 llvm::StringRef field) {
2781 return recTy.getName().str() + "P." + field.str() + ".offset";
2782 }
2783};
2784
2785/// Convert `fir.end`
2786struct FirEndOpConversion : public FIROpConversion<fir::FirEndOp> {
2787 using FIROpConversion::FIROpConversion;
2788
2789 mlir::LogicalResult
2790 matchAndRewrite(fir::FirEndOp firEnd, OpAdaptor,
2791 mlir::ConversionPatternRewriter &rewriter) const override {
2792 TODO(firEnd.getLoc(), "fir.end codegen")do { fir::emitFatalError(firEnd.getLoc(), llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "2792" ": not yet implemented: ") + llvm::Twine("fir.end codegen"
), false); } while (false)
;
2793 return mlir::failure();
2794 }
2795};
2796
2797/// Lower `fir.type_desc` to a global addr.
2798struct TypeDescOpConversion : public FIROpConversion<fir::TypeDescOp> {
2799 using FIROpConversion::FIROpConversion;
2800
2801 mlir::LogicalResult
2802 matchAndRewrite(fir::TypeDescOp typeDescOp, OpAdaptor adaptor,
2803 mlir::ConversionPatternRewriter &rewriter) const override {
2804 mlir::Type inTy = typeDescOp.getInType();
2805 assert(inTy.isa<fir::RecordType>() && "expecting fir.type")(static_cast <bool> (inTy.isa<fir::RecordType>() &&
"expecting fir.type") ? void (0) : __assert_fail ("inTy.isa<fir::RecordType>() && \"expecting fir.type\""
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 2805, __extension__
__PRETTY_FUNCTION__))
;
2806 auto recordType = inTy.dyn_cast<fir::RecordType>();
2807 auto module = typeDescOp.getOperation()->getParentOfType<mlir::ModuleOp>();
2808 std::string typeDescName =
2809 fir::NameUniquer::getTypeDescriptorName(recordType.getName());
2810 if (auto global = module.lookupSymbol<mlir::LLVM::GlobalOp>(typeDescName)) {
2811 auto ty = mlir::LLVM::LLVMPointerType::get(
2812 this->lowerTy().convertType(global.getType()));
2813 rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(typeDescOp, ty,
2814 global.getSymName());
2815 return mlir::success();
2816 } else if (auto global = module.lookupSymbol<fir::GlobalOp>(typeDescName)) {
2817 auto ty = mlir::LLVM::LLVMPointerType::get(
2818 this->lowerTy().convertType(global.getType()));
2819 rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(typeDescOp, ty,
2820 global.getSymName());
2821 return mlir::success();
2822 }
2823 return mlir::failure();
2824 }
2825};
2826
2827/// Lower `fir.has_value` operation to `llvm.return` operation.
2828struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
2829 using FIROpConversion::FIROpConversion;
2830
2831 mlir::LogicalResult
2832 matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor,
2833 mlir::ConversionPatternRewriter &rewriter) const override {
2834 rewriter.replaceOpWithNewOp<mlir::LLVM::ReturnOp>(op,
2835 adaptor.getOperands());
2836 return mlir::success();
2837 }
2838};
2839
2840/// Lower `fir.global` operation to `llvm.global` operation.
2841/// `fir.insert_on_range` operations are replaced with constant dense attribute
2842/// if they are applied on the full range.
2843struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
2844 using FIROpConversion::FIROpConversion;
2845
2846 mlir::LogicalResult
2847 matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor,
2848 mlir::ConversionPatternRewriter &rewriter) const override {
2849 auto tyAttr = convertType(global.getType());
2850 if (global.getType().isa<fir::BaseBoxType>())
2851 tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType();
2852 auto loc = global.getLoc();
2853 mlir::Attribute initAttr = global.getInitVal().value_or(mlir::Attribute());
2854 auto linkage = convertLinkage(global.getLinkName());
2855 auto isConst = global.getConstant().has_value();
2856 auto g = rewriter.create<mlir::LLVM::GlobalOp>(
2857 loc, tyAttr, isConst, linkage, global.getSymName(), initAttr);
2858
2859 // Apply all non-Fir::GlobalOp attributes to the LLVM::GlobalOp, preserving
2860 // them; whilst taking care not to apply attributes that are lowered in
2861 // other ways.
2862 llvm::SmallDenseSet<llvm::StringRef> elidedAttrsSet(
2863 global.getAttributeNames().begin(), global.getAttributeNames().end());
2864 for (auto &attr : global->getAttrs())
2865 if (!elidedAttrsSet.contains(attr.getName().strref()))
2866 g->setAttr(attr.getName(), attr.getValue());
2867
2868 auto &gr = g.getInitializerRegion();
2869 rewriter.inlineRegionBefore(global.getRegion(), gr, gr.end());
2870 if (!gr.empty()) {
2871 // Replace insert_on_range with a constant dense attribute if the
2872 // initialization is on the full range.
2873 auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
2874 for (auto insertOp : insertOnRangeOps) {
2875 if (isFullRange(insertOp.getCoor(), insertOp.getType())) {
2876 auto seqTyAttr = convertType(insertOp.getType());
2877 auto *op = insertOp.getVal().getDefiningOp();
2878 auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op);
2879 if (!constant) {
2880 auto convertOp = mlir::dyn_cast<fir::ConvertOp>(op);
2881 if (!convertOp)
2882 continue;
2883 constant = mlir::cast<mlir::arith::ConstantOp>(
2884 convertOp.getValue().getDefiningOp());
2885 }
2886 mlir::Type vecType = mlir::VectorType::get(
2887 insertOp.getType().getShape(), constant.getType());
2888 auto denseAttr = mlir::DenseElementsAttr::get(
2889 vecType.cast<mlir::ShapedType>(), constant.getValue());
2890 rewriter.setInsertionPointAfter(insertOp);
2891 rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
2892 insertOp, seqTyAttr, denseAttr);
2893 }
2894 }
2895 }
2896 rewriter.eraseOp(global);
2897 return mlir::success();
2898 }
2899
2900 bool isFullRange(mlir::DenseIntElementsAttr indexes,
2901 fir::SequenceType seqTy) const {
2902 auto extents = seqTy.getShape();
2903 if (indexes.size() / 2 != static_cast<int64_t>(extents.size()))
2904 return false;
2905 auto cur_index = indexes.value_begin<int64_t>();
2906 for (unsigned i = 0; i < indexes.size(); i += 2) {
2907 if (*(cur_index++) != 0)
2908 return false;
2909 if (*(cur_index++) != extents[i / 2] - 1)
2910 return false;
2911 }
2912 return true;
2913 }
2914
2915 // TODO: String comparaison should be avoided. Replace linkName with an
2916 // enumeration.
2917 mlir::LLVM::Linkage
2918 convertLinkage(std::optional<llvm::StringRef> optLinkage) const {
2919 if (optLinkage) {
2920 auto name = *optLinkage;
2921 if (name == "internal")
2922 return mlir::LLVM::Linkage::Internal;
2923 if (name == "linkonce")
2924 return mlir::LLVM::Linkage::Linkonce;
2925 if (name == "linkonce_odr")
2926 return mlir::LLVM::Linkage::LinkonceODR;
2927 if (name == "common")
2928 return mlir::LLVM::Linkage::Common;
2929 if (name == "weak")
2930 return mlir::LLVM::Linkage::Weak;
2931 }
2932 return mlir::LLVM::Linkage::External;
2933 }
2934};
2935
2936/// `fir.load` --> `llvm.load`
2937struct LoadOpConversion : public FIROpConversion<fir::LoadOp> {
2938 using FIROpConversion::FIROpConversion;
2939
2940 mlir::LogicalResult
2941 matchAndRewrite(fir::LoadOp load, OpAdaptor adaptor,
2942 mlir::ConversionPatternRewriter &rewriter) const override {
2943 if (auto boxTy = load.getType().dyn_cast<fir::BaseBoxType>()) {
2944 // fir.box is a special case because it is considered as an ssa values in
2945 // fir, but it is lowered as a pointer to a descriptor. So
2946 // fir.ref<fir.box> and fir.box end up being the same llvm types and
2947 // loading a fir.ref<fir.box> is implemented as taking a snapshot of the
2948 // descriptor value into a new descriptor temp.
2949 auto inputBoxStorage = adaptor.getOperands()[0];
2950 mlir::Location loc = load.getLoc();
2951 fir::SequenceType seqTy = fir::unwrapUntilSeqType(boxTy);
2952 // fir.box of assumed rank do not have a storage
2953 // size that is know at compile time. The copy needs to be runtime driven
2954 // depending on the actual dynamic rank or type.
2955 if (seqTy && seqTy.hasUnknownShape())
2956 TODO(loc, "loading or assumed rank fir.box")do { fir::emitFatalError(loc, llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "2956" ": not yet implemented: ") + llvm::Twine("loading or assumed rank fir.box"
), false); } while (false)
;
2957 mlir::Type boxPtrTy = inputBoxStorage.getType();
2958 auto boxValue = rewriter.create<mlir::LLVM::LoadOp>(
2959 loc, boxPtrTy.cast<mlir::LLVM::LLVMPointerType>().getElementType(),
2960 inputBoxStorage);
2961 attachTBAATag(boxValue, boxTy, boxTy, nullptr);
2962 auto newBoxStorage =
2963 genAllocaWithType(loc, boxPtrTy, defaultAlign, rewriter);
2964 auto storeOp =
2965 rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, newBoxStorage);
2966 attachTBAATag(storeOp, boxTy, boxTy, nullptr);
2967 rewriter.replaceOp(load, newBoxStorage.getResult());
2968 } else {
2969 mlir::Type loadTy = convertType(load.getType());
2970 auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(
2971 load.getLoc(), loadTy, adaptor.getOperands(), load->getAttrs());
2972 attachTBAATag(loadOp, load.getType(), load.getType(), nullptr);
2973 rewriter.replaceOp(load, loadOp.getResult());
2974 }
2975 return mlir::success();
2976 }
2977};
2978
2979/// Lower `fir.no_reassoc` to LLVM IR dialect.
2980/// TODO: how do we want to enforce this in LLVM-IR? Can we manipulate the fast
2981/// math flags?
2982struct NoReassocOpConversion : public FIROpConversion<fir::NoReassocOp> {
2983 using FIROpConversion::FIROpConversion;
2984
2985 mlir::LogicalResult
2986 matchAndRewrite(fir::NoReassocOp noreassoc, OpAdaptor adaptor,
2987 mlir::ConversionPatternRewriter &rewriter) const override {
2988 rewriter.replaceOp(noreassoc, adaptor.getOperands()[0]);
2989 return mlir::success();
2990 }
2991};
2992
2993static void genCondBrOp(mlir::Location loc, mlir::Value cmp, mlir::Block *dest,
2994 std::optional<mlir::ValueRange> destOps,
2995 mlir::ConversionPatternRewriter &rewriter,
2996 mlir::Block *newBlock) {
2997 if (destOps)
2998 rewriter.create<mlir::LLVM::CondBrOp>(loc, cmp, dest, *destOps, newBlock,
2999 mlir::ValueRange());
3000 else
3001 rewriter.create<mlir::LLVM::CondBrOp>(loc, cmp, dest, newBlock);
3002}
3003
3004template <typename A, typename B>
3005static void genBrOp(A caseOp, mlir::Block *dest, std::optional<B> destOps,
3006 mlir::ConversionPatternRewriter &rewriter) {
3007 if (destOps)
3008 rewriter.replaceOpWithNewOp<mlir::LLVM::BrOp>(caseOp, *destOps, dest);
3009 else
3010 rewriter.replaceOpWithNewOp<mlir::LLVM::BrOp>(caseOp, std::nullopt, dest);
3011}
3012
3013static void genCaseLadderStep(mlir::Location loc, mlir::Value cmp,
3014 mlir::Block *dest,
3015 std::optional<mlir::ValueRange> destOps,
3016 mlir::ConversionPatternRewriter &rewriter) {
3017 auto *thisBlock = rewriter.getInsertionBlock();
3018 auto *newBlock = createBlock(rewriter, dest);
3019 rewriter.setInsertionPointToEnd(thisBlock);
3020 genCondBrOp(loc, cmp, dest, destOps, rewriter, newBlock);
3021 rewriter.setInsertionPointToEnd(newBlock);
3022}
3023
3024/// Conversion of `fir.select_case`
3025///
3026/// The `fir.select_case` operation is converted to a if-then-else ladder.
3027/// Depending on the case condition type, one or several comparison and
3028/// conditional branching can be generated.
3029///
3030/// A a point value case such as `case(4)`, a lower bound case such as
3031/// `case(5:)` or an upper bound case such as `case(:3)` are converted to a
3032/// simple comparison between the selector value and the constant value in the
3033/// case. The block associated with the case condition is then executed if
3034/// the comparison succeed otherwise it branch to the next block with the
3035/// comparison for the the next case conditon.
3036///
3037/// A closed interval case condition such as `case(7:10)` is converted with a
3038/// first comparison and conditional branching for the lower bound. If
3039/// successful, it branch to a second block with the comparison for the
3040/// upper bound in the same case condition.
3041///
3042/// TODO: lowering of CHARACTER type cases is not handled yet.
3043struct SelectCaseOpConversion : public FIROpConversion<fir::SelectCaseOp> {
3044 using FIROpConversion::FIROpConversion;
3045
3046 mlir::LogicalResult
3047 matchAndRewrite(fir::SelectCaseOp caseOp, OpAdaptor adaptor,
3048 mlir::ConversionPatternRewriter &rewriter) const override {
3049 unsigned conds = caseOp.getNumConditions();
3050 llvm::ArrayRef<mlir::Attribute> cases = caseOp.getCases().getValue();
3051 // Type can be CHARACTER, INTEGER, or LOGICAL (C1145)
3052 auto ty = caseOp.getSelector().getType();
3053 if (ty.isa<fir::CharacterType>()) {
3054 TODO(caseOp.getLoc(), "fir.select_case codegen with character type")do { fir::emitFatalError(caseOp.getLoc(), llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "3054" ": not yet implemented: ") + llvm::Twine("fir.select_case codegen with character type"
), false); } while (false)
;
3055 return mlir::failure();
3056 }
3057 mlir::Value selector = caseOp.getSelector(adaptor.getOperands());
3058 auto loc = caseOp.getLoc();
3059 for (unsigned t = 0; t != conds; ++t) {
3060 mlir::Block *dest = caseOp.getSuccessor(t);
3061 std::optional<mlir::ValueRange> destOps =
3062 caseOp.getSuccessorOperands(adaptor.getOperands(), t);
3063 std::optional<mlir::ValueRange> cmpOps =
3064 *caseOp.getCompareOperands(adaptor.getOperands(), t);
3065 mlir::Value caseArg = *(cmpOps.value().begin());
3066 mlir::Attribute attr = cases[t];
3067 if (attr.isa<fir::PointIntervalAttr>()) {
3068 auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
3069 loc, mlir::LLVM::ICmpPredicate::eq, selector, caseArg);
3070 genCaseLadderStep(loc, cmp, dest, destOps, rewriter);
3071 continue;
3072 }
3073 if (attr.isa<fir::LowerBoundAttr>()) {
3074 auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
3075 loc, mlir::LLVM::ICmpPredicate::sle, caseArg, selector);
3076 genCaseLadderStep(loc, cmp, dest, destOps, rewriter);
3077 continue;
3078 }
3079 if (attr.isa<fir::UpperBoundAttr>()) {
3080 auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
3081 loc, mlir::LLVM::ICmpPredicate::sle, selector, caseArg);
3082 genCaseLadderStep(loc, cmp, dest, destOps, rewriter);
3083 continue;
3084 }
3085 if (attr.isa<fir::ClosedIntervalAttr>()) {
3086 auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
3087 loc, mlir::LLVM::ICmpPredicate::sle, caseArg, selector);
3088 auto *thisBlock = rewriter.getInsertionBlock();
3089 auto *newBlock1 = createBlock(rewriter, dest);
3090 auto *newBlock2 = createBlock(rewriter, dest);
3091 rewriter.setInsertionPointToEnd(thisBlock);
3092 rewriter.create<mlir::LLVM::CondBrOp>(loc, cmp, newBlock1, newBlock2);
3093 rewriter.setInsertionPointToEnd(newBlock1);
3094 mlir::Value caseArg0 = *(cmpOps.value().begin() + 1);
3095 auto cmp0 = rewriter.create<mlir::LLVM::ICmpOp>(
3096 loc, mlir::LLVM::ICmpPredicate::sle, selector, caseArg0);
3097 genCondBrOp(loc, cmp0, dest, destOps, rewriter, newBlock2);
3098 rewriter.setInsertionPointToEnd(newBlock2);
3099 continue;
3100 }
3101 assert(attr.isa<mlir::UnitAttr>())(static_cast <bool> (attr.isa<mlir::UnitAttr>()) ?
void (0) : __assert_fail ("attr.isa<mlir::UnitAttr>()"
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 3101, __extension__
__PRETTY_FUNCTION__))
;
3102 assert((t + 1 == conds) && "unit must be last")(static_cast <bool> ((t + 1 == conds) && "unit must be last"
) ? void (0) : __assert_fail ("(t + 1 == conds) && \"unit must be last\""
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 3102, __extension__
__PRETTY_FUNCTION__))
;
3103 genBrOp(caseOp, dest, destOps, rewriter);
3104 }
3105 return mlir::success();
3106 }
3107};
3108
3109template <typename OP>
3110static void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
3111 typename OP::Adaptor adaptor,
3112 mlir::ConversionPatternRewriter &rewriter) {
3113 unsigned conds = select.getNumConditions();
3114 auto cases = select.getCases().getValue();
3115 mlir::Value selector = adaptor.getSelector();
3116 auto loc = select.getLoc();
3117 assert(conds > 0 && "select must have cases")(static_cast <bool> (conds > 0 && "select must have cases"
) ? void (0) : __assert_fail ("conds > 0 && \"select must have cases\""
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 3117, __extension__
__PRETTY_FUNCTION__))
;
2
Assuming 'conds' is > 0
3
'?' condition is true
3118
3119 llvm::SmallVector<mlir::Block *> destinations;
3120 llvm::SmallVector<mlir::ValueRange> destinationsOperands;
3121 mlir::Block *defaultDestination;
4
'defaultDestination' declared without an initial value
3122 mlir::ValueRange defaultOperands;
3123 llvm::SmallVector<int32_t> caseValues;
3124
3125 for (unsigned t = 0; t
4.1
't' is not equal to 'conds'
4.1
't' is not equal to 'conds'
4.1
't' is not equal to 'conds'
!= conds
; ++t) {
5
Loop condition is true. Entering loop body
10
Assuming 't' is equal to 'conds'
11
Loop condition is false. Execution continues on line 3142
3126 mlir::Block *dest = select.getSuccessor(t);
3127 auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
3128 const mlir::Attribute &attr = cases[t];
3129 if (auto intAttr = attr.template dyn_cast<mlir::IntegerAttr>()) {
6
Taking true branch
3130 destinations.push_back(dest);
3131 destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
7
Assuming the condition is false
8
'?' condition is false
3132 caseValues.push_back(intAttr.getInt());
3133 continue;
9
Execution continues on line 3125
3134 }
3135 assert(attr.template dyn_cast_or_null<mlir::UnitAttr>())(static_cast <bool> (attr.template dyn_cast_or_null<
mlir::UnitAttr>()) ? void (0) : __assert_fail ("attr.template dyn_cast_or_null<mlir::UnitAttr>()"
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 3135, __extension__
__PRETTY_FUNCTION__))
;
3136 assert((t + 1 == conds) && "unit must be last")(static_cast <bool> ((t + 1 == conds) && "unit must be last"
) ? void (0) : __assert_fail ("(t + 1 == conds) && \"unit must be last\""
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 3136, __extension__
__PRETTY_FUNCTION__))
;
3137 defaultDestination = dest;
3138 defaultOperands = destOps ? *destOps : mlir::ValueRange{};
3139 }
3140
3141 // LLVM::SwitchOp takes a i32 type for the selector.
3142 if (select.getSelector().getType() != rewriter.getI32Type())
12
Taking false branch
3143 selector = rewriter.create<mlir::LLVM::TruncOp>(loc, rewriter.getI32Type(),
3144 selector);
3145
3146 rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
13
Calling 'RewriterBase::replaceOpWithNewOp'
3147 select, selector,
3148 /*defaultDestination=*/defaultDestination,
3149 /*defaultOperands=*/defaultOperands,
3150 /*caseValues=*/caseValues,
3151 /*caseDestinations=*/destinations,
3152 /*caseOperands=*/destinationsOperands,
3153 /*branchWeights=*/llvm::ArrayRef<std::int32_t>());
3154}
3155
3156/// conversion of fir::SelectOp to an if-then-else ladder
3157struct SelectOpConversion : public FIROpConversion<fir::SelectOp> {
3158 using FIROpConversion::FIROpConversion;
3159
3160 mlir::LogicalResult
3161 matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
3162 mlir::ConversionPatternRewriter &rewriter) const override {
3163 selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor, rewriter);
3164 return mlir::success();
3165 }
3166};
3167
3168/// conversion of fir::SelectRankOp to an if-then-else ladder
3169struct SelectRankOpConversion : public FIROpConversion<fir::SelectRankOp> {
3170 using FIROpConversion::FIROpConversion;
3171
3172 mlir::LogicalResult
3173 matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
3174 mlir::ConversionPatternRewriter &rewriter) const override {
3175 selectMatchAndRewrite<fir::SelectRankOp>(lowerTy(), op, adaptor, rewriter);
1
Calling 'selectMatchAndRewrite<fir::SelectRankOp>'
3176 return mlir::success();
3177 }
3178};
3179
3180/// Lower `fir.select_type` to LLVM IR dialect.
3181struct SelectTypeOpConversion : public FIROpConversion<fir::SelectTypeOp> {
3182 using FIROpConversion::FIROpConversion;
3183
3184 mlir::LogicalResult
3185 matchAndRewrite(fir::SelectTypeOp select, OpAdaptor adaptor,
3186 mlir::ConversionPatternRewriter &rewriter) const override {
3187 mlir::emitError(select.getLoc(),
3188 "fir.select_type should have already been converted");
3189 return mlir::failure();
3190 }
3191};
3192
3193/// `fir.store` --> `llvm.store`
3194struct StoreOpConversion : public FIROpConversion<fir::StoreOp> {
3195 using FIROpConversion::FIROpConversion;
3196
3197 mlir::LogicalResult
3198 matchAndRewrite(fir::StoreOp store, OpAdaptor adaptor,
3199 mlir::ConversionPatternRewriter &rewriter) const override {
3200 mlir::Location loc = store.getLoc();
3201 mlir::Type storeTy = store.getValue().getType();
3202 mlir::LLVM::StoreOp newStoreOp;
3203 if (auto boxTy = storeTy.dyn_cast<fir::BaseBoxType>()) {
3204 // fir.box value is actually in memory, load it first before storing it.
3205 mlir::Type boxPtrTy = adaptor.getOperands()[0].getType();
3206 auto val = rewriter.create<mlir::LLVM::LoadOp>(
3207 loc, boxPtrTy.cast<mlir::LLVM::LLVMPointerType>().getElementType(),
3208 adaptor.getOperands()[0]);
3209 attachTBAATag(val, boxTy, boxTy, nullptr);
3210 newStoreOp = rewriter.create<mlir::LLVM::StoreOp>(
3211 loc, val, adaptor.getOperands()[1]);
3212 } else {
3213 newStoreOp = rewriter.create<mlir::LLVM::StoreOp>(
3214 loc, adaptor.getOperands()[0], adaptor.getOperands()[1]);
3215 }
3216 attachTBAATag(newStoreOp, storeTy, storeTy, nullptr);
3217 rewriter.eraseOp(store);
3218 return mlir::success();
3219 }
3220};
3221
3222namespace {
3223
3224/// Convert `fir.unboxchar` into two `llvm.extractvalue` instructions. One for
3225/// the character buffer and one for the buffer length.
3226struct UnboxCharOpConversion : public FIROpConversion<fir::UnboxCharOp> {
3227 using FIROpConversion::FIROpConversion;
3228
3229 mlir::LogicalResult
3230 matchAndRewrite(fir::UnboxCharOp unboxchar, OpAdaptor adaptor,
3231 mlir::ConversionPatternRewriter &rewriter) const override {
3232 mlir::Type lenTy = convertType(unboxchar.getType(1));
3233 mlir::Value tuple = adaptor.getOperands()[0];
3234
3235 mlir::Location loc = unboxchar.getLoc();
3236 mlir::Value ptrToBuffer =
3237 rewriter.create<mlir::LLVM::ExtractValueOp>(loc, tuple, 0);
3238
3239 auto len = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, tuple, 1);
3240 mlir::Value lenAfterCast = integerCast(loc, rewriter, lenTy, len);
3241
3242 rewriter.replaceOp(unboxchar,
3243 llvm::ArrayRef<mlir::Value>{ptrToBuffer, lenAfterCast});
3244 return mlir::success();
3245 }
3246};
3247
3248/// Lower `fir.unboxproc` operation. Unbox a procedure box value, yielding its
3249/// components.
3250/// TODO: Part of supporting Fortran 2003 procedure pointers.
3251struct UnboxProcOpConversion : public FIROpConversion<fir::UnboxProcOp> {
3252 using FIROpConversion::FIROpConversion;
3253
3254 mlir::LogicalResult
3255 matchAndRewrite(fir::UnboxProcOp unboxproc, OpAdaptor adaptor,
3256 mlir::ConversionPatternRewriter &rewriter) const override {
3257 TODO(unboxproc.getLoc(), "fir.unboxproc codegen")do { fir::emitFatalError(unboxproc.getLoc(), llvm::Twine("flang/lib/Optimizer/CodeGen/CodeGen.cpp"
":" "3257" ": not yet implemented: ") + llvm::Twine("fir.unboxproc codegen"
), false); } while (false)
;
3258 return mlir::failure();
3259 }
3260};
3261
3262/// convert to LLVM IR dialect `undef`
3263struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
3264 using FIROpConversion::FIROpConversion;
3265
3266 mlir::LogicalResult
3267 matchAndRewrite(fir::UndefOp undef, OpAdaptor,
3268 mlir::ConversionPatternRewriter &rewriter) const override {
3269 rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>(
3270 undef, convertType(undef.getType()));
3271 return mlir::success();
3272 }
3273};
3274
3275struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> {
3276 using FIROpConversion::FIROpConversion;
3277
3278 mlir::LogicalResult
3279 matchAndRewrite(fir::ZeroOp zero, OpAdaptor,
3280 mlir::ConversionPatternRewriter &rewriter) const override {
3281 mlir::Type ty = convertType(zero.getType());
3282 if (ty.isa<mlir::LLVM::LLVMPointerType>()) {
3283 rewriter.replaceOpWithNewOp<mlir::LLVM::NullOp>(zero, ty);
3284 } else if (ty.isa<mlir::IntegerType>()) {
3285 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
3286 zero, ty, mlir::IntegerAttr::get(ty, 0));
3287 } else if (mlir::LLVM::isCompatibleFloatingPointType(ty)) {
3288 rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
3289 zero, ty, mlir::FloatAttr::get(zero.getType(), 0.0));
3290 } else {
3291 // TODO: create ConstantAggregateZero for FIR aggregate/array types.
3292 return rewriter.notifyMatchFailure(
3293 zero,
3294 "conversion of fir.zero with aggregate type not implemented yet");
3295 }
3296 return mlir::success();
3297 }
3298};
3299
3300/// `fir.unreachable` --> `llvm.unreachable`
3301struct UnreachableOpConversion : public FIROpConversion<fir::UnreachableOp> {
3302 using FIROpConversion::FIROpConversion;
3303
3304 mlir::LogicalResult
3305 matchAndRewrite(fir::UnreachableOp unreach, OpAdaptor adaptor,
3306 mlir::ConversionPatternRewriter &rewriter) const override {
3307 rewriter.replaceOpWithNewOp<mlir::LLVM::UnreachableOp>(unreach);
3308 return mlir::success();
3309 }
3310};
3311
3312/// `fir.is_present` -->
3313/// ```
3314/// %0 = llvm.mlir.constant(0 : i64)
3315/// %1 = llvm.ptrtoint %0
3316/// %2 = llvm.icmp "ne" %1, %0 : i64
3317/// ```
3318struct IsPresentOpConversion : public FIROpConversion<fir::IsPresentOp> {
3319 using FIROpConversion::FIROpConversion;
3320
3321 mlir::LogicalResult
3322 matchAndRewrite(fir::IsPresentOp isPresent, OpAdaptor adaptor,
3323 mlir::ConversionPatternRewriter &rewriter) const override {
3324 mlir::Type idxTy = lowerTy().indexType();
3325 mlir::Location loc = isPresent.getLoc();
3326 auto ptr = adaptor.getOperands()[0];
3327
3328 if (isPresent.getVal().getType().isa<fir::BoxCharType>()) {
3329 [[maybe_unused]] auto structTy =
3330 ptr.getType().cast<mlir::LLVM::LLVMStructType>();
3331 assert(!structTy.isOpaque() && !structTy.getBody().empty())(static_cast <bool> (!structTy.isOpaque() && !structTy
.getBody().empty()) ? void (0) : __assert_fail ("!structTy.isOpaque() && !structTy.getBody().empty()"
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 3331, __extension__
__PRETTY_FUNCTION__))
;
3332
3333 ptr = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, ptr, 0);
3334 }
3335 mlir::LLVM::ConstantOp c0 =
3336 genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0);
3337 auto addr = rewriter.create<mlir::LLVM::PtrToIntOp>(loc, idxTy, ptr);
3338 rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
3339 isPresent, mlir::LLVM::ICmpPredicate::ne, addr, c0);
3340
3341 return mlir::success();
3342 }
3343};
3344
3345/// Create value signaling an absent optional argument in a call, e.g.
3346/// `fir.absent !fir.ref<i64>` --> `llvm.mlir.null : !llvm.ptr<i64>`
3347struct AbsentOpConversion : public FIROpConversion<fir::AbsentOp> {
3348 using FIROpConversion::FIROpConversion;
3349
3350 mlir::LogicalResult
3351 matchAndRewrite(fir::AbsentOp absent, OpAdaptor,
3352 mlir::ConversionPatternRewriter &rewriter) const override {
3353 mlir::Type ty = convertType(absent.getType());
3354 mlir::Location loc = absent.getLoc();
3355
3356 if (absent.getType().isa<fir::BoxCharType>()) {
3357 auto structTy = ty.cast<mlir::LLVM::LLVMStructType>();
3358 assert(!structTy.isOpaque() && !structTy.getBody().empty())(static_cast <bool> (!structTy.isOpaque() && !structTy
.getBody().empty()) ? void (0) : __assert_fail ("!structTy.isOpaque() && !structTy.getBody().empty()"
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 3358, __extension__
__PRETTY_FUNCTION__))
;
3359 auto undefStruct = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
3360 auto nullField =
3361 rewriter.create<mlir::LLVM::NullOp>(loc, structTy.getBody()[0]);
3362 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
3363 absent, undefStruct, nullField, 0);
3364 } else {
3365 rewriter.replaceOpWithNewOp<mlir::LLVM::NullOp>(absent, ty);
3366 }
3367 return mlir::success();
3368 }
3369};
3370
3371//
3372// Primitive operations on Complex types
3373//
3374
3375/// Generate inline code for complex addition/subtraction
3376template <typename LLVMOP, typename OPTY>
3377static mlir::LLVM::InsertValueOp
3378complexSum(OPTY sumop, mlir::ValueRange opnds,
3379 mlir::ConversionPatternRewriter &rewriter,
3380 fir::LLVMTypeConverter &lowering) {
3381 mlir::Value a = opnds[0];
3382 mlir::Value b = opnds[1];
3383 auto loc = sumop.getLoc();
3384 mlir::Type eleTy = lowering.convertType(getComplexEleTy(sumop.getType()));
3385 mlir::Type ty = lowering.convertType(sumop.getType());
3386 auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 0);
3387 auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
3388 auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
3389 auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
3390 auto rx = rewriter.create<LLVMOP>(loc, eleTy, x0, x1);
3391 auto ry = rewriter.create<LLVMOP>(loc, eleTy, y0, y1);
3392 auto r0 = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
3393 auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r0, rx, 0);
3394 return rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ry, 1);
3395}
3396} // namespace
3397
3398namespace {
3399struct AddcOpConversion : public FIROpConversion<fir::AddcOp> {
3400 using FIROpConversion::FIROpConversion;
3401
3402 mlir::LogicalResult
3403 matchAndRewrite(fir::AddcOp addc, OpAdaptor adaptor,
3404 mlir::ConversionPatternRewriter &rewriter) const override {
3405 // given: (x + iy) + (x' + iy')
3406 // result: (x + x') + i(y + y')
3407 auto r = complexSum<mlir::LLVM::FAddOp>(addc, adaptor.getOperands(),
3408 rewriter, lowerTy());
3409 rewriter.replaceOp(addc, r.getResult());
3410 return mlir::success();
3411 }
3412};
3413
3414struct SubcOpConversion : public FIROpConversion<fir::SubcOp> {
3415 using FIROpConversion::FIROpConversion;
3416
3417 mlir::LogicalResult
3418 matchAndRewrite(fir::SubcOp subc, OpAdaptor adaptor,
3419 mlir::ConversionPatternRewriter &rewriter) const override {
3420 // given: (x + iy) - (x' + iy')
3421 // result: (x - x') + i(y - y')
3422 auto r = complexSum<mlir::LLVM::FSubOp>(subc, adaptor.getOperands(),
3423 rewriter, lowerTy());
3424 rewriter.replaceOp(subc, r.getResult());
3425 return mlir::success();
3426 }
3427};
3428
3429/// Inlined complex multiply
3430struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
3431 using FIROpConversion::FIROpConversion;
3432
3433 mlir::LogicalResult
3434 matchAndRewrite(fir::MulcOp mulc, OpAdaptor adaptor,
3435 mlir::ConversionPatternRewriter &rewriter) const override {
3436 // TODO: Can we use a call to __muldc3 ?
3437 // given: (x + iy) * (x' + iy')
3438 // result: (xx'-yy')+i(xy'+yx')
3439 mlir::Value a = adaptor.getOperands()[0];
3440 mlir::Value b = adaptor.getOperands()[1];
3441 auto loc = mulc.getLoc();
3442 mlir::Type eleTy = convertType(getComplexEleTy(mulc.getType()));
3443 mlir::Type ty = convertType(mulc.getType());
3444 auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 0);
3445 auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
3446 auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
3447 auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
3448 auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
3449 auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
3450 auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
3451 auto ri = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xy, yx);
3452 auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
3453 auto rr = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, xx, yy);
3454 auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
3455 auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
3456 auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
3457 rewriter.replaceOp(mulc, r0.getResult());
3458 return mlir::success();
3459 }
3460};
3461
3462/// Inlined complex division
3463struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
3464 using FIROpConversion::FIROpConversion;
3465
3466 mlir::LogicalResult
3467 matchAndRewrite(fir::DivcOp divc, OpAdaptor adaptor,
3468 mlir::ConversionPatternRewriter &rewriter) const override {
3469 // TODO: Can we use a call to __divdc3 instead?
3470 // Just generate inline code for now.
3471 // given: (x + iy) / (x' + iy')
3472 // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
3473 mlir::Value a = adaptor.getOperands()[0];
3474 mlir::Value b = adaptor.getOperands()[1];
3475 auto loc = divc.getLoc();
3476 mlir::Type eleTy = convertType(getComplexEleTy(divc.getType()));
3477 mlir::Type ty = convertType(divc.getType());
3478 auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 0);
3479 auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
3480 auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
3481 auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
3482 auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
3483 auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
3484 auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
3485 auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
3486 auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
3487 auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
3488 auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
3489 auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
3490 auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
3491 auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
3492 auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
3493 auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
3494 auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
3495 auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
3496 rewriter.replaceOp(divc, r0.getResult());
3497 return mlir::success();
3498 }
3499};
3500
3501/// Inlined complex negation
3502struct NegcOpConversion : public FIROpConversion<fir::NegcOp> {
3503 using FIROpConversion::FIROpConversion;
3504
3505 mlir::LogicalResult
3506 matchAndRewrite(fir::NegcOp neg, OpAdaptor adaptor,
3507 mlir::ConversionPatternRewriter &rewriter) const override {
3508 // given: -(x + iy)
3509 // result: -x - iy
3510 auto eleTy = convertType(getComplexEleTy(neg.getType()));
3511 auto loc = neg.getLoc();
3512 mlir::Value o0 = adaptor.getOperands()[0];
3513 auto rp = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, o0, 0);
3514 auto ip = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, o0, 1);
3515 auto nrp = rewriter.create<mlir::LLVM::FNegOp>(loc, eleTy, rp);
3516 auto nip = rewriter.create<mlir::LLVM::FNegOp>(loc, eleTy, ip);
3517 auto r = rewriter.create<mlir::LLVM::InsertValueOp>(loc, o0, nrp, 0);
3518 rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(neg, r, nip, 1);
3519 return mlir::success();
3520 }
3521};
3522
3523/// Conversion pattern for operation that must be dead. The information in these
3524/// operations is used by other operation. At this point they should not have
3525/// anymore uses.
3526/// These operations are normally dead after the pre-codegen pass.
3527template <typename FromOp>
3528struct MustBeDeadConversion : public FIROpConversion<FromOp> {
3529 explicit MustBeDeadConversion(fir::LLVMTypeConverter &lowering,
3530 const fir::FIRToLLVMPassOptions &options)
3531 : FIROpConversion<FromOp>(lowering, options) {}
3532 using OpAdaptor = typename FromOp::Adaptor;
3533
3534 mlir::LogicalResult
3535 matchAndRewrite(FromOp op, OpAdaptor adaptor,
3536 mlir::ConversionPatternRewriter &rewriter) const final {
3537 if (!op->getUses().empty())
3538 return rewriter.notifyMatchFailure(op, "op must be dead");
3539 rewriter.eraseOp(op);
3540 return mlir::success();
3541 }
3542};
3543
3544struct UnrealizedConversionCastOpConversion
3545 : public FIROpConversion<mlir::UnrealizedConversionCastOp> {
3546 using FIROpConversion::FIROpConversion;
3547
3548 mlir::LogicalResult
3549 matchAndRewrite(mlir::UnrealizedConversionCastOp op, OpAdaptor adaptor,
3550 mlir::ConversionPatternRewriter &rewriter) const override {
3551 assert(op.getOutputs().getTypes().size() == 1 && "expect a single type")(static_cast <bool> (op.getOutputs().getTypes().size() ==
1 && "expect a single type") ? void (0) : __assert_fail
("op.getOutputs().getTypes().size() == 1 && \"expect a single type\""
, "flang/lib/Optimizer/CodeGen/CodeGen.cpp", 3551, __extension__
__PRETTY_FUNCTION__))
;
3552 mlir::Type convertedType = convertType(op.getOutputs().getTypes()[0]);
3553 if (convertedType == adaptor.getInputs().getTypes()[0]) {
3554 rewriter.replaceOp(op, adaptor.getInputs());
3555 return mlir::success();
3556 }
3557
3558 convertedType = adaptor.getInputs().getTypes()[0];
3559 if (convertedType == op.getOutputs().getType()[0]) {
3560 rewriter.replaceOp(op, adaptor.getInputs());
3561 return mlir::success();
3562 }
3563 return mlir::failure();
3564 }
3565};
3566
3567struct ShapeOpConversion : public MustBeDeadConversion<fir::ShapeOp> {
3568 using MustBeDeadConversion::MustBeDeadConversion;
3569};
3570
3571struct ShapeShiftOpConversion : public MustBeDeadConversion<fir::ShapeShiftOp> {
3572 using MustBeDeadConversion::MustBeDeadConversion;
3573};
3574
3575struct ShiftOpConversion : public MustBeDeadConversion<fir::ShiftOp> {
3576 using MustBeDeadConversion::MustBeDeadConversion;
3577};
3578
3579struct SliceOpConversion : public MustBeDeadConversion<fir::SliceOp> {
3580 using MustBeDeadConversion::MustBeDeadConversion;
3581};
3582
3583} // namespace
3584
3585namespace {
3586class RenameMSVCLibmCallees
3587 : public mlir::OpRewritePattern<mlir::LLVM::CallOp> {
3588public:
3589 using OpRewritePattern::OpRewritePattern;
3590
3591 mlir::LogicalResult
3592 matchAndRewrite(mlir::LLVM::CallOp op,
3593 mlir::PatternRewriter &rewriter) const override {
3594 rewriter.startRootUpdate(op);
3595 auto callee = op.getCallee();
3596 if (callee)
3597 if (callee->equals("hypotf"))
3598 op.setCalleeAttr(mlir::SymbolRefAttr::get(op.getContext(), "_hypotf"));
3599
3600 rewriter.finalizeRootUpdate(op);
3601 return mlir::success();
3602 }
3603};
3604
3605class RenameMSVCLibmFuncs
3606 : public mlir::OpRewritePattern<mlir::LLVM::LLVMFuncOp> {
3607public:
3608 using OpRewritePattern::OpRewritePattern;
3609
3610 mlir::LogicalResult
3611 matchAndRewrite(mlir::LLVM::LLVMFuncOp op,
3612 mlir::PatternRewriter &rewriter) const override {
3613 rewriter.startRootUpdate(op);
3614 if (op.getSymName().equals("hypotf"))
3615 op.setSymNameAttr(rewriter.getStringAttr("_hypotf"));
3616 rewriter.finalizeRootUpdate(op);
3617 return mlir::success();
3618 }
3619};
3620} // namespace
3621
3622namespace {
3623/// Convert FIR dialect to LLVM dialect
3624///
3625/// This pass lowers all FIR dialect operations to LLVM IR dialect. An
3626/// MLIR pass is used to lower residual Std dialect to LLVM IR dialect.
3627class FIRToLLVMLowering
3628 : public fir::impl::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
3629public:
3630 FIRToLLVMLowering() = default;
3631 FIRToLLVMLowering(fir::FIRToLLVMPassOptions options) : options{options} {}
3632 mlir::ModuleOp getModule() { return getOperation(); }
3633
3634 void runOnOperation() override final {
3635 auto mod = getModule();
3636 if (!forcedTargetTriple.empty())
3637 fir::setTargetTriple(mod, forcedTargetTriple);
3638
3639 // Run dynamic pass pipeline for converting Math dialect
3640 // operations into other dialects (llvm, func, etc.).
3641 // Some conversions of Math operations cannot be done
3642 // by just using conversion patterns. This is true for
3643 // conversions that affect the ModuleOp, e.g. create new
3644 // function operations in it. We have to run such conversions
3645 // as passes here.
3646 mlir::OpPassManager mathConvertionPM("builtin.module");
3647
3648 // Convert math::FPowI operations to inline implementation
3649 // only if the exponent's width is greater than 32, otherwise,
3650 // it will be lowered to LLVM intrinsic operation by a later conversion.
3651 mlir::ConvertMathToFuncsOptions mathToFuncsOptions{};
3652 mathToFuncsOptions.minWidthOfFPowIExponent = 33;
3653 mathConvertionPM.addPass(
3654 mlir::createConvertMathToFuncs(mathToFuncsOptions));
3655 mathConvertionPM.addPass(mlir::createConvertComplexToStandardPass());
3656 // Convert Math dialect operations into LLVM dialect operations.
3657 // There is no way to prefer MathToLLVM patterns over MathToLibm
3658 // patterns (applied below), so we have to run MathToLLVM conversion here.
3659 mathConvertionPM.addNestedPass<mlir::func::FuncOp>(
3660 mlir::createConvertMathToLLVMPass());
3661 if (mlir::failed(runPipeline(mathConvertionPM, mod)))
3662 return signalPassFailure();
3663
3664 auto *context = getModule().getContext();
3665 fir::LLVMTypeConverter typeConverter{getModule(),
3666 options.applyTBAA || applyTBAA};
3667 mlir::RewritePatternSet pattern(context);
3668 pattern.insert<
3669 AbsentOpConversion, AddcOpConversion, AddrOfOpConversion,
3670 AllocaOpConversion, AllocMemOpConversion, BoxAddrOpConversion,
3671 BoxCharLenOpConversion, BoxDimsOpConversion, BoxEleSizeOpConversion,
3672 BoxIsAllocOpConversion, BoxIsArrayOpConversion, BoxIsPtrOpConversion,
3673 BoxProcHostOpConversion, BoxRankOpConversion, BoxTypeCodeOpConversion,
3674 BoxTypeDescOpConversion, CallOpConversion, CmpcOpConversion,
3675 ConstcOpConversion, ConvertOpConversion, CoordinateOpConversion,
3676 DispatchTableOpConversion, DTEntryOpConversion, DivcOpConversion,
3677 EmboxOpConversion, EmboxCharOpConversion, EmboxProcOpConversion,
3678 ExtractValueOpConversion, FieldIndexOpConversion, FirEndOpConversion,
3679 FreeMemOpConversion, GlobalLenOpConversion, GlobalOpConversion,
3680 HasValueOpConversion, InsertOnRangeOpConversion,
3681 InsertValueOpConversion, IsPresentOpConversion,
3682 LenParamIndexOpConversion, LoadOpConversion, MulcOpConversion,
3683 NegcOpConversion, NoReassocOpConversion, SelectCaseOpConversion,
3684 SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
3685 ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion,
3686 SliceOpConversion, StoreOpConversion, StringLitOpConversion,
3687 SubcOpConversion, TypeDescOpConversion, UnboxCharOpConversion,
3688 UnboxProcOpConversion, UndefOpConversion, UnreachableOpConversion,
3689 UnrealizedConversionCastOpConversion, XArrayCoorOpConversion,
3690 XEmboxOpConversion, XReboxOpConversion, ZeroOpConversion>(typeConverter,
3691 options);
3692 mlir::populateFuncToLLVMConversionPatterns(typeConverter, pattern);
3693 mlir::populateOpenACCToLLVMConversionPatterns(typeConverter, pattern);
3694 mlir::populateOpenMPToLLVMConversionPatterns(typeConverter, pattern);
3695 mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, pattern);
3696 mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
3697 pattern);
3698 // Math operations that have not been converted yet must be converted
3699 // to Libm.
3700 mlir::populateMathToLibmConversionPatterns(pattern);
3701 mlir::populateComplexToLLVMConversionPatterns(typeConverter, pattern);
3702 mlir::ConversionTarget target{*context};
3703 target.addLegalDialect<mlir::LLVM::LLVMDialect>();
3704 // The OpenMP dialect is legal for Operations without regions, for those
3705 // which contains regions it is legal if the region contains only the
3706 // LLVM dialect. Add OpenMP dialect as a legal dialect for conversion and
3707 // legalize conversion of OpenMP operations without regions.
3708 mlir::configureOpenMPToLLVMConversionLegality(target, typeConverter);
3709 target.addLegalDialect<mlir::omp::OpenMPDialect>();
3710 target.addLegalDialect<mlir::acc::OpenACCDialect>();
3711
3712 // required NOPs for applying a full conversion
3713 target.addLegalOp<mlir::ModuleOp>();
3714
3715 // If we're on Windows, we might need to rename some libm calls.
3716 bool isMSVC = fir::getTargetTriple(mod).isOSMSVCRT();
3717 if (isMSVC) {
3718 pattern.insert<RenameMSVCLibmCallees, RenameMSVCLibmFuncs>(context);
3719
3720 target.addDynamicallyLegalOp<mlir::LLVM::CallOp>(
3721 [](mlir::LLVM::CallOp op) {
3722 auto callee = op.getCallee();
3723 if (!callee)
3724 return true;
3725 return !callee->equals("hypotf");
3726 });
3727 target.addDynamicallyLegalOp<mlir::LLVM::LLVMFuncOp>(
3728 [](mlir::LLVM::LLVMFuncOp op) {
3729 return !op.getSymName().equals("hypotf");
3730 });
3731 }
3732
3733 // apply the patterns
3734 if (mlir::failed(mlir::applyFullConversion(getModule(), target,
3735 std::move(pattern)))) {
3736 signalPassFailure();
3737 }
3738 }
3739
3740private:
3741 fir::FIRToLLVMPassOptions options;
3742};
3743
3744/// Lower from LLVM IR dialect to proper LLVM-IR and dump the module
3745struct LLVMIRLoweringPass
3746 : public mlir::PassWrapper<LLVMIRLoweringPass,
3747 mlir::OperationPass<mlir::ModuleOp>> {
3748 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LLVMIRLoweringPass)static ::mlir::TypeID resolveTypeID() { static ::mlir::SelfOwningTypeID
id; return id; } static_assert( ::mlir::detail::InlineTypeIDResolver
::has_resolve_typeid< LLVMIRLoweringPass>::value, "`MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID` must be placed in a "
"public section of `" "LLVMIRLoweringPass" "`");
3749
3750 LLVMIRLoweringPass(llvm::raw_ostream &output, fir::LLVMIRLoweringPrinter p)
3751 : output{output}, printer{p} {}
3752
3753 mlir::ModuleOp getModule() { return getOperation(); }
3754
3755 void runOnOperation() override final {
3756 auto *ctx = getModule().getContext();
3757 auto optName = getModule().getName();
3758 llvm::LLVMContext llvmCtx;
3759 if (auto llvmModule = mlir::translateModuleToLLVMIR(
3760 getModule(), llvmCtx, optName ? *optName : "FIRModule")) {
3761 printer(*llvmModule, output);
3762 return;
3763 }
3764
3765 mlir::emitError(mlir::UnknownLoc::get(ctx), "could not emit LLVM-IR\n");
3766 signalPassFailure();
3767 }
3768
3769private:
3770 llvm::raw_ostream &output;
3771 fir::LLVMIRLoweringPrinter printer;
3772};
3773
3774} // namespace
3775
3776std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() {
3777 return std::make_unique<FIRToLLVMLowering>();
3778}
3779
3780std::unique_ptr<mlir::Pass>
3781fir::createFIRToLLVMPass(fir::FIRToLLVMPassOptions options) {
3782 return std::make_unique<FIRToLLVMLowering>(options);
3783}
3784
3785std::unique_ptr<mlir::Pass>
3786fir::createLLVMDialectToLLVMPass(llvm::raw_ostream &output,
3787 fir::LLVMIRLoweringPrinter printer) {
3788 return std::make_unique<LLVMIRLoweringPass>(output, printer);
3789}

/build/source/llvm/../mlir/include/mlir/IR/PatternMatch.h

1//===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_PATTERNMATCH_H
10#define MLIR_IR_PATTERNMATCH_H
11
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "llvm/ADT/FunctionExtras.h"
15#include "llvm/Support/TypeName.h"
16#include <optional>
17
18namespace mlir {
19
20class PatternRewriter;
21
22//===----------------------------------------------------------------------===//
23// PatternBenefit class
24//===----------------------------------------------------------------------===//
25
/// Unitless measure of how profitable applying a pattern is, ranging from 0
/// (very little benefit) to 65K. A common unit is "number of operations
/// matched".
///
/// A default-constructed value holds a sentinel meaning the pattern can never
/// match; `impossibleToMatch()` returns exactly that value.
class PatternBenefit {
  enum { ImpossibleToMatchSentinel = 65535 };

public:
  PatternBenefit() = default;
  PatternBenefit(unsigned benefit);
  PatternBenefit(const PatternBenefit &) = default;
  PatternBenefit &operator=(const PatternBenefit &) = default;

  /// The sentinel benefit used for patterns that never match.
  static PatternBenefit impossibleToMatch() { return PatternBenefit{}; }

  /// True iff this benefit is the "never matches" sentinel.
  bool isImpossibleToMatch() const {
    return representation == ImpossibleToMatchSentinel;
  }

  /// If the corresponding pattern can match, return its benefit. If the
  /// corresponding pattern isImpossibleToMatch() then this aborts.
  unsigned short getBenefit() const;

  // All comparisons order by the raw stored value.
  bool operator==(const PatternBenefit &rhs) const {
    return representation == rhs.representation;
  }
  bool operator!=(const PatternBenefit &rhs) const {
    return representation != rhs.representation;
  }
  bool operator<(const PatternBenefit &rhs) const {
    return representation < rhs.representation;
  }
  bool operator>(const PatternBenefit &rhs) const {
    return representation > rhs.representation;
  }
  bool operator<=(const PatternBenefit &rhs) const {
    return representation <= rhs.representation;
  }
  bool operator>=(const PatternBenefit &rhs) const {
    return representation >= rhs.representation;
  }

private:
  unsigned short representation{ImpossibleToMatchSentinel};
};
63
64//===----------------------------------------------------------------------===//
65// Pattern
66//===----------------------------------------------------------------------===//
67
68/// This class contains all of the data related to a pattern, but does not
69/// contain any methods or logic for the actual matching. This class is solely
70/// used to interface with the metadata of a pattern, such as the benefit or
71/// root operation.
72class Pattern {
73 /// This enum represents the kind of value used to select the root operations
74 /// that match this pattern.
75 enum class RootKind {
76 /// The pattern root matches "any" operation.
77 Any,
78 /// The pattern root is matched using a concrete operation name.
79 OperationName,
80 /// The pattern root is matched using an interface ID.
81 InterfaceID,
82 /// The patter root is matched using a trait ID.
83 TraitID
84 };
85
86public:
87 /// Return a list of operations that may be generated when rewriting an
88 /// operation instance with this pattern.
89 ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
90
91 /// Return the root node that this pattern matches. Patterns that can match
92 /// multiple root types return std::nullopt.
93 std::optional<OperationName> getRootKind() const {
94 if (rootKind == RootKind::OperationName)
95 return OperationName::getFromOpaquePointer(rootValue);
96 return std::nullopt;
97 }
98
99 /// Return the interface ID used to match the root operation of this pattern.
100 /// If the pattern does not use an interface ID for deciding the root match,
101 /// this returns std::nullopt.
102 std::optional<TypeID> getRootInterfaceID() const {
103 if (rootKind == RootKind::InterfaceID)
104 return TypeID::getFromOpaquePointer(rootValue);
105 return std::nullopt;
106 }
107
108 /// Return the trait ID used to match the root operation of this pattern.
109 /// If the pattern does not use a trait ID for deciding the root match, this
110 /// returns std::nullopt.
111 std::optional<TypeID> getRootTraitID() const {
112 if (rootKind == RootKind::TraitID)
113 return TypeID::getFromOpaquePointer(rootValue);
114 return std::nullopt;
115 }
116
117 /// Return the benefit (the inverse of "cost") of matching this pattern. The
118 /// benefit of a Pattern is always static - rewrites that may have dynamic
119 /// benefit can be instantiated multiple times (different Pattern instances)
120 /// for each benefit that they may return, and be guarded by different match
121 /// condition predicates.
122 PatternBenefit getBenefit() const { return benefit; }
123
124 /// Returns true if this pattern is known to result in recursive application,
125 /// i.e. this pattern may generate IR that also matches this pattern, but is
126 /// known to bound the recursion. This signals to a rewrite driver that it is
127 /// safe to apply this pattern recursively to generated IR.
128 bool hasBoundedRewriteRecursion() const {
129 return contextAndHasBoundedRecursion.getInt();
130 }
131
132 /// Return the MLIRContext used to create this pattern.
133 MLIRContext *getContext() const {
134 return contextAndHasBoundedRecursion.getPointer();
135 }
136
137 /// Return a readable name for this pattern. This name should only be used for
138 /// debugging purposes, and may be empty.
139 StringRef getDebugName() const { return debugName; }
140
141 /// Set the human readable debug name used for this pattern. This name will
142 /// only be used for debugging purposes.
143 void setDebugName(StringRef name) { debugName = name; }
144
145 /// Return the set of debug labels attached to this pattern.
146 ArrayRef<StringRef> getDebugLabels() const { return debugLabels; }
147
148 /// Add the provided debug labels to this pattern.
149 void addDebugLabels(ArrayRef<StringRef> labels) {
150 debugLabels.append(labels.begin(), labels.end());
151 }
152 void addDebugLabels(StringRef label) { debugLabels.push_back(label); }
153
154protected:
155 /// This class acts as a special tag that makes the desire to match "any"
156 /// operation type explicit. This helps to avoid unnecessary usages of this
157 /// feature, and ensures that the user is making a conscious decision.
158 struct MatchAnyOpTypeTag {};
159 /// This class acts as a special tag that makes the desire to match any
160 /// operation that implements a given interface explicit. This helps to avoid
161 /// unnecessary usages of this feature, and ensures that the user is making a
162 /// conscious decision.
163 struct MatchInterfaceOpTypeTag {};
164 /// This class acts as a special tag that makes the desire to match any
165 /// operation that implements a given trait explicit. This helps to avoid
166 /// unnecessary usages of this feature, and ensures that the user is making a
167 /// conscious decision.
168 struct MatchTraitOpTypeTag {};
169
170 /// Construct a pattern with a certain benefit that matches the operation
171 /// with the given root name.
172 Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context,
173 ArrayRef<StringRef> generatedNames = {});
174 /// Construct a pattern that may match any operation type. `generatedNames`
175 /// contains the names of operations that may be generated during a successful
176 /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
177 /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
178 /// always be supplied here.
179 Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
180 ArrayRef<StringRef> generatedNames = {});
181 /// Construct a pattern that may match any operation that implements the
182 /// interface defined by the provided `interfaceID`. `generatedNames` contains
183 /// the names of operations that may be generated during a successful rewrite.
184 /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
185 /// interface" behavior is what the user actually desired,
186 /// `MatchInterfaceOpTypeTag()` should always be supplied here.
187 Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
188 PatternBenefit benefit, MLIRContext *context,
189 ArrayRef<StringRef> generatedNames = {});
190 /// Construct a pattern that may match any operation that implements the
191 /// trait defined by the provided `traitID`. `generatedNames` contains the
192 /// names of operations that may be generated during a successful rewrite.
193 /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
194 /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
195 /// always be supplied here.
196 Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
197 MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
198
199 /// Set the flag detailing if this pattern has bounded rewrite recursion or
200 /// not.
201 void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
202 contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
203 }
204
205private:
206 Pattern(const void *rootValue, RootKind rootKind,
207 ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
208 MLIRContext *context);
209
210 /// The value used to match the root operation of the pattern.
211 const void *rootValue;
212 RootKind rootKind;
213
214 /// The expected benefit of matching this pattern.
215 const PatternBenefit benefit;
216
217 /// The context this pattern was created from, and a boolean flag indicating
218 /// whether this pattern has bounded recursion or not.
219 llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
220
221 /// A list of the potential operations that may be generated when rewriting
222 /// an op with this pattern.
223 SmallVector<OperationName, 2> generatedOps;
224
225 /// A readable name for this pattern. May be empty.
226 StringRef debugName;
227
228 /// The set of debug labels attached to this pattern.
229 SmallVector<StringRef, 0> debugLabels;
230};
231
232//===----------------------------------------------------------------------===//
233// RewritePattern
234//===----------------------------------------------------------------------===//
235
236/// RewritePattern is the common base class for all DAG to DAG replacements.
237/// There are two possible usages of this class:
238/// * Multi-step RewritePattern with "match" and "rewrite"
239/// - By overloading the "match" and "rewrite" functions, the user can
240/// separate the concerns of matching and rewriting.
241/// * Single-step RewritePattern with "matchAndRewrite"
242/// - By overloading the "matchAndRewrite" function, the user can perform
243/// the rewrite in the same call as the match.
244///
245class RewritePattern : public Pattern {
246public:
247 virtual ~RewritePattern() = default;
248
249 /// Rewrite the IR rooted at the specified operation with the result of
250 /// this pattern, generating any new operations with the specified
251 /// builder. If an unexpected error is encountered (an internal
252 /// compiler error), it is emitted through the normal MLIR diagnostic
253 /// hooks and the IR is left in a valid state.
254 virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
255
256 /// Attempt to match against code rooted at the specified operation,
257 /// which is the same operation code as getRootKind().
258 virtual LogicalResult match(Operation *op) const;
259
260 /// Attempt to match against code rooted at the specified operation,
261 /// which is the same operation code as getRootKind(). If successful, this
262 /// function will automatically perform the rewrite.
263 virtual LogicalResult matchAndRewrite(Operation *op,
264 PatternRewriter &rewriter) const {
265 if (succeeded(match(op))) {
266 rewrite(op, rewriter);
267 return success();
268 }
269 return failure();
270 }
271
272 /// This method provides a convenient interface for creating and initializing
273 /// derived rewrite patterns of the given type `T`.
274 template <typename T, typename... Args>
275 static std::unique_ptr<T> create(Args &&...args) {
276 std::unique_ptr<T> pattern =
277 std::make_unique<T>(std::forward<Args>(args)...);
278 initializePattern<T>(*pattern);
279
280 // Set a default debug name if one wasn't provided.
281 if (pattern->getDebugName().empty())
282 pattern->setDebugName(llvm::getTypeName<T>());
283 return pattern;
284 }
285
286protected:
287 /// Inherit the base constructors from `Pattern`.
288 using Pattern::Pattern;
289
290private:
291 /// Trait to check if T provides a `getOperationName` method.
292 template <typename T, typename... Args>
293 using has_initialize = decltype(std::declval<T>().initialize());
294 template <typename T>
295 using detect_has_initialize = llvm::is_detected<has_initialize, T>;
296
297 /// Initialize the derived pattern by calling its `initialize` method.
298 template <typename T>
299 static std::enable_if_t<detect_has_initialize<T>::value>
300 initializePattern(T &pattern) {
301 pattern.initialize();
302 }
303 /// Empty derived pattern initializer for patterns that do not have an
304 /// initialize method.
305 template <typename T>
306 static std::enable_if_t<!detect_has_initialize<T>::value>
307 initializePattern(T &) {}
308
309 /// An anchor for the virtual table.
310 virtual void anchor();
311};
312
313namespace detail {
314/// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
315/// allows for matching and rewriting against an instance of a derived operation
316/// class or Interface.
317template <typename SourceOp>
318struct OpOrInterfaceRewritePatternBase : public RewritePattern {
319 using RewritePattern::RewritePattern;
320
321 /// Wrappers around the RewritePattern methods that pass the derived op type.
322 void rewrite(Operation *op, PatternRewriter &rewriter) const final {
323 rewrite(cast<SourceOp>(op), rewriter);
324 }
325 LogicalResult match(Operation *op) const final {
326 return match(cast<SourceOp>(op));
327 }
328 LogicalResult matchAndRewrite(Operation *op,
329 PatternRewriter &rewriter) const final {
330 return matchAndRewrite(cast<SourceOp>(op), rewriter);
331 }
332
333 /// Rewrite and Match methods that operate on the SourceOp type. These must be
334 /// overridden by the derived pattern class.
335 virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
336 llvm_unreachable("must override rewrite or matchAndRewrite")::llvm::llvm_unreachable_internal("must override rewrite or matchAndRewrite"
, "llvm/../mlir/include/mlir/IR/PatternMatch.h", 336)
;
337 }
338 virtual LogicalResult match(SourceOp op) const {
339 llvm_unreachable("must override match or matchAndRewrite")::llvm::llvm_unreachable_internal("must override match or matchAndRewrite"
, "llvm/../mlir/include/mlir/IR/PatternMatch.h", 339)
;
340 }
341 virtual LogicalResult matchAndRewrite(SourceOp op,
342 PatternRewriter &rewriter) const {
343 if (succeeded(match(op))) {
344 rewrite(op, rewriter);
345 return success();
346 }
347 return failure();
348 }
349};
350} // namespace detail
351
352/// OpRewritePattern is a wrapper around RewritePattern that allows for
353/// matching and rewriting against an instance of a derived operation class as
354/// opposed to a raw Operation.
355template <typename SourceOp>
356struct OpRewritePattern
357 : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
358 /// Patterns must specify the root operation name they match against, and can
359 /// also specify the benefit of the pattern matching and a list of generated
360 /// ops.
361 OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1,
362 ArrayRef<StringRef> generatedNames = {})
363 : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
364 SourceOp::getOperationName(), benefit, context, generatedNames) {}
365};
366
367/// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
368/// matching and rewriting against an instance of an operation interface instead
369/// of a raw Operation.
370template <typename SourceOp>
371struct OpInterfaceRewritePattern
372 : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
373 OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
374 : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
375 Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
376 benefit, context) {}
377};
378
379/// OpTraitRewritePattern is a wrapper around RewritePattern that allows for
380/// matching and rewriting against instances of an operation that possess a
381/// given trait.
382template <template <typename> class TraitType>
383class OpTraitRewritePattern : public RewritePattern {
384public:
385 OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
386 : RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(),
387 benefit, context) {}
388};
389
390//===----------------------------------------------------------------------===//
391// RewriterBase
392//===----------------------------------------------------------------------===//
393
394/// This class coordinates the application of a rewrite on a set of IR,
395/// providing a way for clients to track mutations and create new operations.
396/// This class serves as a common API for IR mutation between pattern rewrites
397/// and non-pattern rewrites, and facilitates the development of shared
398/// IR transformation utilities.
399class RewriterBase : public OpBuilder {
400public:
401 struct Listener : public OpBuilder::Listener {
402 Listener()
403 : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
404
405 /// Notify the listener that the specified operation was modified in-place.
406 virtual void notifyOperationModified(Operation *op) {}
407
408 /// Notify the listener that the specified operation is about to be replaced
409 /// with the set of values potentially produced by new operations. This is
410 /// called before the uses of the operation have been changed.
411 virtual void notifyOperationReplaced(Operation *op,
412 ValueRange replacement) {}
413
414 /// This is called on an operation that a rewrite is removing, right before
415 /// the operation is deleted. At this point, the operation has zero uses.
416 virtual void notifyOperationRemoved(Operation *op) {}
417
418 /// Notify the listener that the pattern failed to match the given
419 /// operation, and provide a callback to populate a diagnostic with the
420 /// reason why the failure occurred. This method allows for derived
421 /// listeners to optionally hook into the reason why a rewrite failed, and
422 /// display it to users.
423 virtual LogicalResult
424 notifyMatchFailure(Location loc,
425 function_ref<void(Diagnostic &)> reasonCallback) {
426 return failure();
427 }
428
429 static bool classof(const OpBuilder::Listener *base);
430 };
431
432 /// Move the blocks that belong to "region" before the given position in
433 /// another region "parent". The two regions must be different. The caller
434 /// is responsible for creating or updating the operation transferring flow
435 /// of control to the region and passing it the correct block arguments.
436 virtual void inlineRegionBefore(Region &region, Region &parent,
437 Region::iterator before);
438 void inlineRegionBefore(Region &region, Block *before);
439
440 /// Clone the blocks that belong to "region" before the given position in
441 /// another region "parent". The two regions must be different. The caller is
442 /// responsible for creating or updating the operation transferring flow of
443 /// control to the region and passing it the correct block arguments.
444 virtual void cloneRegionBefore(Region &region, Region &parent,
445 Region::iterator before, IRMapping &mapping);
446 void cloneRegionBefore(Region &region, Region &parent,
447 Region::iterator before);
448 void cloneRegionBefore(Region &region, Block *before);
449
450 /// This method replaces the uses of the results of `op` with the values in
451 /// `newValues` when the provided `functor` returns true for a specific use.
452 /// The number of values in `newValues` is required to match the number of
453 /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
454 /// the uses of `op` were replaced. Note that in some rewriters, the given
455 /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
456 /// As such, the function should not capture by reference and instead use
457 /// value capture as necessary.
458 virtual void
459 replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
460 llvm::unique_function<bool(OpOperand &) const> functor);
461 void replaceOpWithIf(Operation *op, ValueRange newValues,
462 llvm::unique_function<bool(OpOperand &) const> functor) {
463 replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
464 std::move(functor));
465 }
466
467 /// This method replaces the uses of the results of `op` with the values in
468 /// `newValues` when a use is nested within the given `block`. The number of
469 /// values in `newValues` is required to match the number of results of `op`.
470 /// If all uses of this operation are replaced, the operation is erased.
471 void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
472 bool *allUsesReplaced = nullptr);
473
474 /// This method replaces the results of the operation with the specified list
475 /// of values. The number of provided values must match the number of results
476 /// of the operation.
477 virtual void replaceOp(Operation *op, ValueRange newValues);
478
479 /// Replaces the result op with a new op that is created without verification.
480 /// The result values of the two ops must be the same types.
481 template <typename OpTy, typename... Args>
482 OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
483 auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
14
Calling 'forward<mlir::Block *&>'
15
Returning from 'forward<mlir::Block *&>'
16
Calling 'OpBuilder::create'
484 replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
485 return newOp;
486 }
487
488 /// This method erases an operation that is known to have no uses.
489 virtual void eraseOp(Operation *op);
490
491 /// This method erases all operations in a block.
492 virtual void eraseBlock(Block *block);
493
494 /// Inline the operations of block 'source' into block 'dest' before the given
495 /// position. The source block will be deleted and must have no uses.
496 /// 'argValues' is used to replace the block arguments of 'source'.
497 ///
498 /// If the source block is inserted at the end of the dest block, the dest
499 /// block must have no successors. Similarly, if the source block is inserted
500 /// somewhere in the middle (or beginning) of the dest block, the source block
501 /// must have no successors. Otherwise, the resulting IR would have
502 /// unreachable operations.
503 virtual void inlineBlockBefore(Block *source, Block *dest,
504 Block::iterator before,
505 ValueRange argValues = std::nullopt);
506
507 /// Inline the operations of block 'source' before the operation 'op'. The
508 /// source block will be deleted and must have no uses. 'argValues' is used to
509 /// replace the block arguments of 'source'
510 ///
511 /// The source block must have no successors. Otherwise, the resulting IR
512 /// would have unreachable operations.
513 void inlineBlockBefore(Block *source, Operation *op,
514 ValueRange argValues = std::nullopt);
515
516 /// Inline the operations of block 'source' into the end of block 'dest'. The
517 /// source block will be deleted and must have no uses. 'argValues' is used to
518 /// replace the block arguments of 'source'
519 ///
520 /// The dest block must have no successors. Otherwise, the resulting IR would
521 /// have unreachable operation.
522 void mergeBlocks(Block *source, Block *dest,
523 ValueRange argValues = std::nullopt);
524
525 /// Split the operations starting at "before" (inclusive) out of the given
526 /// block into a new block, and return it.
527 virtual Block *splitBlock(Block *block, Block::iterator before);
528
529 /// This method is used to notify the rewriter that an in-place operation
530 /// modification is about to happen. A call to this function *must* be
531 /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
532 /// This is a minor efficiency win (it avoids creating a new operation and
533 /// removing the old one) but also often allows simpler code in the client.
534 virtual void startRootUpdate(Operation *op) {}
535
536 /// This method is used to signal the end of a root update on the given
537 /// operation. This can only be called on operations that were provided to a
538 /// call to `startRootUpdate`.
539 virtual void finalizeRootUpdate(Operation *op);
540
541 /// This method cancels a pending root update. This can only be called on
542 /// operations that were provided to a call to `startRootUpdate`.
543 virtual void cancelRootUpdate(Operation *op) {}
544
545 /// This method is a utility wrapper around a root update of an operation. It
546 /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
547 /// callable.
548 template <typename CallableT>
549 void updateRootInPlace(Operation *root, CallableT &&callable) {
550 startRootUpdate(root);
551 callable();
552 finalizeRootUpdate(root);
553 }
554
555 /// Find uses of `from` and replace them with `to`. It also marks every
556 /// modified uses and notifies the rewriter that an in-place operation
557 /// modification is about to happen.
558 void replaceAllUsesWith(Value from, Value to) {
559 return replaceAllUsesWith(from.getImpl(), to);
560 }
561 template <typename OperandType, typename ValueT>
562 void replaceAllUsesWith(IRObjectWithUseList<OperandType> *from, ValueT &&to) {
563 for (OperandType &operand : llvm::make_early_inc_range(from->getUses())) {
564 Operation *op = operand.getOwner();
565 updateRootInPlace(op, [&]() { operand.set(to); });
566 }
567 }
568 void replaceAllUsesWith(ValueRange from, ValueRange to) {
569 assert(from.size() == to.size() && "incorrect number of replacements")(static_cast <bool> (from.size() == to.size() &&
"incorrect number of replacements") ? void (0) : __assert_fail
("from.size() == to.size() && \"incorrect number of replacements\""
, "llvm/../mlir/include/mlir/IR/PatternMatch.h", 569, __extension__
__PRETTY_FUNCTION__))
;
570 for (auto it : llvm::zip(from, to))
571 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
572 }
573
574 /// Find uses of `from` and replace them with `to` if the `functor` returns
575 /// true. It also marks every modified uses and notifies the rewriter that an
576 /// in-place operation modification is about to happen.
577 void replaceUsesWithIf(Value from, Value to,
578 function_ref<bool(OpOperand &)> functor);
579
580 /// Find uses of `from` and replace them with `to` except if the user is
581 /// `exceptedUser`. It also marks every modified uses and notifies the
582 /// rewriter that an in-place operation modification is about to happen.
583 void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
584 return replaceUsesWithIf(from, to, [&](OpOperand &use) {
585 Operation *user = use.getOwner();
586 return user != exceptedUser;
587 });
588 }
589
590 /// Used to notify the rewriter that the IR failed to be rewritten because of
591 /// a match failure, and provide a callback to populate a diagnostic with the
592 /// reason why the failure occurred. This method allows for derived rewriters
593 /// to optionally hook into the reason why a rewrite failed, and display it to
594 /// users.
595 template <typename CallbackT>
596 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
597 notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
598#ifndef NDEBUG
599 if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
600 return rewriteListener->notifyMatchFailure(
601 loc, function_ref<void(Diagnostic &)>(reasonCallback));
602 return failure();
603#else
604 return failure();
605#endif
606 }
607 template <typename CallbackT>
608 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
609 notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
610 if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
611 return rewriteListener->notifyMatchFailure(
612 op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback));
613 return failure();
614 }
615 template <typename ArgT>
616 LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
617 return notifyMatchFailure(std::forward<ArgT>(arg),
618 [&](Diagnostic &diag) { diag << msg; });
619 }
620 template <typename ArgT>
621 LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) {
622 return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
623 }
624
625protected:
626 /// Initialize the builder.
627 explicit RewriterBase(MLIRContext *ctx,
628 OpBuilder::Listener *listener = nullptr)
629 : OpBuilder(ctx, listener) {}
630 explicit RewriterBase(const OpBuilder &otherBuilder)
631 : OpBuilder(otherBuilder) {}
632 virtual ~RewriterBase();
633
634private:
635 void operator=(const RewriterBase &) = delete;
636 RewriterBase(const RewriterBase &) = delete;
637
638 /// 'op' and 'newOp' are known to have the same number of results, replace the
639 /// uses of op with uses of newOp.
640 void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
641};
642
643//===----------------------------------------------------------------------===//
644// IRRewriter
645//===----------------------------------------------------------------------===//
646
647/// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
648/// providing a way to keep track of the mutations made to the IR. This class
649/// should only be used in situations where another `RewriterBase` instance,
650/// such as a `PatternRewriter`, is not available.
651class IRRewriter : public RewriterBase {
652public:
653 explicit IRRewriter(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr)
654 : RewriterBase(ctx, listener) {}
655 explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
656};
657
658//===----------------------------------------------------------------------===//
659// PatternRewriter
660//===----------------------------------------------------------------------===//
661
662/// A special type of `RewriterBase` that coordinates the application of a
663/// rewrite pattern on the current IR being matched, providing a way to keep
664/// track of any mutations made. This class should be used to perform all
665/// necessary IR mutations within a rewrite pattern, as the pattern driver may
666/// be tracking various state that would be invalidated when a mutation takes
667/// place.
668class PatternRewriter : public RewriterBase {
669public:
670 using RewriterBase::RewriterBase;
671
672 /// A hook used to indicate if the pattern rewriter can recover from failure
673 /// during the rewrite stage of a pattern. For example, if the pattern
674 /// rewriter supports rollback, it may progress smoothly even if IR was
675 /// changed during the rewrite.
676 virtual bool canRecoverFromRewriteFailure() const { return false; }
677};
678
679//===----------------------------------------------------------------------===//
680// PDL Patterns
681//===----------------------------------------------------------------------===//
682
683//===----------------------------------------------------------------------===//
684// PDLValue
685
686/// Storage type of byte-code interpreter values. These are passed to constraint
687/// functions as arguments.
688class PDLValue {
689public:
690 /// The underlying kind of a PDL value.
691 enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange };
692
693 /// Construct a new PDL value.
694 PDLValue(const PDLValue &other) = default;
695 PDLValue(std::nullptr_t = nullptr) {}
696 PDLValue(Attribute value)
697 : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
698 PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
699 PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
700 PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
701 PDLValue(Value value)
702 : value(value.getAsOpaquePointer()), kind(Kind::Value) {}
703 PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
704
705 /// Returns true if the type of the held value is `T`.
706 template <typename T>
707 bool isa() const {
708 assert(value && "isa<> used on a null value")(static_cast <bool> (value && "isa<> used on a null value"
) ? void (0) : __assert_fail ("value && \"isa<> used on a null value\""
, "llvm/../mlir/include/mlir/IR/PatternMatch.h", 708, __extension__
__PRETTY_FUNCTION__))
;
709 return kind == getKindOf<T>();
710 }
711
712 /// Attempt to dynamically cast this value to type `T`, returns null if this
713 /// value is not an instance of `T`.
714 template <typename T,
715 typename ResultT = std::conditional_t<
716 std::is_convertible<T, bool>::value, T, std::optional<T>>>
717 ResultT dyn_cast() const {
718 return isa<T>() ? castImpl<T>() : ResultT();
719 }
720
721 /// Cast this value to type `T`, asserts if this value is not an instance of
722 /// `T`.
723 template <typename T>
724 T cast() const {
725 assert(isa<T>() && "expected value to be of type `T`")(static_cast <bool> (isa<T>() && "expected value to be of type `T`"
) ? void (0) : __assert_fail ("isa<T>() && \"expected value to be of type `T`\""
, "llvm/../mlir/include/mlir/IR/PatternMatch.h", 725, __extension__
__PRETTY_FUNCTION__))
;
726 return castImpl<T>();
727 }
728
729 /// Get an opaque pointer to the value.
730 const void *getAsOpaquePointer() const { return value; }
731
732 /// Return if this value is null or not.
733 explicit operator bool() const { return value; }
734
735 /// Return the kind of this value.
736 Kind getKind() const { return kind; }
737
738 /// Print this value to the provided output stream.
739 void print(raw_ostream &os) const;
740
741 /// Print the specified value kind to an output stream.
742 static void print(raw_ostream &os, Kind kind);
743
744private:
745 /// Find the index of a given type in a range of other types.
746 template <typename...>
747 struct index_of_t;
748 template <typename T, typename... R>
749 struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
750 template <typename T, typename F, typename... R>
751 struct index_of_t<T, F, R...>
752 : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
753
754 /// Return the kind used for the given T.
755 template <typename T>
756 static Kind getKindOf() {
757 return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
758 TypeRange, Value, ValueRange>::value);
759 }
760
761 /// The internal implementation of `cast`, that returns the underlying value
762 /// as the given type `T`.
763 template <typename T>
764 std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
765 castImpl() const {
766 return T::getFromOpaquePointer(value);
767 }
768 template <typename T>
769 std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
770 castImpl() const {
771 return *reinterpret_cast<T *>(const_cast<void *>(value));
772 }
773 template <typename T>
774 std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
775 return reinterpret_cast<T>(const_cast<void *>(value));
776 }
777
778 /// The internal opaque representation of a PDLValue.
779 const void *value{nullptr};
780 /// The kind of the opaque value.
781 Kind kind{Kind::Attribute};
782};
783
784inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
785 value.print(os);
786 return os;
787}
788
789inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {
790 PDLValue::print(os, kind);
791 return os;
792}
793
794//===----------------------------------------------------------------------===//
795// PDLResultList
796
797/// The class represents a list of PDL results, returned by a native rewrite
798/// method. It provides the mechanism with which to pass PDLValues back to the
799/// PDL bytecode.
800class PDLResultList {
801public:
802 /// Push a new Attribute value onto the result list.
803 void push_back(Attribute value) { results.push_back(value); }
804
805 /// Push a new Operation onto the result list.
806 void push_back(Operation *value) { results.push_back(value); }
807
808 /// Push a new Type onto the result list.
809 void push_back(Type value) { results.push_back(value); }
810
811 /// Push a new TypeRange onto the result list.
812 void push_back(TypeRange value) {
813 // The lifetime of a TypeRange can't be guaranteed, so we'll need to
814 // allocate a storage for it.
815 llvm::OwningArrayRef<Type> storage(value.size());
816 llvm::copy(value, storage.begin());
817 allocatedTypeRanges.emplace_back(std::move(storage));
818 typeRanges.push_back(allocatedTypeRanges.back());
819 results.push_back(&typeRanges.back());
820 }
821 void push_back(ValueTypeRange<OperandRange> value) {
822 typeRanges.push_back(value);
823 results.push_back(&typeRanges.back());
824 }
825 void push_back(ValueTypeRange<ResultRange> value) {
826 typeRanges.push_back(value);
827 results.push_back(&typeRanges.back());
828 }
829
830 /// Push a new Value onto the result list.
831 void push_back(Value value) { results.push_back(value); }
832
833 /// Push a new ValueRange onto the result list.
834 void push_back(ValueRange value) {
835 // The lifetime of a ValueRange can't be guaranteed, so we'll need to
836 // allocate a storage for it.
837 llvm::OwningArrayRef<Value> storage(value.size());
838 llvm::copy(value, storage.begin());
839 allocatedValueRanges.emplace_back(std::move(storage));
840 valueRanges.push_back(allocatedValueRanges.back());
841 results.push_back(&valueRanges.back());
842 }
843 void push_back(OperandRange value) {
844 valueRanges.push_back(value);
845 results.push_back(&valueRanges.back());
846 }
847 void push_back(ResultRange value) {
848 valueRanges.push_back(value);
849 results.push_back(&valueRanges.back());
850 }
851
852protected:
853 /// Create a new result list with the expected number of results.
854 PDLResultList(unsigned maxNumResults) {
855 // For now just reserve enough space for all of the results. We could do
856 // separate counts per range type, but it isn't really worth it unless there
857 // are a "large" number of results.
858 typeRanges.reserve(maxNumResults);
859 valueRanges.reserve(maxNumResults);
860 }
861
862 /// The PDL results held by this list.
863 SmallVector<PDLValue> results;
864 /// Memory used to store ranges held by the list.
865 SmallVector<TypeRange> typeRanges;
866 SmallVector<ValueRange> valueRanges;
867 /// Memory allocated to store ranges in the result list whose lifetime was
868 /// generated in the native function.
869 SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges;
870 SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
871};
872
873//===----------------------------------------------------------------------===//
874// PDLPatternConfig
875
876/// An individual configuration for a pattern, which can be accessed by native
877/// functions via the PDLPatternConfigSet. This allows for injecting additional
878/// configuration into PDL patterns that is specific to certain compilation
879/// flows.
880class PDLPatternConfig {
881public:
882 virtual ~PDLPatternConfig() = default;
883
884 /// Hooks that are invoked at the beginning and end of a rewrite of a matched
885 /// pattern. These can be used to setup any specific state necessary for the
886 /// rewrite.
887 virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
888 virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
889
890 /// Return the TypeID that represents this configuration.
891 TypeID getTypeID() const { return id; }
892
893protected:
894 PDLPatternConfig(TypeID id) : id(id) {}
895
896private:
897 TypeID id;
898};
899
900/// This class provides a base class for users implementing a type of pattern
901/// configuration.
902template <typename T>
903class PDLPatternConfigBase : public PDLPatternConfig {
904public:
905 /// Support LLVM style casting.
906 static bool classof(const PDLPatternConfig *config) {
907 return config->getTypeID() == getConfigID();
908 }
909
910 /// Return the type id used for this configuration.
911 static TypeID getConfigID() { return TypeID::get<T>(); }
912
913protected:
914 PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
915};
916
917/// This class contains a set of configurations for a specific pattern.
918/// Configurations are uniqued by TypeID, meaning that only one configuration of
919/// each type is allowed.
920class PDLPatternConfigSet {
921public:
922 PDLPatternConfigSet() = default;
923
924 /// Construct a set with the given configurations.
925 template <typename... ConfigsT>
926 PDLPatternConfigSet(ConfigsT &&...configs) {
927 (addConfig(std::forward<ConfigsT>(configs)), ...);
928 }
929
930 /// Get the configuration defined by the given type. Asserts that the
931 /// configuration of the provided type exists.
932 template <typename T>
933 const T &get() const {
934 const T *config = tryGet<T>();
935 assert(config && "configuration not found")(static_cast <bool> (config && "configuration not found"
) ? void (0) : __assert_fail ("config && \"configuration not found\""
, "llvm/../mlir/include/mlir/IR/PatternMatch.h", 935, __extension__
__PRETTY_FUNCTION__))
;
936 return *config;
937 }
938
939 /// Get the configuration defined by the given type, returns nullptr if the
940 /// configuration does not exist.
941 template <typename T>
942 const T *tryGet() const {
943 for (const auto &configIt : configs)
944 if (const T *config = dyn_cast<T>(configIt.get()))
945 return config;
946 return nullptr;
947 }
948
949 /// Notify the configurations within this set at the beginning or end of a
950 /// rewrite of a matched pattern.
951 void notifyRewriteBegin(PatternRewriter &rewriter) {
952 for (const auto &config : configs)
953 config->notifyRewriteBegin(rewriter);
954 }
955 void notifyRewriteEnd(PatternRewriter &rewriter) {
956 for (const auto &config : configs)
957 config->notifyRewriteEnd(rewriter);
958 }
959
960protected:
961 /// Add a configuration to the set.
962 template <typename T>
963 void addConfig(T &&config) {
964 assert(!tryGet<std::decay_t<T>>() && "configuration already exists")(static_cast <bool> (!tryGet<std::decay_t<T>>
() && "configuration already exists") ? void (0) : __assert_fail
("!tryGet<std::decay_t<T>>() && \"configuration already exists\""
, "llvm/../mlir/include/mlir/IR/PatternMatch.h", 964, __extension__
__PRETTY_FUNCTION__))
;
965 configs.emplace_back(
966 std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
967 }
968
969 /// The set of configurations for this pattern. This uses a vector instead of
970 /// a map with the expectation that the number of configurations per set is
971 /// small (<= 1).
972 SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
973};
974
975//===----------------------------------------------------------------------===//
976// PDLPatternModule
977
978/// A generic PDL pattern constraint function. This function applies a
979/// constraint to a given set of opaque PDLValue entities. Returns success if
980/// the constraint successfully held, failure otherwise.
981using PDLConstraintFunction =
982 std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
983/// A native PDL rewrite function. This function performs a rewrite on the
984/// given set of values. Any results from this rewrite that should be passed
985/// back to PDL should be added to the provided result list. This method is only
986/// invoked when the corresponding match was successful. Returns failure if an
987/// invariant of the rewrite was broken (certain rewriters may recover from
988/// partial pattern application).
989using PDLRewriteFunction = std::function<LogicalResult(
990 PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
991
992namespace detail {
993namespace pdl_function_builder {
994/// A utility variable that always resolves to false. This is useful for static
995/// asserts that are always false, but only should fire in certain templated
996/// constructs. For example, if a templated function should never be called, the
997/// function could be defined as:
998///
999/// template <typename T>
1000/// void foo() {
1001/// static_assert(always_false<T>, "This function should never be called");
1002/// }
1003///
1004template <class... T>
1005constexpr bool always_false = false;
1006
1007//===----------------------------------------------------------------------===//
1008// PDL Function Builder: Type Processing
1009//===----------------------------------------------------------------------===//
1010
1011/// This struct provides a convenient way to determine how to process a given
1012/// type as either a PDL parameter, or a result value. This allows for
1013/// supporting complex types in constraint and rewrite functions, without
1014/// requiring the user to hand-write the necessary glue code themselves.
1015/// Specializations of this class should implement the following methods to
1016/// enable support as a PDL argument or result type:
1017///
1018/// static LogicalResult verifyAsArg(
1019/// function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
1020/// size_t argIdx);
1021///
1022/// * This method verifies that the given PDLValue is valid for use as a
1023/// value of `T`.
1024///
1025/// static T processAsArg(PDLValue pdlValue);
1026///
1027/// * This method processes the given PDLValue as a value of `T`.
1028///
1029/// static void processAsResult(PatternRewriter &, PDLResultList &results,
1030/// const T &value);
1031///
1032/// * This method processes the given value of `T` as the result of a
1033/// function invocation. The method should package the value into an
1034/// appropriate form and append it to the given result list.
1035///
1036/// If the type `T` is based on a higher order value, consider using
1037/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify
1038/// the implementation.
1039///
1040template <typename T, typename Enable = void>
1041struct ProcessPDLValue;
1042
1043/// This struct provides a simplified model for processing types that are based
1044/// on another type, e.g. APInt is based on the handling for IntegerAttr. This
1045/// allows for building the necessary processing functions on top of the base
1046/// value instead of a PDLValue. Derived users should implement the following
1047/// (which subsume the ProcessPDLValue variants):
1048///
1049/// static LogicalResult verifyAsArg(
1050/// function_ref<LogicalResult(const Twine &)> errorFn,
1051/// const BaseT &baseValue, size_t argIdx);
1052///
1053/// * This method verifies that the given PDLValue is valid for use as a
1054/// value of `T`.
1055///
1056/// static T processAsArg(BaseT baseValue);
1057///
1058/// * This method processes the given base value as a value of `T`.
1059///
1060template <typename T, typename BaseT>
1061struct ProcessPDLValueBasedOn {
1062 static LogicalResult
1063 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
1064 PDLValue pdlValue, size_t argIdx) {
1065 // Verify the base class before continuing.
1066 if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
1067 return failure();
1068 return ProcessPDLValue<T>::verifyAsArg(
1069 errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
1070 }
1071 static T processAsArg(PDLValue pdlValue) {
1072 return ProcessPDLValue<T>::processAsArg(
1073 ProcessPDLValue<BaseT>::processAsArg(pdlValue));
1074 }
1075
1076 /// Explicitly add the expected parent API to ensure the parent class
1077 /// implements the necessary API (and doesn't implicitly inherit it from
1078 /// somewhere else).
1079 static LogicalResult
1080 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
1081 size_t argIdx) {
1082 return success();
1083 }
1084 static T processAsArg(BaseT baseValue);
1085};
1086
1087/// This struct provides a simplified model for processing types that have
1088/// "builtin" PDLValue support:
1089/// * Attribute, Operation *, Type, TypeRange, ValueRange
1090template <typename T>
1091struct ProcessBuiltinPDLValue {
1092 static LogicalResult
1093 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
1094 PDLValue pdlValue, size_t argIdx) {
1095 if (pdlValue)
1096 return success();
1097 return errorFn("expected a non-null value for argument " + Twine(argIdx) +
1098 " of type: " + llvm::getTypeName<T>());
1099 }
1100
1101 static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); }
1102 static void processAsResult(PatternRewriter &, PDLResultList &results,
1103 T value) {
1104 results.push_back(value);
1105 }
1106};
1107
1108/// This struct provides a simplified model for processing types that inherit
1109/// from builtin PDLValue types. For example, derived attributes like
1110/// IntegerAttr, derived types like IntegerType, derived operations like
1111/// ModuleOp, Interfaces, etc.
1112template <typename T, typename BaseT>
1113struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {
1114 static LogicalResult
1115 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
1116 BaseT baseValue, size_t argIdx) {
1117 return TypeSwitch<BaseT, LogicalResult>(baseValue)
1118 .Case([&](T) { return success(); })
1119 .Default([&](BaseT) {
1120 return errorFn("expected argument " + Twine(argIdx) +
1121 " to be of type: " + llvm::getTypeName<T>());
1122 });
1123 }
1124 using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg;
1125
1126 static T processAsArg(BaseT baseValue) {
1127 return baseValue.template cast<T>();
1128 }
1129 using ProcessPDLValueBasedOn<T, BaseT>::processAsArg;
1130
1131 static void processAsResult(PatternRewriter &, PDLResultList &results,
1132 T value) {
1133 results.push_back(value);
1134 }
1135};
1136
1137//===----------------------------------------------------------------------===//
1138// Attribute
1139
1140template <>
1141struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};
1142template <typename T>
1143struct ProcessPDLValue<T,
1144 std::enable_if_t<std::is_base_of<Attribute, T>::value>>
1145 : public ProcessDerivedPDLValue<T, Attribute> {};
1146
1147/// Handling for various Attribute value types.
1148template <>
1149struct ProcessPDLValue<StringRef>
1150 : public ProcessPDLValueBasedOn<StringRef, StringAttr> {
1151 static StringRef processAsArg(StringAttr value) { return value.getValue(); }
1152 using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg;
1153
1154 static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
1155 StringRef value) {
1156 results.push_back(rewriter.getStringAttr(value));
1157 }
1158};
1159template <>
1160struct ProcessPDLValue<std::string>
1161 : public ProcessPDLValueBasedOn<std::string, StringAttr> {
1162 template <typename T>
1163 static std::string processAsArg(T value) {
1164 static_assert(always_false<T>,
1165 "`std::string` arguments require a string copy, use "
1166 "`StringRef` for string-like arguments instead");
1167 return {};
1168 }
1169 static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
1170 StringRef value) {
1171 results.push_back(rewriter.getStringAttr(value));
1172 }
1173};
1174
1175//===----------------------------------------------------------------------===//
1176// Operation
1177
1178template <>
1179struct ProcessPDLValue<Operation *>
1180 : public ProcessBuiltinPDLValue<Operation *> {};
1181template <typename T>
1182struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
1183 : public ProcessDerivedPDLValue<T, Operation *> {
1184 static T processAsArg(Operation *value) { return cast<T>(value); }
1185};
1186
1187//===----------------------------------------------------------------------===//
1188// Type
1189
1190template <>
1191struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};
1192template <typename T>
1193struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
1194 : public ProcessDerivedPDLValue<T, Type> {};
1195
1196//===----------------------------------------------------------------------===//
1197// TypeRange
1198
1199template <>
1200struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {};
1201template <>
1202struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
1203 static void processAsResult(PatternRewriter &, PDLResultList &results,
1204 ValueTypeRange<OperandRange> types) {
1205 results.push_back(types);
1206 }
1207};
1208template <>
1209struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
1210 static void processAsResult(PatternRewriter &, PDLResultList &results,
1211 ValueTypeRange<ResultRange> types) {
1212 results.push_back(types);
1213 }
1214};
1215template <unsigned N>
1216struct ProcessPDLValue<SmallVector<Type, N>> {
1217 static void processAsResult(PatternRewriter &, PDLResultList &results,
1218 SmallVector<Type, N> values) {
1219 results.push_back(TypeRange(values));
1220 }
1221};
1222
1223//===----------------------------------------------------------------------===//
1224// Value
1225
1226template <>
1227struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {};
1228
1229//===----------------------------------------------------------------------===//
1230// ValueRange
1231
1232template <>
1233struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {
1234};
1235template <>
1236struct ProcessPDLValue<OperandRange> {
1237 static void processAsResult(PatternRewriter &, PDLResultList &results,
1238 OperandRange values) {
1239 results.push_back(values);
1240 }
1241};
1242template <>
1243struct ProcessPDLValue<ResultRange> {
1244 static void processAsResult(PatternRewriter &, PDLResultList &results,
1245 ResultRange values) {
1246 results.push_back(values);
1247 }
1248};
1249template <unsigned N>
1250struct ProcessPDLValue<SmallVector<Value, N>> {
1251 static void processAsResult(PatternRewriter &, PDLResultList &results,
1252 SmallVector<Value, N> values) {
1253 results.push_back(ValueRange(values));
1254 }
1255};
1256
1257//===----------------------------------------------------------------------===//
1258// PDL Function Builder: Argument Handling
1259//===----------------------------------------------------------------------===//
1260
1261/// Validate the given PDLValues match the constraints defined by the argument
1262/// types of the given function. In the case of failure, a match failure
1263/// diagnostic is emitted.
1264/// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
1265/// does not currently preserve Constraint application ordering.
1266template <typename PDLFnT, std::size_t... I>
1267LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
1268 std::index_sequence<I...>) {
1269 using FnTraitsT = llvm::function_traits<PDLFnT>;
1270
1271 auto errorFn = [&](const Twine &msg) {
1272 return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
1273 };
1274 return success(
1275 (succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
1276 verifyAsArg(errorFn, values[I], I)) &&
1277 ...));
1278}
1279
1280/// Assert that the given PDLValues match the constraints defined by the
1281/// arguments of the given function. In the case of failure, a fatal error
1282/// is emitted.
1283template <typename PDLFnT, std::size_t... I>
1284void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
1285 std::index_sequence<I...>) {
1286 // We only want to do verification in debug builds, same as with `assert`.
1287#if LLVM_ENABLE_ABI_BREAKING_CHECKS1
1288 using FnTraitsT = llvm::function_traits<PDLFnT>;
1289 auto errorFn = [&](const Twine &msg) -> LogicalResult {
1290 llvm::report_fatal_error(msg);
1291 };
1292 (void)errorFn;
1293 assert((succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::(static_cast <bool> ((succeeded(ProcessPDLValue<typename
FnTraitsT::template arg_t<I + 1>>:: verifyAsArg(errorFn
, values[I], I)) && ...)) ? void (0) : __assert_fail (
"(succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>:: verifyAsArg(errorFn, values[I], I)) && ...)"
, "llvm/../mlir/include/mlir/IR/PatternMatch.h", 1295, __extension__
__PRETTY_FUNCTION__))
1294 verifyAsArg(errorFn, values[I], I)) &&(static_cast <bool> ((succeeded(ProcessPDLValue<typename
FnTraitsT::template arg_t<I + 1>>:: verifyAsArg(errorFn
, values[I], I)) && ...)) ? void (0) : __assert_fail (
"(succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>:: verifyAsArg(errorFn, values[I], I)) && ...)"
, "llvm/../mlir/include/mlir/IR/PatternMatch.h", 1295, __extension__
__PRETTY_FUNCTION__))
1295 ...))(static_cast <bool> ((succeeded(ProcessPDLValue<typename
FnTraitsT::template arg_t<I + 1>>:: verifyAsArg(errorFn
, values[I], I)) && ...)) ? void (0) : __assert_fail (
"(succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>:: verifyAsArg(errorFn, values[I], I)) && ...)"
, "llvm/../mlir/include/mlir/IR/PatternMatch.h", 1295, __extension__
__PRETTY_FUNCTION__))
;
1296#endif
1297 (void)values;
1298}
1299
1300//===----------------------------------------------------------------------===//
1301// PDL Function Builder: Results Handling
1302//===----------------------------------------------------------------------===//
1303
1304/// Store a single result within the result list.
1305template <typename T>
1306static LogicalResult processResults(PatternRewriter &rewriter,
1307 PDLResultList &results, T &&value) {
1308 ProcessPDLValue<T>::processAsResult(rewriter, results,
1309 std::forward<T>(value));
1310 return success();
1311}
1312
1313/// Store a std::pair<> as individual results within the result list.
1314template <typename T1, typename T2>
1315static LogicalResult processResults(PatternRewriter &rewriter,
1316 PDLResultList &results,
1317 std::pair<T1, T2> &&pair) {
1318 if (failed(processResults(rewriter, results, std::move(pair.first))) ||
1319 failed(processResults(rewriter, results, std::move(pair.second))))
1320 return failure();
1321 return success();
1322}
1323
1324/// Store a std::tuple<> as individual results within the result list.
1325template <typename... Ts>
1326static LogicalResult processResults(PatternRewriter &rewriter,
1327 PDLResultList &results,
1328 std::tuple<Ts...> &&tuple) {
1329 auto applyFn = [&](auto &&...args) {
1330 return (succeeded(processResults(rewriter, results, std::move(args))) &&
1331 ...);
1332 };
1333 return success(std::apply(applyFn, std::move(tuple)));
1334}
1335
1336/// Handle LogicalResult propagation.
1337inline LogicalResult processResults(PatternRewriter &rewriter,
1338 PDLResultList &results,
1339 LogicalResult &&result) {
1340 return result;
1341}
1342template <typename T>
1343static LogicalResult processResults(PatternRewriter &rewriter,
1344 PDLResultList &results,
1345 FailureOr<T> &&result) {
1346 if (failed(result))
1347 return failure();
1348 return processResults(rewriter, results, std::move(*result));
1349}
1350
1351//===----------------------------------------------------------------------===//
1352// PDL Constraint Builder
1353//===----------------------------------------------------------------------===//
1354
1355/// Process the arguments of a native constraint and invoke it.
1356template <typename PDLFnT, std::size_t... I,
1357 typename FnTraitsT = llvm::function_traits<PDLFnT>>
1358typename FnTraitsT::result_t
1359processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
1360 ArrayRef<PDLValue> values,
1361 std::index_sequence<I...>) {
1362 return fn(
1363 rewriter,
1364 (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
1365 values[I]))...);
1366}
1367
1368/// Build a constraint function from the given function `ConstraintFnT`. This
1369/// allows for enabling the user to define simpler, more direct constraint
1370/// functions without needing to handle the low-level PDL goop.
1371///
1372/// If the constraint function is already in the correct form, we just forward
1373/// it directly.
1374template <typename ConstraintFnT>
1375std::enable_if_t<
1376 std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
1377 PDLConstraintFunction>
1378buildConstraintFn(ConstraintFnT &&constraintFn) {
1379 return std::forward<ConstraintFnT>(constraintFn);
1380}
1381/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
1382/// we desire.
1383template <typename ConstraintFnT>
1384std::enable_if_t<
1385 !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
1386 PDLConstraintFunction>
1387buildConstraintFn(ConstraintFnT &&constraintFn) {
1388 return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
1389 PatternRewriter &rewriter,
1390 ArrayRef<PDLValue> values) -> LogicalResult {
1391 auto argIndices = std::make_index_sequence<
1392 llvm::function_traits<ConstraintFnT>::num_args - 1>();
1393 if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
1394 return failure();
1395 return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
1396 argIndices);
1397 };
1398}
1399
1400//===----------------------------------------------------------------------===//
1401// PDL Rewrite Builder
1402//===----------------------------------------------------------------------===//
1403
1404/// Process the arguments of a native rewrite and invoke it.
1405/// This overload handles the case of no return values.
1406template <typename PDLFnT, std::size_t... I,
1407 typename FnTraitsT = llvm::function_traits<PDLFnT>>
1408std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
1409 LogicalResult>
1410processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
1411 PDLResultList &, ArrayRef<PDLValue> values,
1412 std::index_sequence<I...>) {
1413 fn(rewriter,
1414 (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
1415 values[I]))...);
1416 return success();
1417}
1418/// This overload handles the case of return values, which need to be packaged
1419/// into the result list.
1420template <typename PDLFnT, std::size_t... I,
1421 typename FnTraitsT = llvm::function_traits<PDLFnT>>
1422std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
1423 LogicalResult>
1424processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
1425 PDLResultList &results, ArrayRef<PDLValue> values,
1426 std::index_sequence<I...>) {
1427 return processResults(
1428 rewriter, results,
1429 fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
1430 processAsArg(values[I]))...));
1431 (void)values;
1432}
1433
1434/// Build a rewrite function from the given function `RewriteFnT`. This
1435/// allows for enabling the user to define simpler, more direct rewrite
1436/// functions without needing to handle the low-level PDL goop.
1437///
1438/// If the rewrite function is already in the correct form, we just forward
1439/// it directly.
1440template <typename RewriteFnT>
1441std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
1442 PDLRewriteFunction>
1443buildRewriteFn(RewriteFnT &&rewriteFn) {
1444 return std::forward<RewriteFnT>(rewriteFn);
1445}
1446/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
1447/// we desire.
1448template <typename RewriteFnT>
1449std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
1450 PDLRewriteFunction>
1451buildRewriteFn(RewriteFnT &&rewriteFn) {
1452 return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
1453 PatternRewriter &rewriter, PDLResultList &results,
1454 ArrayRef<PDLValue> values) {
1455 auto argIndices =
1456 std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
1457 1>();
1458 assertArgs<RewriteFnT>(rewriter, values, argIndices);
1459 return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
1460 argIndices);
1461 };
1462}
1463
1464} // namespace pdl_function_builder
1465} // namespace detail
1466
1467//===----------------------------------------------------------------------===//
1468// PDLPatternModule
1469
1470/// This class contains all of the necessary data for a set of PDL patterns, or
1471/// pattern rewrites specified in the form of the PDL dialect. This PDL module
1472/// contained by this pattern may contain any number of `pdl.pattern`
1473/// operations.
1474class PDLPatternModule {
1475public:
1476 PDLPatternModule() = default;
1477
1478 /// Construct a PDL pattern with the given module and configurations.
1479 PDLPatternModule(OwningOpRef<ModuleOp> module)
1480 : pdlModule(std::move(module)) {}
1481 template <typename... ConfigsT>
1482 PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
1483 : PDLPatternModule(std::move(module)) {
1484 auto configSet = std::make_unique<PDLPatternConfigSet>(
1485 std::forward<ConfigsT>(patternConfigs)...);
1486 attachConfigToPatterns(*pdlModule, *configSet);
1487 configs.emplace_back(std::move(configSet));
1488 }
1489
1490 /// Merge the state in `other` into this pattern module.
1491 void mergeIn(PDLPatternModule &&other);
1492
1493 /// Return the internal PDL module of this pattern.
1494 ModuleOp getModule() { return pdlModule.get(); }
1495
1496 //===--------------------------------------------------------------------===//
1497 // Function Registry
1498
1499 /// Register a constraint function with PDL. A constraint function may be
1500 /// specified in one of two ways:
1501 ///
1502 /// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
1503 ///
1504 /// In this overload the arguments of the constraint function are passed via
1505 /// the low-level PDLValue form.
1506 ///
1507 /// * `LogicalResult (PatternRewriter &, ValueTs... values)`
1508 ///
1509 /// In this form the arguments of the constraint function are passed via the
1510 /// expected high level C++ type. In this form, the framework will
1511 /// automatically unwrap PDLValues and convert them to the expected ValueTs.
1512 /// For example, if the constraint function accepts a `Operation *`, the
1513 /// framework will automatically cast the input PDLValue. In the case of a
1514 /// `StringRef`, the framework will automatically unwrap the argument as a
1515 /// StringAttr and pass the underlying string value. To see the full list of
1516 /// supported types, or to see how to add handling for custom types, view
1517 /// the definition of `ProcessPDLValue` above.
1518 void registerConstraintFunction(StringRef name,
1519 PDLConstraintFunction constraintFn);
1520 template <typename ConstraintFnT>
1521 void registerConstraintFunction(StringRef name,
1522 ConstraintFnT &&constraintFn) {
1523 registerConstraintFunction(name,
1524 detail::pdl_function_builder::buildConstraintFn(
1525 std::forward<ConstraintFnT>(constraintFn)));
1526 }
1527
1528 /// Register a rewrite function with PDL. A rewrite function may be specified
1529 /// in one of two ways:
1530 ///
1531 /// * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)`
1532 ///
1533 /// In this overload the arguments of the constraint function are passed via
1534 /// the low-level PDLValue form, and the results are manually appended to
1535 /// the given result list.
1536 ///
1537 /// * `ResultT (PatternRewriter &, ValueTs... values)`
1538 ///
1539 /// In this form the arguments and result of the rewrite function are passed
1540 /// via the expected high level C++ type. In this form, the framework will
1541 /// automatically unwrap the PDLValues arguments and convert them to the
1542 /// expected ValueTs. It will also automatically handle the processing and
1543 /// packaging of the result value to the result list. For example, if the
1544 /// rewrite function takes a `Operation *`, the framework will automatically
1545 /// cast the input PDLValue. In the case of a `StringRef`, the framework
1546 /// will automatically unwrap the argument as a StringAttr and pass the
1547 /// underlying string value. In the reverse case, if the rewrite returns a
1548 /// StringRef or std::string, it will automatically package this as a
1549 /// StringAttr and append it to the result list. To see the full list of
1550 /// supported types, or to see how to add handling for custom types, view
1551 /// the definition of `ProcessPDLValue` above.
1552 void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
1553 template <typename RewriteFnT>
1554 void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
1555 registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
1556 std::forward<RewriteFnT>(rewriteFn)));
1557 }
1558
1559 /// Return the set of the registered constraint functions.
1560 const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
1561 return constraintFunctions;
1562 }
1563 llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
1564 return constraintFunctions;
1565 }
1566 /// Return the set of the registered rewrite functions.
1567 const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
1568 return rewriteFunctions;
1569 }
1570 llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
1571 return rewriteFunctions;
1572 }
1573
1574 /// Return the set of the registered pattern configs.
1575 SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
1576 return std::move(configs);
1577 }
1578 DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
1579 return std::move(configMap);
1580 }
1581
1582 /// Clear out the patterns and functions within this module.
1583 void clear() {
1584 pdlModule = nullptr;
1585 constraintFunctions.clear();
1586 rewriteFunctions.clear();
1587 }
1588
1589private:
1590 /// Attach the given pattern config set to the patterns defined within the
1591 /// given module.
1592 void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
1593
1594 /// The module containing the `pdl.pattern` operations.
1595 OwningOpRef<ModuleOp> pdlModule;
1596
1597 /// The set of configuration sets referenced by patterns within `pdlModule`.
1598 SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
1599 DenseMap<Operation *, PDLPatternConfigSet *> configMap;
1600
1601 /// The external functions referenced from within the PDL module.
1602 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
1603 llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
1604};
1605
1606//===----------------------------------------------------------------------===//
1607// RewritePatternSet
1608//===----------------------------------------------------------------------===//
1609
1610class RewritePatternSet {
1611 using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
1612
1613public:
1614 RewritePatternSet(MLIRContext *context) : context(context) {}
1615
1616 /// Construct a RewritePatternSet populated with the given pattern.
1617 RewritePatternSet(MLIRContext *context,
1618 std::unique_ptr<RewritePattern> pattern)
1619 : context(context) {
1620 nativePatterns.emplace_back(std::move(pattern));
1621 }
1622 RewritePatternSet(PDLPatternModule &&pattern)
1623 : context(pattern.getModule()->getContext()),
1624 pdlPatterns(std::move(pattern)) {}
1625
1626 MLIRContext *getContext() const { return context; }
1627
1628 /// Return the native patterns held in this list.
1629 NativePatternListT &getNativePatterns() { return nativePatterns; }
1630
1631 /// Return the PDL patterns held in this list.
1632 PDLPatternModule &getPDLPatterns() { return pdlPatterns; }
1633
1634 /// Clear out all of the held patterns in this list.
1635 void clear() {
1636 nativePatterns.clear();
1637 pdlPatterns.clear();
1638 }
1639
1640 //===--------------------------------------------------------------------===//
1641 // 'add' methods for adding patterns to the set.
1642 //===--------------------------------------------------------------------===//
1643
1644 /// Add an instance of each of the pattern types 'Ts' to the pattern list with
1645 /// the given arguments. Return a reference to `this` for chaining insertions.
1646 /// Note: ConstructorArg is necessary here to separate the two variadic lists.
1647 template <typename... Ts, typename ConstructorArg,
1648 typename... ConstructorArgs,
1649 typename = std::enable_if_t<sizeof...(Ts) != 0>>
1650 RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) {
1651 // The following expands a call to emplace_back for each of the pattern
1652 // types 'Ts'.
1653 (addImpl<Ts>(/*debugLabels=*/std::nullopt,
1654 std::forward<ConstructorArg>(arg),
1655 std::forward<ConstructorArgs>(args)...),
1656 ...);
1657 return *this;
1658 }
1659 /// An overload of the above `add` method that allows for attaching a set
1660 /// of debug labels to the attached patterns. This is useful for labeling
1661 /// groups of patterns that may be shared between multiple different
1662 /// passes/users.
1663 template <typename... Ts, typename ConstructorArg,
1664 typename... ConstructorArgs,
1665 typename = std::enable_if_t<sizeof...(Ts) != 0>>
1666 RewritePatternSet &addWithLabel(ArrayRef<StringRef> debugLabels,
1667 ConstructorArg &&arg,
1668 ConstructorArgs &&...args) {
1669 // The following expands a call to emplace_back for each of the pattern
1670 // types 'Ts'.
1671 (addImpl<Ts>(debugLabels, arg, args...), ...);
1672 return *this;
1673 }
1674
1675 /// Add an instance of each of the pattern types 'Ts'. Return a reference to
1676 /// `this` for chaining insertions.
1677 template <typename... Ts>
1678 RewritePatternSet &add() {
1679 (addImpl<Ts>(), ...);
1680 return *this;
1681 }
1682
1683 /// Add the given native pattern to the pattern list. Return a reference to
1684 /// `this` for chaining insertions.
1685 RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) {
1686 nativePatterns.emplace_back(std::move(pattern));
1687 return *this;
1688 }
1689
1690 /// Add the given PDL pattern to the pattern list. Return a reference to
1691 /// `this` for chaining insertions.
1692 RewritePatternSet &add(PDLPatternModule &&pattern) {
1693 pdlPatterns.mergeIn(std::move(pattern));
1694 return *this;
1695 }
1696
1697 // Add a matchAndRewrite style pattern represented as a C function pointer.
1698 template <typename OpType>
1699 RewritePatternSet &
1700 add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
1701 PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) {
1702 struct FnPattern final : public OpRewritePattern<OpType> {
1703 FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
1704 MLIRContext *context, PatternBenefit benefit,
1705 ArrayRef<StringRef> generatedNames)
1706 : OpRewritePattern<OpType>(context, benefit, generatedNames),
1707 implFn(implFn) {}
1708
1709 LogicalResult matchAndRewrite(OpType op,
1710 PatternRewriter &rewriter) const override {
1711 return implFn(op, rewriter);
1712 }
1713
1714 private:
1715 LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
1716 };
1717 add(std::make_unique<FnPattern>(std::move(implFn), getContext(), benefit,
1718 generatedNames));
1719 return *this;
1720 }
1721
1722 //===--------------------------------------------------------------------===//
1723 // Pattern Insertion
1724 //===--------------------------------------------------------------------===//
1725
1726 // TODO: These are soft deprecated in favor of the 'add' methods above.
1727
1728 /// Add an instance of each of the pattern types 'Ts' to the pattern list with
1729 /// the given arguments. Return a reference to `this` for chaining insertions.
1730 /// Note: ConstructorArg is necessary here to separate the two variadic lists.
1731 template <typename... Ts, typename ConstructorArg,
1732 typename... ConstructorArgs,
1733 typename = std::enable_if_t<sizeof...(Ts) != 0>>
1734 RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) {
1735 // The following expands a call to emplace_back for each of the pattern
1736 // types 'Ts'.
1737 (addImpl<Ts>(/*debugLabels=*/std::nullopt, arg, args...), ...);
1738 return *this;
1739 }
1740
1741 /// Add an instance of each of the pattern types 'Ts'. Return a reference to
1742 /// `this` for chaining insertions.
1743 template <typename... Ts>
1744 RewritePatternSet &insert() {
1745 (addImpl<Ts>(), ...);
1746 return *this;
1747 }
1748
1749 /// Add the given native pattern to the pattern list. Return a reference to
1750 /// `this` for chaining insertions.
1751 RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) {
1752 nativePatterns.emplace_back(std::move(pattern));
1753 return *this;
1754 }
1755
1756 /// Add the given PDL pattern to the pattern list. Return a reference to
1757 /// `this` for chaining insertions.
1758 RewritePatternSet &insert(PDLPatternModule &&pattern) {
1759 pdlPatterns.mergeIn(std::move(pattern));
1760 return *this;
1761 }
1762
1763 // Add a matchAndRewrite style pattern represented as a C function pointer.
1764 template <typename OpType>
1765 RewritePatternSet &
1766 insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
1767 struct FnPattern final : public OpRewritePattern<OpType> {
1768 FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
1769 MLIRContext *context)
1770 : OpRewritePattern<OpType>(context), implFn(implFn) {
1771 this->setDebugName(llvm::getTypeName<FnPattern>());
1772 }
1773
1774 LogicalResult matchAndRewrite(OpType op,
1775 PatternRewriter &rewriter) const override {
1776 return implFn(op, rewriter);
1777 }
1778
1779 private:
1780 LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
1781 };
1782 add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
1783 return *this;
1784 }
1785
1786private:
1787 /// Add an instance of the pattern type 'T'. Return a reference to `this` for
1788 /// chaining insertions.
1789 template <typename T, typename... Args>
1790 std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
1791 addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
1792 std::unique_ptr<T> pattern =
1793 RewritePattern::create<T>(std::forward<Args>(args)...);
1794 pattern->addDebugLabels(debugLabels);
1795 nativePatterns.emplace_back(std::move(pattern));
1796 }
1797 template <typename T, typename... Args>
1798 std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
1799 addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
1800 // TODO: Add the provided labels to the PDL pattern when PDL supports
1801 // labels.
1802 pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
1803 }
1804
1805 MLIRContext *const context;
1806 NativePatternListT nativePatterns;
1807 PDLPatternModule pdlPatterns;
1808};
1809
1810} // namespace mlir
1811
1812#endif // MLIR_IR_PATTERNMATCH_H

/build/source/llvm/../mlir/include/mlir/IR/Builders.h

1//===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_BUILDERS_H
10#define MLIR_IR_BUILDERS_H
11
12#include "mlir/IR/OpDefinition.h"
13#include "llvm/Support/Compiler.h"
14#include <optional>
15
16namespace mlir {
17
18class AffineExpr;
19class IRMapping;
20class UnknownLoc;
21class FileLineColLoc;
22class Type;
23class PrimitiveType;
24class IntegerType;
25class FloatType;
26class FunctionType;
27class IndexType;
28class MemRefType;
29class VectorType;
30class RankedTensorType;
31class UnrankedTensorType;
32class TupleType;
33class NoneType;
34class BoolAttr;
35class IntegerAttr;
36class FloatAttr;
37class StringAttr;
38class TypeAttr;
39class ArrayAttr;
40class SymbolRefAttr;
41class ElementsAttr;
42class DenseElementsAttr;
43class DenseIntElementsAttr;
44class AffineMapAttr;
45class AffineMap;
46class UnitAttr;
47
48/// This class is a general helper class for creating context-global objects
49/// like types, attributes, and affine expressions.
50class Builder {
51public:
52 explicit Builder(MLIRContext *context) : context(context) {}
53 explicit Builder(Operation *op) : Builder(op->getContext()) {}
54
55 MLIRContext *getContext() const { return context; }
56
57 // Locations.
58 Location getUnknownLoc();
59 Location getFusedLoc(ArrayRef<Location> locs,
60 Attribute metadata = Attribute());
61
62 // Types.
63 FloatType getFloat8E5M2Type();
64 FloatType getFloat8E4M3FNType();
65 FloatType getFloat8E5M2FNUZType();
66 FloatType getFloat8E4M3FNUZType();
67 FloatType getFloat8E4M3B11FNUZType();
68 FloatType getBF16Type();
69 FloatType getF16Type();
70 FloatType getF32Type();
71 FloatType getF64Type();
72 FloatType getF80Type();
73 FloatType getF128Type();
74
75 IndexType getIndexType();
76
77 IntegerType getI1Type();
78 IntegerType getI2Type();
79 IntegerType getI4Type();
80 IntegerType getI8Type();
81 IntegerType getI16Type();
82 IntegerType getI32Type();
83 IntegerType getI64Type();
84 IntegerType getIntegerType(unsigned width);
85 IntegerType getIntegerType(unsigned width, bool isSigned);
86 FunctionType getFunctionType(TypeRange inputs, TypeRange results);
87 TupleType getTupleType(TypeRange elementTypes);
88 NoneType getNoneType();
89
90 /// Get or construct an instance of the type `Ty` with provided arguments.
91 template <typename Ty, typename... Args>
92 Ty getType(Args &&...args) {
93 return Ty::get(context, std::forward<Args>(args)...);
94 }
95
96 /// Get or construct an instance of the attribute `Attr` with provided
97 /// arguments.
98 template <typename Attr, typename... Args>
99 Attr getAttr(Args &&...args) {
100 return Attr::get(context, std::forward<Args>(args)...);
101 }
102
103 // Attributes.
104 NamedAttribute getNamedAttr(StringRef name, Attribute val);
105
106 UnitAttr getUnitAttr();
107 BoolAttr getBoolAttr(bool value);
108 DictionaryAttr getDictionaryAttr(ArrayRef<NamedAttribute> value);
109 IntegerAttr getIntegerAttr(Type type, int64_t value);
110 IntegerAttr getIntegerAttr(Type type, const APInt &value);
111 FloatAttr getFloatAttr(Type type, double value);
112 FloatAttr getFloatAttr(Type type, const APFloat &value);
113 StringAttr getStringAttr(const Twine &bytes);
114 ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
115
116 // Returns a 0-valued attribute of the given `type`. This function only
117 // supports boolean, integer, and 16-/32-/64-bit float types, and vector or
118 // ranked tensor of them. Returns null attribute otherwise.
119 TypedAttr getZeroAttr(Type type);
120
121 // Convenience methods for fixed types.
122 FloatAttr getF16FloatAttr(float value);
123 FloatAttr getF32FloatAttr(float value);
124 FloatAttr getF64FloatAttr(double value);
125
126 IntegerAttr getI8IntegerAttr(int8_t value);
127 IntegerAttr getI16IntegerAttr(int16_t value);
128 IntegerAttr getI32IntegerAttr(int32_t value);
129 IntegerAttr getI64IntegerAttr(int64_t value);
130 IntegerAttr getIndexAttr(int64_t value);
131
132 /// Signed and unsigned integer attribute getters.
133 IntegerAttr getSI32IntegerAttr(int32_t value);
134 IntegerAttr getUI32IntegerAttr(uint32_t value);
135
136 /// Vector-typed DenseIntElementsAttr getters. `values` must not be empty.
137 DenseIntElementsAttr getBoolVectorAttr(ArrayRef<bool> values);
138 DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values);
139 DenseIntElementsAttr getI64VectorAttr(ArrayRef<int64_t> values);
140 DenseIntElementsAttr getIndexVectorAttr(ArrayRef<int64_t> values);
141
142 /// Tensor-typed DenseIntElementsAttr getters. `values` can be empty.
143 /// These are generally preferable for representing general lists of integers
144 /// as attributes.
145 DenseIntElementsAttr getI32TensorAttr(ArrayRef<int32_t> values);
146 DenseIntElementsAttr getI64TensorAttr(ArrayRef<int64_t> values);
147 DenseIntElementsAttr getIndexTensorAttr(ArrayRef<int64_t> values);
148
149 /// Tensor-typed DenseArrayAttr getters.
150 DenseBoolArrayAttr getDenseBoolArrayAttr(ArrayRef<bool> values);
151 DenseI8ArrayAttr getDenseI8ArrayAttr(ArrayRef<int8_t> values);
152 DenseI16ArrayAttr getDenseI16ArrayAttr(ArrayRef<int16_t> values);
153 DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef<int32_t> values);
154 DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef<int64_t> values);
155 DenseF32ArrayAttr getDenseF32ArrayAttr(ArrayRef<float> values);
156 DenseF64ArrayAttr getDenseF64ArrayAttr(ArrayRef<double> values);
157
158 ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
159 ArrayAttr getBoolArrayAttr(ArrayRef<bool> values);
160 ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values);
161 ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values);
162 ArrayAttr getIndexArrayAttr(ArrayRef<int64_t> values);
163 ArrayAttr getF32ArrayAttr(ArrayRef<float> values);
164 ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
165 ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
166 ArrayAttr getTypeArrayAttr(TypeRange values);
167
168 // Affine expressions and affine maps.
169 AffineExpr getAffineDimExpr(unsigned position);
170 AffineExpr getAffineSymbolExpr(unsigned position);
171 AffineExpr getAffineConstantExpr(int64_t constant);
172
173 // Special cases of affine maps and integer sets
174 /// Returns a zero result affine map with no dimensions or symbols: () -> ().
175 AffineMap getEmptyAffineMap();
176 /// Returns a single constant result affine map with 0 dimensions and 0
177 /// symbols. One constant result: () -> (val).
178 AffineMap getConstantAffineMap(int64_t val);
179 // One dimension id identity map: (i) -> (i).
180 AffineMap getDimIdentityMap();
181 // Multi-dimensional identity map: (d0, d1, d2) -> (d0, d1, d2).
182 AffineMap getMultiDimIdentityMap(unsigned rank);
183 // One symbol identity map: ()[s] -> (s).
184 AffineMap getSymbolIdentityMap();
185
186 /// Returns a map that shifts its (single) input dimension by 'shift'.
187 /// (d0) -> (d0 + shift)
188 AffineMap getSingleDimShiftAffineMap(int64_t shift);
189
190 /// Returns an affine map that is a translation (shift) of all result
191 /// expressions in 'map' by 'shift'.
192 /// Eg: input: (d0, d1)[s0] -> (d0, d1 + s0), shift = 2
193 /// returns: (d0, d1)[s0] -> (d0 + 2, d1 + s0 + 2)
194 AffineMap getShiftedAffineMap(AffineMap map, int64_t shift);
195
196protected:
197 MLIRContext *context;
198};
199
200/// This class helps build Operations. Operations that are created are
201/// automatically inserted at an insertion point. The builder is copyable.
202class OpBuilder : public Builder {
203public:
204 struct Listener;
205
206 /// Create a builder with the given context.
207 explicit OpBuilder(MLIRContext *ctx, Listener *listener = nullptr)
208 : Builder(ctx), listener(listener) {}
209
210 /// Create a builder and set the insertion point to the start of the region.
211 explicit OpBuilder(Region *region, Listener *listener = nullptr)
212 : OpBuilder(region->getContext(), listener) {
213 if (!region->empty())
214 setInsertionPoint(&region->front(), region->front().begin());
215 }
216 explicit OpBuilder(Region &region, Listener *listener = nullptr)
217 : OpBuilder(&region, listener) {}
218
219 /// Create a builder and set insertion point to the given operation, which
220 /// will cause subsequent insertions to go right before it.
221 explicit OpBuilder(Operation *op, Listener *listener = nullptr)
222 : OpBuilder(op->getContext(), listener) {
223 setInsertionPoint(op);
224 }
225
226 OpBuilder(Block *block, Block::iterator insertPoint,
227 Listener *listener = nullptr)
228 : OpBuilder(block->getParent()->getContext(), listener) {
229 setInsertionPoint(block, insertPoint);
230 }
231
232 /// Create a builder and set the insertion point to before the first operation
233 /// in the block but still inside the block.
234 static OpBuilder atBlockBegin(Block *block, Listener *listener = nullptr) {
235 return OpBuilder(block, block->begin(), listener);
236 }
237
238 /// Create a builder and set the insertion point to after the last operation
239 /// in the block but still inside the block.
240 static OpBuilder atBlockEnd(Block *block, Listener *listener = nullptr) {
241 return OpBuilder(block, block->end(), listener);
242 }
243
244 /// Create a builder and set the insertion point to before the block
245 /// terminator.
246 static OpBuilder atBlockTerminator(Block *block,
247 Listener *listener = nullptr) {
248 auto *terminator = block->getTerminator();
249 assert(terminator != nullptr && "the block has no terminator")(static_cast <bool> (terminator != nullptr && "the block has no terminator"
) ? void (0) : __assert_fail ("terminator != nullptr && \"the block has no terminator\""
, "llvm/../mlir/include/mlir/IR/Builders.h", 249, __extension__
__PRETTY_FUNCTION__))
;
250 return OpBuilder(block, Block::iterator(terminator), listener);
251 }
252
253 //===--------------------------------------------------------------------===//
254 // Listeners
255 //===--------------------------------------------------------------------===//
256
257 /// Base class for listeners.
258 struct ListenerBase {
259 /// The kind of listener.
260 enum class Kind {
261 /// OpBuilder::Listener or user-derived class.
262 OpBuilderListener = 0,
263
264 /// RewriterBase::Listener or user-derived class.
265 RewriterBaseListener = 1
266 };
267
268 Kind getKind() const { return kind; }
269
270 protected:
271 ListenerBase(Kind kind) : kind(kind) {}
272
273 private:
274 const Kind kind;
275 };
276
277 /// This class represents a listener that may be used to hook into various
278 /// actions within an OpBuilder.
279 struct Listener : public ListenerBase {
280 Listener() : ListenerBase(ListenerBase::Kind::OpBuilderListener) {}
281
282 virtual ~Listener() = default;
283
284 /// Notification handler for when an operation is inserted into the builder.
285 /// `op` is the operation that was inserted.
286 virtual void notifyOperationInserted(Operation *op) {}
287
288 /// Notification handler for when a block is created using the builder.
289 /// `block` is the block that was created.
290 virtual void notifyBlockCreated(Block *block) {}
291
292 protected:
293 Listener(Kind kind) : ListenerBase(kind) {}
294 };
295
296 /// Sets the listener of this builder to the one provided.
297 void setListener(Listener *newListener) { listener = newListener; }
298
299 /// Returns the current listener of this builder, or nullptr if this builder
300 /// doesn't have a listener.
301 Listener *getListener() const { return listener; }
302
303 //===--------------------------------------------------------------------===//
304 // Insertion Point Management
305 //===--------------------------------------------------------------------===//
306
307 /// This class represents a saved insertion point.
308 class InsertPoint {
309 public:
310 /// Creates a new insertion point which doesn't point to anything.
311 InsertPoint() = default;
312
313 /// Creates a new insertion point at the given location.
314 InsertPoint(Block *insertBlock, Block::iterator insertPt)
315 : block(insertBlock), point(insertPt) {}
316
317 /// Returns true if this insert point is set.
318 bool isSet() const { return (block != nullptr); }
319
320 Block *getBlock() const { return block; }
321 Block::iterator getPoint() const { return point; }
322
323 private:
324 Block *block = nullptr;
325 Block::iterator point;
326 };
327
328 /// RAII guard to reset the insertion point of the builder when destroyed.
329 class InsertionGuard {
330 public:
331 InsertionGuard(OpBuilder &builder)
332 : builder(&builder), ip(builder.saveInsertionPoint()) {}
333
334 ~InsertionGuard() {
335 if (builder)
336 builder->restoreInsertionPoint(ip);
337 }
338
339 InsertionGuard(const InsertionGuard &) = delete;
340 InsertionGuard &operator=(const InsertionGuard &) = delete;
341
342 /// Implement the move constructor to clear the builder field of `other`.
343 /// That way it does not restore the insertion point upon destruction as
344 /// that should be done exclusively by the just constructed InsertionGuard.
345 InsertionGuard(InsertionGuard &&other) noexcept
346 : builder(other.builder), ip(other.ip) {
347 other.builder = nullptr;
348 }
349
350 InsertionGuard &operator=(InsertionGuard &&other) = delete;
351
352 private:
353 OpBuilder *builder;
354 OpBuilder::InsertPoint ip;
355 };
356
357 /// Reset the insertion point to no location. Creating an operation without a
358 /// set insertion point is an error, but this can still be useful when the
359 /// current insertion point a builder refers to is being removed.
360 void clearInsertionPoint() {
361 this->block = nullptr;
362 insertPoint = Block::iterator();
363 }
364
365 /// Return a saved insertion point.
366 InsertPoint saveInsertionPoint() const {
367 return InsertPoint(getInsertionBlock(), getInsertionPoint());
368 }
369
370 /// Restore the insert point to a previously saved point.
371 void restoreInsertionPoint(InsertPoint ip) {
372 if (ip.isSet())
373 setInsertionPoint(ip.getBlock(), ip.getPoint());
374 else
375 clearInsertionPoint();
376 }
377
378 /// Set the insertion point to the specified location.
379 void setInsertionPoint(Block *block, Block::iterator insertPoint) {
380 // TODO: check that insertPoint is in this rather than some other block.
381 this->block = block;
382 this->insertPoint = insertPoint;
383 }
384
385 /// Sets the insertion point to the specified operation, which will cause
386 /// subsequent insertions to go right before it.
387 void setInsertionPoint(Operation *op) {
388 setInsertionPoint(op->getBlock(), Block::iterator(op));
389 }
390
391 /// Sets the insertion point to the node after the specified operation, which
392 /// will cause subsequent insertions to go right after it.
393 void setInsertionPointAfter(Operation *op) {
394 setInsertionPoint(op->getBlock(), ++Block::iterator(op));
395 }
396
397 /// Sets the insertion point to the node after the specified value. If value
398 /// has a defining operation, sets the insertion point to the node after such
399 /// defining operation. This will cause subsequent insertions to go right
400 /// after it. Otherwise, value is a BlockArgument. Sets the insertion point to
401 /// the start of its block.
402 void setInsertionPointAfterValue(Value val) {
403 if (Operation *op = val.getDefiningOp()) {
404 setInsertionPointAfter(op);
405 } else {
406 auto blockArg = val.cast<BlockArgument>();
407 setInsertionPointToStart(blockArg.getOwner());
408 }
409 }
410
411 /// Sets the insertion point to the start of the specified block.
412 void setInsertionPointToStart(Block *block) {
413 setInsertionPoint(block, block->begin());
414 }
415
416 /// Sets the insertion point to the end of the specified block.
417 void setInsertionPointToEnd(Block *block) {
418 setInsertionPoint(block, block->end());
419 }
420
421 /// Return the block the current insertion point belongs to. Note that the
422 /// insertion point is not necessarily the end of the block.
423 Block *getInsertionBlock() const { return block; }
424
425 /// Returns the current insertion point of the builder.
426 Block::iterator getInsertionPoint() const { return insertPoint; }
427
428 /// Returns the current block of the builder.
429 Block *getBlock() const { return block; }
430
431 //===--------------------------------------------------------------------===//
432 // Block Creation
433 //===--------------------------------------------------------------------===//
434
435 /// Add new block with 'argTypes' arguments and set the insertion point to the
436 /// end of it. The block is inserted at the provided insertion point of
437 /// 'parent'. `locs` contains the locations of the inserted arguments, and
438 /// should match the size of `argTypes`.
439 Block *createBlock(Region *parent, Region::iterator insertPt = {},
440 TypeRange argTypes = std::nullopt,
441 ArrayRef<Location> locs = std::nullopt);
442
443 /// Add new block with 'argTypes' arguments and set the insertion point to the
444 /// end of it. The block is placed before 'insertBefore'. `locs` contains the
445 /// locations of the inserted arguments, and should match the size of
446 /// `argTypes`.
447 Block *createBlock(Block *insertBefore, TypeRange argTypes = std::nullopt,
448 ArrayRef<Location> locs = std::nullopt);
449
450 //===--------------------------------------------------------------------===//
451 // Operation Creation
452 //===--------------------------------------------------------------------===//
453
454 /// Insert the given operation at the current insertion point and return it.
455 Operation *insert(Operation *op);
456
457 /// Creates an operation given the fields represented as an OperationState.
458 Operation *create(const OperationState &state);
459
460 /// Creates an operation with the given fields.
461 Operation *create(Location loc, StringAttr opName, ValueRange operands,
462 TypeRange types = {},
463 ArrayRef<NamedAttribute> attributes = {},
464 BlockRange successors = {},
465 MutableArrayRef<std::unique_ptr<Region>> regions = {});
466
467private:
468 /// Helper for sanity checking preconditions for create* methods below.
469 template <typename OpT>
470 RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) {
471 std::optional<RegisteredOperationName> opName =
472 RegisteredOperationName::lookup(OpT::getOperationName(), ctx);
473 if (LLVM_UNLIKELY(!opName)__builtin_expect((bool)(!opName), false)) {
474 llvm::report_fatal_error(
475 "Building op `" + OpT::getOperationName() +
476 "` but it isn't registered in this MLIRContext: the dialect may not "
477 "be loaded or this operation isn't registered by the dialect. See "
478 "also https://mlir.llvm.org/getting_started/Faq/"
479 "#registered-loaded-dependent-whats-up-with-dialects-management");
480 }
481 return *opName;
482 }
483
484public:
485 /// Create an operation of specific op type at the current insertion point.
486 template <typename OpTy, typename... Args>
487 OpTy create(Location location, Args &&...args) {
488 OperationState state(location,
489 getCheckRegisteredInfo<OpTy>(location.getContext()));
490 OpTy::build(*this, state, std::forward<Args>(args)...);
17
Calling 'forward<mlir::Block *&>'
18
Returning from 'forward<mlir::Block *&>'
19
4th function call argument is an uninitialized value
491 auto *op = create(state);
492 auto result = dyn_cast<OpTy>(op);
493 assert(result && "builder didn't return the right type")(static_cast <bool> (result && "builder didn't return the right type"
) ? void (0) : __assert_fail ("result && \"builder didn't return the right type\""
, "llvm/../mlir/include/mlir/IR/Builders.h", 493, __extension__
__PRETTY_FUNCTION__))
;
494 return result;
495 }
496
497 /// Create an operation of specific op type at the current insertion point,
498 /// and immediately try to fold it. This functions populates 'results' with
499 /// the results after folding the operation.
500 template <typename OpTy, typename... Args>
501 void createOrFold(SmallVectorImpl<Value> &results, Location location,
502 Args &&...args) {
503 // Create the operation without using 'create' as we don't want to
504 // insert it yet.
505 OperationState state(location,
506 getCheckRegisteredInfo<OpTy>(location.getContext()));
507 OpTy::build(*this, state, std::forward<Args>(args)...);
508 Operation *op = Operation::create(state);
509
510 // Fold the operation. If successful destroy it, otherwise insert it.
511 if (succeeded(tryFold(op, results)))
512 op->destroy();
513 else
514 insert(op);
515 }
516
517 /// Overload to create or fold a single result operation.
518 template <typename OpTy, typename... Args>
519 std::enable_if_t<OpTy::template hasTrait<OpTrait::OneResult>(), Value>
520 createOrFold(Location location, Args &&...args) {
521 SmallVector<Value, 1> results;
522 createOrFold<OpTy>(results, location, std::forward<Args>(args)...);
523 return results.front();
524 }
525
526 /// Overload to create or fold a zero result operation.
527 template <typename OpTy, typename... Args>
528 std::enable_if_t<OpTy::template hasTrait<OpTrait::ZeroResults>(), OpTy>
529 createOrFold(Location location, Args &&...args) {
530 auto op = create<OpTy>(location, std::forward<Args>(args)...);
531 SmallVector<Value, 0> unused;
532 (void)tryFold(op.getOperation(), unused);
533
534 // Folding cannot remove a zero-result operation, so for convenience we
535 // continue to return it.
536 return op;
537 }
538
539 /// Attempts to fold the given operation and places new results within
540 /// 'results'. Returns success if the operation was folded, failure otherwise.
541 /// Note: This function does not erase the operation on a successful fold.
542 LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
543
544 /// Creates a deep copy of the specified operation, remapping any operands
545 /// that use values outside of the operation using the map that is provided
546 /// ( leaving them alone if no entry is present). Replaces references to
547 /// cloned sub-operations to the corresponding operation that is copied,
548 /// and adds those mappings to the map.
549 Operation *clone(Operation &op, IRMapping &mapper);
550 Operation *clone(Operation &op);
551
552 /// Creates a deep copy of this operation but keep the operation regions
553 /// empty. Operands are remapped using `mapper` (if present), and `mapper` is
554 /// updated to contain the results.
555 Operation *cloneWithoutRegions(Operation &op, IRMapping &mapper) {
556 return insert(op.cloneWithoutRegions(mapper));
557 }
558 Operation *cloneWithoutRegions(Operation &op) {
559 return insert(op.cloneWithoutRegions());
560 }
561 template <typename OpT>
562 OpT cloneWithoutRegions(OpT op) {
563 return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
564 }
565
566protected:
567 /// The optional listener for events of this builder.
568 Listener *listener;
569
570private:
571 /// The current block this builder is inserting into.
572 Block *block = nullptr;
573 /// The insertion point within the block that this builder is inserting
574 /// before.
575 Block::iterator insertPoint;
576};
577
578} // namespace mlir
579
580#endif