]> git.ipfire.org Git - thirdparty/gcc.git/blame - libgfortran/generated/matmul_c4.c
re PR fortran/26025 (Optionally use BLAS for matmul)
[thirdparty/gcc.git] / libgfortran / generated / matmul_c4.c
CommitLineData
/* Implementation of the MATMUL intrinsic
   Copyright 2002, 2005, 2006 Free Software Foundation, Inc.
   Contributed by Paul Brook <paul@nowt.org>

This file is part of the GNU Fortran 95 runtime library (libgfortran).

Libgfortran is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public
License as published by the Free Software Foundation; either
version 2 of the License, or (at your option) any later version.

In addition to the permissions in the GNU General Public License, the
Free Software Foundation gives you unlimited permission to link the
compiled version of this file into combinations with other programs,
and to distribute those combinations without any restriction coming
from the use of this file.  (The General Public License restrictions
do apply in other respects; for example, they cover modification of
the file, and distribution when not linked into a combine
executable.)

Libgfortran is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public
License along with libgfortran; see the file COPYING.  If not,
write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
Boston, MA 02110-1301, USA.  */
6de9cd9a
DN
30
31#include "config.h"
32#include <stdlib.h>
410d3bba 33#include <string.h>
6de9cd9a
DN
34#include <assert.h>
35#include "libgfortran.h"
36
644cb69f
FXC
37#if defined (HAVE_GFC_COMPLEX_4)
38
5a0aad31
FXC
39/* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
40 passed to us by the front-end, in which case we'll call it for large
41 matrices. */
42
43typedef void (*blas_call)(const char *, const char *, const int *, const int *,
44 const int *, const GFC_COMPLEX_4 *, const GFC_COMPLEX_4 *,
45 const int *, const GFC_COMPLEX_4 *, const int *,
46 const GFC_COMPLEX_4 *, GFC_COMPLEX_4 *, const int *,
47 int, int);
48
1524f80b
RS
/* The order of loops is different in the case of plain matrix
   multiplication C=MATMUL(A,B), and in the frequent special case where
   the argument A is the temporary result of a TRANSPOSE intrinsic:
   C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
   looking at their strides.

   The equivalent Fortran pseudo-code is:

   DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
   IF (.NOT.IS_TRANSPOSED(A)) THEN
     C = 0
     DO J=1,N
       DO K=1,COUNT
         DO I=1,M
           C(I,J) = C(I,J)+A(I,K)*B(K,J)
   ELSE
     DO J=1,N
       DO I=1,M
         S = 0
         DO K=1,COUNT
           S = S+A(I,K)*B(K,J)
         C(I,J) = S
   ENDIF
*/
73
5a0aad31
FXC
74/* If try_blas is set to a nonzero value, then the matmul function will
75 see if there is a way to perform the matrix multiplication by a call
76 to the BLAS gemm function. */
77
85206901 78extern void matmul_c4 (gfc_array_c4 * const restrict retarray,
5a0aad31
FXC
79 gfc_array_c4 * const restrict a, gfc_array_c4 * const restrict b, int try_blas,
80 int blas_limit, blas_call gemm);
7f68c75f 81export_proto(matmul_c4);
7d7b8bfe 82
6de9cd9a 83void
85206901 84matmul_c4 (gfc_array_c4 * const restrict retarray,
5a0aad31
FXC
85 gfc_array_c4 * const restrict a, gfc_array_c4 * const restrict b, int try_blas,
86 int blas_limit, blas_call gemm)
6de9cd9a 87{
85206901
JB
88 const GFC_COMPLEX_4 * restrict abase;
89 const GFC_COMPLEX_4 * restrict bbase;
90 GFC_COMPLEX_4 * restrict dest;
410d3bba
VL
91
92 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
93 index_type x, y, n, count, xcount, ycount;
6de9cd9a
DN
94
95 assert (GFC_DESCRIPTOR_RANK (a) == 2
96 || GFC_DESCRIPTOR_RANK (b) == 2);
883c9d4d 97
410d3bba
VL
98/* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
99
100 Either A or B (but not both) can be rank 1:
101
102 o One-dimensional argument A is implicitly treated as a row matrix
103 dimensioned [1,count], so xcount=1.
104
105 o One-dimensional argument B is implicitly treated as a column matrix
106 dimensioned [count, 1], so ycount=1.
107 */
108
883c9d4d
VL
109 if (retarray->data == NULL)
110 {
111 if (GFC_DESCRIPTOR_RANK (a) == 1)
112 {
113 retarray->dim[0].lbound = 0;
114 retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
115 retarray->dim[0].stride = 1;
116 }
117 else if (GFC_DESCRIPTOR_RANK (b) == 1)
118 {
119 retarray->dim[0].lbound = 0;
120 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
121 retarray->dim[0].stride = 1;
122 }
123 else
124 {
125 retarray->dim[0].lbound = 0;
126 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
127 retarray->dim[0].stride = 1;
420aa7b8 128
883c9d4d
VL
129 retarray->dim[1].lbound = 0;
130 retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
131 retarray->dim[1].stride = retarray->dim[0].ubound+1;
132 }
420aa7b8 133
07d3cebe 134 retarray->data
4b6903ec 135 = internal_malloc_size (sizeof (GFC_COMPLEX_4) * size0 ((array_t *) retarray));
efd4dc1a 136 retarray->offset = 0;
883c9d4d
VL
137 }
138
6de9cd9a
DN
139
140 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
141 {
410d3bba
VL
142 /* One-dimensional result may be addressed in the code below
143 either as a row or a column matrix. We want both cases to
144 work. */
145 rxstride = rystride = retarray->dim[0].stride;
6de9cd9a
DN
146 }
147 else
148 {
149 rxstride = retarray->dim[0].stride;
150 rystride = retarray->dim[1].stride;
151 }
152
410d3bba 153
6de9cd9a
DN
154 if (GFC_DESCRIPTOR_RANK (a) == 1)
155 {
410d3bba
VL
156 /* Treat it as a a row matrix A[1,count]. */
157 axstride = a->dim[0].stride;
158 aystride = 1;
159
6de9cd9a 160 xcount = 1;
410d3bba 161 count = a->dim[0].ubound + 1 - a->dim[0].lbound;
6de9cd9a
DN
162 }
163 else
164 {
410d3bba
VL
165 axstride = a->dim[0].stride;
166 aystride = a->dim[1].stride;
167
6de9cd9a 168 count = a->dim[1].ubound + 1 - a->dim[1].lbound;
6de9cd9a
DN
169 xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
170 }
410d3bba
VL
171
172 assert(count == b->dim[0].ubound + 1 - b->dim[0].lbound);
173
6de9cd9a
DN
174 if (GFC_DESCRIPTOR_RANK (b) == 1)
175 {
410d3bba
VL
176 /* Treat it as a column matrix B[count,1] */
177 bxstride = b->dim[0].stride;
178
179 /* bystride should never be used for 1-dimensional b.
180 in case it is we want it to cause a segfault, rather than
181 an incorrect result. */
420aa7b8 182 bystride = 0xDEADBEEF;
6de9cd9a
DN
183 ycount = 1;
184 }
185 else
186 {
410d3bba
VL
187 bxstride = b->dim[0].stride;
188 bystride = b->dim[1].stride;
6de9cd9a
DN
189 ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
190 }
191
410d3bba
VL
192 abase = a->data;
193 bbase = b->data;
194 dest = retarray->data;
195
5a0aad31
FXC
196
197 /* Now that everything is set up, we're performing the multiplication
198 itself. */
199
200#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
201
202 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
203 && (bxstride == 1 || bystride == 1)
204 && (((float) xcount) * ((float) ycount) * ((float) count)
205 > POW3(blas_limit)))
206 {
207 const int m = xcount, n = ycount, k = count, ldc = rystride;
208 const GFC_COMPLEX_4 one = 1, zero = 0;
209 const int lda = (axstride == 1) ? aystride : axstride,
210 ldb = (bxstride == 1) ? bystride : bxstride;
211
212 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
213 {
214 assert (gemm != NULL);
215 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
216 &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
217 return;
218 }
219 }
220
410d3bba 221 if (rxstride == 1 && axstride == 1 && bxstride == 1)
6de9cd9a 222 {
85206901
JB
223 const GFC_COMPLEX_4 * restrict bbase_y;
224 GFC_COMPLEX_4 * restrict dest_y;
225 const GFC_COMPLEX_4 * restrict abase_n;
410d3bba
VL
226 GFC_COMPLEX_4 bbase_yn;
227
1633cb7c
FXC
228 if (rystride == xcount)
229 memset (dest, 0, (sizeof (GFC_COMPLEX_4) * xcount * ycount));
ae740cce
TK
230 else
231 {
232 for (y = 0; y < ycount; y++)
233 for (x = 0; x < xcount; x++)
234 dest[x + y*rystride] = (GFC_COMPLEX_4)0;
235 }
410d3bba
VL
236
237 for (y = 0; y < ycount; y++)
238 {
239 bbase_y = bbase + y*bystride;
240 dest_y = dest + y*rystride;
241 for (n = 0; n < count; n++)
242 {
243 abase_n = abase + n*aystride;
244 bbase_yn = bbase_y[n];
245 for (x = 0; x < xcount; x++)
246 {
247 dest_y[x] += abase_n[x] * bbase_yn;
248 }
249 }
250 }
251 }
1524f80b
RS
252 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
253 {
a4a11197
PT
254 if (GFC_DESCRIPTOR_RANK (a) != 1)
255 {
256 const GFC_COMPLEX_4 *restrict abase_x;
257 const GFC_COMPLEX_4 *restrict bbase_y;
258 GFC_COMPLEX_4 *restrict dest_y;
259 GFC_COMPLEX_4 s;
1524f80b 260
a4a11197
PT
261 for (y = 0; y < ycount; y++)
262 {
263 bbase_y = &bbase[y*bystride];
264 dest_y = &dest[y*rystride];
265 for (x = 0; x < xcount; x++)
266 {
267 abase_x = &abase[x*axstride];
268 s = (GFC_COMPLEX_4) 0;
269 for (n = 0; n < count; n++)
270 s += abase_x[n] * bbase_y[n];
271 dest_y[x] = s;
272 }
273 }
274 }
275 else
1524f80b 276 {
a4a11197
PT
277 const GFC_COMPLEX_4 *restrict bbase_y;
278 GFC_COMPLEX_4 s;
279
280 for (y = 0; y < ycount; y++)
1524f80b 281 {
a4a11197 282 bbase_y = &bbase[y*bystride];
1524f80b
RS
283 s = (GFC_COMPLEX_4) 0;
284 for (n = 0; n < count; n++)
a4a11197
PT
285 s += abase[n*axstride] * bbase_y[n];
286 dest[y*rystride] = s;
1524f80b
RS
287 }
288 }
289 }
290 else if (axstride < aystride)
410d3bba
VL
291 {
292 for (y = 0; y < ycount; y++)
293 for (x = 0; x < xcount; x++)
294 dest[x*rxstride + y*rystride] = (GFC_COMPLEX_4)0;
295
296 for (y = 0; y < ycount; y++)
297 for (n = 0; n < count; n++)
298 for (x = 0; x < xcount; x++)
299 /* dest[x,y] += a[x,n] * b[n,y] */
300 dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
6de9cd9a 301 }
f0e871d6
PT
302 else if (GFC_DESCRIPTOR_RANK (a) == 1)
303 {
304 const GFC_COMPLEX_4 *restrict bbase_y;
305 GFC_COMPLEX_4 s;
306
307 for (y = 0; y < ycount; y++)
308 {
309 bbase_y = &bbase[y*bystride];
310 s = (GFC_COMPLEX_4) 0;
311 for (n = 0; n < count; n++)
312 s += abase[n*axstride] * bbase_y[n*bxstride];
313 dest[y*rxstride] = s;
314 }
315 }
1524f80b
RS
316 else
317 {
318 const GFC_COMPLEX_4 *restrict abase_x;
319 const GFC_COMPLEX_4 *restrict bbase_y;
320 GFC_COMPLEX_4 *restrict dest_y;
321 GFC_COMPLEX_4 s;
322
323 for (y = 0; y < ycount; y++)
324 {
325 bbase_y = &bbase[y*bystride];
326 dest_y = &dest[y*rystride];
327 for (x = 0; x < xcount; x++)
328 {
329 abase_x = &abase[x*axstride];
330 s = (GFC_COMPLEX_4) 0;
331 for (n = 0; n < count; n++)
332 s += abase_x[n*aystride] * bbase_y[n*bxstride];
333 dest_y[x*rxstride] = s;
334 }
335 }
336 }
6de9cd9a 337}
644cb69f
FXC
338
339#endif