File: | build/source/flang/runtime/matmul.cpp |
Warning: | line 171, column 7 2nd function call argument is an uninitialized value |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===-- runtime/matmul.cpp ------------------------------------------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | ||||
9 | // Implements all forms of MATMUL (Fortran 2018 16.9.124) | |||
10 | // | |||
11 | // There are two main entry points; one establishes a descriptor for the | |||
12 | // result and allocates it, and the other expects a result descriptor that | |||
13 | // points to existing storage. | |||
14 | // | |||
15 | // This implementation must handle all combinations of numeric types and | |||
16 | // kinds (100 - 165 cases depending on the target), plus all combinations | |||
17 | // of logical kinds (16). A single template undergoes many instantiations | |||
18 | // to cover all of the valid possibilities. | |||
19 | // | |||
20 | // Places where BLAS routines could be called are marked as TODO items. | |||
21 | ||||
22 | #include "flang/Runtime/matmul.h" | |||
23 | #include "terminator.h" | |||
24 | #include "tools.h" | |||
25 | #include "flang/Runtime/c-or-cpp.h" | |||
26 | #include "flang/Runtime/cpp-type.h" | |||
27 | #include "flang/Runtime/descriptor.h" | |||
28 | #include <cstring> | |||
29 | ||||
30 | namespace Fortran::runtime { | |||
31 | ||||
32 | // General accumulator for any type and stride; this is not used for | |||
33 | // contiguous numeric cases. | |||
34 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> | |||
35 | class Accumulator { | |||
36 | public: | |||
37 | using Result = AccumulationType<RCAT, RKIND>; | |||
38 | Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {} | |||
39 | void Accumulate(const SubscriptValue xAt[], const SubscriptValue yAt[]) { | |||
40 | if constexpr (RCAT == TypeCategory::Logical) { | |||
41 | sum_ = sum_ || | |||
42 | (IsLogicalElementTrue(x_, xAt) && IsLogicalElementTrue(y_, yAt)); | |||
43 | } else { | |||
44 | sum_ += static_cast<Result>(*x_.Element<XT>(xAt)) * | |||
45 | static_cast<Result>(*y_.Element<YT>(yAt)); | |||
46 | } | |||
47 | } | |||
48 | Result GetResult() const { return sum_; } | |||
49 | ||||
50 | private: | |||
51 | const Descriptor &x_, &y_; | |||
52 | Result sum_{}; | |||
53 | }; | |||
54 | ||||
55 | // Contiguous numeric matrix*matrix multiplication | |||
56 | // matrix(rows,n) * matrix(n,cols) -> matrix(rows,cols) | |||
57 | // Straightforward algorithm: | |||
58 | // DO 1 I = 1, NROWS | |||
59 | // DO 1 J = 1, NCOLS | |||
60 | // RES(I,J) = 0 | |||
61 | // DO 1 K = 1, N | |||
62 | // 1 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) | |||
63 | // With loop distribution and transposition to avoid the inner sum | |||
64 | // reduction and to avoid non-unit strides: | |||
65 | // DO 1 I = 1, NROWS | |||
66 | // DO 1 J = 1, NCOLS | |||
67 | // 1 RES(I,J) = 0 | |||
68 | // DO 2 K = 1, N | |||
69 | // DO 2 J = 1, NCOLS | |||
70 | // DO 2 I = 1, NROWS | |||
71 | // 2 RES(I,J) = RES(I,J) + X(I,K)*Y(K,J) ! loop-invariant last term | |||
72 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> | |||
73 | inline void MatrixTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT__restrict product, | |||
74 | SubscriptValue rows, SubscriptValue cols, const XT *RESTRICT__restrict x, | |||
75 | const YT *RESTRICT__restrict y, SubscriptValue n) { | |||
76 | using ResultType = CppTypeFor<RCAT, RKIND>; | |||
77 | std::memset(product, 0, rows * cols * sizeof *product); | |||
78 | const XT *RESTRICT__restrict xp0{x}; | |||
79 | for (SubscriptValue k{0}; k < n; ++k) { | |||
80 | ResultType *RESTRICT__restrict p{product}; | |||
81 | for (SubscriptValue j{0}; j < cols; ++j) { | |||
82 | const XT *RESTRICT__restrict xp{xp0}; | |||
83 | auto yv{static_cast<ResultType>(y[k + j * n])}; | |||
84 | for (SubscriptValue i{0}; i < rows; ++i) { | |||
85 | *p++ += static_cast<ResultType>(*xp++) * yv; | |||
86 | } | |||
87 | } | |||
88 | xp0 += rows; | |||
89 | } | |||
90 | } | |||
91 | ||||
92 | // Contiguous numeric matrix*vector multiplication | |||
93 | // matrix(rows,n) * column vector(n) -> column vector(rows) | |||
94 | // Straightforward algorithm: | |||
95 | // DO 1 J = 1, NROWS | |||
96 | // RES(J) = 0 | |||
97 | // DO 1 K = 1, N | |||
98 | // 1 RES(J) = RES(J) + X(J,K)*Y(K) | |||
99 | // With loop distribution and transposition to avoid the inner | |||
100 | // sum reduction and to avoid non-unit strides: | |||
101 | // DO 1 J = 1, NROWS | |||
102 | // 1 RES(J) = 0 | |||
103 | // DO 2 K = 1, N | |||
104 | // DO 2 J = 1, NROWS | |||
105 | // 2 RES(J) = RES(J) + X(J,K)*Y(K) | |||
106 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> | |||
107 | inline void MatrixTimesVector(CppTypeFor<RCAT, RKIND> *RESTRICT__restrict product, | |||
108 | SubscriptValue rows, SubscriptValue n, const XT *RESTRICT__restrict x, | |||
109 | const YT *RESTRICT__restrict y) { | |||
110 | using ResultType = CppTypeFor<RCAT, RKIND>; | |||
111 | std::memset(product, 0, rows * sizeof *product); | |||
112 | for (SubscriptValue k{0}; k < n; ++k) { | |||
113 | ResultType *RESTRICT__restrict p{product}; | |||
114 | auto yv{static_cast<ResultType>(*y++)}; | |||
115 | for (SubscriptValue j{0}; j < rows; ++j) { | |||
116 | *p++ += static_cast<ResultType>(*x++) * yv; | |||
117 | } | |||
118 | } | |||
119 | } | |||
120 | ||||
121 | // Contiguous numeric vector*matrix multiplication | |||
122 | // row vector(n) * matrix(n,cols) -> row vector(cols) | |||
123 | // Straightforward algorithm: | |||
124 | // DO 1 J = 1, NCOLS | |||
125 | // RES(J) = 0 | |||
126 | // DO 1 K = 1, N | |||
127 | // 1 RES(J) = RES(J) + X(K)*Y(K,J) | |||
128 | // With loop distribution and transposition to avoid the inner | |||
129 | // sum reduction and one non-unit stride (the other remains): | |||
130 | // DO 1 J = 1, NCOLS | |||
131 | // 1 RES(J) = 0 | |||
132 | // DO 2 K = 1, N | |||
133 | // DO 2 J = 1, NCOLS | |||
134 | // 2 RES(J) = RES(J) + X(K)*Y(K,J) | |||
135 | template <TypeCategory RCAT, int RKIND, typename XT, typename YT> | |||
136 | inline void VectorTimesMatrix(CppTypeFor<RCAT, RKIND> *RESTRICT__restrict product, | |||
137 | SubscriptValue n, SubscriptValue cols, const XT *RESTRICT__restrict x, | |||
138 | const YT *RESTRICT__restrict y) { | |||
139 | using ResultType = CppTypeFor<RCAT, RKIND>; | |||
140 | std::memset(product, 0, cols * sizeof *product); | |||
141 | for (SubscriptValue k{0}; k < n; ++k) { | |||
142 | ResultType *RESTRICT__restrict p{product}; | |||
143 | auto xv{static_cast<ResultType>(*x++)}; | |||
144 | const YT *RESTRICT__restrict yp{&y[k]}; | |||
145 | for (SubscriptValue j{0}; j < cols; ++j) { | |||
146 | *p++ += xv * static_cast<ResultType>(*yp); | |||
147 | yp += n; | |||
148 | } | |||
149 | } | |||
150 | } | |||
151 | ||||
152 | // Implements an instance of MATMUL for given argument types. | |||
153 | template <bool IS_ALLOCATING, TypeCategory RCAT, int RKIND, typename XT, | |||
154 | typename YT> | |||
155 | static inline void DoMatmul( | |||
156 | std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor> &result, | |||
157 | const Descriptor &x, const Descriptor &y, Terminator &terminator) { | |||
158 | int xRank{x.rank()}; | |||
159 | int yRank{y.rank()}; | |||
160 | int resRank{xRank + yRank - 2}; | |||
161 | if (xRank * yRank != 2 * resRank) { | |||
| ||||
162 | terminator.Crash("MATMUL: bad argument ranks (%d * %d)", xRank, yRank); | |||
163 | } | |||
164 | SubscriptValue extent[2]{ | |||
165 | xRank == 2 ? x.GetDimension(0).Extent() : y.GetDimension(1).Extent(), | |||
166 | resRank == 2 ? y.GetDimension(1).Extent() : 0}; | |||
167 | if constexpr (IS_ALLOCATING) { | |||
168 | result.Establish( | |||
169 | RCAT, RKIND, nullptr, resRank, extent, CFI_attribute_allocatable2); | |||
170 | for (int j{0}; j
| |||
171 | result.GetDimension(j).SetBounds(1, extent[j]); | |||
| ||||
172 | } | |||
173 | if (int stat{result.Allocate()}) { | |||
174 | terminator.Crash( | |||
175 | "MATMUL: could not allocate memory for result; STAT=%d", stat); | |||
176 | } | |||
177 | } else { | |||
178 | RUNTIME_CHECK(terminator, resRank == result.rank())if (resRank == result.rank()) ; else (terminator).CheckFailed ("resRank == result.rank()", "flang/runtime/matmul.cpp", 178); | |||
179 | RUNTIME_CHECK(if (result.ElementBytes() == static_cast<std::size_t>(RKIND )) ; else (terminator).CheckFailed("result.ElementBytes() == static_cast<std::size_t>(RKIND)" , "flang/runtime/matmul.cpp", 180) | |||
180 | terminator, result.ElementBytes() == static_cast<std::size_t>(RKIND))if (result.ElementBytes() == static_cast<std::size_t>(RKIND )) ; else (terminator).CheckFailed("result.ElementBytes() == static_cast<std::size_t>(RKIND)" , "flang/runtime/matmul.cpp", 180); | |||
181 | RUNTIME_CHECK(terminator, result.GetDimension(0).Extent() == extent[0])if (result.GetDimension(0).Extent() == extent[0]) ; else (terminator ).CheckFailed("result.GetDimension(0).Extent() == extent[0]", "flang/runtime/matmul.cpp", 181); | |||
182 | RUNTIME_CHECK(terminator,if (resRank == 1 || result.GetDimension(1).Extent() == extent [1]) ; else (terminator).CheckFailed("resRank == 1 || result.GetDimension(1).Extent() == extent[1]" , "flang/runtime/matmul.cpp", 183) | |||
183 | resRank == 1 || result.GetDimension(1).Extent() == extent[1])if (resRank == 1 || result.GetDimension(1).Extent() == extent [1]) ; else (terminator).CheckFailed("resRank == 1 || result.GetDimension(1).Extent() == extent[1]" , "flang/runtime/matmul.cpp", 183); | |||
184 | } | |||
185 | SubscriptValue n{x.GetDimension(xRank - 1).Extent()}; | |||
186 | if (n != y.GetDimension(0).Extent()) { | |||
187 | terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)", | |||
188 | static_cast<std::intmax_t>(x.GetDimension(0).Extent()), | |||
189 | static_cast<std::intmax_t>(n), | |||
190 | static_cast<std::intmax_t>(y.GetDimension(0).Extent()), | |||
191 | static_cast<std::intmax_t>(y.GetDimension(1).Extent())); | |||
192 | } | |||
193 | using WriteResult = | |||
194 | CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT, | |||
195 | RKIND>; | |||
196 | if constexpr (RCAT != TypeCategory::Logical) { | |||
197 | if (x.IsContiguous() && y.IsContiguous() && | |||
198 | (IS_ALLOCATING || result.IsContiguous())) { | |||
199 | // Contiguous numeric matrices | |||
200 | if (resRank == 2) { // M*M -> M | |||
201 | if (std::is_same_v<XT, YT>) { | |||
202 | if constexpr (std::is_same_v<XT, float>) { | |||
203 | // TODO: call BLAS-3 SGEMM | |||
204 | } else if constexpr (std::is_same_v<XT, double>) { | |||
205 | // TODO: call BLAS-3 DGEMM | |||
206 | } else if constexpr (std::is_same_v<XT, std::complex<float>>) { | |||
207 | // TODO: call BLAS-3 CGEMM | |||
208 | } else if constexpr (std::is_same_v<XT, std::complex<double>>) { | |||
209 | // TODO: call BLAS-3 ZGEMM | |||
210 | } | |||
211 | } | |||
212 | MatrixTimesMatrix<RCAT, RKIND, XT, YT>( | |||
213 | result.template OffsetElement<WriteResult>(), extent[0], extent[1], | |||
214 | x.OffsetElement<XT>(), y.OffsetElement<YT>(), n); | |||
215 | return; | |||
216 | } else if (xRank == 2) { // M*V -> V | |||
217 | if (std::is_same_v<XT, YT>) { | |||
218 | if constexpr (std::is_same_v<XT, float>) { | |||
219 | // TODO: call BLAS-2 SGEMV(x,y) | |||
220 | } else if constexpr (std::is_same_v<XT, double>) { | |||
221 | // TODO: call BLAS-2 DGEMV(x,y) | |||
222 | } else if constexpr (std::is_same_v<XT, std::complex<float>>) { | |||
223 | // TODO: call BLAS-2 CGEMV(x,y) | |||
224 | } else if constexpr (std::is_same_v<XT, std::complex<double>>) { | |||
225 | // TODO: call BLAS-2 ZGEMV(x,y) | |||
226 | } | |||
227 | } | |||
228 | MatrixTimesVector<RCAT, RKIND, XT, YT>( | |||
229 | result.template OffsetElement<WriteResult>(), extent[0], n, | |||
230 | x.OffsetElement<XT>(), y.OffsetElement<YT>()); | |||
231 | return; | |||
232 | } else { // V*M -> V | |||
233 | if (std::is_same_v<XT, YT>) { | |||
234 | if constexpr (std::is_same_v<XT, float>) { | |||
235 | // TODO: call BLAS-2 SGEMV(y,x) | |||
236 | } else if constexpr (std::is_same_v<XT, double>) { | |||
237 | // TODO: call BLAS-2 DGEMV(y,x) | |||
238 | } else if constexpr (std::is_same_v<XT, std::complex<float>>) { | |||
239 | // TODO: call BLAS-2 CGEMV(y,x) | |||
240 | } else if constexpr (std::is_same_v<XT, std::complex<double>>) { | |||
241 | // TODO: call BLAS-2 ZGEMV(y,x) | |||
242 | } | |||
243 | } | |||
244 | VectorTimesMatrix<RCAT, RKIND, XT, YT>( | |||
245 | result.template OffsetElement<WriteResult>(), n, extent[0], | |||
246 | x.OffsetElement<XT>(), y.OffsetElement<YT>()); | |||
247 | return; | |||
248 | } | |||
249 | } | |||
250 | } | |||
251 | // General algorithms for LOGICAL and noncontiguity | |||
252 | SubscriptValue xAt[2], yAt[2], resAt[2]; | |||
253 | x.GetLowerBounds(xAt); | |||
254 | y.GetLowerBounds(yAt); | |||
255 | result.GetLowerBounds(resAt); | |||
256 | if (resRank == 2) { // M*M -> M | |||
257 | SubscriptValue x1{xAt[1]}, y0{yAt[0]}, y1{yAt[1]}, res1{resAt[1]}; | |||
258 | for (SubscriptValue i{0}; i < extent[0]; ++i) { | |||
259 | for (SubscriptValue j{0}; j < extent[1]; ++j) { | |||
260 | Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; | |||
261 | yAt[1] = y1 + j; | |||
262 | for (SubscriptValue k{0}; k < n; ++k) { | |||
263 | xAt[1] = x1 + k; | |||
264 | yAt[0] = y0 + k; | |||
265 | accumulator.Accumulate(xAt, yAt); | |||
266 | } | |||
267 | resAt[1] = res1 + j; | |||
268 | *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); | |||
269 | } | |||
270 | ++resAt[0]; | |||
271 | ++xAt[0]; | |||
272 | } | |||
273 | } else if (xRank == 2) { // M*V -> V | |||
274 | SubscriptValue x1{xAt[1]}, y0{yAt[0]}; | |||
275 | for (SubscriptValue j{0}; j < extent[0]; ++j) { | |||
276 | Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; | |||
277 | for (SubscriptValue k{0}; k < n; ++k) { | |||
278 | xAt[1] = x1 + k; | |||
279 | yAt[0] = y0 + k; | |||
280 | accumulator.Accumulate(xAt, yAt); | |||
281 | } | |||
282 | *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); | |||
283 | ++resAt[0]; | |||
284 | ++xAt[0]; | |||
285 | } | |||
286 | } else { // V*M -> V | |||
287 | SubscriptValue x0{xAt[0]}, y0{yAt[0]}; | |||
288 | for (SubscriptValue j{0}; j < extent[0]; ++j) { | |||
289 | Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; | |||
290 | for (SubscriptValue k{0}; k < n; ++k) { | |||
291 | xAt[0] = x0 + k; | |||
292 | yAt[0] = y0 + k; | |||
293 | accumulator.Accumulate(xAt, yAt); | |||
294 | } | |||
295 | *result.template Element<WriteResult>(resAt) = accumulator.GetResult(); | |||
296 | ++resAt[0]; | |||
297 | ++yAt[1]; | |||
298 | } | |||
299 | } | |||
300 | } | |||
301 | ||||
302 | // Maps the dynamic type information from the arguments' descriptors | |||
303 | // to the right instantiation of DoMatmul() for valid combinations of | |||
304 | // types. | |||
305 | template <bool IS_ALLOCATING> struct Matmul { | |||
306 | using ResultDescriptor = | |||
307 | std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>; | |||
308 | template <TypeCategory XCAT, int XKIND> struct MM1 { | |||
309 | template <TypeCategory YCAT, int YKIND> struct MM2 { | |||
310 | void operator()(ResultDescriptor &result, const Descriptor &x, | |||
311 | const Descriptor &y, Terminator &terminator) const { | |||
312 | if constexpr (constexpr auto resultType{ | |||
313 | GetResultType(XCAT, XKIND, YCAT, YKIND)}) { | |||
314 | if constexpr (common::IsNumericTypeCategory(resultType->first) || | |||
315 | resultType->first == TypeCategory::Logical) { | |||
316 | return DoMatmul<IS_ALLOCATING, resultType->first, | |||
317 | resultType->second, CppTypeFor<XCAT, XKIND>, | |||
318 | CppTypeFor<YCAT, YKIND>>(result, x, y, terminator); | |||
319 | } | |||
320 | } | |||
321 | terminator.Crash("MATMUL: bad operand types (%d(%d), %d(%d))", | |||
322 | static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND); | |||
323 | } | |||
324 | }; | |||
325 | void operator()(ResultDescriptor &result, const Descriptor &x, | |||
326 | const Descriptor &y, Terminator &terminator, TypeCategory yCat, | |||
327 | int yKind) const { | |||
328 | ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator); | |||
329 | } | |||
330 | }; | |||
331 | void operator()(ResultDescriptor &result, const Descriptor &x, | |||
332 | const Descriptor &y, const char *sourceFile, int line) const { | |||
333 | Terminator terminator{sourceFile, line}; | |||
334 | auto xCatKind{x.type().GetCategoryAndKind()}; | |||
335 | auto yCatKind{y.type().GetCategoryAndKind()}; | |||
336 | RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value())if (xCatKind.has_value() && yCatKind.has_value()) ; else (terminator).CheckFailed("xCatKind.has_value() && yCatKind.has_value()" , "flang/runtime/matmul.cpp", 336); | |||
337 | ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result, | |||
338 | x, y, terminator, yCatKind->first, yCatKind->second); | |||
339 | } | |||
340 | }; | |||
341 | ||||
342 | extern "C" { | |||
343 | void RTNAME(Matmul)_FortranAMatmul(Descriptor &result, const Descriptor &x, | |||
344 | const Descriptor &y, const char *sourceFile, int line) { | |||
345 | Matmul<true>{}(result, x, y, sourceFile, line); | |||
346 | } | |||
347 | void RTNAME(MatmulDirect)_FortranAMatmulDirect(const Descriptor &result, const Descriptor &x, | |||
348 | const Descriptor &y, const char *sourceFile, int line) { | |||
349 | Matmul<false>{}(result, x, y, sourceFile, line); | |||
350 | } | |||
351 | } // extern "C" | |||
352 | } // namespace Fortran::runtime |