]> git.ipfire.org Git - thirdparty/gcc.git/blame - libgfortran/generated/matmul_i8.c
Daily bump.
[thirdparty/gcc.git] / libgfortran / generated / matmul_i8.c
CommitLineData
6de9cd9a 1/* Implementation of the MATMUL intrinsic
36ae8a61 2 Copyright 2002, 2005, 2006, 2007 Free Software Foundation, Inc.
6de9cd9a
DN
3 Contributed by Paul Brook <paul@nowt.org>
4
883c9d4d 5This file is part of the GNU Fortran 95 runtime library (libgfortran).
6de9cd9a
DN
6
7Libgfortran is free software; you can redistribute it and/or
57dea9f6 8modify it under the terms of the GNU General Public
6de9cd9a 9License as published by the Free Software Foundation; either
57dea9f6
TM
10version 2 of the License, or (at your option) any later version.
11
12In addition to the permissions in the GNU General Public License, the
13Free Software Foundation gives you unlimited permission to link the
14compiled version of this file into combinations with other programs,
15and to distribute those combinations without any restriction coming
16from the use of this file. (The General Public License restrictions
17do apply in other respects; for example, they cover modification of
18the file, and distribution when not linked into a combine
19executable.)
6de9cd9a
DN
20
21Libgfortran is distributed in the hope that it will be useful,
22but WITHOUT ANY WARRANTY; without even the implied warranty of
23MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
57dea9f6 24GNU General Public License for more details.
6de9cd9a 25
57dea9f6
TM
26You should have received a copy of the GNU General Public
27License along with libgfortran; see the file COPYING. If not,
fe2ae685
KC
28write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
29Boston, MA 02110-1301, USA. */
6de9cd9a 30
36ae8a61 31#include "libgfortran.h"
6de9cd9a 32#include <stdlib.h>
410d3bba 33#include <string.h>
6de9cd9a 34#include <assert.h>
36ae8a61 35
6de9cd9a 36
644cb69f
FXC
37#if defined (HAVE_GFC_INTEGER_8)
38
5a0aad31
FXC
39/* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
40 passed to us by the front-end, in which case we'll call it for large
41 matrices. */
42
43typedef void (*blas_call)(const char *, const char *, const int *, const int *,
44 const int *, const GFC_INTEGER_8 *, const GFC_INTEGER_8 *,
45 const int *, const GFC_INTEGER_8 *, const int *,
46 const GFC_INTEGER_8 *, GFC_INTEGER_8 *, const int *,
47 int, int);
48
1524f80b
RS
49/* The order of loops is different in the case of plain matrix
50 multiplication C=MATMUL(A,B), and in the frequent special case where
51 the argument A is the temporary result of a TRANSPOSE intrinsic:
52 C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
53 looking at their strides.
54
55 The equivalent Fortran pseudo-code is:
410d3bba
VL
56
57 DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
1524f80b
RS
58 IF (.NOT.IS_TRANSPOSED(A)) THEN
59 C = 0
60 DO J=1,N
61 DO K=1,COUNT
62 DO I=1,M
63 C(I,J) = C(I,J)+A(I,K)*B(K,J)
64 ELSE
65 DO J=1,N
410d3bba 66 DO I=1,M
1524f80b
RS
67 S = 0
68 DO K=1,COUNT
5a0aad31 69 S = S+A(I,K)*B(K,J)
1524f80b
RS
70 C(I,J) = S
71 ENDIF
410d3bba
VL
72*/
73
5a0aad31
FXC
74/* If try_blas is set to a nonzero value, then the matmul function will
75 see if there is a way to perform the matrix multiplication by a call
76 to the BLAS gemm function. */
77
85206901 78extern void matmul_i8 (gfc_array_i8 * const restrict retarray,
5a0aad31
FXC
79 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
80 int blas_limit, blas_call gemm);
7f68c75f 81export_proto(matmul_i8);
7d7b8bfe 82
6de9cd9a 83void
85206901 84matmul_i8 (gfc_array_i8 * const restrict retarray,
5a0aad31
FXC
85 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
86 int blas_limit, blas_call gemm)
6de9cd9a 87{
85206901
JB
88 const GFC_INTEGER_8 * restrict abase;
89 const GFC_INTEGER_8 * restrict bbase;
90 GFC_INTEGER_8 * restrict dest;
410d3bba
VL
91
92 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
93 index_type x, y, n, count, xcount, ycount;
6de9cd9a
DN
94
95 assert (GFC_DESCRIPTOR_RANK (a) == 2
96 || GFC_DESCRIPTOR_RANK (b) == 2);
883c9d4d 97
410d3bba
VL
98/* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
99
100 Either A or B (but not both) can be rank 1:
101
102 o One-dimensional argument A is implicitly treated as a row matrix
103 dimensioned [1,count], so xcount=1.
104
105 o One-dimensional argument B is implicitly treated as a column matrix
106 dimensioned [count, 1], so ycount=1.
107 */
108
883c9d4d
VL
109 if (retarray->data == NULL)
110 {
111 if (GFC_DESCRIPTOR_RANK (a) == 1)
112 {
113 retarray->dim[0].lbound = 0;
114 retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
115 retarray->dim[0].stride = 1;
116 }
117 else if (GFC_DESCRIPTOR_RANK (b) == 1)
118 {
119 retarray->dim[0].lbound = 0;
120 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
121 retarray->dim[0].stride = 1;
122 }
123 else
124 {
125 retarray->dim[0].lbound = 0;
126 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
127 retarray->dim[0].stride = 1;
420aa7b8 128
883c9d4d
VL
129 retarray->dim[1].lbound = 0;
130 retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
131 retarray->dim[1].stride = retarray->dim[0].ubound+1;
132 }
420aa7b8 133
07d3cebe 134 retarray->data
4b6903ec 135 = internal_malloc_size (sizeof (GFC_INTEGER_8) * size0 ((array_t *) retarray));
efd4dc1a 136 retarray->offset = 0;
883c9d4d
VL
137 }
138
6de9cd9a
DN
139
140 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
141 {
410d3bba
VL
142 /* One-dimensional result may be addressed in the code below
143 either as a row or a column matrix. We want both cases to
144 work. */
145 rxstride = rystride = retarray->dim[0].stride;
6de9cd9a
DN
146 }
147 else
148 {
149 rxstride = retarray->dim[0].stride;
150 rystride = retarray->dim[1].stride;
151 }
152
410d3bba 153
6de9cd9a
DN
154 if (GFC_DESCRIPTOR_RANK (a) == 1)
155 {
410d3bba
VL
156 /* Treat it as a a row matrix A[1,count]. */
157 axstride = a->dim[0].stride;
158 aystride = 1;
159
6de9cd9a 160 xcount = 1;
410d3bba 161 count = a->dim[0].ubound + 1 - a->dim[0].lbound;
6de9cd9a
DN
162 }
163 else
164 {
410d3bba
VL
165 axstride = a->dim[0].stride;
166 aystride = a->dim[1].stride;
167
6de9cd9a 168 count = a->dim[1].ubound + 1 - a->dim[1].lbound;
6de9cd9a
DN
169 xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
170 }
410d3bba 171
18c492a9
TK
172 if (count != b->dim[0].ubound + 1 - b->dim[0].lbound)
173 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
410d3bba 174
6de9cd9a
DN
175 if (GFC_DESCRIPTOR_RANK (b) == 1)
176 {
410d3bba
VL
177 /* Treat it as a column matrix B[count,1] */
178 bxstride = b->dim[0].stride;
179
180 /* bystride should never be used for 1-dimensional b.
181 in case it is we want it to cause a segfault, rather than
182 an incorrect result. */
420aa7b8 183 bystride = 0xDEADBEEF;
6de9cd9a
DN
184 ycount = 1;
185 }
186 else
187 {
410d3bba
VL
188 bxstride = b->dim[0].stride;
189 bystride = b->dim[1].stride;
6de9cd9a
DN
190 ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
191 }
192
410d3bba
VL
193 abase = a->data;
194 bbase = b->data;
195 dest = retarray->data;
196
5a0aad31
FXC
197
198 /* Now that everything is set up, we're performing the multiplication
199 itself. */
200
201#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
202
203 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
204 && (bxstride == 1 || bystride == 1)
205 && (((float) xcount) * ((float) ycount) * ((float) count)
206 > POW3(blas_limit)))
207 {
208 const int m = xcount, n = ycount, k = count, ldc = rystride;
209 const GFC_INTEGER_8 one = 1, zero = 0;
210 const int lda = (axstride == 1) ? aystride : axstride,
211 ldb = (bxstride == 1) ? bystride : bxstride;
212
213 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
214 {
215 assert (gemm != NULL);
216 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
217 &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
218 return;
219 }
220 }
221
410d3bba 222 if (rxstride == 1 && axstride == 1 && bxstride == 1)
6de9cd9a 223 {
85206901
JB
224 const GFC_INTEGER_8 * restrict bbase_y;
225 GFC_INTEGER_8 * restrict dest_y;
226 const GFC_INTEGER_8 * restrict abase_n;
410d3bba
VL
227 GFC_INTEGER_8 bbase_yn;
228
1633cb7c
FXC
229 if (rystride == xcount)
230 memset (dest, 0, (sizeof (GFC_INTEGER_8) * xcount * ycount));
ae740cce
TK
231 else
232 {
233 for (y = 0; y < ycount; y++)
234 for (x = 0; x < xcount; x++)
235 dest[x + y*rystride] = (GFC_INTEGER_8)0;
236 }
410d3bba
VL
237
238 for (y = 0; y < ycount; y++)
239 {
240 bbase_y = bbase + y*bystride;
241 dest_y = dest + y*rystride;
242 for (n = 0; n < count; n++)
243 {
244 abase_n = abase + n*aystride;
245 bbase_yn = bbase_y[n];
246 for (x = 0; x < xcount; x++)
247 {
248 dest_y[x] += abase_n[x] * bbase_yn;
249 }
250 }
251 }
252 }
1524f80b
RS
253 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
254 {
a4a11197
PT
255 if (GFC_DESCRIPTOR_RANK (a) != 1)
256 {
257 const GFC_INTEGER_8 *restrict abase_x;
258 const GFC_INTEGER_8 *restrict bbase_y;
259 GFC_INTEGER_8 *restrict dest_y;
260 GFC_INTEGER_8 s;
1524f80b 261
a4a11197
PT
262 for (y = 0; y < ycount; y++)
263 {
264 bbase_y = &bbase[y*bystride];
265 dest_y = &dest[y*rystride];
266 for (x = 0; x < xcount; x++)
267 {
268 abase_x = &abase[x*axstride];
269 s = (GFC_INTEGER_8) 0;
270 for (n = 0; n < count; n++)
271 s += abase_x[n] * bbase_y[n];
272 dest_y[x] = s;
273 }
274 }
275 }
276 else
1524f80b 277 {
a4a11197
PT
278 const GFC_INTEGER_8 *restrict bbase_y;
279 GFC_INTEGER_8 s;
280
281 for (y = 0; y < ycount; y++)
1524f80b 282 {
a4a11197 283 bbase_y = &bbase[y*bystride];
1524f80b
RS
284 s = (GFC_INTEGER_8) 0;
285 for (n = 0; n < count; n++)
a4a11197
PT
286 s += abase[n*axstride] * bbase_y[n];
287 dest[y*rystride] = s;
1524f80b
RS
288 }
289 }
290 }
291 else if (axstride < aystride)
410d3bba
VL
292 {
293 for (y = 0; y < ycount; y++)
294 for (x = 0; x < xcount; x++)
295 dest[x*rxstride + y*rystride] = (GFC_INTEGER_8)0;
296
297 for (y = 0; y < ycount; y++)
298 for (n = 0; n < count; n++)
299 for (x = 0; x < xcount; x++)
300 /* dest[x,y] += a[x,n] * b[n,y] */
301 dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
6de9cd9a 302 }
f0e871d6
PT
303 else if (GFC_DESCRIPTOR_RANK (a) == 1)
304 {
305 const GFC_INTEGER_8 *restrict bbase_y;
306 GFC_INTEGER_8 s;
307
308 for (y = 0; y < ycount; y++)
309 {
310 bbase_y = &bbase[y*bystride];
311 s = (GFC_INTEGER_8) 0;
312 for (n = 0; n < count; n++)
313 s += abase[n*axstride] * bbase_y[n*bxstride];
314 dest[y*rxstride] = s;
315 }
316 }
1524f80b
RS
317 else
318 {
319 const GFC_INTEGER_8 *restrict abase_x;
320 const GFC_INTEGER_8 *restrict bbase_y;
321 GFC_INTEGER_8 *restrict dest_y;
322 GFC_INTEGER_8 s;
323
324 for (y = 0; y < ycount; y++)
325 {
326 bbase_y = &bbase[y*bystride];
327 dest_y = &dest[y*rystride];
328 for (x = 0; x < xcount; x++)
329 {
330 abase_x = &abase[x*axstride];
331 s = (GFC_INTEGER_8) 0;
332 for (n = 0; n < count; n++)
333 s += abase_x[n*aystride] * bbase_y[n*bxstride];
334 dest_y[x*rxstride] = s;
335 }
336 }
337 }
6de9cd9a 338}
644cb69f
FXC
339
340#endif