]> git.ipfire.org Git - thirdparty/gcc.git/blob - libgfortran/m4/matmul.m4
re PR fortran/28947 (Double MATMUL() uses wrong array elements)
[thirdparty/gcc.git] / libgfortran / m4 / matmul.m4
1 `/* Implementation of the MATMUL intrinsic
2 Copyright 2002, 2005, 2006 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
4
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
6
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 2 of the License, or (at your option) any later version.
11
12 In addition to the permissions in the GNU General Public License, the
13 Free Software Foundation gives you unlimited permission to link the
14 compiled version of this file into combinations with other programs,
15 and to distribute those combinations without any restriction coming
16 from the use of this file. (The General Public License restrictions
17 do apply in other respects; for example, they cover modification of
18 the file, and distribution when not linked into a combine
19 executable.)
20
21 Libgfortran is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24 GNU General Public License for more details.
25
26 You should have received a copy of the GNU General Public
27 License along with libgfortran; see the file COPYING. If not,
28 write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
29 Boston, MA 02110-1301, USA. */
30
31 #include "config.h"
32 #include <stdlib.h>
33 #include <string.h>
34 #include <assert.h>
35 #include "libgfortran.h"'
36 include(iparm.m4)dnl
37
38 `#if defined (HAVE_'rtype_name`)'
39
40 /* The order of loops is different in the case of plain matrix
41 multiplication C=MATMUL(A,B), and in the frequent special case where
42 the argument A is the temporary result of a TRANSPOSE intrinsic:
43 C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
44 looking at their strides.
45
46 The equivalent Fortran pseudo-code is:
47
48 DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
49 IF (.NOT.IS_TRANSPOSED(A)) THEN
50 C = 0
51 DO J=1,N
52 DO K=1,COUNT
53 DO I=1,M
54 C(I,J) = C(I,J)+A(I,K)*B(K,J)
55 ELSE
56 DO J=1,N
57 DO I=1,M
58 S = 0
59 DO K=1,COUNT
60 S = S+A(I,K)*B(K,J)
61 C(I,J) = S
62 ENDIF
63 */
64
65 extern void matmul_`'rtype_code (rtype * const restrict retarray,
66 rtype * const restrict a, rtype * const restrict b);
67 export_proto(matmul_`'rtype_code);
68
69 void
70 matmul_`'rtype_code (rtype * const restrict retarray,
71 rtype * const restrict a, rtype * const restrict b)
72 {
73 const rtype_name * restrict abase;
74 const rtype_name * restrict bbase;
75 rtype_name * restrict dest;
76
77 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
78 index_type x, y, n, count, xcount, ycount;
79
80 assert (GFC_DESCRIPTOR_RANK (a) == 2
81 || GFC_DESCRIPTOR_RANK (b) == 2);
82
83 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
84
85 Either A or B (but not both) can be rank 1:
86
87 o One-dimensional argument A is implicitly treated as a row matrix
88 dimensioned [1,count], so xcount=1.
89
90 o One-dimensional argument B is implicitly treated as a column matrix
91 dimensioned [count, 1], so ycount=1.
92 */
93
94 if (retarray->data == NULL)
95 {
96 if (GFC_DESCRIPTOR_RANK (a) == 1)
97 {
98 retarray->dim[0].lbound = 0;
99 retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
100 retarray->dim[0].stride = 1;
101 }
102 else if (GFC_DESCRIPTOR_RANK (b) == 1)
103 {
104 retarray->dim[0].lbound = 0;
105 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
106 retarray->dim[0].stride = 1;
107 }
108 else
109 {
110 retarray->dim[0].lbound = 0;
111 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
112 retarray->dim[0].stride = 1;
113
114 retarray->dim[1].lbound = 0;
115 retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
116 retarray->dim[1].stride = retarray->dim[0].ubound+1;
117 }
118
119 retarray->data
120 = internal_malloc_size (sizeof (rtype_name) * size0 ((array_t *) retarray));
121 retarray->offset = 0;
122 }
123
124 sinclude(`matmul_asm_'rtype_code`.m4')dnl
125
126 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
127 {
128 /* One-dimensional result may be addressed in the code below
129 either as a row or a column matrix. We want both cases to
130 work. */
131 rxstride = rystride = retarray->dim[0].stride;
132 }
133 else
134 {
135 rxstride = retarray->dim[0].stride;
136 rystride = retarray->dim[1].stride;
137 }
138
139
140 if (GFC_DESCRIPTOR_RANK (a) == 1)
141 {
142 /* Treat it as a a row matrix A[1,count]. */
143 axstride = a->dim[0].stride;
144 aystride = 1;
145
146 xcount = 1;
147 count = a->dim[0].ubound + 1 - a->dim[0].lbound;
148 }
149 else
150 {
151 axstride = a->dim[0].stride;
152 aystride = a->dim[1].stride;
153
154 count = a->dim[1].ubound + 1 - a->dim[1].lbound;
155 xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
156 }
157
158 assert(count == b->dim[0].ubound + 1 - b->dim[0].lbound);
159
160 if (GFC_DESCRIPTOR_RANK (b) == 1)
161 {
162 /* Treat it as a column matrix B[count,1] */
163 bxstride = b->dim[0].stride;
164
165 /* bystride should never be used for 1-dimensional b.
166 in case it is we want it to cause a segfault, rather than
167 an incorrect result. */
168 bystride = 0xDEADBEEF;
169 ycount = 1;
170 }
171 else
172 {
173 bxstride = b->dim[0].stride;
174 bystride = b->dim[1].stride;
175 ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
176 }
177
178 abase = a->data;
179 bbase = b->data;
180 dest = retarray->data;
181
182 if (rxstride == 1 && axstride == 1 && bxstride == 1)
183 {
184 const rtype_name * restrict bbase_y;
185 rtype_name * restrict dest_y;
186 const rtype_name * restrict abase_n;
187 rtype_name bbase_yn;
188
189 if (rystride == xcount)
190 memset (dest, 0, (sizeof (rtype_name) * xcount * ycount));
191 else
192 {
193 for (y = 0; y < ycount; y++)
194 for (x = 0; x < xcount; x++)
195 dest[x + y*rystride] = (rtype_name)0;
196 }
197
198 for (y = 0; y < ycount; y++)
199 {
200 bbase_y = bbase + y*bystride;
201 dest_y = dest + y*rystride;
202 for (n = 0; n < count; n++)
203 {
204 abase_n = abase + n*aystride;
205 bbase_yn = bbase_y[n];
206 for (x = 0; x < xcount; x++)
207 {
208 dest_y[x] += abase_n[x] * bbase_yn;
209 }
210 }
211 }
212 }
213 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
214 {
215 if (GFC_DESCRIPTOR_RANK (a) != 1)
216 {
217 const rtype_name *restrict abase_x;
218 const rtype_name *restrict bbase_y;
219 rtype_name *restrict dest_y;
220 rtype_name s;
221
222 for (y = 0; y < ycount; y++)
223 {
224 bbase_y = &bbase[y*bystride];
225 dest_y = &dest[y*rystride];
226 for (x = 0; x < xcount; x++)
227 {
228 abase_x = &abase[x*axstride];
229 s = (rtype_name) 0;
230 for (n = 0; n < count; n++)
231 s += abase_x[n] * bbase_y[n];
232 dest_y[x] = s;
233 }
234 }
235 }
236 else
237 {
238 const rtype_name *restrict bbase_y;
239 rtype_name s;
240
241 for (y = 0; y < ycount; y++)
242 {
243 bbase_y = &bbase[y*bystride];
244 s = (rtype_name) 0;
245 for (n = 0; n < count; n++)
246 s += abase[n*axstride] * bbase_y[n];
247 dest[y*rystride] = s;
248 }
249 }
250 }
251 else if (axstride < aystride)
252 {
253 for (y = 0; y < ycount; y++)
254 for (x = 0; x < xcount; x++)
255 dest[x*rxstride + y*rystride] = (rtype_name)0;
256
257 for (y = 0; y < ycount; y++)
258 for (n = 0; n < count; n++)
259 for (x = 0; x < xcount; x++)
260 /* dest[x,y] += a[x,n] * b[n,y] */
261 dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
262 }
263 else if (GFC_DESCRIPTOR_RANK (a) == 1)
264 {
265 const rtype_name *restrict bbase_y;
266 rtype_name s;
267
268 for (y = 0; y < ycount; y++)
269 {
270 bbase_y = &bbase[y*bystride];
271 s = (rtype_name) 0;
272 for (n = 0; n < count; n++)
273 s += abase[n*axstride] * bbase_y[n*bxstride];
274 dest[y*rxstride] = s;
275 }
276 }
277 else
278 {
279 const rtype_name *restrict abase_x;
280 const rtype_name *restrict bbase_y;
281 rtype_name *restrict dest_y;
282 rtype_name s;
283
284 for (y = 0; y < ycount; y++)
285 {
286 bbase_y = &bbase[y*bystride];
287 dest_y = &dest[y*rystride];
288 for (x = 0; x < xcount; x++)
289 {
290 abase_x = &abase[x*axstride];
291 s = (rtype_name) 0;
292 for (n = 0; n < count; n++)
293 s += abase_x[n*aystride] * bbase_y[n*bxstride];
294 dest_y[x*rxstride] = s;
295 }
296 }
297 }
298 }
299
300 #endif