]> git.ipfire.org Git - thirdparty/gcc.git/blob - libgfortran/generated/matmul_c8.c
in_pack.m4: Add TODO comment about detecting temporaries...
[thirdparty/gcc.git] / libgfortran / generated / matmul_c8.c
1 /* Implementation of the MATMUL intrinsic
2 Copyright 2002, 2005, 2006 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
4
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
6
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 2 of the License, or (at your option) any later version.
11
12 In addition to the permissions in the GNU General Public License, the
13 Free Software Foundation gives you unlimited permission to link the
14 compiled version of this file into combinations with other programs,
15 and to distribute those combinations without any restriction coming
16 from the use of this file. (The General Public License restrictions
17 do apply in other respects; for example, they cover modification of
18 the file, and distribution when not linked into a combine
19 executable.)
20
21 Libgfortran is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24 GNU General Public License for more details.
25
26 You should have received a copy of the GNU General Public
27 License along with libgfortran; see the file COPYING. If not,
28 write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
29 Boston, MA 02110-1301, USA. */
30
31 #include "config.h"
32 #include <stdlib.h>
33 #include <string.h>
34 #include <assert.h>
35 #include "libgfortran.h"
36
37 #if defined (HAVE_GFC_COMPLEX_8)
38
/* The order of loops is different in the case of plain matrix
   multiplication C=MATMUL(A,B), and in the frequent special case where
   the argument A is the temporary result of a TRANSPOSE intrinsic:
   C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
   looking at their strides.

   The equivalent Fortran pseudo-code is:

   DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
   IF (.NOT.IS_TRANSPOSED(A)) THEN
     C = 0
     DO J=1,N
       DO K=1,COUNT
         DO I=1,M
           C(I,J) = C(I,J)+A(I,K)*B(K,J)
   ELSE
     DO J=1,N
       DO I=1,M
         S = 0
         DO K=1,COUNT
           S = S+A(I,K)*B(K,J)
         C(I,J) = S
   ENDIF
*/
63
64 extern void matmul_c8 (gfc_array_c8 * const restrict retarray,
65 gfc_array_c8 * const restrict a, gfc_array_c8 * const restrict b);
66 export_proto(matmul_c8);
67
68 void
69 matmul_c8 (gfc_array_c8 * const restrict retarray,
70 gfc_array_c8 * const restrict a, gfc_array_c8 * const restrict b)
71 {
72 const GFC_COMPLEX_8 * restrict abase;
73 const GFC_COMPLEX_8 * restrict bbase;
74 GFC_COMPLEX_8 * restrict dest;
75
76 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
77 index_type x, y, n, count, xcount, ycount;
78
79 assert (GFC_DESCRIPTOR_RANK (a) == 2
80 || GFC_DESCRIPTOR_RANK (b) == 2);
81
82 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
83
84 Either A or B (but not both) can be rank 1:
85
86 o One-dimensional argument A is implicitly treated as a row matrix
87 dimensioned [1,count], so xcount=1.
88
89 o One-dimensional argument B is implicitly treated as a column matrix
90 dimensioned [count, 1], so ycount=1.
91 */
92
93 if (retarray->data == NULL)
94 {
95 if (GFC_DESCRIPTOR_RANK (a) == 1)
96 {
97 retarray->dim[0].lbound = 0;
98 retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
99 retarray->dim[0].stride = 1;
100 }
101 else if (GFC_DESCRIPTOR_RANK (b) == 1)
102 {
103 retarray->dim[0].lbound = 0;
104 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
105 retarray->dim[0].stride = 1;
106 }
107 else
108 {
109 retarray->dim[0].lbound = 0;
110 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
111 retarray->dim[0].stride = 1;
112
113 retarray->dim[1].lbound = 0;
114 retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
115 retarray->dim[1].stride = retarray->dim[0].ubound+1;
116 }
117
118 retarray->data
119 = internal_malloc_size (sizeof (GFC_COMPLEX_8) * size0 ((array_t *) retarray));
120 retarray->offset = 0;
121 }
122
123
124 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
125 {
126 /* One-dimensional result may be addressed in the code below
127 either as a row or a column matrix. We want both cases to
128 work. */
129 rxstride = rystride = retarray->dim[0].stride;
130 }
131 else
132 {
133 rxstride = retarray->dim[0].stride;
134 rystride = retarray->dim[1].stride;
135 }
136
137
138 if (GFC_DESCRIPTOR_RANK (a) == 1)
139 {
140 /* Treat it as a a row matrix A[1,count]. */
141 axstride = a->dim[0].stride;
142 aystride = 1;
143
144 xcount = 1;
145 count = a->dim[0].ubound + 1 - a->dim[0].lbound;
146 }
147 else
148 {
149 axstride = a->dim[0].stride;
150 aystride = a->dim[1].stride;
151
152 count = a->dim[1].ubound + 1 - a->dim[1].lbound;
153 xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
154 }
155
156 assert(count == b->dim[0].ubound + 1 - b->dim[0].lbound);
157
158 if (GFC_DESCRIPTOR_RANK (b) == 1)
159 {
160 /* Treat it as a column matrix B[count,1] */
161 bxstride = b->dim[0].stride;
162
163 /* bystride should never be used for 1-dimensional b.
164 in case it is we want it to cause a segfault, rather than
165 an incorrect result. */
166 bystride = 0xDEADBEEF;
167 ycount = 1;
168 }
169 else
170 {
171 bxstride = b->dim[0].stride;
172 bystride = b->dim[1].stride;
173 ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
174 }
175
176 abase = a->data;
177 bbase = b->data;
178 dest = retarray->data;
179
180 if (rxstride == 1 && axstride == 1 && bxstride == 1)
181 {
182 const GFC_COMPLEX_8 * restrict bbase_y;
183 GFC_COMPLEX_8 * restrict dest_y;
184 const GFC_COMPLEX_8 * restrict abase_n;
185 GFC_COMPLEX_8 bbase_yn;
186
187 if (rystride == xcount)
188 memset (dest, 0, (sizeof (GFC_COMPLEX_8) * xcount * ycount));
189 else
190 {
191 for (y = 0; y < ycount; y++)
192 for (x = 0; x < xcount; x++)
193 dest[x + y*rystride] = (GFC_COMPLEX_8)0;
194 }
195
196 for (y = 0; y < ycount; y++)
197 {
198 bbase_y = bbase + y*bystride;
199 dest_y = dest + y*rystride;
200 for (n = 0; n < count; n++)
201 {
202 abase_n = abase + n*aystride;
203 bbase_yn = bbase_y[n];
204 for (x = 0; x < xcount; x++)
205 {
206 dest_y[x] += abase_n[x] * bbase_yn;
207 }
208 }
209 }
210 }
211 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
212 {
213 const GFC_COMPLEX_8 *restrict abase_x;
214 const GFC_COMPLEX_8 *restrict bbase_y;
215 GFC_COMPLEX_8 *restrict dest_y;
216 GFC_COMPLEX_8 s;
217
218 for (y = 0; y < ycount; y++)
219 {
220 bbase_y = &bbase[y*bystride];
221 dest_y = &dest[y*rystride];
222 for (x = 0; x < xcount; x++)
223 {
224 abase_x = &abase[x*axstride];
225 s = (GFC_COMPLEX_8) 0;
226 for (n = 0; n < count; n++)
227 s += abase_x[n] * bbase_y[n];
228 dest_y[x] = s;
229 }
230 }
231 }
232 else if (axstride < aystride)
233 {
234 for (y = 0; y < ycount; y++)
235 for (x = 0; x < xcount; x++)
236 dest[x*rxstride + y*rystride] = (GFC_COMPLEX_8)0;
237
238 for (y = 0; y < ycount; y++)
239 for (n = 0; n < count; n++)
240 for (x = 0; x < xcount; x++)
241 /* dest[x,y] += a[x,n] * b[n,y] */
242 dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
243 }
244 else
245 {
246 const GFC_COMPLEX_8 *restrict abase_x;
247 const GFC_COMPLEX_8 *restrict bbase_y;
248 GFC_COMPLEX_8 *restrict dest_y;
249 GFC_COMPLEX_8 s;
250
251 for (y = 0; y < ycount; y++)
252 {
253 bbase_y = &bbase[y*bystride];
254 dest_y = &dest[y*rystride];
255 for (x = 0; x < xcount; x++)
256 {
257 abase_x = &abase[x*axstride];
258 s = (GFC_COMPLEX_8) 0;
259 for (n = 0; n < count; n++)
260 s += abase_x[n*aystride] * bbase_y[n*bxstride];
261 dest_y[x*rxstride] = s;
262 }
263 }
264 }
265 }
266
267 #endif