1 `/* Implementation of the MATMUL intrinsic
2 Copyright 2002, 2005, 2006 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 2 of the License, or (at your option) any later version.
12 In addition to the permissions in the GNU General Public License, the
13 Free Software Foundation gives you unlimited permission to link the
14 compiled version of this file into combinations with other programs,
15 and to distribute those combinations without any restriction coming
16 from the use of this file. (The General Public License restrictions
17 do apply in other respects; for example, they cover modification of
18 the file, and distribution when not linked into a combine
21 Libgfortran is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24 GNU General Public License for more details.
26 You should have received a copy of the GNU General Public
27 License along with libgfortran; see the file COPYING. If not,
28 write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
29 Boston, MA 02110-1301, USA. */
35 #include "libgfortran.h"'
38 `#if defined (HAVE_'rtype_name`)'
40 /* The order of loops is different in the case of plain matrix
41 multiplication C=MATMUL(A,B), and in the frequent special case where
42 the argument A is the temporary result of a TRANSPOSE intrinsic:
43 C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
44 looking at their strides.
46 The equivalent Fortran pseudo-code is:
48 DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
49 IF (.NOT.IS_TRANSPOSED(A)) THEN
54 C(I,J) = C(I,J)+A(I,K)*B(K,J)
/* Entry point for the MATMUL intrinsic; this template emits one such
   function per result type.  Stores the product of A and B into RETARRAY.
   When RETARRAY has no data yet, the definition below sets up its bounds
   and allocates storage for it.  Either A or B (but not both) may be
   rank 1.  Passing the symbol to export_proto presumably marks it for
   export from the runtime library -- TODO confirm in libgfortran.h.  */
65 extern void matmul_`'rtype_code (rtype * const restrict retarray,
66 rtype * const restrict a, rtype * const restrict b);
67 export_proto(matmul_`'rtype_code);
70 matmul_`'rtype_code (rtype * const restrict retarray,
71 rtype * const restrict a, rtype * const restrict b)
73 const rtype_name * restrict abase;
74 const rtype_name * restrict bbase;
75 rtype_name * restrict dest;
77 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
78 index_type x, y, n, count, xcount, ycount;
80 assert (GFC_DESCRIPTOR_RANK (a) == 2
81 || GFC_DESCRIPTOR_RANK (b) == 2);
83 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
85 Either A or B (but not both) can be rank 1:
87 o One-dimensional argument A is implicitly treated as a row matrix
88 dimensioned [1,count], so xcount=1.
90 o One-dimensional argument B is implicitly treated as a column matrix
91 dimensioned [count, 1], so ycount=1.
94 if (retarray->data == NULL)
96 if (GFC_DESCRIPTOR_RANK (a) == 1)
98 retarray->dim[0].lbound = 0;
99 retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
100 retarray->dim[0].stride = 1;
102 else if (GFC_DESCRIPTOR_RANK (b) == 1)
104 retarray->dim[0].lbound = 0;
105 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
106 retarray->dim[0].stride = 1;
110 retarray->dim[0].lbound = 0;
111 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
112 retarray->dim[0].stride = 1;
114 retarray->dim[1].lbound = 0;
115 retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
116 retarray->dim[1].stride = retarray->dim[0].ubound+1;
120 = internal_malloc_size (sizeof (rtype_name) * size0 ((array_t *) retarray));
121 retarray->offset = 0;
124 sinclude(`matmul_asm_'rtype_code`.m4')dnl
126 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
128 /* One-dimensional result may be addressed in the code below
129 either as a row or a column matrix. We want both cases to
131 rxstride = rystride = retarray->dim[0].stride;
135 rxstride = retarray->dim[0].stride;
136 rystride = retarray->dim[1].stride;
140 if (GFC_DESCRIPTOR_RANK (a) == 1)
142 /* Treat it as a row matrix A[1,count]. */
143 axstride = a->dim[0].stride;
147 count = a->dim[0].ubound + 1 - a->dim[0].lbound;
151 axstride = a->dim[0].stride;
152 aystride = a->dim[1].stride;
154 count = a->dim[1].ubound + 1 - a->dim[1].lbound;
155 xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
158 assert(count == b->dim[0].ubound + 1 - b->dim[0].lbound);
160 if (GFC_DESCRIPTOR_RANK (b) == 1)
162 /* Treat it as a column matrix B[count,1] */
163 bxstride = b->dim[0].stride;
165 /* bystride should never be used for 1-dimensional b.
166 If it is, we want it to cause a segfault rather than
167 an incorrect result. */
168 bystride = 0xDEADBEEF;
173 bxstride = b->dim[0].stride;
174 bystride = b->dim[1].stride;
175 ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
180 dest = retarray->data;
182 if (rxstride == 1 && axstride == 1 && bxstride == 1)
184 const rtype_name * restrict bbase_y;
185 rtype_name * restrict dest_y;
186 const rtype_name * restrict abase_n;
189 if (rystride == xcount)
190 memset (dest, 0, (sizeof (rtype_name) * xcount * ycount));
193 for (y = 0; y < ycount; y++)
194 for (x = 0; x < xcount; x++)
195 dest[x + y*rystride] = (rtype_name)0;
198 for (y = 0; y < ycount; y++)
200 bbase_y = bbase + y*bystride;
201 dest_y = dest + y*rystride;
202 for (n = 0; n < count; n++)
204 abase_n = abase + n*aystride;
205 bbase_yn = bbase_y[n];
206 for (x = 0; x < xcount; x++)
208 dest_y[x] += abase_n[x] * bbase_yn;
213 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
215 if (GFC_DESCRIPTOR_RANK (a) != 1)
217 const rtype_name *restrict abase_x;
218 const rtype_name *restrict bbase_y;
219 rtype_name *restrict dest_y;
222 for (y = 0; y < ycount; y++)
224 bbase_y = &bbase[y*bystride];
225 dest_y = &dest[y*rystride];
226 for (x = 0; x < xcount; x++)
228 abase_x = &abase[x*axstride];
230 for (n = 0; n < count; n++)
231 s += abase_x[n] * bbase_y[n];
238 const rtype_name *restrict bbase_y;
241 for (y = 0; y < ycount; y++)
243 bbase_y = &bbase[y*bystride];
245 for (n = 0; n < count; n++)
246 s += abase[n*axstride] * bbase_y[n];
247 dest[y*rystride] = s;
251 else if (axstride < aystride)
253 for (y = 0; y < ycount; y++)
254 for (x = 0; x < xcount; x++)
255 dest[x*rxstride + y*rystride] = (rtype_name)0;
257 for (y = 0; y < ycount; y++)
258 for (n = 0; n < count; n++)
259 for (x = 0; x < xcount; x++)
260 /* dest[x,y] += a[x,n] * b[n,y] */
261 dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
263 else if (GFC_DESCRIPTOR_RANK (a) == 1)
265 const rtype_name *restrict bbase_y;
268 for (y = 0; y < ycount; y++)
270 bbase_y = &bbase[y*bystride];
272 for (n = 0; n < count; n++)
273 s += abase[n*axstride] * bbase_y[n*bxstride];
274 dest[y*rxstride] = s;
279 const rtype_name *restrict abase_x;
280 const rtype_name *restrict bbase_y;
281 rtype_name *restrict dest_y;
284 for (y = 0; y < ycount; y++)
286 bbase_y = &bbase[y*bystride];
287 dest_y = &dest[y*rystride];
288 for (x = 0; x < xcount; x++)
290 abase_x = &abase[x*axstride];
292 for (n = 0; n < count; n++)
293 s += abase_x[n*aystride] * bbase_y[n*bxstride];
294 dest_y[x*rxstride] = s;