1 `/* Implementation of the MATMUL intrinsic
2 Copyright 2002, 2005, 2006, 2007, 2009 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
26 #include "libgfortran.h"
33 `#if defined (HAVE_'rtype_name`)
35 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
36 passed to us by the front-end, in which case we''`ll call it for large
39 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
40 const int *, const 'rtype_name` *, const 'rtype_name` *,
41 const int *, const 'rtype_name` *, const int *,
42 const 'rtype_name` *, 'rtype_name` *, const int *,
45 /* The order of loops is different in the case of plain matrix
46 multiplication C=MATMUL(A,B), and in the frequent special case where
47 the argument A is the temporary result of a TRANSPOSE intrinsic:
48 C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
49 looking at their strides.
51 The equivalent Fortran pseudo-code is:
53 DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
54 IF (.NOT.IS_TRANSPOSED(A)) THEN
59 C(I,J) = C(I,J)+A(I,K)*B(K,J)
70 /* If try_blas is set to a nonzero value, then the matmul function will
71 see if there is a way to perform the matrix multiplication by a call
72 to the BLAS gemm function. */
/* Arguments of the MATMUL entry point declared below:
   retarray   - result descriptor.  If its data pointer is NULL, the
                function sets its bounds and allocates its storage.
                Otherwise, when bounds checking is enabled, its extents
                are checked against the extents implied by A and B.
   a, b       - operand descriptors.  At least one of them must have
                rank 2 (asserted).  A rank-1 A is treated as a row matrix
                and a rank-1 B as a column matrix.
   try_blas   - nonzero allows the BLAS gemm path to be considered (it is
                further restricted by stride conditions in the body).
   blas_limit - NOTE(review): appears to be a size threshold that is
                compared, through POW3, with xcount*ycount*count to decide
                whether gemm is worth calling; the exact comparison is not
                visible in this view - confirm.
   gemm       - pointer to the BLAS ?gemm routine.  It must be non-NULL
                whenever the gemm path is actually taken (asserted).  */
74 extern void matmul_'rtype_code` ('rtype` * const restrict retarray,
75 'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
76 int blas_limit, blas_call gemm);
77 export_proto(matmul_'rtype_code`);
80 matmul_'rtype_code` ('rtype` * const restrict retarray,
81 'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
82 int blas_limit, blas_call gemm)
84 const 'rtype_name` * restrict abase;
85 const 'rtype_name` * restrict bbase;
86 'rtype_name` * restrict dest;
88 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
89 index_type x, y, n, count, xcount, ycount;
91 assert (GFC_DESCRIPTOR_RANK (a) == 2
92 || GFC_DESCRIPTOR_RANK (b) == 2);
94 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
96 Either A or B (but not both) can be rank 1:
98 o One-dimensional argument A is implicitly treated as a row matrix
99 dimensioned [1,count], so xcount=1.
101 o One-dimensional argument B is implicitly treated as a column matrix
102 dimensioned [count, 1], so ycount=1.
105 if (retarray->data == NULL)
107 if (GFC_DESCRIPTOR_RANK (a) == 1)
109 retarray->dim[0].lbound = 0;
110 retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
111 retarray->dim[0].stride = 1;
113 else if (GFC_DESCRIPTOR_RANK (b) == 1)
115 retarray->dim[0].lbound = 0;
116 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
117 retarray->dim[0].stride = 1;
121 retarray->dim[0].lbound = 0;
122 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
123 retarray->dim[0].stride = 1;
125 retarray->dim[1].lbound = 0;
126 retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
127 retarray->dim[1].stride = retarray->dim[0].ubound+1;
131 = internal_malloc_size (sizeof ('rtype_name`) * size0 ((array_t *) retarray));
132 retarray->offset = 0;
134 else if (unlikely (compile_options.bounds_check))
136 index_type ret_extent, arg_extent;
138 if (GFC_DESCRIPTOR_RANK (a) == 1)
140 arg_extent = b->dim[1].ubound + 1 - b->dim[1].lbound;
141 ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
142 if (arg_extent != ret_extent)
143 runtime_error ("Incorrect extent in return array in"
144 " MATMUL intrinsic: is %ld, should be %ld",
145 (long int) ret_extent, (long int) arg_extent);
147 else if (GFC_DESCRIPTOR_RANK (b) == 1)
149 arg_extent = a->dim[0].ubound + 1 - a->dim[0].lbound;
150 ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
151 if (arg_extent != ret_extent)
152 runtime_error ("Incorrect extent in return array in"
153 " MATMUL intrinsic: is %ld, should be %ld",
154 (long int) ret_extent, (long int) arg_extent);
158 arg_extent = a->dim[0].ubound + 1 - a->dim[0].lbound;
159 ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
160 if (arg_extent != ret_extent)
161 runtime_error ("Incorrect extent in return array in"
162 " MATMUL intrinsic for dimension 1:"
163 " is %ld, should be %ld",
164 (long int) ret_extent, (long int) arg_extent);
166 arg_extent = b->dim[1].ubound + 1 - b->dim[1].lbound;
167 ret_extent = retarray->dim[1].ubound + 1 - retarray->dim[1].lbound;
168 if (arg_extent != ret_extent)
169 runtime_error ("Incorrect extent in return array in"
170 " MATMUL intrinsic for dimension 2:"
171 " is %ld, should be %ld",
172 (long int) ret_extent, (long int) arg_extent);
176 sinclude(`matmul_asm_'rtype_code`.m4')dnl
178 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
180 /* One-dimensional result may be addressed in the code below
181 either as a row or a column matrix. We want both cases to
183 rxstride = rystride = retarray->dim[0].stride;
187 rxstride = retarray->dim[0].stride;
188 rystride = retarray->dim[1].stride;
192 if (GFC_DESCRIPTOR_RANK (a) == 1)
194 /* Treat it as a row matrix A[1,count]. */
195 axstride = a->dim[0].stride;
199 count = a->dim[0].ubound + 1 - a->dim[0].lbound;
203 axstride = a->dim[0].stride;
204 aystride = a->dim[1].stride;
206 count = a->dim[1].ubound + 1 - a->dim[1].lbound;
207 xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
210 if (count != b->dim[0].ubound + 1 - b->dim[0].lbound)
212 if (count > 0 || b->dim[0].ubound + 1 - b->dim[0].lbound > 0)
213 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
216 if (GFC_DESCRIPTOR_RANK (b) == 1)
218 /* Treat it as a column matrix B[count,1] */
219 bxstride = b->dim[0].stride;
221 /* bystride should never be used for 1-dimensional b.
222 In case it is, we want it to cause a segfault rather than
223 an incorrect result. */
224 bystride = 0xDEADBEEF;
229 bxstride = b->dim[0].stride;
230 bystride = b->dim[1].stride;
231 ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
236 dest = retarray->data;
239 /* Now that everything is set up, we''`re performing the multiplication
242 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
244 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
245 && (bxstride == 1 || bystride == 1)
246 && (((float) xcount) * ((float) ycount) * ((float) count)
249 const int m = xcount, n = ycount, k = count, ldc = rystride;
250 const 'rtype_name` one = 1, zero = 0;
251 const int lda = (axstride == 1) ? aystride : axstride,
252 ldb = (bxstride == 1) ? bystride : bxstride;
254 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
256 assert (gemm != NULL);
257 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
258 &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
263 if (rxstride == 1 && axstride == 1 && bxstride == 1)
265 const 'rtype_name` * restrict bbase_y;
266 'rtype_name` * restrict dest_y;
267 const 'rtype_name` * restrict abase_n;
268 'rtype_name` bbase_yn;
270 if (rystride == xcount)
271 memset (dest, 0, (sizeof ('rtype_name`) * xcount * ycount));
274 for (y = 0; y < ycount; y++)
275 for (x = 0; x < xcount; x++)
276 dest[x + y*rystride] = ('rtype_name`)0;
279 for (y = 0; y < ycount; y++)
281 bbase_y = bbase + y*bystride;
282 dest_y = dest + y*rystride;
283 for (n = 0; n < count; n++)
285 abase_n = abase + n*aystride;
286 bbase_yn = bbase_y[n];
287 for (x = 0; x < xcount; x++)
289 dest_y[x] += abase_n[x] * bbase_yn;
294 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
296 if (GFC_DESCRIPTOR_RANK (a) != 1)
298 const 'rtype_name` *restrict abase_x;
299 const 'rtype_name` *restrict bbase_y;
300 'rtype_name` *restrict dest_y;
303 for (y = 0; y < ycount; y++)
305 bbase_y = &bbase[y*bystride];
306 dest_y = &dest[y*rystride];
307 for (x = 0; x < xcount; x++)
309 abase_x = &abase[x*axstride];
310 s = ('rtype_name`) 0;
311 for (n = 0; n < count; n++)
312 s += abase_x[n] * bbase_y[n];
319 const 'rtype_name` *restrict bbase_y;
322 for (y = 0; y < ycount; y++)
324 bbase_y = &bbase[y*bystride];
325 s = ('rtype_name`) 0;
326 for (n = 0; n < count; n++)
327 s += abase[n*axstride] * bbase_y[n];
328 dest[y*rystride] = s;
332 else if (axstride < aystride)
334 for (y = 0; y < ycount; y++)
335 for (x = 0; x < xcount; x++)
336 dest[x*rxstride + y*rystride] = ('rtype_name`)0;
338 for (y = 0; y < ycount; y++)
339 for (n = 0; n < count; n++)
340 for (x = 0; x < xcount; x++)
341 /* dest[x,y] += a[x,n] * b[n,y] */
342 dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
344 else if (GFC_DESCRIPTOR_RANK (a) == 1)
346 const 'rtype_name` *restrict bbase_y;
349 for (y = 0; y < ycount; y++)
351 bbase_y = &bbase[y*bystride];
352 s = ('rtype_name`) 0;
353 for (n = 0; n < count; n++)
354 s += abase[n*axstride] * bbase_y[n*bxstride];
355 dest[y*rxstride] = s;
360 const 'rtype_name` *restrict abase_x;
361 const 'rtype_name` *restrict bbase_y;
362 'rtype_name` *restrict dest_y;
365 for (y = 0; y < ycount; y++)
367 bbase_y = &bbase[y*bystride];
368 dest_y = &dest[y*rystride];
369 for (x = 0; x < xcount; x++)
371 abase_x = &abase[x*axstride];
372 s = ('rtype_name`) 0;
373 for (n = 0; n < count; n++)
374 s += abase_x[n*aystride] * bbase_y[n*bxstride];
375 dest_y[x*rxstride] = s;