]> git.ipfire.org Git - thirdparty/gcc.git/blob - libgfortran/m4/matmul.m4
re PR fortran/36341 (MATMUL: Bounds check missing)
[thirdparty/gcc.git] / libgfortran / m4 / matmul.m4
1 `/* Implementation of the MATMUL intrinsic
2 Copyright 2002, 2005, 2006, 2007 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
4
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
6
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 2 of the License, or (at your option) any later version.
11
12 In addition to the permissions in the GNU General Public License, the
13 Free Software Foundation gives you unlimited permission to link the
14 compiled version of this file into combinations with other programs,
15 and to distribute those combinations without any restriction coming
16 from the use of this file. (The General Public License restrictions
17 do apply in other respects; for example, they cover modification of
18 the file, and distribution when not linked into a combine
19 executable.)
20
21 Libgfortran is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24 GNU General Public License for more details.
25
26 You should have received a copy of the GNU General Public
27 License along with libgfortran; see the file COPYING. If not,
28 write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
29 Boston, MA 02110-1301, USA. */
30
31 #include "libgfortran.h"
32 #include <stdlib.h>
33 #include <string.h>
34 #include <assert.h>'
35
36 include(iparm.m4)dnl
37
38 `#if defined (HAVE_'rtype_name`)
39
40 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
41 passed to us by the front-end, in which case we''`ll call it for large
42 matrices. */
43
/* Argument order, as used by the gemm call in the matmul function below:
   transa, transb, m, n, k, scale factor for the product (passed as one),
   a, lda, b, ldb, scale factor for the old result (passed as zero), c, ldc,
   followed by two int values.  The caller passes 1, 1 for the latter;
   presumably these are the hidden Fortran lengths of the two character
   arguments -- TODO confirm against the BLAS interface.  */
44 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
45 const int *, const 'rtype_name` *, const 'rtype_name` *,
46 const int *, const 'rtype_name` *, const int *,
47 const 'rtype_name` *, 'rtype_name` *, const int *,
48 int, int);
49
50 /* The order of loops is different in the case of plain matrix
51 multiplication C=MATMUL(A,B), and in the frequent special case where
52 the argument A is the temporary result of a TRANSPOSE intrinsic:
53 C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
54 looking at their strides.
55
56 The equivalent Fortran pseudo-code is:
57
58 DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
59 IF (.NOT.IS_TRANSPOSED(A)) THEN
60 C = 0
61 DO J=1,N
62 DO K=1,COUNT
63 DO I=1,M
64 C(I,J) = C(I,J)+A(I,K)*B(K,J)
65 ELSE
66 DO J=1,N
67 DO I=1,M
68 S = 0
69 DO K=1,COUNT
70 S = S+A(I,K)*B(K,J)
71 C(I,J) = S
72 ENDIF
73 */
74
75 /* If try_blas is set to a nonzero value, then the matmul function will
76 see if there is a way to perform the matrix multiplication by a call
77 to the BLAS gemm function. */
78
79 extern void matmul_'rtype_code` ('rtype` * const restrict retarray,
80 'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
81 int blas_limit, blas_call gemm);
82 export_proto(matmul_'rtype_code`);
83
84 void
85 matmul_'rtype_code` ('rtype` * const restrict retarray,
86 'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
87 int blas_limit, blas_call gemm)
88 {
89 const 'rtype_name` * restrict abase;
90 const 'rtype_name` * restrict bbase;
91 'rtype_name` * restrict dest;
92
93 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
94 index_type x, y, n, count, xcount, ycount;
95
96 assert (GFC_DESCRIPTOR_RANK (a) == 2
97 || GFC_DESCRIPTOR_RANK (b) == 2);
98
99 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
100
101 Either A or B (but not both) can be rank 1:
102
103 o One-dimensional argument A is implicitly treated as a row matrix
104 dimensioned [1,count], so xcount=1.
105
106 o One-dimensional argument B is implicitly treated as a column matrix
107 dimensioned [count, 1], so ycount=1.
108 */
109
110 if (retarray->data == NULL)
111 {
112 if (GFC_DESCRIPTOR_RANK (a) == 1)
113 {
114 retarray->dim[0].lbound = 0;
115 retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
116 retarray->dim[0].stride = 1;
117 }
118 else if (GFC_DESCRIPTOR_RANK (b) == 1)
119 {
120 retarray->dim[0].lbound = 0;
121 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
122 retarray->dim[0].stride = 1;
123 }
124 else
125 {
126 retarray->dim[0].lbound = 0;
127 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
128 retarray->dim[0].stride = 1;
129
130 retarray->dim[1].lbound = 0;
131 retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
132 retarray->dim[1].stride = retarray->dim[0].ubound+1;
133 }
134
135 retarray->data
136 = internal_malloc_size (sizeof ('rtype_name`) * size0 ((array_t *) retarray));
137 retarray->offset = 0;
138 }
139 else if (compile_options.bounds_check)
140 {
141 index_type ret_extent, arg_extent;
142
143 if (GFC_DESCRIPTOR_RANK (a) == 1)
144 {
145 arg_extent = b->dim[1].ubound + 1 - b->dim[1].lbound;
146 ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
147 if (arg_extent != ret_extent)
148 runtime_error ("Incorrect extent in return array in"
149 " MATMUL intrinsic: is %ld, should be %ld",
150 (long int) ret_extent, (long int) arg_extent);
151 }
152 else if (GFC_DESCRIPTOR_RANK (b) == 1)
153 {
154 arg_extent = a->dim[0].ubound + 1 - a->dim[0].lbound;
155 ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
156 if (arg_extent != ret_extent)
157 runtime_error ("Incorrect extent in return array in"
158 " MATMUL intrinsic: is %ld, should be %ld",
159 (long int) ret_extent, (long int) arg_extent);
160 }
161 else
162 {
163 arg_extent = a->dim[0].ubound + 1 - a->dim[0].lbound;
164 ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
165 if (arg_extent != ret_extent)
166 runtime_error ("Incorrect extent in return array in"
167 " MATMUL intrinsic for dimension 1:"
168 " is %ld, should be %ld",
169 (long int) ret_extent, (long int) arg_extent);
170
171 arg_extent = b->dim[1].ubound + 1 - b->dim[1].lbound;
172 ret_extent = retarray->dim[1].ubound + 1 - retarray->dim[1].lbound;
173 if (arg_extent != ret_extent)
174 runtime_error ("Incorrect extent in return array in"
175 " MATMUL intrinsic for dimension 2:"
176 " is %ld, should be %ld",
177 (long int) ret_extent, (long int) arg_extent);
178 }
179 }
180 '
181 sinclude(`matmul_asm_'rtype_code`.m4')dnl
182 `
183 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
184 {
185 /* One-dimensional result may be addressed in the code below
186 either as a row or a column matrix. We want both cases to
187 work. */
188 rxstride = rystride = retarray->dim[0].stride;
189 }
190 else
191 {
192 rxstride = retarray->dim[0].stride;
193 rystride = retarray->dim[1].stride;
194 }
195
196
197 if (GFC_DESCRIPTOR_RANK (a) == 1)
198 {
199 /* Treat it as a a row matrix A[1,count]. */
200 axstride = a->dim[0].stride;
201 aystride = 1;
202
203 xcount = 1;
204 count = a->dim[0].ubound + 1 - a->dim[0].lbound;
205 }
206 else
207 {
208 axstride = a->dim[0].stride;
209 aystride = a->dim[1].stride;
210
211 count = a->dim[1].ubound + 1 - a->dim[1].lbound;
212 xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
213 }
214
215 if (count != b->dim[0].ubound + 1 - b->dim[0].lbound)
216 {
217 if (count > 0 || b->dim[0].ubound + 1 - b->dim[0].lbound > 0)
218 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
219 }
220
221 if (GFC_DESCRIPTOR_RANK (b) == 1)
222 {
223 /* Treat it as a column matrix B[count,1] */
224 bxstride = b->dim[0].stride;
225
226 /* bystride should never be used for 1-dimensional b.
227 in case it is we want it to cause a segfault, rather than
228 an incorrect result. */
229 bystride = 0xDEADBEEF;
230 ycount = 1;
231 }
232 else
233 {
234 bxstride = b->dim[0].stride;
235 bystride = b->dim[1].stride;
236 ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
237 }
238
239 abase = a->data;
240 bbase = b->data;
241 dest = retarray->data;
242
243
244 /* Now that everything is set up, we''`re performing the multiplication
245 itself. */
246
247 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
248
249 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
250 && (bxstride == 1 || bystride == 1)
251 && (((float) xcount) * ((float) ycount) * ((float) count)
252 > POW3(blas_limit)))
253 {
254 const int m = xcount, n = ycount, k = count, ldc = rystride;
255 const 'rtype_name` one = 1, zero = 0;
256 const int lda = (axstride == 1) ? aystride : axstride,
257 ldb = (bxstride == 1) ? bystride : bxstride;
258
259 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
260 {
261 assert (gemm != NULL);
262 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
263 &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
264 return;
265 }
266 }
267
268 if (rxstride == 1 && axstride == 1 && bxstride == 1)
269 {
270 const 'rtype_name` * restrict bbase_y;
271 'rtype_name` * restrict dest_y;
272 const 'rtype_name` * restrict abase_n;
273 'rtype_name` bbase_yn;
274
275 if (rystride == xcount)
276 memset (dest, 0, (sizeof ('rtype_name`) * xcount * ycount));
277 else
278 {
279 for (y = 0; y < ycount; y++)
280 for (x = 0; x < xcount; x++)
281 dest[x + y*rystride] = ('rtype_name`)0;
282 }
283
284 for (y = 0; y < ycount; y++)
285 {
286 bbase_y = bbase + y*bystride;
287 dest_y = dest + y*rystride;
288 for (n = 0; n < count; n++)
289 {
290 abase_n = abase + n*aystride;
291 bbase_yn = bbase_y[n];
292 for (x = 0; x < xcount; x++)
293 {
294 dest_y[x] += abase_n[x] * bbase_yn;
295 }
296 }
297 }
298 }
299 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
300 {
301 if (GFC_DESCRIPTOR_RANK (a) != 1)
302 {
303 const 'rtype_name` *restrict abase_x;
304 const 'rtype_name` *restrict bbase_y;
305 'rtype_name` *restrict dest_y;
306 'rtype_name` s;
307
308 for (y = 0; y < ycount; y++)
309 {
310 bbase_y = &bbase[y*bystride];
311 dest_y = &dest[y*rystride];
312 for (x = 0; x < xcount; x++)
313 {
314 abase_x = &abase[x*axstride];
315 s = ('rtype_name`) 0;
316 for (n = 0; n < count; n++)
317 s += abase_x[n] * bbase_y[n];
318 dest_y[x] = s;
319 }
320 }
321 }
322 else
323 {
324 const 'rtype_name` *restrict bbase_y;
325 'rtype_name` s;
326
327 for (y = 0; y < ycount; y++)
328 {
329 bbase_y = &bbase[y*bystride];
330 s = ('rtype_name`) 0;
331 for (n = 0; n < count; n++)
332 s += abase[n*axstride] * bbase_y[n];
333 dest[y*rystride] = s;
334 }
335 }
336 }
337 else if (axstride < aystride)
338 {
339 for (y = 0; y < ycount; y++)
340 for (x = 0; x < xcount; x++)
341 dest[x*rxstride + y*rystride] = ('rtype_name`)0;
342
343 for (y = 0; y < ycount; y++)
344 for (n = 0; n < count; n++)
345 for (x = 0; x < xcount; x++)
346 /* dest[x,y] += a[x,n] * b[n,y] */
347 dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
348 }
349 else if (GFC_DESCRIPTOR_RANK (a) == 1)
350 {
351 const 'rtype_name` *restrict bbase_y;
352 'rtype_name` s;
353
354 for (y = 0; y < ycount; y++)
355 {
356 bbase_y = &bbase[y*bystride];
357 s = ('rtype_name`) 0;
358 for (n = 0; n < count; n++)
359 s += abase[n*axstride] * bbase_y[n*bxstride];
360 dest[y*rxstride] = s;
361 }
362 }
363 else
364 {
365 const 'rtype_name` *restrict abase_x;
366 const 'rtype_name` *restrict bbase_y;
367 'rtype_name` *restrict dest_y;
368 'rtype_name` s;
369
370 for (y = 0; y < ycount; y++)
371 {
372 bbase_y = &bbase[y*bystride];
373 dest_y = &dest[y*rystride];
374 for (x = 0; x < xcount; x++)
375 {
376 abase_x = &abase[x*axstride];
377 s = ('rtype_name`) 0;
378 for (n = 0; n < count; n++)
379 s += abase_x[n*aystride] * bbase_y[n*bxstride];
380 dest_y[x*rxstride] = s;
381 }
382 }
383 }
384 }
385
386 #endif'