File: | build/source/flang/runtime/matmul.cpp |
Warning: | line 287, column 20 Assigned value is garbage or undefined |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===-- runtime/matmul.cpp ------------------------------------------------===// | ||||
2 | // | ||||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||
4 | // See https://llvm.org/LICENSE.txt for license information. | ||||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||
6 | // | ||||
7 | //===----------------------------------------------------------------------===// | ||||
8 | |||||
9 | // Implements all forms of MATMUL (Fortran 2018 16.9.124) | ||||
10 | // | ||||
11 | // There are two main entry points; one establishes a descriptor for the | ||||
12 | // result and allocates it, and the other expects a result descriptor that | ||||
13 | // points to existing storage. | ||||
14 | // | ||||
15 | // This implementation must handle all combinations of numeric types and | ||||
16 | // kinds (100 - 165 cases depending on the target), plus all combinations | ||||
17 | // of logical kinds (16). A single template undergoes many instantiations | ||||
18 | // to cover all of the valid possibilities. | ||||
19 | // | ||||
20 | // Places where BLAS routines could be called are marked as TODO items. | ||||
21 | |||||
22 | #include "flang/Runtime/matmul.h" | ||||
23 | #include "terminator.h" | ||||
24 | #include "tools.h" | ||||
25 | #include "flang/Runtime/c-or-cpp.h" | ||||
26 | #include "flang/Runtime/cpp-type.h" | ||||
27 | #include "flang/Runtime/descriptor.h" | ||||
28 | #include <cstring> | ||||
29 | |||||
30 | namespace Fortran::runtime { | ||||
31 | |||||
32 | // General accumulator for any type and stride; this is not used for | ||||
33 | // contiguous numeric cases. | ||||
34 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> | ||||
35 | class Accumulator { | ||||
36 | public: | ||||
37 | using Result = AccumulationType<RCAT, RKIND>; | ||||
38 | Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {} | ||||
39 | void Accumulate(const SubscriptValue xAt[], const SubscriptValue yAt[]) { | ||||
40 | if constexpr (RCAT == TypeCategory::Logical) { | ||||
41 | sum_ = sum_ || | ||||
42 | (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt)); | ||||
43 | } else { | ||||
44 | sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) * | ||||
45 | static_cast<Result>(*y_.Element<YT>(yAt)); | ||||
46 | } | ||||
47 | } | ||||
48 | Result GetResult() const { return sum_; } | ||||
49 | |||||
50 | private: | ||||
51 | const Descriptor &x_, &y_; | ||||
52 | Result sum_{}; | ||||
53 | }; | ||||
54 | |||||
55 | // Contiguous numeric matrix*matrix multiplication | ||||
56 | // matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols) | ||||
57 | // Straightforward algorithm: | ||||
58 | // DO 1 I = 1, NROWS | ||||
59 | // DO 1 J = 1, NCOLS | ||||
60 | // RES(I,J) = 0 | ||||
61 | // DO 1 K = 1, N | ||||
62 | // 1 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) | ||||
63 | // With loop distribution and transposition to avoid the inner sum | ||||
64 | // reduction and to avoid non-unit strides: | ||||
65 | // DO 1 I = 1, NROWS | ||||
66 | // DO 1 J = 1, NCOLS | ||||
67 | // 1 RES(I,J) = 0 | ||||
68 | // DO 2 K = 1, N | ||||
69 | // DO 2 J = 1, NCOLS | ||||
70 | // DO 2 I = 1, NROWS | ||||
71 | // 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term | ||||
72 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> | ||||
73 | inline void MatrixTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT__restrict product, | ||||
74 | SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT__restrict x, | ||||
75 | const YT *RESTRICT__restrict y, SubscriptValue n) { | ||||
76 | using ResultType = CppTypeFor<RCAT, RKIND>; | ||||
77 | std::memset(product, 0, rows * cols * sizeof *product); | ||||
78 | const XT *RESTRICT__restrict xp0{x}; | ||||
79 | for (SubscriptValue k{0}; k < n; ++k) { | ||||
80 | ResultType *RESTRICT__restrict p{product}; | ||||
81 | for (SubscriptValue j{0}; j < cols; ++j) { | ||||
82 | const XT *RESTRICT__restrict xp{xp0}; | ||||
83 | auto yv{static_cast<ResultType>(y[k + j * n])}; | ||||
84 | for (SubscriptValue i{0}; i < rows; ++i) { | ||||
85 | *p++ += static_cast<ResultType>(*xp++) * yv; | ||||
86 | } | ||||
87 | } | ||||
88 | xp0 += rows; | ||||
89 | } | ||||
90 | } | ||||
91 | |||||
92 | // Contiguous numeric matrix*vector multiplication | ||||
93 | // matrix(rows,n) * column vector(n) -> column vector(rows) | ||||
94 | // Straightforward algorithm: | ||||
95 | // DO 1 J = 1, NROWS | ||||
96 | // RES(J) = 0 | ||||
97 | // DO 1 K = 1, N | ||||
98 | // 1 RES(J) = RES(J) + X(J,K)*Y(K) | ||||
99 | // With loop distribution and transposition to avoid the inner | ||||
100 | // sum reduction and to avoid non-unit strides: | ||||
101 | // DO 1 J = 1, NROWS | ||||
102 | // 1 RES(J) = 0 | ||||
103 | // DO 2 K = 1, N | ||||
104 | // DO 2 J = 1, NROWS | ||||
105 | // 2 RES(J) = RES(J) + X(J,K)*Y(K) | ||||
106 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> | ||||
107 | inline void MatrixTimesVector(CppTypeFor<RCAT, RKIND> *RESTRICT__restrict product, | ||||
108 | SubscriptValue rows, SubscriptValue n, const XT *RESTRICT__restrict x, | ||||
109 | const YT *RESTRICT__restrict y) { | ||||
110 | using ResultType = CppTypeFor<RCAT, RKIND>; | ||||
111 | std::memset(product, 0, rows * sizeof *product); | ||||
112 | for (SubscriptValue k{0}; k < n; ++k) { | ||||
113 | ResultType *RESTRICT__restrict p{product}; | ||||
114 | auto yv{static_cast<ResultType>(*y++)}; | ||||
115 | for (SubscriptValue j{0}; j < rows; ++j) { | ||||
116 | *p++ += static_cast<ResultType>(*x++) * yv; | ||||
117 | } | ||||
118 | } | ||||
119 | } | ||||
120 | |||||
121 | // Contiguous numeric vector*matrix multiplication | ||||
122 | // row vector(n) * matrix(n,cols) -> row vector(cols) | ||||
123 | // Straightforward algorithm: | ||||
124 | // DO 1 J = 1, NCOLS | ||||
125 | // RES(J) = 0 | ||||
126 | // DO 1 K = 1, N | ||||
127 | // 1 RES(J) = RES(J) + X(K)*Y(K,J) | ||||
128 | // With loop distribution and transposition to avoid the inner | ||||
129 | // sum reduction and one non-unit stride (the other remains): | ||||
130 | // DO 1 J = 1, NCOLS | ||||
131 | // 1 RES(J) = 0 | ||||
132 | // DO 2 K = 1, N | ||||
133 | // DO 2 J = 1, NCOLS | ||||
134 | // 2 RES(J) = RES(J) + X(K)*Y(K,J) | ||||
135 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> | ||||
136 | inline void VectorTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT__restrict product, | ||||
137 | SubscriptValue n, SubscriptValue cols, const XT *RESTRICT__restrict x, | ||||
138 | const YT *RESTRICT__restrict y) { | ||||
139 | using ResultType = CppTypeFor<RCAT, RKIND>; | ||||
140 | std::memset(product, 0, cols * sizeof *product); | ||||
141 | for (SubscriptValue k{0}; k < n; ++k) { | ||||
142 | ResultType *RESTRICT__restrict p{product}; | ||||
143 | auto xv{static_cast<ResultType>(*x++)}; | ||||
144 | const YT *RESTRICT__restrict yp{&y[k]}; | ||||
145 | for (SubscriptValue j{0}; j < cols; ++j) { | ||||
146 | *p++ += xv * static_cast<ResultType>(*yp); | ||||
147 | yp += n; | ||||
148 | } | ||||
149 | } | ||||
150 | } | ||||
151 | |||||
152 | // Implements an instance of MATMUL for given argument types. | ||||
153 | template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT, | ||||
154 | typename YT> | ||||
155 | static inline void DoMatmul( | ||||
156 | std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result, | ||||
157 | const Descriptor &x, const Descriptor &y, Terminator &terminator) { | ||||
158 | int xRank{x.rank()}; | ||||
159 | int yRank{y.rank()}; | ||||
160 | int resRank{xRank + yRank - 2}; | ||||
161 | if (xRank * yRank != 2 * resRank) { | ||||
| |||||
162 | terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank); | ||||
163 | } | ||||
164 | SubscriptValue extent[2]{ | ||||
165 | xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(), | ||||
166 | resRank == 2 ? y.GetDimension(1).Extent() : 0}; | ||||
167 | if constexpr (IS_ALLOCATING) { | ||||
168 | result.Establish( | ||||
169 | RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable2); | ||||
170 | for (int j{0}; j < resRank; ++j) { | ||||
171 | result.GetDimension(j).SetBounds(1, extent[j]); | ||||
172 | } | ||||
173 | if (int stat{result.Allocate()}) { | ||||
174 | terminator.Crash( | ||||
175 | "MATMUL: could not allocate memory for result; STAT=%d", stat); | ||||
176 | } | ||||
177 | } else { | ||||
178 | RUNTIME_CHECK(terminator, resRank == result.rank())if (resRank == result.rank()) ; else (terminator).CheckFailed ("resRank == result.rank()", "flang/runtime/matmul.cpp", 178); | ||||
179 | RUNTIME_CHECK(if (result.ElementBytes() == static_cast<std::size_t>(RKIND )) ; else (terminator).CheckFailed("result.ElementBytes() == static_cast<std::size_t>(RKIND)" , "flang/runtime/matmul.cpp", 180) | ||||
180 | terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND))if (result.ElementBytes() == static_cast<std::size_t>(RKIND )) ; else (terminator).CheckFailed("result.ElementBytes() == static_cast<std::size_t>(RKIND)" , "flang/runtime/matmul.cpp", 180); | ||||
181 | RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0])if (result.GetDimension(0).Extent() == extent[0]) ; else (terminator ).CheckFailed("result.GetDimension(0).Extent() == extent[0]", "flang/runtime/matmul.cpp", 181); | ||||
182 | RUNTIME_CHECK(terminator,if (resRank == 1 || result.GetDimension(1).Extent() == extent [1]) ; else (terminator).CheckFailed("resRank == 1 || result.GetDimension(1).Extent() == extent[1]" , "flang/runtime/matmul.cpp", 183) | ||||
183 | resRank == 1 || result.GetDimension(1).Extent() == extent[1])if (resRank == 1 || result.GetDimension(1).Extent() == extent [1]) ; else (terminator).CheckFailed("resRank == 1 || result.GetDimension(1).Extent() == extent[1]" , "flang/runtime/matmul.cpp", 183); | ||||
184 | } | ||||
185 | SubscriptValue n{x.GetDimension(xRank - 1).Extent()}; | ||||
186 | if (n != y.GetDimension(0).Extent()) { | ||||
187 | terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)", | ||||
188 | static_cast<std::intmax_t>(x.GetDimension(0).Extent()), | ||||
189 | static_cast<std::intmax_t>(n), | ||||
190 | static_cast<std::intmax_t>(y.GetDimension(0).Extent()), | ||||
191 | static_cast<std::intmax_t>(y.GetDimension(1).Extent())); | ||||
192 | } | ||||
193 | using WriteResult = | ||||
194 | CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT, | ||||
195 | RKIND>; | ||||
196 | if constexpr (RCAT
| ||||
197 | if (x.IsContiguous() && y.IsContiguous() && | ||||
198 | (IS_ALLOCATING || result.IsContiguous())) { | ||||
199 | // Contiguous numeric matrices | ||||
200 | if (resRank == 2) { // M*M -> M | ||||
201 | if (std::is_same_v<XT, YT>) { | ||||
202 | if constexpr (std::is_same_v<XT, float>) { | ||||
203 | // TODO: call BLAS-3 SGEMM | ||||
204 | } else if constexpr (std::is_same_v<XT, double>) { | ||||
205 | // TODO: call BLAS-3 DGEMM | ||||
206 | } else if constexpr (std::is_same_v<XT, std::complex<float>>) { | ||||
207 | // TODO: call BLAS-3 CGEMM | ||||
208 | } else if constexpr (std::is_same_v<XT, std::complex<double>>) { | ||||
209 | // TODO: call BLAS-3 ZGEMM | ||||
210 | } | ||||
211 | } | ||||
212 | MatrixTimesMatrix<RCAT, RKIND, XT, YT>( | ||||
213 | result.template OffsetElement<WriteResult>(), extent[0], extent[1], | ||||
214 | x.OffsetElement<XT>(), y.OffsetElement<YT>(), n); | ||||
215 | return; | ||||
216 | } else if (xRank == 2) { // M*V -> V | ||||
217 | if (std::is_same_v<XT, YT>) { | ||||
218 | if constexpr (std::is_same_v<XT, float>) { | ||||
219 | // TODO: call BLAS-2 SGEMV(x,y) | ||||
220 | } else if constexpr (std::is_same_v<XT, double>) { | ||||
221 | // TODO: call BLAS-2 DGEMV(x,y) | ||||
222 | } else if constexpr (std::is_same_v<XT, std::complex<float>>) { | ||||
223 | // TODO: call BLAS-2 CGEMV(x,y) | ||||
224 | } else if constexpr (std::is_same_v<XT, std::complex<double>>) { | ||||
225 | // TODO: call BLAS-2 ZGEMV(x,y) | ||||
226 | } | ||||
227 | } | ||||
228 | MatrixTimesVector<RCAT, RKIND, XT, YT>( | ||||
229 | result.template OffsetElement<WriteResult>(), extent[0], n, | ||||
230 | x.OffsetElement<XT>(), y.OffsetElement<YT>()); | ||||
231 | return; | ||||
232 | } else { // V*M -> V | ||||
233 | if (std::is_same_v<XT, YT>) { | ||||
234 | if constexpr (std::is_same_v<XT, float>) { | ||||
235 | // TODO: call BLAS-2 SGEMV(y,x) | ||||
236 | } else if constexpr (std::is_same_v<XT, double>) { | ||||
237 | // TODO: call BLAS-2 DGEMV(y,x) | ||||
238 | } else if constexpr (std::is_same_v<XT, std::complex<float>>) { | ||||
239 | // TODO: call BLAS-2 CGEMV(y,x) | ||||
240 | } else if constexpr (std::is_same_v<XT, std::complex<double>>) { | ||||
241 | // TODO: call BLAS-2 ZGEMV(y,x) | ||||
242 | } | ||||
243 | } | ||||
244 | VectorTimesMatrix<RCAT, RKIND, XT, YT>( | ||||
245 | result.template OffsetElement<WriteResult>(), n, extent[0], | ||||
246 | x.OffsetElement<XT>(), y.OffsetElement<YT>()); | ||||
247 | return; | ||||
248 | } | ||||
249 | } | ||||
250 | } | ||||
251 | // General algorithms for LOGICAL and noncontiguity | ||||
252 | SubscriptValue xAt[2], yAt[2], resAt[2]; | ||||
253 | x.GetLowerBounds(xAt); | ||||
254 | y.GetLowerBounds(yAt); | ||||
255 | result.GetLowerBounds(resAt); | ||||
256 | if (resRank == 2) { // M*M -> M | ||||
257 | SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]}; | ||||
258 | for (SubscriptValue i{0}; i < extent[0]; ++i) { | ||||
259 | for (SubscriptValue j{0}; j < extent[1]; ++j) { | ||||
260 | Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; | ||||
261 | yAt[1] = y1 + j; | ||||
262 | for (SubscriptValue k{0}; k < n; ++k) { | ||||
263 | xAt[1] = x1 + k; | ||||
264 | yAt[0] = y0 + k; | ||||
265 | accumulator.Accumulate(xAt, yAt); | ||||
266 | } | ||||
267 | resAt[1] = res1 + j; | ||||
268 | *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); | ||||
269 | } | ||||
270 | ++resAt[0]; | ||||
271 | ++xAt[0]; | ||||
272 | } | ||||
273 | } else if (xRank
| ||||
274 | SubscriptValue x1{xAt[1]}, y0{yAt[0]}; | ||||
275 | for (SubscriptValue j{0}; j < extent[0]; ++j) { | ||||
276 | Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; | ||||
277 | for (SubscriptValue k{0}; k < n; ++k) { | ||||
278 | xAt[1] = x1 + k; | ||||
279 | yAt[0] = y0 + k; | ||||
280 | accumulator.Accumulate(xAt, yAt); | ||||
281 | } | ||||
282 | *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); | ||||
283 | ++resAt[0]; | ||||
284 | ++xAt[0]; | ||||
285 | } | ||||
286 | } else { // V*M -> V | ||||
287 | SubscriptValue x0{xAt[0]}, y0{yAt[0]}; | ||||
| |||||
288 | for (SubscriptValue j{0}; j < extent[0]; ++j) { | ||||
289 | Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; | ||||
290 | for (SubscriptValue k{0}; k < n; ++k) { | ||||
291 | xAt[0] = x0 + k; | ||||
292 | yAt[0] = y0 + k; | ||||
293 | accumulator.Accumulate(xAt, yAt); | ||||
294 | } | ||||
295 | *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); | ||||
296 | ++resAt[0]; | ||||
297 | ++yAt[1]; | ||||
298 | } | ||||
299 | } | ||||
300 | } | ||||
301 | |||||
302 | // Maps the dynamic type information from the arguments' descriptors | ||||
303 | // to the right instantiation of DoMatmul() for valid combinations of | ||||
304 | // types. | ||||
305 | template <bool IS_ALLOCATING> struct Matmul { | ||||
306 | using ResultDescriptor = | ||||
307 | std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>; | ||||
308 | template <TypeCategory XCAT, int XKIND> struct MM1 { | ||||
309 | template <TypeCategory YCAT, int YKIND> struct MM2 { | ||||
310 | void operator()(ResultDescriptor &result, const Descriptor &x, | ||||
311 | const Descriptor &y, Terminator &terminator) const { | ||||
312 | if constexpr (constexpr auto resultType{ | ||||
313 | GetResultType(XCAT, XKIND, YCAT, YKIND)}) { | ||||
314 | if constexpr (common::IsNumericTypeCategory(resultType->first) || | ||||
315 | resultType->first == TypeCategory::Logical) { | ||||
316 | return DoMatmul<IS_ALLOCATING, resultType->first, | ||||
317 | resultType->second, CppTypeFor<XCAT, XKIND>, | ||||
318 | CppTypeFor<YCAT, YKIND>>(result, x, y, terminator); | ||||
319 | } | ||||
320 | } | ||||
321 | terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))", | ||||
322 | static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND); | ||||
323 | } | ||||
324 | }; | ||||
325 | void operator()(ResultDescriptor &result, const Descriptor &x, | ||||
326 | const Descriptor &y, Terminator &terminator, TypeCategory yCat, | ||||
327 | int yKind) const { | ||||
328 | ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator); | ||||
329 | } | ||||
330 | }; | ||||
331 | void operator()(ResultDescriptor &result, const Descriptor &x, | ||||
332 | const Descriptor &y, const char *sourceFile, int line) const { | ||||
333 | Terminator terminator{sourceFile, line}; | ||||
334 | auto xCatKind{x.type().GetCategoryAndKind()}; | ||||
335 | auto yCatKind{y.type().GetCategoryAndKind()}; | ||||
336 | RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value())if (xCatKind.has_value() && yCatKind.has_value()) ; else (terminator).CheckFailed("xCatKind.has_value() && yCatKind.has_value()" , "flang/runtime/matmul.cpp", 336); | ||||
337 | ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result, | ||||
338 | x, y, terminator, yCatKind->first, yCatKind->second); | ||||
339 | } | ||||
340 | }; | ||||
341 | |||||
342 | extern "C" { | ||||
343 | void RTNAME(Matmul)_FortranAMatmul(Descriptor &result, const Descriptor &x, | ||||
344 | const Descriptor &y, const char *sourceFile, int line) { | ||||
345 | Matmul<true>{}(result, x, y, sourceFile, line); | ||||
346 | } | ||||
347 | void RTNAME(MatmulDirect)_FortranAMatmulDirect(const Descriptor &result, const Descriptor &x, | ||||
348 | const Descriptor &y, const char *sourceFile, int line) { | ||||
349 | Matmul<false>{}(result, x, y, sourceFile, line); | ||||
350 | } | ||||
351 | } // extern "C" | ||||
352 | } // namespace Fortran::runtime |
1 | //===-- include/flang/Runtime/descriptor.h ----------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef FORTRAN_RUNTIME_DESCRIPTOR_H_ |
10 | #define FORTRAN_RUNTIME_DESCRIPTOR_H_ |
11 | |
12 | // Defines data structures used during execution of a Fortran program |
13 | // to implement nontrivial dummy arguments, pointers, allocatables, |
14 | // function results, and the special behaviors of instances of derived types. |
15 | // This header file includes and extends the published language |
16 | // interoperability header that is required by the Fortran 2018 standard |
17 | // as a subset of definitions suitable for exposure to user C/C++ code. |
18 | // User C code is welcome to depend on that ISO_Fortran_binding.h file, |
19 | // but should never reference this internal header. |
20 | |
21 | #include "flang/ISO_Fortran_binding.h" |
22 | #include "flang/Runtime/memory.h" |
23 | #include "flang/Runtime/type-code.h" |
24 | #include <algorithm> |
25 | #include <cassert> |
26 | #include <cinttypes> |
27 | #include <cstddef> |
28 | #include <cstdio> |
29 | #include <cstring> |
30 | |
// Names from the runtime's type-information component that this header
// needs only by declaration.
namespace Fortran::runtime::typeInfo {
class DerivedType;
// Representation of a LEN type parameter value.
using TypeParameterValue = std::int64_t;
} // namespace Fortran::runtime::typeInfo
35 | |
36 | namespace Fortran::runtime { |
37 | |
38 | using SubscriptValue = ISO::CFI_index_t; |
39 | |
40 | static constexpr int maxRank{CFI_MAX_RANK15}; |
41 | |
42 | // A C++ view of the sole interoperable standard descriptor (ISO::CFI_cdesc_t) |
43 | // and its type and per-dimension information. |
44 | |
45 | class Dimension { |
46 | public: |
47 | SubscriptValue LowerBound() const { return raw_.lower_bound; } |
48 | SubscriptValue Extent() const { return raw_.extent; } |
49 | SubscriptValue UpperBound() const { return LowerBound() + Extent() - 1; } |
50 | SubscriptValue ByteStride() const { return raw_.sm; } |
51 | |
52 | Dimension &SetBounds(SubscriptValue lower, SubscriptValue upper) { |
53 | if (upper >= lower) { |
54 | raw_.lower_bound = lower; |
55 | raw_.extent = upper - lower + 1; |
56 | } else { |
57 | raw_.lower_bound = 1; |
58 | raw_.extent = 0; |
59 | } |
60 | return *this; |
61 | } |
62 | // Do not use this API to cause the LB of an empty dimension |
63 | // to be anything other than 1. Use SetBounds() instead if you can. |
64 | Dimension &SetLowerBound(SubscriptValue lower) { |
65 | raw_.lower_bound = lower; |
66 | return *this; |
67 | } |
68 | Dimension &SetUpperBound(SubscriptValue upper) { |
69 | auto lower{raw_.lower_bound}; |
70 | raw_.extent = upper >= lower ? upper - lower + 1 : 0; |
71 | return *this; |
72 | } |
73 | Dimension &SetExtent(SubscriptValue extent) { |
74 | raw_.extent = extent; |
75 | return *this; |
76 | } |
77 | Dimension &SetByteStride(SubscriptValue bytes) { |
78 | raw_.sm = bytes; |
79 | return *this; |
80 | } |
81 | |
82 | private: |
83 | ISO::CFI_dim_t raw_; |
84 | }; |
85 | |
86 | // The storage for this object follows the last used dim[] entry in a |
87 | // Descriptor (CFI_cdesc_t) generic descriptor. Space matters here, since |
88 | // descriptors serve as POINTER and ALLOCATABLE components of derived type |
89 | // instances. The presence of this structure is implied by the flag |
90 | // CFI_cdesc_t.f18Addendum, and the number of elements in the len_[] |
91 | // array is determined by derivedType_->LenParameters(). |
92 | class DescriptorAddendum { |
93 | public: |
94 | explicit DescriptorAddendum(const typeInfo::DerivedType *dt = nullptr) |
95 | : derivedType_{dt} {} |
96 | DescriptorAddendum &operator=(const DescriptorAddendum &); |
97 | |
98 | const typeInfo::DerivedType *derivedType() const { return derivedType_; } |
99 | DescriptorAddendum &set_derivedType(const typeInfo::DerivedType *dt) { |
100 | derivedType_ = dt; |
101 | return *this; |
102 | } |
103 | |
104 | std::size_t LenParameters() const; |
105 | |
106 | typeInfo::TypeParameterValue LenParameterValue(int which) const { |
107 | return len_[which]; |
108 | } |
109 | static constexpr std::size_t SizeInBytes(int lenParameters) { |
110 | // TODO: Don't waste that last word if lenParameters == 0 |
111 | return sizeof(DescriptorAddendum) + |
112 | std::max(lenParameters - 1, 0) * sizeof(typeInfo::TypeParameterValue); |
113 | } |
114 | std::size_t SizeInBytes() const; |
115 | |
116 | void SetLenParameterValue(int which, typeInfo::TypeParameterValue x) { |
117 | len_[which] = x; |
118 | } |
119 | |
120 | void Dump(FILE * = stdoutstdout) const; |
121 | |
122 | private: |
123 | const typeInfo::DerivedType *derivedType_; |
124 | typeInfo::TypeParameterValue len_[1]; // must be the last component |
125 | // The LEN type parameter values can also include captured values of |
126 | // specification expressions that were used for bounds and for LEN type |
127 | // parameters of components. The values have been truncated to the LEN |
128 | // type parameter's type, if shorter than 64 bits, then sign-extended. |
129 | }; |
130 | |
131 | // A C++ view of a standard descriptor object. |
132 | class Descriptor { |
133 | public: |
134 | // Be advised: this class type is not suitable for use when allocating |
135 | // a descriptor -- it is a dynamic view of the common descriptor format. |
136 | // If used in a simple declaration of a local variable or dynamic allocation, |
137 | // the size is going to be correct only by accident, since the true size of |
138 | // a descriptor depends on the number of its dimensions and the presence and |
139 | // size of an addendum, which depends on the type of the data. |
140 | // Use the class template StaticDescriptor (below) to declare a descriptor |
141 | // whose type and rank are fixed and known at compilation time. Use the |
142 | // Create() static member functions otherwise to dynamically allocate a |
143 | // descriptor. |
144 | |
145 | Descriptor(const Descriptor &); |
146 | Descriptor &operator=(const Descriptor &); |
147 | |
148 | // Returns the number of bytes occupied by an element of the given |
149 | // category and kind including any alignment padding required |
150 | // between adjacent elements. |
151 | static std::size_t BytesFor(TypeCategory category, int kind); |
152 | |
153 | void Establish(TypeCode t, std::size_t elementBytes, void *p = nullptr, |
154 | int rank = maxRank, const SubscriptValue *extent = nullptr, |
155 | ISO::CFI_attribute_t attribute = CFI_attribute_other0, |
156 | bool addendum = false); |
157 | void Establish(TypeCategory, int kind, void *p = nullptr, int rank = maxRank, |
158 | const SubscriptValue *extent = nullptr, |
159 | ISO::CFI_attribute_t attribute = CFI_attribute_other0, |
160 | bool addendum = false); |
161 | void Establish(int characterKind, std::size_t characters, void *p = nullptr, |
162 | int rank = maxRank, const SubscriptValue *extent = nullptr, |
163 | ISO::CFI_attribute_t attribute = CFI_attribute_other0, |
164 | bool addendum = false); |
165 | void Establish(const typeInfo::DerivedType &dt, void *p = nullptr, |
166 | int rank = maxRank, const SubscriptValue *extent = nullptr, |
167 | ISO::CFI_attribute_t attribute = CFI_attribute_other0); |
168 | |
169 | static OwningPtr<Descriptor> Create(TypeCode t, std::size_t elementBytes, |
170 | void *p = nullptr, int rank = maxRank, |
171 | const SubscriptValue *extent = nullptr, |
172 | ISO::CFI_attribute_t attribute = CFI_attribute_other0, |
173 | int derivedTypeLenParameters = 0); |
174 | static OwningPtr<Descriptor> Create(TypeCategory, int kind, void *p = nullptr, |
175 | int rank = maxRank, const SubscriptValue *extent = nullptr, |
176 | ISO::CFI_attribute_t attribute = CFI_attribute_other0); |
177 | static OwningPtr<Descriptor> Create(int characterKind, |
178 | SubscriptValue characters, void *p = nullptr, int rank = maxRank, |
179 | const SubscriptValue *extent = nullptr, |
180 | ISO::CFI_attribute_t attribute = CFI_attribute_other0); |
181 | static OwningPtr<Descriptor> Create(const typeInfo::DerivedType &dt, |
182 | void *p = nullptr, int rank = maxRank, |
183 | const SubscriptValue *extent = nullptr, |
184 | ISO::CFI_attribute_t attribute = CFI_attribute_other0); |
185 | |
186 | ISO::CFI_cdesc_t &raw() { return raw_; } |
187 | const ISO::CFI_cdesc_t &raw() const { return raw_; } |
188 | std::size_t ElementBytes() const { return raw_.elem_len; } |
189 | int rank() const { return raw_.rank; } |
190 | TypeCode type() const { return TypeCode{raw_.type}; } |
191 | |
192 | Descriptor &set_base_addr(void *p) { |
193 | raw_.base_addr = p; |
194 | return *this; |
195 | } |
196 | |
197 | bool IsPointer() const { return raw_.attribute == CFI_attribute_pointer1; } |
198 | bool IsAllocatable() const { |
199 | return raw_.attribute == CFI_attribute_allocatable2; |
200 | } |
201 | bool IsAllocated() const { return raw_.base_addr != nullptr; } |
202 | |
203 | Dimension &GetDimension(int dim) { |
204 | return *reinterpret_cast<Dimension *>(&raw_.dim[dim]); |
205 | } |
206 | const Dimension &GetDimension(int dim) const { |
207 | return *reinterpret_cast<const Dimension *>(&raw_.dim[dim]); |
208 | } |
209 | |
210 | std::size_t SubscriptByteOffset( |
211 | int dim, SubscriptValue subscriptValue) const { |
212 | const Dimension &dimension{GetDimension(dim)}; |
213 | return (subscriptValue - dimension.LowerBound()) * dimension.ByteStride(); |
214 | } |
215 | |
216 | std::size_t SubscriptsToByteOffset(const SubscriptValue subscript[]) const { |
217 | std::size_t offset{0}; |
218 | for (int j{0}; j < raw_.rank; ++j) { |
219 | offset += SubscriptByteOffset(j, subscript[j]); |
220 | } |
221 | return offset; |
222 | } |
223 | |
224 | template <typename A = char> A *OffsetElement(std::size_t offset = 0) const { |
225 | return reinterpret_cast<A *>( |
226 | reinterpret_cast<char *>(raw_.base_addr) + offset); |
227 | } |
228 | |
229 | template <typename A> A *Element(const SubscriptValue subscript[]) const { |
230 | return OffsetElement<A>(SubscriptsToByteOffset(subscript)); |
231 | } |
232 | |
233 | template <typename A> A *ZeroBasedIndexedElement(std::size_t n) const { |
234 | SubscriptValue at[maxRank]; |
235 | if (SubscriptsForZeroBasedElementNumber(at, n)) { |
236 | return Element<A>(at); |
237 | } |
238 | return nullptr; |
239 | } |
240 | |
241 | int GetLowerBounds(SubscriptValue subscript[]) const { |
242 | for (int j{0}; j < raw_.rank; ++j) { |
243 | subscript[j] = GetDimension(j).LowerBound(); |
244 | } |
245 | return raw_.rank; |
246 | } |
247 | |
248 | int GetShape(SubscriptValue subscript[]) const { |
249 | for (int j{0}; j < raw_.rank; ++j) { |
250 | subscript[j] = GetDimension(j).Extent(); |
251 | } |
252 | return raw_.rank; |
253 | } |
254 | |
255 | // When the passed subscript vector contains the last (or first) |
256 | // subscripts of the array, these wrap the subscripts around to |
257 | // their first (or last) values and return false. |
258 | bool IncrementSubscripts( |
259 | SubscriptValue subscript[], const int *permutation = nullptr) const { |
260 | for (int j{0}; j < raw_.rank; ++j) { |
261 | int k{permutation ? permutation[j] : j}; |
262 | const Dimension &dim{GetDimension(k)}; |
263 | if (subscript[k]++ < dim.UpperBound()) { |
264 | return true; |
265 | } |
266 | subscript[k] = dim.LowerBound(); |
267 | } |
268 | return false; |
269 | } |
270 | |
271 | bool DecrementSubscripts( |
272 | SubscriptValue[], const int *permutation = nullptr) const; |
273 | |
274 | // False when out of range. |
275 | bool SubscriptsForZeroBasedElementNumber(SubscriptValue subscript[], |
276 | std::size_t elementNumber, const int *permutation = nullptr) const { |
277 | if (raw_.rank == 0) { |
278 | return elementNumber == 0; |
279 | } |
280 | std::size_t dimCoefficient[maxRank]; |
281 | int k0{permutation ? permutation[0] : 0}; |
282 | dimCoefficient[0] = 1; |
283 | auto coefficient{static_cast<std::size_t>(GetDimension(k0).Extent())}; |
284 | for (int j{1}; j < raw_.rank; ++j) { |
285 | int k{permutation ? permutation[j] : j}; |
286 | const Dimension &dim{GetDimension(k)}; |
287 | dimCoefficient[j] = coefficient; |
288 | coefficient *= dim.Extent(); |
289 | } |
290 | if (elementNumber >= coefficient) { |
291 | return false; // out of range |
292 | } |
293 | for (int j{raw_.rank - 1}; j > 0; --j) { |
294 | int k{permutation ? permutation[j] : j}; |
295 | const Dimension &dim{GetDimension(k)}; |
296 | std::size_t quotient{elementNumber / dimCoefficient[j]}; |
297 | subscript[k] = quotient + dim.LowerBound(); |
298 | elementNumber -= quotient * dimCoefficient[j]; |
299 | } |
300 | subscript[k0] = elementNumber + GetDimension(k0).LowerBound(); |
301 | return true; |
302 | } |
303 | |
304 | std::size_t ZeroBasedElementNumber( |
305 | const SubscriptValue *, const int *permutation = nullptr) const; |
306 | |
307 | DescriptorAddendum *Addendum() { |
308 | if (raw_.f18Addendum != 0) { |
309 | return reinterpret_cast<DescriptorAddendum *>(&GetDimension(rank())); |
310 | } else { |
311 | return nullptr; |
312 | } |
313 | } |
314 | const DescriptorAddendum *Addendum() const { |
315 | if (raw_.f18Addendum != 0) { |
316 | return reinterpret_cast<const DescriptorAddendum *>( |
317 | &GetDimension(rank())); |
318 | } else { |
319 | return nullptr; |
320 | } |
321 | } |
322 | |
323 | // Returns size in bytes of the descriptor (not the data) |
324 | static constexpr std::size_t SizeInBytes( |
325 | int rank, bool addendum = false, int lengthTypeParameters = 0) { |
326 | std::size_t bytes{sizeof(Descriptor) - sizeof(Dimension)}; |
327 | bytes += rank * sizeof(Dimension); |
328 | if (addendum || lengthTypeParameters > 0) { |
329 | bytes += DescriptorAddendum::SizeInBytes(lengthTypeParameters); |
330 | } |
331 | return bytes; |
332 | } |
333 | |
334 | std::size_t SizeInBytes() const; |
335 | |
336 | std::size_t Elements() const; |
337 | |
338 | // Allocate() assumes Elements() and ElementBytes() work; |
339 | // define the extents of the dimensions and the element length |
340 | // before calling. It (re)computes the byte strides after |
341 | // allocation. Does not allocate automatic components or |
342 | // perform default component initialization. |
343 | int Allocate(); |
344 | |
345 | // Deallocates storage; does not call FINAL subroutines or |
346 | // deallocate allocatable/automatic components. |
347 | int Deallocate(); |
348 | |
349 | // Deallocates storage, including allocatable and automatic |
350 | // components. Optionally invokes FINAL subroutines. |
351 | int Destroy(bool finalize = false, bool destroyPointers = false); |
352 | |
353 | bool IsContiguous(int leadingDimensions = maxRank) const { |
354 | auto bytes{static_cast<SubscriptValue>(ElementBytes())}; |
355 | if (leadingDimensions > raw_.rank) { |
356 | leadingDimensions = raw_.rank; |
357 | } |
358 | for (int j{0}; j < leadingDimensions; ++j) { |
359 | const Dimension &dim{GetDimension(j)}; |
360 | if (bytes != dim.ByteStride()) { |
361 | return false; |
362 | } |
363 | bytes *= dim.Extent(); |
364 | } |
365 | return true; |
366 | } |
367 | |
368 | // Establishes a pointer to a section or element. |
369 | bool EstablishPointerSection(const Descriptor &source, |
370 | const SubscriptValue *lower = nullptr, |
371 | const SubscriptValue *upper = nullptr, |
372 | const SubscriptValue *stride = nullptr); |
373 | |
374 | void Check() const; |
375 | |
376 | void Dump(FILE * = stdoutstdout) const; |
377 | |
378 | private: |
379 | ISO::CFI_cdesc_t raw_; |
380 | }; |
381 | static_assert(sizeof(Descriptor) == sizeof(ISO::CFI_cdesc_t)); |
382 | |
383 | // Properly configured instances of StaticDescriptor will occupy the |
384 | // exact amount of storage required for the descriptor, its dimensional |
385 | // information, and possible addendum. To build such a static descriptor, |
386 | // declare an instance of StaticDescriptor<>, extract a reference to its |
387 | // descriptor via the descriptor() accessor, and then built a Descriptor |
388 | // therein via descriptor.Establish(), e.g.: |
389 | // StaticDescriptor<R,A,LP> statDesc; |
390 | // Descriptor &descriptor{statDesc.descriptor()}; |
391 | // descriptor.Establish( ... ); |
392 | template <int MAX_RANK = maxRank, bool ADDENDUM = false, int MAX_LEN_PARMS = 0> |
393 | class alignas(Descriptor) StaticDescriptor { |
394 | public: |
395 | static constexpr int maxRank{MAX_RANK}; |
396 | static constexpr int maxLengthTypeParameters{MAX_LEN_PARMS}; |
397 | static constexpr bool hasAddendum{ADDENDUM || MAX_LEN_PARMS > 0}; |
398 | static constexpr std::size_t byteSize{ |
399 | Descriptor::SizeInBytes(maxRank, hasAddendum, maxLengthTypeParameters)}; |
400 | |
401 | Descriptor &descriptor() { return *reinterpret_cast<Descriptor *>(storage_); } |
402 | const Descriptor &descriptor() const { |
403 | return *reinterpret_cast<const Descriptor *>(storage_); |
404 | } |
405 | |
406 | void Check() { |
407 | assert(descriptor().rank() <= maxRank)(static_cast <bool> (descriptor().rank() <= maxRank) ? void (0) : __assert_fail ("descriptor().rank() <= maxRank" , "flang/include/flang/Runtime/descriptor.h", 407, __extension__ __PRETTY_FUNCTION__)); |
408 | assert(descriptor().SizeInBytes() <= byteSize)(static_cast <bool> (descriptor().SizeInBytes() <= byteSize ) ? void (0) : __assert_fail ("descriptor().SizeInBytes() <= byteSize" , "flang/include/flang/Runtime/descriptor.h", 408, __extension__ __PRETTY_FUNCTION__)); |
409 | if (DescriptorAddendum * addendum{descriptor().Addendum()}) { |
410 | assert(hasAddendum)(static_cast <bool> (hasAddendum) ? void (0) : __assert_fail ("hasAddendum", "flang/include/flang/Runtime/descriptor.h", 410 , __extension__ __PRETTY_FUNCTION__)); |
411 | assert(addendum->LenParameters() <= maxLengthTypeParameters)(static_cast <bool> (addendum->LenParameters() <= maxLengthTypeParameters) ? void (0) : __assert_fail ("addendum->LenParameters() <= maxLengthTypeParameters" , "flang/include/flang/Runtime/descriptor.h", 411, __extension__ __PRETTY_FUNCTION__)); |
412 | } else { |
413 | assert(!hasAddendum)(static_cast <bool> (!hasAddendum) ? void (0) : __assert_fail ("!hasAddendum", "flang/include/flang/Runtime/descriptor.h", 413, __extension__ __PRETTY_FUNCTION__)); |
414 | assert(maxLengthTypeParameters == 0)(static_cast <bool> (maxLengthTypeParameters == 0) ? void (0) : __assert_fail ("maxLengthTypeParameters == 0", "flang/include/flang/Runtime/descriptor.h" , 414, __extension__ __PRETTY_FUNCTION__)); |
415 | } |
416 | descriptor().Check(); |
417 | } |
418 | |
419 | private: |
420 | char storage_[byteSize]{}; |
421 | }; |
422 | } // namespace Fortran::runtime |
423 | #endif // FORTRAN_RUNTIME_DESCRIPTOR_H_ |