]>
Commit | Line | Data |
---|---|---|
644cb69f | 1 | /* Implementation of the MATMUL intrinsic |
f0bcf628 | 2 | Copyright (C) 2002-2014 Free Software Foundation, Inc. |
644cb69f FXC |
3 | Contributed by Paul Brook <paul@nowt.org> |
4 | ||
21d1335b | 5 | This file is part of the GNU Fortran runtime library (libgfortran). |
644cb69f FXC |
6 | |
7 | Libgfortran is free software; you can redistribute it and/or | |
8 | modify it under the terms of the GNU General Public | |
9 | License as published by the Free Software Foundation; either | |
748086b7 | 10 | version 3 of the License, or (at your option) any later version. |
644cb69f FXC |
11 | |
12 | Libgfortran is distributed in the hope that it will be useful, | |
13 | but WITHOUT ANY WARRANTY; without even the implied warranty of | |
14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
15 | GNU General Public License for more details. | |
16 | ||
748086b7 JJ |
17 | Under Section 7 of GPL version 3, you are granted additional |
18 | permissions described in the GCC Runtime Library Exception, version | |
19 | 3.1, as published by the Free Software Foundation. | |
20 | ||
21 | You should have received a copy of the GNU General Public License and | |
22 | a copy of the GCC Runtime Library Exception along with this program; | |
23 | see the files COPYING3 and COPYING.RUNTIME respectively. If not, see | |
24 | <http://www.gnu.org/licenses/>. */ | |
644cb69f | 25 | |
36ae8a61 | 26 | #include "libgfortran.h" |
644cb69f FXC |
27 | #include <stdlib.h> |
28 | #include <string.h> | |
29 | #include <assert.h> | |
36ae8a61 | 30 | |
644cb69f FXC |
31 | |
32 | #if defined (HAVE_GFC_COMPLEX_16) | |
33 | ||
5a0aad31 FXC |
34 | /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be |
35 | passed to us by the front-end, in which case we'll call it for large | |
36 | matrices. In the one call made in this file the arguments are, in order: an "N"/"T" flag for A, an "N"/"T" flag for B, m, n, k, a constant one, A's data, lda, B's data, ldb, a constant zero, the result data, ldc, and finally the two ints 1, 1 (presumably the hidden lengths of the two flag strings -- TODO confirm against the Fortran calling convention). */ | |
37 | ||
38 | typedef void (*blas_call)(const char *, const char *, const int *, const int *, | |
39 | const int *, const GFC_COMPLEX_16 *, const GFC_COMPLEX_16 *, | |
40 | const int *, const GFC_COMPLEX_16 *, const int *, | |
41 | const GFC_COMPLEX_16 *, GFC_COMPLEX_16 *, const int *, | |
42 | int, int); | |
43 | ||
1524f80b RS |
44 | /* The order of loops is different in the case of plain matrix |
45 | multiplication C=MATMUL(A,B), and in the frequent special case where | |
46 | the argument A is the temporary result of a TRANSPOSE intrinsic: | |
47 | C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by | |
48 | looking at their strides. | |
49 | ||
50 | The equivalent Fortran pseudo-code is: | |
644cb69f FXC |
51 | |
52 | DIMENSION A(M,COUNT), B(COUNT,N), C(M,N) | |
1524f80b RS |
53 | IF (.NOT.IS_TRANSPOSED(A)) THEN |
54 | C = 0 | |
55 | DO J=1,N | |
56 | DO K=1,COUNT | |
57 | DO I=1,M | |
58 | C(I,J) = C(I,J)+A(I,K)*B(K,J) | |
59 | ELSE | |
60 | DO J=1,N | |
644cb69f | 61 | DO I=1,M |
1524f80b RS |
62 | S = 0 |
63 | DO K=1,COUNT | |
5a0aad31 | 64 | S = S+A(I,K)*B(K,J) |
1524f80b RS |
65 | C(I,J) = S |
66 | ENDIF | |
644cb69f FXC |
67 | */ |
68 | ||
5a0aad31 FXC |
69 | /* matmul_c16 stores the matrix product of A and B in RETARRAY. When RETARRAY has a NULL base_addr its dimensions are filled in and its data is allocated here; otherwise its extents are checked only if compile_options.bounds_check is set. If try_blas is set to a nonzero value, then the matmul function will |
70 | see if there is a way to perform the matrix multiplication by a call | |
71 | to the BLAS gemm function. GEMM is called only when the result has a first-dimension stride of 1, A and B each have a stride of 1 in one dimension, xcount*ycount*count exceeds blas_limit cubed, all leading dimensions are positive and m, n, k are all above 1; in every other case the loops in the function body are used. */ | |
72 | ||
85206901 | 73 | extern void matmul_c16 (gfc_array_c16 * const restrict retarray, |
5a0aad31 FXC |
74 | gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas, |
75 | int blas_limit, blas_call gemm); | |
644cb69f FXC |
76 | export_proto(matmul_c16); |
77 | ||
78 | void | |
85206901 | 79 | matmul_c16 (gfc_array_c16 * const restrict retarray, |
5a0aad31 FXC |
80 | gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas, |
81 | int blas_limit, blas_call gemm) | |
644cb69f | 82 | { |
85206901 JB |
83 | const GFC_COMPLEX_16 * restrict abase; |
84 | const GFC_COMPLEX_16 * restrict bbase; | |
85 | GFC_COMPLEX_16 * restrict dest; | |
644cb69f FXC |
86 | |
87 | index_type rxstride, rystride, axstride, aystride, bxstride, bystride; | |
88 | index_type x, y, n, count, xcount, ycount; | |
89 | ||
90 | assert (GFC_DESCRIPTOR_RANK (a) == 2 | |
91 | || GFC_DESCRIPTOR_RANK (b) == 2); | |
92 | ||
93 | /* C[xcount,ycount] = A[xcount, count] * B[count,ycount] | |
94 | ||
95 | Either A or B (but not both) can be rank 1: | |
96 | ||
97 | o One-dimensional argument A is implicitly treated as a row matrix | |
98 | dimensioned [1,count], so xcount=1. | |
99 | ||
100 | o One-dimensional argument B is implicitly treated as a column matrix | |
101 | dimensioned [count, 1], so ycount=1. | |
102 | */ | |
103 | ||
21d1335b | 104 | if (retarray->base_addr == NULL) |
644cb69f FXC |
105 | { /* Result not yet allocated: fill in its dimensions from the extents of A and B, then allocate the data. */ |
106 | if (GFC_DESCRIPTOR_RANK (a) == 1) | |
107 | { | |
dfb55fdc TK |
108 | GFC_DIMENSION_SET(retarray->dim[0], 0, |
109 | GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1); | |
644cb69f FXC |
110 | } |
111 | else if (GFC_DESCRIPTOR_RANK (b) == 1) | |
112 | { | |
dfb55fdc TK |
113 | GFC_DIMENSION_SET(retarray->dim[0], 0, |
114 | GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1); | |
644cb69f FXC |
115 | } |
116 | else | |
117 | { | |
dfb55fdc TK |
118 | GFC_DIMENSION_SET(retarray->dim[0], 0, |
119 | GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1); | |
644cb69f | 120 | |
dfb55fdc TK |
121 | GFC_DIMENSION_SET(retarray->dim[1], 0, |
122 | GFC_DESCRIPTOR_EXTENT(b,1) - 1, | |
123 | GFC_DESCRIPTOR_EXTENT(retarray,0)); | |
644cb69f FXC |
124 | } |
125 | ||
21d1335b | 126 | retarray->base_addr |
92e6f3a4 | 127 | = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_COMPLEX_16)); |
644cb69f FXC |
128 | retarray->offset = 0; |
129 | } | |
9731c4a3 | 130 | else if (unlikely (compile_options.bounds_check)) |
9ad13e91 TK |
131 | { /* Caller-supplied result: compare its extents with those implied by A and B and report any mismatch. */ |
132 | index_type ret_extent, arg_extent; | |
133 | ||
134 | if (GFC_DESCRIPTOR_RANK (a) == 1) | |
135 | { | |
dfb55fdc TK |
136 | arg_extent = GFC_DESCRIPTOR_EXTENT(b,1); |
137 | ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0); | |
9ad13e91 TK |
138 | if (arg_extent != ret_extent) |
139 | runtime_error ("Incorrect extent in return array in" | |
140 | " MATMUL intrinsic: is %ld, should be %ld", | |
141 | (long int) ret_extent, (long int) arg_extent); | |
142 | } | |
143 | else if (GFC_DESCRIPTOR_RANK (b) == 1) | |
144 | { | |
dfb55fdc TK |
145 | arg_extent = GFC_DESCRIPTOR_EXTENT(a,0); |
146 | ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0); | |
9ad13e91 TK |
147 | if (arg_extent != ret_extent) |
148 | runtime_error ("Incorrect extent in return array in" | |
149 | " MATMUL intrinsic: is %ld, should be %ld", | |
150 | (long int) ret_extent, (long int) arg_extent); | |
151 | } | |
152 | else | |
153 | { | |
dfb55fdc TK |
154 | arg_extent = GFC_DESCRIPTOR_EXTENT(a,0); |
155 | ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0); | |
9ad13e91 TK |
156 | if (arg_extent != ret_extent) |
157 | runtime_error ("Incorrect extent in return array in" | |
158 | " MATMUL intrinsic for dimension 1:" | |
159 | " is %ld, should be %ld", | |
160 | (long int) ret_extent, (long int) arg_extent); | |
161 | ||
dfb55fdc TK |
162 | arg_extent = GFC_DESCRIPTOR_EXTENT(b,1); |
163 | ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1); | |
9ad13e91 TK |
164 | if (arg_extent != ret_extent) |
165 | runtime_error ("Incorrect extent in return array in" | |
166 | " MATMUL intrinsic for dimension 2:" | |
167 | " is %ld, should be %ld", | |
168 | (long int) ret_extent, (long int) arg_extent); | |
169 | } | |
170 | } | |
644cb69f | 171 | |
644cb69f FXC |
172 | |
173 | if (GFC_DESCRIPTOR_RANK (retarray) == 1) | |
174 | { | |
175 | /* One-dimensional result may be addressed in the code below | |
176 | either as a row or a column matrix. We want both cases to | |
177 | work. */ | |
dfb55fdc | 178 | rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0); |
644cb69f FXC |
179 | } |
180 | else | |
181 | { | |
dfb55fdc TK |
182 | rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0); |
183 | rystride = GFC_DESCRIPTOR_STRIDE(retarray,1); | |
644cb69f FXC |
184 | } |
185 | ||
186 | ||
187 | if (GFC_DESCRIPTOR_RANK (a) == 1) | |
188 | { | |
189 | /* Treat it as a row matrix A[1,count]. */ | |
dfb55fdc | 190 | axstride = GFC_DESCRIPTOR_STRIDE(a,0); |
644cb69f FXC |
191 | aystride = 1; |
192 | ||
193 | xcount = 1; | |
dfb55fdc | 194 | count = GFC_DESCRIPTOR_EXTENT(a,0); |
644cb69f FXC |
195 | } |
196 | else | |
197 | { | |
dfb55fdc TK |
198 | axstride = GFC_DESCRIPTOR_STRIDE(a,0); |
199 | aystride = GFC_DESCRIPTOR_STRIDE(a,1); | |
644cb69f | 200 | |
dfb55fdc TK |
201 | count = GFC_DESCRIPTOR_EXTENT(a,1); |
202 | xcount = GFC_DESCRIPTOR_EXTENT(a,0); | |
644cb69f FXC |
203 | } |
204 | ||
dfb55fdc | 205 | if (count != GFC_DESCRIPTOR_EXTENT(b,0)) |
7edc89d4 | 206 | { /* A mismatch is tolerated only when neither extent is positive. */ |
dfb55fdc | 207 | if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0) |
7edc89d4 TK |
208 | runtime_error ("dimension of array B incorrect in MATMUL intrinsic"); |
209 | } | |
644cb69f FXC |
210 | |
211 | if (GFC_DESCRIPTOR_RANK (b) == 1) | |
212 | { | |
213 | /* Treat it as a column matrix B[count,1] */ | |
dfb55fdc | 214 | bxstride = GFC_DESCRIPTOR_STRIDE(b,0); |
644cb69f FXC |
215 | |
216 | /* bystride should never be used for 1-dimensional b. | |
217 | In case it is, we want it to cause a segfault rather than | |
218 | an incorrect result. With ycount set to 1 here, the loops below only ever multiply it by y == 0. */ | |
219 | bystride = 0xDEADBEEF; | |
220 | ycount = 1; | |
221 | } | |
222 | else | |
223 | { | |
dfb55fdc TK |
224 | bxstride = GFC_DESCRIPTOR_STRIDE(b,0); |
225 | bystride = GFC_DESCRIPTOR_STRIDE(b,1); | |
226 | ycount = GFC_DESCRIPTOR_EXTENT(b,1); | |
644cb69f FXC |
227 | } |
228 | ||
21d1335b TB |
229 | abase = a->base_addr; |
230 | bbase = b->base_addr; | |
231 | dest = retarray->base_addr; | |
644cb69f | 232 | |
5a0aad31 FXC |
233 | |
234 | /* Now that everything is set up, we're performing the multiplication | |
235 | itself. The BLAS route is tried first; its size test compares xcount*ycount*count with blas_limit cubed, both computed in float. */ | |
236 | ||
237 | #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x))) | |
238 | ||
239 | if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1) | |
240 | && (bxstride == 1 || bystride == 1) | |
241 | && (((float) xcount) * ((float) ycount) * ((float) count) | |
242 | > POW3(blas_limit))) | |
243 | { | |
244 | const int m = xcount, n = ycount, k = count, ldc = rystride; | |
245 | const GFC_COMPLEX_16 one = 1, zero = 0; | |
246 | const int lda = (axstride == 1) ? aystride : axstride, | |
247 | ldb = (bxstride == 1) ? bystride : bxstride; | |
248 | ||
249 | if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1) | |
250 | { /* "N" is passed for an argument whose first-dimension stride is 1, "T" otherwise; if this test fails we fall through to the loops below. */ | |
251 | assert (gemm != NULL); | |
252 | gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k, | |
253 | &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1); | |
254 | return; | |
255 | } | |
256 | } | |
257 | ||
644cb69f FXC |
258 | if (rxstride == 1 && axstride == 1 && bxstride == 1) |
259 | { /* All first-dimension strides are 1: zero the result, then accumulate a[:,n] * b[n,y] into column y of the result. */ | |
85206901 JB |
260 | const GFC_COMPLEX_16 * restrict bbase_y; |
261 | GFC_COMPLEX_16 * restrict dest_y; | |
262 | const GFC_COMPLEX_16 * restrict abase_n; | |
644cb69f FXC |
263 | GFC_COMPLEX_16 bbase_yn; |
264 | ||
1633cb7c FXC |
265 | if (rystride == xcount) /* Result columns are contiguous, so one memset clears them all. */ |
266 | memset (dest, 0, (sizeof (GFC_COMPLEX_16) * xcount * ycount)); | |
644cb69f FXC |
267 | else |
268 | { | |
269 | for (y = 0; y < ycount; y++) | |
270 | for (x = 0; x < xcount; x++) | |
271 | dest[x + y*rystride] = (GFC_COMPLEX_16)0; | |
272 | } | |
273 | ||
274 | for (y = 0; y < ycount; y++) | |
275 | { | |
276 | bbase_y = bbase + y*bystride; | |
277 | dest_y = dest + y*rystride; | |
278 | for (n = 0; n < count; n++) | |
279 | { | |
280 | abase_n = abase + n*aystride; | |
281 | bbase_yn = bbase_y[n]; | |
282 | for (x = 0; x < xcount; x++) | |
283 | { | |
284 | dest_y[x] += abase_n[x] * bbase_yn; | |
285 | } | |
286 | } | |
287 | } | |
288 | } | |
1524f80b RS |
289 | else if (rxstride == 1 && aystride == 1 && bxstride == 1) |
290 | { /* A's second-dimension stride is 1 (as for a transposed temporary, or a rank-1 A): each result element is a dot product over n. */ | |
a4a11197 PT |
291 | if (GFC_DESCRIPTOR_RANK (a) != 1) |
292 | { | |
293 | const GFC_COMPLEX_16 *restrict abase_x; | |
294 | const GFC_COMPLEX_16 *restrict bbase_y; | |
295 | GFC_COMPLEX_16 *restrict dest_y; | |
296 | GFC_COMPLEX_16 s; | |
1524f80b | 297 | |
a4a11197 PT |
298 | for (y = 0; y < ycount; y++) |
299 | { | |
300 | bbase_y = &bbase[y*bystride]; | |
301 | dest_y = &dest[y*rystride]; | |
302 | for (x = 0; x < xcount; x++) | |
303 | { | |
304 | abase_x = &abase[x*axstride]; | |
305 | s = (GFC_COMPLEX_16) 0; | |
306 | for (n = 0; n < count; n++) | |
307 | s += abase_x[n] * bbase_y[n]; | |
308 | dest_y[x] = s; | |
309 | } | |
310 | } | |
311 | } | |
312 | else | |
1524f80b | 313 | { /* Rank-1 A: consecutive elements are axstride apart, so index with n*axstride. */ |
a4a11197 PT |
314 | const GFC_COMPLEX_16 *restrict bbase_y; |
315 | GFC_COMPLEX_16 s; | |
316 | ||
317 | for (y = 0; y < ycount; y++) | |
1524f80b | 318 | { |
a4a11197 | 319 | bbase_y = &bbase[y*bystride]; |
1524f80b RS |
320 | s = (GFC_COMPLEX_16) 0; |
321 | for (n = 0; n < count; n++) | |
a4a11197 PT |
322 | s += abase[n*axstride] * bbase_y[n]; |
323 | dest[y*rystride] = s; | |
1524f80b RS |
324 | } |
325 | } | |
326 | } | |
327 | else if (axstride < aystride) | |
644cb69f FXC |
328 | { /* General strides with axstride the smaller one: zero the result, then accumulate with x as the innermost loop. */ |
329 | for (y = 0; y < ycount; y++) | |
330 | for (x = 0; x < xcount; x++) | |
331 | dest[x*rxstride + y*rystride] = (GFC_COMPLEX_16)0; | |
332 | ||
333 | for (y = 0; y < ycount; y++) | |
334 | for (n = 0; n < count; n++) | |
335 | for (x = 0; x < xcount; x++) | |
336 | /* dest[x,y] += a[x,n] * b[n,y] */ | |
337 | dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride]; | |
338 | } | |
f0e871d6 PT |
339 | else if (GFC_DESCRIPTOR_RANK (a) == 1) |
340 | { /* Rank-1 A with general strides: one dot product per result element. */ | |
341 | const GFC_COMPLEX_16 *restrict bbase_y; | |
342 | GFC_COMPLEX_16 s; | |
343 | ||
344 | for (y = 0; y < ycount; y++) | |
345 | { | |
346 | bbase_y = &bbase[y*bystride]; | |
347 | s = (GFC_COMPLEX_16) 0; | |
348 | for (n = 0; n < count; n++) | |
349 | s += abase[n*axstride] * bbase_y[n*bxstride]; | |
350 | dest[y*rxstride] = s; | |
351 | } | |
352 | } | |
1524f80b RS |
353 | else |
354 | { /* Remaining general-stride case: a dot product over n for every (x,y). */ | |
355 | const GFC_COMPLEX_16 *restrict abase_x; | |
356 | const GFC_COMPLEX_16 *restrict bbase_y; | |
357 | GFC_COMPLEX_16 *restrict dest_y; | |
358 | GFC_COMPLEX_16 s; | |
359 | ||
360 | for (y = 0; y < ycount; y++) | |
361 | { | |
362 | bbase_y = &bbase[y*bystride]; | |
363 | dest_y = &dest[y*rystride]; | |
364 | for (x = 0; x < xcount; x++) | |
365 | { | |
366 | abase_x = &abase[x*axstride]; | |
367 | s = (GFC_COMPLEX_16) 0; | |
368 | for (n = 0; n < count; n++) | |
369 | s += abase_x[n*aystride] * bbase_y[n*bxstride]; | |
370 | dest_y[x*rxstride] = s; | |
371 | } | |
372 | } | |
373 | } | |
644cb69f FXC |
374 | } |
375 | ||
376 | #endif |