]> git.ipfire.org Git - thirdparty/gcc.git/blob - libgfortran/m4/matmul.m4
Licensing changes to GPLv3 resp. GPLv3 with GCC Runtime Exception.
[thirdparty/gcc.git] / libgfortran / m4 / matmul.m4
1 `/* Implementation of the MATMUL intrinsic
2 Copyright 2002, 2005, 2006, 2007, 2009 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
4
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
6
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
11
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
20
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
25
26 #include "libgfortran.h"
27 #include <stdlib.h>
28 #include <string.h>
29 #include <assert.h>'
30
31 include(iparm.m4)dnl
32
33 `#if defined (HAVE_'rtype_name`)
34
35 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
36 passed to us by the front-end, in which case we will call it for large
37 matrices. The arguments are, in the order used at the call site in
/* (continued) */
/* the matmul function of this file: the two transpose flags, m, n, k,
   alpha, A, lda, B, ldb, beta, C, ldc, and two trailing ints.
   NOTE(review): the trailing ints are presumably the hidden lengths of
   the two character arguments; the call site passes 1 for both. Confirm
   against the Fortran calling convention. */
38
39 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
40 const int *, const 'rtype_name` *, const 'rtype_name` *,
41 const int *, const 'rtype_name` *, const int *,
42 const 'rtype_name` *, 'rtype_name` *, const int *,
43 int, int);
44
45 /* The order of loops is different in the case of plain matrix
46 multiplication C=MATMUL(A,B), and in the frequent special case where
47 the argument A is the temporary result of a TRANSPOSE intrinsic:
48 C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
49 looking at their strides.
50
51 The equivalent Fortran pseudo-code is:
52
53 DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
54 IF (.NOT.IS_TRANSPOSED(A)) THEN
55 C = 0
56 DO J=1,N
57 DO K=1,COUNT
58 DO I=1,M
59 C(I,J) = C(I,J)+A(I,K)*B(K,J)
60 ELSE
61 DO J=1,N
62 DO I=1,M
63 S = 0
64 DO K=1,COUNT
65 S = S+A(I,K)*B(K,J)
66 C(I,J) = S
67 ENDIF
68 */
69
70 /* Compute retarray = MATMUL (a, b). If try_blas is set to a nonzero
71 value, then the matmul function will see if there is a way to perform
72 the matrix multiplication by a call to the BLAS gemm function. */
/* The BLAS path is taken only when the result has unit stride in its
   first dimension, each operand has unit stride in one of its dimensions,
   and the product of the three matrix extents exceeds blas_limit cubed;
   the gemm pointer is then asserted to be non-NULL. In every other case
   the built-in loops are used. */
73
74 extern void matmul_'rtype_code` ('rtype` * const restrict retarray,
75 'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
76 int blas_limit, blas_call gemm);
77 export_proto(matmul_'rtype_code`);
78
/* Multiply a by b and store the product in retarray. At least one of a
   and b must have rank 2 (asserted below). If retarray->data is NULL the
   bounds of the result are set and its storage is allocated here;
   otherwise its extents are verified only when bounds checking is
   enabled. try_blas, blas_limit and gemm control the optional BLAS path.
   A mismatch between the inner extents of a and b is reported through
   runtime_error unless neither extent is positive. */
79 void
80 matmul_'rtype_code` ('rtype` * const restrict retarray,
81 'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
82 int blas_limit, blas_call gemm)
83 {
84 const 'rtype_name` * restrict abase;
85 const 'rtype_name` * restrict bbase;
86 'rtype_name` * restrict dest;
87
88 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
89 index_type x, y, n, count, xcount, ycount;
90
91 assert (GFC_DESCRIPTOR_RANK (a) == 2
92 || GFC_DESCRIPTOR_RANK (b) == 2);
93
94 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
95
96 Either A or B (but not both) can be rank 1:
97
98 o One-dimensional argument A is implicitly treated as a row matrix
99 dimensioned [1,count], so xcount=1.
100
101 o One-dimensional argument B is implicitly treated as a column matrix
102 dimensioned [count, 1], so ycount=1.
103 */
104
105 if (retarray->data == NULL) /* Result not allocated yet: set its bounds, then allocate. */
106 {
107 if (GFC_DESCRIPTOR_RANK (a) == 1)
108 {
109 retarray->dim[0].lbound = 0;
110 retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
111 retarray->dim[0].stride = 1;
112 }
113 else if (GFC_DESCRIPTOR_RANK (b) == 1)
114 {
115 retarray->dim[0].lbound = 0;
116 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
117 retarray->dim[0].stride = 1;
118 }
119 else
120 {
121 retarray->dim[0].lbound = 0;
122 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
123 retarray->dim[0].stride = 1;
124
125 retarray->dim[1].lbound = 0;
126 retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
127 retarray->dim[1].stride = retarray->dim[0].ubound+1;
128 }
129
130 retarray->data
131 = internal_malloc_size (sizeof ('rtype_name`) * size0 ((array_t *) retarray));
132 retarray->offset = 0;
133 }
134 else if (unlikely (compile_options.bounds_check)) /* Result supplied by the caller: check its extents. */
135 {
136 index_type ret_extent, arg_extent;
137
138 if (GFC_DESCRIPTOR_RANK (a) == 1)
139 {
140 arg_extent = b->dim[1].ubound + 1 - b->dim[1].lbound;
141 ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
142 if (arg_extent != ret_extent)
143 runtime_error ("Incorrect extent in return array in"
144 " MATMUL intrinsic: is %ld, should be %ld",
145 (long int) ret_extent, (long int) arg_extent);
146 }
147 else if (GFC_DESCRIPTOR_RANK (b) == 1)
148 {
149 arg_extent = a->dim[0].ubound + 1 - a->dim[0].lbound;
150 ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
151 if (arg_extent != ret_extent)
152 runtime_error ("Incorrect extent in return array in"
153 " MATMUL intrinsic: is %ld, should be %ld",
154 (long int) ret_extent, (long int) arg_extent);
155 }
156 else
157 {
158 arg_extent = a->dim[0].ubound + 1 - a->dim[0].lbound;
159 ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
160 if (arg_extent != ret_extent)
161 runtime_error ("Incorrect extent in return array in"
162 " MATMUL intrinsic for dimension 1:"
163 " is %ld, should be %ld",
164 (long int) ret_extent, (long int) arg_extent);
165
166 arg_extent = b->dim[1].ubound + 1 - b->dim[1].lbound;
167 ret_extent = retarray->dim[1].ubound + 1 - retarray->dim[1].lbound;
168 if (arg_extent != ret_extent)
169 runtime_error ("Incorrect extent in return array in"
170 " MATMUL intrinsic for dimension 2:"
171 " is %ld, should be %ld",
172 (long int) ret_extent, (long int) arg_extent);
173 }
174 }
175 '
176 sinclude(`matmul_asm_'rtype_code`.m4')dnl
177 `
178 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
179 {
180 /* One-dimensional result may be addressed in the code below
181 either as a row or a column matrix. We want both cases to
182 work. */
183 rxstride = rystride = retarray->dim[0].stride;
184 }
185 else
186 {
187 rxstride = retarray->dim[0].stride;
188 rystride = retarray->dim[1].stride;
189 }
190
191
192 if (GFC_DESCRIPTOR_RANK (a) == 1)
193 {
194 /* Treat it as a row matrix A[1,count]. */
195 axstride = a->dim[0].stride;
196 aystride = 1;
197
198 xcount = 1;
199 count = a->dim[0].ubound + 1 - a->dim[0].lbound;
200 }
201 else
202 {
203 axstride = a->dim[0].stride;
204 aystride = a->dim[1].stride;
205
206 count = a->dim[1].ubound + 1 - a->dim[1].lbound;
207 xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
208 }
209
210 if (count != b->dim[0].ubound + 1 - b->dim[0].lbound) /* Inner extents of A and B differ. */
211 {
212 if (count > 0 || b->dim[0].ubound + 1 - b->dim[0].lbound > 0) /* Tolerated only when neither extent is positive. */
213 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
214 }
215
216 if (GFC_DESCRIPTOR_RANK (b) == 1)
217 {
218 /* Treat it as a column matrix B[count,1]. */
219 bxstride = b->dim[0].stride;
220
221 /* bystride should never be used for 1-dimensional b.
222 In case it is, we want it to cause a segfault rather than
223 an incorrect result. */
224 bystride = 0xDEADBEEF;
225 ycount = 1;
226 }
227 else
228 {
229 bxstride = b->dim[0].stride;
230 bystride = b->dim[1].stride;
231 ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
232 }
233
234 abase = a->data;
235 bbase = b->data;
236 dest = retarray->data;
237
238
239 /* Now that everything is set up, we''`re performing the multiplication
240 itself. */
241
242 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
243
244 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
245 && (bxstride == 1 || bystride == 1)
246 && (((float) xcount) * ((float) ycount) * ((float) count)
247 > POW3(blas_limit)))
248 {
249 const int m = xcount, n = ycount, k = count, ldc = rystride;
250 const 'rtype_name` one = 1, zero = 0;
251 const int lda = (axstride == 1) ? aystride : axstride,
252 ldb = (bxstride == 1) ? bystride : bxstride;
253
254 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1) /* gemm is called only with positive leading dimensions and extents above 1. */
255 {
256 assert (gemm != NULL);
257 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
258 &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
259 return;
260 }
261 }
262
263 if (rxstride == 1 && axstride == 1 && bxstride == 1) /* Unit stride in the first dimension of all three arrays. */
264 {
265 const 'rtype_name` * restrict bbase_y;
266 'rtype_name` * restrict dest_y;
267 const 'rtype_name` * restrict abase_n;
268 'rtype_name` bbase_yn;
269
270 if (rystride == xcount)
271 memset (dest, 0, (sizeof ('rtype_name`) * xcount * ycount));
272 else
273 {
274 for (y = 0; y < ycount; y++)
275 for (x = 0; x < xcount; x++)
276 dest[x + y*rystride] = ('rtype_name`)0;
277 }
278
279 for (y = 0; y < ycount; y++)
280 {
281 bbase_y = bbase + y*bystride;
282 dest_y = dest + y*rystride;
283 for (n = 0; n < count; n++)
284 {
285 abase_n = abase + n*aystride;
286 bbase_yn = bbase_y[n];
287 for (x = 0; x < xcount; x++)
288 {
289 dest_y[x] += abase_n[x] * bbase_yn;
290 }
291 }
292 }
293 }
294 else if (rxstride == 1 && aystride == 1 && bxstride == 1) /* A has unit stride along its second dimension: accumulate each element in a scalar. */
295 {
296 if (GFC_DESCRIPTOR_RANK (a) != 1)
297 {
298 const 'rtype_name` *restrict abase_x;
299 const 'rtype_name` *restrict bbase_y;
300 'rtype_name` *restrict dest_y;
301 'rtype_name` s;
302
303 for (y = 0; y < ycount; y++)
304 {
305 bbase_y = &bbase[y*bystride];
306 dest_y = &dest[y*rystride];
307 for (x = 0; x < xcount; x++)
308 {
309 abase_x = &abase[x*axstride];
310 s = ('rtype_name`) 0;
311 for (n = 0; n < count; n++)
312 s += abase_x[n] * bbase_y[n];
313 dest_y[x] = s;
314 }
315 }
316 }
317 else
318 {
319 const 'rtype_name` *restrict bbase_y;
320 'rtype_name` s;
321
322 for (y = 0; y < ycount; y++)
323 {
324 bbase_y = &bbase[y*bystride];
325 s = ('rtype_name`) 0;
326 for (n = 0; n < count; n++)
327 s += abase[n*axstride] * bbase_y[n];
328 dest[y*rystride] = s;
329 }
330 }
331 }
332 else if (axstride < aystride) /* General strides: zero the result, then accumulate with x as the innermost loop. */
333 {
334 for (y = 0; y < ycount; y++)
335 for (x = 0; x < xcount; x++)
336 dest[x*rxstride + y*rystride] = ('rtype_name`)0;
337
338 for (y = 0; y < ycount; y++)
339 for (n = 0; n < count; n++)
340 for (x = 0; x < xcount; x++)
341 /* dest[x,y] += a[x,n] * b[n,y] */
342 dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
343 }
344 else if (GFC_DESCRIPTOR_RANK (a) == 1) /* General strides, rank-1 A: one scalar sum per value of y. */
345 {
346 const 'rtype_name` *restrict bbase_y;
347 'rtype_name` s;
348
349 for (y = 0; y < ycount; y++)
350 {
351 bbase_y = &bbase[y*bystride];
352 s = ('rtype_name`) 0;
353 for (n = 0; n < count; n++)
354 s += abase[n*axstride] * bbase_y[n*bxstride];
355 dest[y*rxstride] = s;
356 }
357 }
358 else /* Remaining general-stride case: one scalar sum per result element. */
359 {
360 const 'rtype_name` *restrict abase_x;
361 const 'rtype_name` *restrict bbase_y;
362 'rtype_name` *restrict dest_y;
363 'rtype_name` s;
364
365 for (y = 0; y < ycount; y++)
366 {
367 bbase_y = &bbase[y*bystride];
368 dest_y = &dest[y*rystride];
369 for (x = 0; x < xcount; x++)
370 {
371 abase_x = &abase[x*axstride];
372 s = ('rtype_name`) 0;
373 for (n = 0; n < count; n++)
374 s += abase_x[n*aystride] * bbase_y[n*bxstride];
375 dest_y[x*rxstride] = s;
376 }
377 }
378 }
379 }
380
381 #endif'