]> git.ipfire.org Git - thirdparty/gcc.git/blame - libgfortran/generated/matmul_i8.c
re PR fortran/35988 (run-time abort for MATMUL of run-time zero sized array)
[thirdparty/gcc.git] / libgfortran / generated / matmul_i8.c
CommitLineData
6de9cd9a 1/* Implementation of the MATMUL intrinsic
36ae8a61 2 Copyright 2002, 2005, 2006, 2007 Free Software Foundation, Inc.
6de9cd9a
DN
3 Contributed by Paul Brook <paul@nowt.org>
4
883c9d4d 5This file is part of the GNU Fortran 95 runtime library (libgfortran).
6de9cd9a
DN
6
7Libgfortran is free software; you can redistribute it and/or
57dea9f6 8modify it under the terms of the GNU General Public
6de9cd9a 9License as published by the Free Software Foundation; either
57dea9f6
TM
10version 2 of the License, or (at your option) any later version.
11
12In addition to the permissions in the GNU General Public License, the
13Free Software Foundation gives you unlimited permission to link the
14compiled version of this file into combinations with other programs,
15and to distribute those combinations without any restriction coming
16from the use of this file. (The General Public License restrictions
17do apply in other respects; for example, they cover modification of
18the file, and distribution when not linked into a combine
19executable.)
6de9cd9a
DN
20
21Libgfortran is distributed in the hope that it will be useful,
22but WITHOUT ANY WARRANTY; without even the implied warranty of
23MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
57dea9f6 24GNU General Public License for more details.
6de9cd9a 25
57dea9f6
TM
26You should have received a copy of the GNU General Public
27License along with libgfortran; see the file COPYING. If not,
fe2ae685
KC
28write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
29Boston, MA 02110-1301, USA. */
6de9cd9a 30
36ae8a61 31#include "libgfortran.h"
6de9cd9a 32#include <stdlib.h>
410d3bba 33#include <string.h>
6de9cd9a 34#include <assert.h>
36ae8a61 35
6de9cd9a 36
644cb69f
FXC
37#if defined (HAVE_GFC_INTEGER_8)
38
5a0aad31
FXC
39/* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
40 passed to us by the front-end, in which case we'll call it for large
41 matrices. */
42
43typedef void (*blas_call)(const char *, const char *, const int *, const int *,
44 const int *, const GFC_INTEGER_8 *, const GFC_INTEGER_8 *,
45 const int *, const GFC_INTEGER_8 *, const int *,
46 const GFC_INTEGER_8 *, GFC_INTEGER_8 *, const int *,
47 int, int);
48
1524f80b
RS
/* The order of loops is different in the case of plain matrix
   multiplication C=MATMUL(A,B), and in the frequent special case where
   the argument A is the temporary result of a TRANSPOSE intrinsic:
   C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
   looking at their strides.

   The equivalent Fortran pseudo-code is:

   DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
   IF (.NOT.IS_TRANSPOSED(A)) THEN
     C = 0
     DO J=1,N
       DO K=1,COUNT
         DO I=1,M
           C(I,J) = C(I,J)+A(I,K)*B(K,J)
   ELSE
     DO J=1,N
       DO I=1,M
         S = 0
         DO K=1,COUNT
           S = S+A(I,K)*B(K,J)
         C(I,J) = S
   ENDIF
*/
73
5a0aad31
FXC
74/* If try_blas is set to a nonzero value, then the matmul function will
75 see if there is a way to perform the matrix multiplication by a call
76 to the BLAS gemm function. */
77
85206901 78extern void matmul_i8 (gfc_array_i8 * const restrict retarray,
5a0aad31
FXC
79 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
80 int blas_limit, blas_call gemm);
7f68c75f 81export_proto(matmul_i8);
7d7b8bfe 82
6de9cd9a 83void
85206901 84matmul_i8 (gfc_array_i8 * const restrict retarray,
5a0aad31
FXC
85 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
86 int blas_limit, blas_call gemm)
6de9cd9a 87{
85206901
JB
88 const GFC_INTEGER_8 * restrict abase;
89 const GFC_INTEGER_8 * restrict bbase;
90 GFC_INTEGER_8 * restrict dest;
410d3bba
VL
91
92 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
93 index_type x, y, n, count, xcount, ycount;
6de9cd9a
DN
94
95 assert (GFC_DESCRIPTOR_RANK (a) == 2
96 || GFC_DESCRIPTOR_RANK (b) == 2);
883c9d4d 97
410d3bba
VL
98/* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
99
100 Either A or B (but not both) can be rank 1:
101
102 o One-dimensional argument A is implicitly treated as a row matrix
103 dimensioned [1,count], so xcount=1.
104
105 o One-dimensional argument B is implicitly treated as a column matrix
106 dimensioned [count, 1], so ycount=1.
107 */
108
883c9d4d
VL
109 if (retarray->data == NULL)
110 {
111 if (GFC_DESCRIPTOR_RANK (a) == 1)
112 {
113 retarray->dim[0].lbound = 0;
114 retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
115 retarray->dim[0].stride = 1;
116 }
117 else if (GFC_DESCRIPTOR_RANK (b) == 1)
118 {
119 retarray->dim[0].lbound = 0;
120 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
121 retarray->dim[0].stride = 1;
122 }
123 else
124 {
125 retarray->dim[0].lbound = 0;
126 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
127 retarray->dim[0].stride = 1;
420aa7b8 128
883c9d4d
VL
129 retarray->dim[1].lbound = 0;
130 retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
131 retarray->dim[1].stride = retarray->dim[0].ubound+1;
132 }
420aa7b8 133
07d3cebe 134 retarray->data
4b6903ec 135 = internal_malloc_size (sizeof (GFC_INTEGER_8) * size0 ((array_t *) retarray));
efd4dc1a 136 retarray->offset = 0;
883c9d4d
VL
137 }
138
6de9cd9a
DN
139
140 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
141 {
410d3bba
VL
142 /* One-dimensional result may be addressed in the code below
143 either as a row or a column matrix. We want both cases to
144 work. */
145 rxstride = rystride = retarray->dim[0].stride;
6de9cd9a
DN
146 }
147 else
148 {
149 rxstride = retarray->dim[0].stride;
150 rystride = retarray->dim[1].stride;
151 }
152
410d3bba 153
6de9cd9a
DN
154 if (GFC_DESCRIPTOR_RANK (a) == 1)
155 {
410d3bba
VL
156 /* Treat it as a a row matrix A[1,count]. */
157 axstride = a->dim[0].stride;
158 aystride = 1;
159
6de9cd9a 160 xcount = 1;
410d3bba 161 count = a->dim[0].ubound + 1 - a->dim[0].lbound;
6de9cd9a
DN
162 }
163 else
164 {
410d3bba
VL
165 axstride = a->dim[0].stride;
166 aystride = a->dim[1].stride;
167
6de9cd9a 168 count = a->dim[1].ubound + 1 - a->dim[1].lbound;
6de9cd9a
DN
169 xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
170 }
410d3bba 171
18c492a9 172 if (count != b->dim[0].ubound + 1 - b->dim[0].lbound)
7edc89d4
TK
173 {
174 if (count > 0 || b->dim[0].ubound + 1 - b->dim[0].lbound > 0)
175 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
176 }
410d3bba 177
6de9cd9a
DN
178 if (GFC_DESCRIPTOR_RANK (b) == 1)
179 {
410d3bba
VL
180 /* Treat it as a column matrix B[count,1] */
181 bxstride = b->dim[0].stride;
182
183 /* bystride should never be used for 1-dimensional b.
184 in case it is we want it to cause a segfault, rather than
185 an incorrect result. */
420aa7b8 186 bystride = 0xDEADBEEF;
6de9cd9a
DN
187 ycount = 1;
188 }
189 else
190 {
410d3bba
VL
191 bxstride = b->dim[0].stride;
192 bystride = b->dim[1].stride;
6de9cd9a
DN
193 ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
194 }
195
410d3bba
VL
196 abase = a->data;
197 bbase = b->data;
198 dest = retarray->data;
199
5a0aad31
FXC
200
201 /* Now that everything is set up, we're performing the multiplication
202 itself. */
203
204#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
205
206 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
207 && (bxstride == 1 || bystride == 1)
208 && (((float) xcount) * ((float) ycount) * ((float) count)
209 > POW3(blas_limit)))
210 {
211 const int m = xcount, n = ycount, k = count, ldc = rystride;
212 const GFC_INTEGER_8 one = 1, zero = 0;
213 const int lda = (axstride == 1) ? aystride : axstride,
214 ldb = (bxstride == 1) ? bystride : bxstride;
215
216 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
217 {
218 assert (gemm != NULL);
219 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
220 &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
221 return;
222 }
223 }
224
410d3bba 225 if (rxstride == 1 && axstride == 1 && bxstride == 1)
6de9cd9a 226 {
85206901
JB
227 const GFC_INTEGER_8 * restrict bbase_y;
228 GFC_INTEGER_8 * restrict dest_y;
229 const GFC_INTEGER_8 * restrict abase_n;
410d3bba
VL
230 GFC_INTEGER_8 bbase_yn;
231
1633cb7c
FXC
232 if (rystride == xcount)
233 memset (dest, 0, (sizeof (GFC_INTEGER_8) * xcount * ycount));
ae740cce
TK
234 else
235 {
236 for (y = 0; y < ycount; y++)
237 for (x = 0; x < xcount; x++)
238 dest[x + y*rystride] = (GFC_INTEGER_8)0;
239 }
410d3bba
VL
240
241 for (y = 0; y < ycount; y++)
242 {
243 bbase_y = bbase + y*bystride;
244 dest_y = dest + y*rystride;
245 for (n = 0; n < count; n++)
246 {
247 abase_n = abase + n*aystride;
248 bbase_yn = bbase_y[n];
249 for (x = 0; x < xcount; x++)
250 {
251 dest_y[x] += abase_n[x] * bbase_yn;
252 }
253 }
254 }
255 }
1524f80b
RS
256 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
257 {
a4a11197
PT
258 if (GFC_DESCRIPTOR_RANK (a) != 1)
259 {
260 const GFC_INTEGER_8 *restrict abase_x;
261 const GFC_INTEGER_8 *restrict bbase_y;
262 GFC_INTEGER_8 *restrict dest_y;
263 GFC_INTEGER_8 s;
1524f80b 264
a4a11197
PT
265 for (y = 0; y < ycount; y++)
266 {
267 bbase_y = &bbase[y*bystride];
268 dest_y = &dest[y*rystride];
269 for (x = 0; x < xcount; x++)
270 {
271 abase_x = &abase[x*axstride];
272 s = (GFC_INTEGER_8) 0;
273 for (n = 0; n < count; n++)
274 s += abase_x[n] * bbase_y[n];
275 dest_y[x] = s;
276 }
277 }
278 }
279 else
1524f80b 280 {
a4a11197
PT
281 const GFC_INTEGER_8 *restrict bbase_y;
282 GFC_INTEGER_8 s;
283
284 for (y = 0; y < ycount; y++)
1524f80b 285 {
a4a11197 286 bbase_y = &bbase[y*bystride];
1524f80b
RS
287 s = (GFC_INTEGER_8) 0;
288 for (n = 0; n < count; n++)
a4a11197
PT
289 s += abase[n*axstride] * bbase_y[n];
290 dest[y*rystride] = s;
1524f80b
RS
291 }
292 }
293 }
294 else if (axstride < aystride)
410d3bba
VL
295 {
296 for (y = 0; y < ycount; y++)
297 for (x = 0; x < xcount; x++)
298 dest[x*rxstride + y*rystride] = (GFC_INTEGER_8)0;
299
300 for (y = 0; y < ycount; y++)
301 for (n = 0; n < count; n++)
302 for (x = 0; x < xcount; x++)
303 /* dest[x,y] += a[x,n] * b[n,y] */
304 dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
6de9cd9a 305 }
f0e871d6
PT
306 else if (GFC_DESCRIPTOR_RANK (a) == 1)
307 {
308 const GFC_INTEGER_8 *restrict bbase_y;
309 GFC_INTEGER_8 s;
310
311 for (y = 0; y < ycount; y++)
312 {
313 bbase_y = &bbase[y*bystride];
314 s = (GFC_INTEGER_8) 0;
315 for (n = 0; n < count; n++)
316 s += abase[n*axstride] * bbase_y[n*bxstride];
317 dest[y*rxstride] = s;
318 }
319 }
1524f80b
RS
320 else
321 {
322 const GFC_INTEGER_8 *restrict abase_x;
323 const GFC_INTEGER_8 *restrict bbase_y;
324 GFC_INTEGER_8 *restrict dest_y;
325 GFC_INTEGER_8 s;
326
327 for (y = 0; y < ycount; y++)
328 {
329 bbase_y = &bbase[y*bystride];
330 dest_y = &dest[y*rystride];
331 for (x = 0; x < xcount; x++)
332 {
333 abase_x = &abase[x*axstride];
334 s = (GFC_INTEGER_8) 0;
335 for (n = 0; n < count; n++)
336 s += abase_x[n*aystride] * bbase_y[n*bxstride];
337 dest_y[x*rxstride] = s;
338 }
339 }
340 }
6de9cd9a 341}
644cb69f
FXC
342
343#endif