]> git.ipfire.org Git - thirdparty/gcc.git/blob - libgfortran/m4/matmul.m4
Update copyright years in libgfortran/
[thirdparty/gcc.git] / libgfortran / m4 / matmul.m4
1 `/* Implementation of the MATMUL intrinsic
2 Copyright (C) 2002-2014 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
4
5 This file is part of the GNU Fortran runtime library (libgfortran).
6
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
11
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
20
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
25
26 #include "libgfortran.h"
27 #include <stdlib.h>
28 #include <string.h>
29 #include <assert.h>'
30
31 include(iparm.m4)dnl
32
33 `#if defined (HAVE_'rtype_name`)
34
35 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
36 passed to us by the front-end, in which case we''`ll call it for large
37 matrices. */
38
/* Argument order, matching the single call made in the matmul function
   in this file: transa, transb, m, n, k, alpha, a, lda, b, ldb, beta,
   c, ldc, followed by two plain int values.  The call site passes
   1, 1 for those last two; NOTE(review): presumably they are the hidden
   lengths of the two leading character arguments -- confirm against
   the Fortran calling convention in use.  */
39 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
40 const int *, const 'rtype_name` *, const 'rtype_name` *,
41 const int *, const 'rtype_name` *, const int *,
42 const 'rtype_name` *, 'rtype_name` *, const int *,
43 int, int);
44
45 /* The order of loops is different in the case of plain matrix
46 multiplication C=MATMUL(A,B), and in the frequent special case where
47 the argument A is the temporary result of a TRANSPOSE intrinsic:
48 C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
49 looking at their strides.
50
51 The equivalent Fortran pseudo-code is:
52
53 DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
54 IF (.NOT.IS_TRANSPOSED(A)) THEN
55 C = 0
56 DO J=1,N
57 DO K=1,COUNT
58 DO I=1,M
59 C(I,J) = C(I,J)+A(I,K)*B(K,J)
60 ELSE
61 DO J=1,N
62 DO I=1,M
63 S = 0
64 DO K=1,COUNT
65 S = S+A(I,K)*B(K,J)
66 C(I,J) = S
67 ENDIF
68 */
69
70 /* If try_blas is set to a nonzero value, then the matmul function will
71 see if there is a way to perform the matrix multiplication by a call
72 to the BLAS gemm function. */
73
/* Declare the type-specific MATMUL routine ahead of its definition and
   hand its name to export_proto.  NOTE(review): export_proto is defined
   outside this file; presumably it controls how the symbol is exported
   from the library -- confirm.  */
74 extern void matmul_'rtype_code` ('rtype` * const restrict retarray,
75 'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
76 int blas_limit, blas_call gemm);
77 export_proto(matmul_'rtype_code`);
78
/* Type-specific implementation of the MATMUL intrinsic: stores the
   product of A and B in RETARRAY.

   RETARRAY    Result descriptor.  When its base_addr is NULL the bounds
               are filled in here and storage is obtained with xmalloc;
               otherwise the extents are only verified, and only when
               compile_options.bounds_check is set.
   A, B        Operand descriptors.  At least one of them must have
               rank 2 (asserted below); the other one may have rank 1.
   TRY_BLAS    Nonzero to allow handing the work to GEMM.
   BLAS_LIMIT  The GEMM path is only considered when the product
               xcount * ycount * count exceeds BLAS_LIMIT cubed.
   GEMM        BLAS gemm routine; asserted to be non-NULL on the path
               that actually calls it.

   Extent mismatches are reported through runtime_error.  */
79 void
80 matmul_'rtype_code` ('rtype` * const restrict retarray,
81 'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
82 int blas_limit, blas_call gemm)
83 {
84 const 'rtype_name` * restrict abase;
85 const 'rtype_name` * restrict bbase;
86 'rtype_name` * restrict dest;
87
88 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
89 index_type x, y, n, count, xcount, ycount;
90
91 assert (GFC_DESCRIPTOR_RANK (a) == 2
92 || GFC_DESCRIPTOR_RANK (b) == 2);
93
94 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
95
96 Either A or B (but not both) can be rank 1:
97
98 o One-dimensional argument A is implicitly treated as a row matrix
99 dimensioned [1,count], so xcount=1.
100
101 o One-dimensional argument B is implicitly treated as a column matrix
102 dimensioned [count, 1], so ycount=1.
103 */
104
105 if (retarray->base_addr == NULL)
106 {
/* No result storage yet: describe the result and allocate it.  The
   three trailing arguments of GFC_DIMENSION_SET appear to be lower
   bound, upper bound and stride -- TODO confirm against the macro
   definition, which is not part of this file.  */
107 if (GFC_DESCRIPTOR_RANK (a) == 1)
108 {
109 GFC_DIMENSION_SET(retarray->dim[0], 0,
110 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
111 }
112 else if (GFC_DESCRIPTOR_RANK (b) == 1)
113 {
114 GFC_DIMENSION_SET(retarray->dim[0], 0,
115 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
116 }
117 else
118 {
119 GFC_DIMENSION_SET(retarray->dim[0], 0,
120 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
121
122 GFC_DIMENSION_SET(retarray->dim[1], 0,
123 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
124 GFC_DESCRIPTOR_EXTENT(retarray,0));
125 }
126
127 retarray->base_addr
128 = xmalloc (sizeof ('rtype_name`) * size0 ((array_t *) retarray));
129 retarray->offset = 0;
130 }
131 else if (unlikely (compile_options.bounds_check))
132 {
/* The caller supplied the result array: check each of its extents
   against the corresponding operand extent.  */
133 index_type ret_extent, arg_extent;
134
135 if (GFC_DESCRIPTOR_RANK (a) == 1)
136 {
137 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
138 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
139 if (arg_extent != ret_extent)
140 runtime_error ("Incorrect extent in return array in"
141 " MATMUL intrinsic: is %ld, should be %ld",
142 (long int) ret_extent, (long int) arg_extent);
143 }
144 else if (GFC_DESCRIPTOR_RANK (b) == 1)
145 {
146 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
147 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
148 if (arg_extent != ret_extent)
149 runtime_error ("Incorrect extent in return array in"
150 " MATMUL intrinsic: is %ld, should be %ld",
151 (long int) ret_extent, (long int) arg_extent);
152 }
153 else
154 {
155 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
156 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
157 if (arg_extent != ret_extent)
158 runtime_error ("Incorrect extent in return array in"
159 " MATMUL intrinsic for dimension 1:"
160 " is %ld, should be %ld",
161 (long int) ret_extent, (long int) arg_extent);
162
163 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
164 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
165 if (arg_extent != ret_extent)
166 runtime_error ("Incorrect extent in return array in"
167 " MATMUL intrinsic for dimension 2:"
168 " is %ld, should be %ld",
169 (long int) ret_extent, (long int) arg_extent);
170 }
171 }
172 '
173 sinclude(`matmul_asm_'rtype_code`.m4')dnl
174 `
175 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
176 {
177 /* One-dimensional result may be addressed in the code below
178 either as a row or a column matrix. We want both cases to
179 work. */
180 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
181 }
182 else
183 {
184 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
185 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
186 }
187
188
189 if (GFC_DESCRIPTOR_RANK (a) == 1)
190 {
191 /* Treat it as a row matrix A[1,count]. */
192 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
193 aystride = 1;
194
195 xcount = 1;
196 count = GFC_DESCRIPTOR_EXTENT(a,0);
197 }
198 else
199 {
200 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
201 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
202
203 count = GFC_DESCRIPTOR_EXTENT(a,1);
204 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
205 }
206
/* The inner extents of A and B must agree.  A mismatch is tolerated
   only when neither of the two extents is positive.  */
207 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
208 {
209 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
210 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
211 }
212
213 if (GFC_DESCRIPTOR_RANK (b) == 1)
214 {
215 /* Treat it as a column matrix B[count,1] */
216 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
217
218 /* bystride should never be used for 1-dimensional b.
219 In case it is, we want it to cause a segfault rather than
220 an incorrect result. */
221 bystride = 0xDEADBEEF;
222 ycount = 1;
223 }
224 else
225 {
226 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
227 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
228 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
229 }
230
231 abase = a->base_addr;
232 bbase = b->base_addr;
233 dest = retarray->base_addr;
234
235
236 /* Now that everything is set up, we''`re performing the multiplication
237 itself. */
238
/* Cube of X evaluated in float.  NOTE(review): presumably float is used
   so that large extents cannot overflow an integer type -- confirm.  */
239 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
240
/* BLAS path.  It is considered only when the caller allows it, the
   result has unit stride along its first dimension, each operand has
   a unit stride along one of its dimensions, and the problem size
   exceeds the cube of blas_limit.  */
241 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
242 && (bxstride == 1 || bystride == 1)
243 && (((float) xcount) * ((float) ycount) * ((float) count)
244 > POW3(blas_limit)))
245 {
246 const int m = xcount, n = ycount, k = count, ldc = rystride;
247 const 'rtype_name` one = 1, zero = 0;
/* The leading dimension of an operand is whichever of its two strides
   is not the unit one.  */
248 const int lda = (axstride == 1) ? aystride : axstride,
249 ldb = (bxstride == 1) ? bystride : bxstride;
250
/* Unless every leading dimension is positive and m, n and k all exceed
   one, fall through to the hand-written loops below.  */
251 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
252 {
253 assert (gemm != NULL);
/* An operand is passed as "N" when its first dimension has unit
   stride and as "T" otherwise.  */
254 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
255 &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
256 return;
257 }
258 }
259
/* Case 1: result, A and B all have unit stride along their first
   dimension.  The result is cleared first and then accumulated with
   the loop order y, n, x, which is the order of the first branch of
   the pseudo-code in the comment near the top of this file.  */
260 if (rxstride == 1 && axstride == 1 && bxstride == 1)
261 {
262 const 'rtype_name` * restrict bbase_y;
263 'rtype_name` * restrict dest_y;
264 const 'rtype_name` * restrict abase_n;
265 'rtype_name` bbase_yn;
266
/* With rxstride == 1 and rystride == xcount the result columns are
   adjacent in memory, so a single memset clears all of them.  */
267 if (rystride == xcount)
268 memset (dest, 0, (sizeof ('rtype_name`) * xcount * ycount));
269 else
270 {
271 for (y = 0; y < ycount; y++)
272 for (x = 0; x < xcount; x++)
273 dest[x + y*rystride] = ('rtype_name`)0;
274 }
275
276 for (y = 0; y < ycount; y++)
277 {
278 bbase_y = bbase + y*bystride;
279 dest_y = dest + y*rystride;
280 for (n = 0; n < count; n++)
281 {
282 abase_n = abase + n*aystride;
283 bbase_yn = bbase_y[n];
284 for (x = 0; x < xcount; x++)
285 {
286 dest_y[x] += abase_n[x] * bbase_yn;
287 }
288 }
289 }
290 }
/* Case 2: A has unit stride along its second dimension instead.
   NOTE(review): presumably this is the transposed-temporary layout
   mentioned in the comment near the top of this file -- confirm.
   Each result element is computed as one dot product.  A rank-1 A
   also lands here because aystride was set to 1 for it above.  */
291 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
292 {
293 if (GFC_DESCRIPTOR_RANK (a) != 1)
294 {
295 const 'rtype_name` *restrict abase_x;
296 const 'rtype_name` *restrict bbase_y;
297 'rtype_name` *restrict dest_y;
298 'rtype_name` s;
299
300 for (y = 0; y < ycount; y++)
301 {
302 bbase_y = &bbase[y*bystride];
303 dest_y = &dest[y*rystride];
304 for (x = 0; x < xcount; x++)
305 {
306 abase_x = &abase[x*axstride];
307 s = ('rtype_name`) 0;
308 for (n = 0; n < count; n++)
309 s += abase_x[n] * bbase_y[n];
310 dest_y[x] = s;
311 }
312 }
313 }
314 else
315 {
316 const 'rtype_name` *restrict bbase_y;
317 'rtype_name` s;
318
319 for (y = 0; y < ycount; y++)
320 {
321 bbase_y = &bbase[y*bystride];
322 s = ('rtype_name`) 0;
323 for (n = 0; n < count; n++)
324 s += abase[n*axstride] * bbase_y[n];
325 dest[y*rystride] = s;
326 }
327 }
328 }
/* Case 3: arbitrary strides, with the first dimension of A being the
   one with the smaller stride.  Clear the result, then accumulate
   using fully general index arithmetic.  */
329 else if (axstride < aystride)
330 {
331 for (y = 0; y < ycount; y++)
332 for (x = 0; x < xcount; x++)
333 dest[x*rxstride + y*rystride] = ('rtype_name`)0;
334
335 for (y = 0; y < ycount; y++)
336 for (n = 0; n < count; n++)
337 for (x = 0; x < xcount; x++)
338 /* dest[x,y] += a[x,n] * b[n,y] */
339 dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
340 }
/* Case 4: arbitrary strides with a rank-1 A.  One dot product is
   computed for each value of y.  */
341 else if (GFC_DESCRIPTOR_RANK (a) == 1)
342 {
343 const 'rtype_name` *restrict bbase_y;
344 'rtype_name` s;
345
346 for (y = 0; y < ycount; y++)
347 {
348 bbase_y = &bbase[y*bystride];
349 s = ('rtype_name`) 0;
350 for (n = 0; n < count; n++)
351 s += abase[n*axstride] * bbase_y[n*bxstride];
352 dest[y*rxstride] = s;
353 }
354 }
/* Case 5: everything else.  Arbitrary strides, one dot product per
   result element.  */
355 else
356 {
357 const 'rtype_name` *restrict abase_x;
358 const 'rtype_name` *restrict bbase_y;
359 'rtype_name` *restrict dest_y;
360 'rtype_name` s;
361
362 for (y = 0; y < ycount; y++)
363 {
364 bbase_y = &bbase[y*bystride];
365 dest_y = &dest[y*rystride];
366 for (x = 0; x < xcount; x++)
367 {
368 abase_x = &abase[x*axstride];
369 s = ('rtype_name`) 0;
370 for (n = 0; n < count; n++)
371 s += abase_x[n*aystride] * bbase_y[n*bxstride];
372 dest_y[x*rxstride] = s;
373 }
374 }
375 }
376 }
377
378 #endif'