]> git.ipfire.org Git - thirdparty/gcc.git/blob - libgfortran/generated/matmul_i2.c
Licensing changes to GPLv3 resp. GPLv3 with GCC Runtime Exception.
[thirdparty/gcc.git] / libgfortran / generated / matmul_i2.c
1 /* Implementation of the MATMUL intrinsic
2 Copyright 2002, 2005, 2006, 2007, 2009 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
4
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
6
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
11
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
20
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
25
26 #include "libgfortran.h"
27 #include <stdlib.h>
28 #include <string.h>
29 #include <assert.h>
30
31
32 #if defined (HAVE_GFC_INTEGER_2)
33
34 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
35 passed to us by the front-end, in which case we'll call it for large
36 matrices. */
37
38 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
39 const int *, const GFC_INTEGER_2 *, const GFC_INTEGER_2 *,
40 const int *, const GFC_INTEGER_2 *, const int *,
41 const GFC_INTEGER_2 *, GFC_INTEGER_2 *, const int *,
42 int, int);
43
/* The order of loops is different in the case of plain matrix
   multiplication C=MATMUL(A,B), and in the frequent special case where
   the argument A is the temporary result of a TRANSPOSE intrinsic:
   C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
   looking at their strides.

   The equivalent Fortran pseudo-code is:

   DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
   IF (.NOT.IS_TRANSPOSED(A)) THEN
     C = 0
     DO J=1,N
       DO K=1,COUNT
         DO I=1,M
           C(I,J) = C(I,J)+A(I,K)*B(K,J)
   ELSE
     DO J=1,N
       DO I=1,M
         S = 0
         DO K=1,COUNT
           S = S+A(I,K)*B(K,J)
         C(I,J) = S
   ENDIF
*/
68
69 /* If try_blas is set to a nonzero value, then the matmul function will
70 see if there is a way to perform the matrix multiplication by a call
71 to the BLAS gemm function. */
72
73 extern void matmul_i2 (gfc_array_i2 * const restrict retarray,
74 gfc_array_i2 * const restrict a, gfc_array_i2 * const restrict b, int try_blas,
75 int blas_limit, blas_call gemm);
76 export_proto(matmul_i2);
77
78 void
79 matmul_i2 (gfc_array_i2 * const restrict retarray,
80 gfc_array_i2 * const restrict a, gfc_array_i2 * const restrict b, int try_blas,
81 int blas_limit, blas_call gemm)
82 {
83 const GFC_INTEGER_2 * restrict abase;
84 const GFC_INTEGER_2 * restrict bbase;
85 GFC_INTEGER_2 * restrict dest;
86
87 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
88 index_type x, y, n, count, xcount, ycount;
89
90 assert (GFC_DESCRIPTOR_RANK (a) == 2
91 || GFC_DESCRIPTOR_RANK (b) == 2);
92
93 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
94
95 Either A or B (but not both) can be rank 1:
96
97 o One-dimensional argument A is implicitly treated as a row matrix
98 dimensioned [1,count], so xcount=1.
99
100 o One-dimensional argument B is implicitly treated as a column matrix
101 dimensioned [count, 1], so ycount=1.
102 */
103
104 if (retarray->data == NULL)
105 {
106 if (GFC_DESCRIPTOR_RANK (a) == 1)
107 {
108 retarray->dim[0].lbound = 0;
109 retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
110 retarray->dim[0].stride = 1;
111 }
112 else if (GFC_DESCRIPTOR_RANK (b) == 1)
113 {
114 retarray->dim[0].lbound = 0;
115 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
116 retarray->dim[0].stride = 1;
117 }
118 else
119 {
120 retarray->dim[0].lbound = 0;
121 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
122 retarray->dim[0].stride = 1;
123
124 retarray->dim[1].lbound = 0;
125 retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
126 retarray->dim[1].stride = retarray->dim[0].ubound+1;
127 }
128
129 retarray->data
130 = internal_malloc_size (sizeof (GFC_INTEGER_2) * size0 ((array_t *) retarray));
131 retarray->offset = 0;
132 }
133 else if (unlikely (compile_options.bounds_check))
134 {
135 index_type ret_extent, arg_extent;
136
137 if (GFC_DESCRIPTOR_RANK (a) == 1)
138 {
139 arg_extent = b->dim[1].ubound + 1 - b->dim[1].lbound;
140 ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
141 if (arg_extent != ret_extent)
142 runtime_error ("Incorrect extent in return array in"
143 " MATMUL intrinsic: is %ld, should be %ld",
144 (long int) ret_extent, (long int) arg_extent);
145 }
146 else if (GFC_DESCRIPTOR_RANK (b) == 1)
147 {
148 arg_extent = a->dim[0].ubound + 1 - a->dim[0].lbound;
149 ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
150 if (arg_extent != ret_extent)
151 runtime_error ("Incorrect extent in return array in"
152 " MATMUL intrinsic: is %ld, should be %ld",
153 (long int) ret_extent, (long int) arg_extent);
154 }
155 else
156 {
157 arg_extent = a->dim[0].ubound + 1 - a->dim[0].lbound;
158 ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
159 if (arg_extent != ret_extent)
160 runtime_error ("Incorrect extent in return array in"
161 " MATMUL intrinsic for dimension 1:"
162 " is %ld, should be %ld",
163 (long int) ret_extent, (long int) arg_extent);
164
165 arg_extent = b->dim[1].ubound + 1 - b->dim[1].lbound;
166 ret_extent = retarray->dim[1].ubound + 1 - retarray->dim[1].lbound;
167 if (arg_extent != ret_extent)
168 runtime_error ("Incorrect extent in return array in"
169 " MATMUL intrinsic for dimension 2:"
170 " is %ld, should be %ld",
171 (long int) ret_extent, (long int) arg_extent);
172 }
173 }
174
175
176 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
177 {
178 /* One-dimensional result may be addressed in the code below
179 either as a row or a column matrix. We want both cases to
180 work. */
181 rxstride = rystride = retarray->dim[0].stride;
182 }
183 else
184 {
185 rxstride = retarray->dim[0].stride;
186 rystride = retarray->dim[1].stride;
187 }
188
189
190 if (GFC_DESCRIPTOR_RANK (a) == 1)
191 {
192 /* Treat it as a a row matrix A[1,count]. */
193 axstride = a->dim[0].stride;
194 aystride = 1;
195
196 xcount = 1;
197 count = a->dim[0].ubound + 1 - a->dim[0].lbound;
198 }
199 else
200 {
201 axstride = a->dim[0].stride;
202 aystride = a->dim[1].stride;
203
204 count = a->dim[1].ubound + 1 - a->dim[1].lbound;
205 xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
206 }
207
208 if (count != b->dim[0].ubound + 1 - b->dim[0].lbound)
209 {
210 if (count > 0 || b->dim[0].ubound + 1 - b->dim[0].lbound > 0)
211 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
212 }
213
214 if (GFC_DESCRIPTOR_RANK (b) == 1)
215 {
216 /* Treat it as a column matrix B[count,1] */
217 bxstride = b->dim[0].stride;
218
219 /* bystride should never be used for 1-dimensional b.
220 in case it is we want it to cause a segfault, rather than
221 an incorrect result. */
222 bystride = 0xDEADBEEF;
223 ycount = 1;
224 }
225 else
226 {
227 bxstride = b->dim[0].stride;
228 bystride = b->dim[1].stride;
229 ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
230 }
231
232 abase = a->data;
233 bbase = b->data;
234 dest = retarray->data;
235
236
237 /* Now that everything is set up, we're performing the multiplication
238 itself. */
239
240 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
241
242 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
243 && (bxstride == 1 || bystride == 1)
244 && (((float) xcount) * ((float) ycount) * ((float) count)
245 > POW3(blas_limit)))
246 {
247 const int m = xcount, n = ycount, k = count, ldc = rystride;
248 const GFC_INTEGER_2 one = 1, zero = 0;
249 const int lda = (axstride == 1) ? aystride : axstride,
250 ldb = (bxstride == 1) ? bystride : bxstride;
251
252 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
253 {
254 assert (gemm != NULL);
255 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
256 &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
257 return;
258 }
259 }
260
261 if (rxstride == 1 && axstride == 1 && bxstride == 1)
262 {
263 const GFC_INTEGER_2 * restrict bbase_y;
264 GFC_INTEGER_2 * restrict dest_y;
265 const GFC_INTEGER_2 * restrict abase_n;
266 GFC_INTEGER_2 bbase_yn;
267
268 if (rystride == xcount)
269 memset (dest, 0, (sizeof (GFC_INTEGER_2) * xcount * ycount));
270 else
271 {
272 for (y = 0; y < ycount; y++)
273 for (x = 0; x < xcount; x++)
274 dest[x + y*rystride] = (GFC_INTEGER_2)0;
275 }
276
277 for (y = 0; y < ycount; y++)
278 {
279 bbase_y = bbase + y*bystride;
280 dest_y = dest + y*rystride;
281 for (n = 0; n < count; n++)
282 {
283 abase_n = abase + n*aystride;
284 bbase_yn = bbase_y[n];
285 for (x = 0; x < xcount; x++)
286 {
287 dest_y[x] += abase_n[x] * bbase_yn;
288 }
289 }
290 }
291 }
292 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
293 {
294 if (GFC_DESCRIPTOR_RANK (a) != 1)
295 {
296 const GFC_INTEGER_2 *restrict abase_x;
297 const GFC_INTEGER_2 *restrict bbase_y;
298 GFC_INTEGER_2 *restrict dest_y;
299 GFC_INTEGER_2 s;
300
301 for (y = 0; y < ycount; y++)
302 {
303 bbase_y = &bbase[y*bystride];
304 dest_y = &dest[y*rystride];
305 for (x = 0; x < xcount; x++)
306 {
307 abase_x = &abase[x*axstride];
308 s = (GFC_INTEGER_2) 0;
309 for (n = 0; n < count; n++)
310 s += abase_x[n] * bbase_y[n];
311 dest_y[x] = s;
312 }
313 }
314 }
315 else
316 {
317 const GFC_INTEGER_2 *restrict bbase_y;
318 GFC_INTEGER_2 s;
319
320 for (y = 0; y < ycount; y++)
321 {
322 bbase_y = &bbase[y*bystride];
323 s = (GFC_INTEGER_2) 0;
324 for (n = 0; n < count; n++)
325 s += abase[n*axstride] * bbase_y[n];
326 dest[y*rystride] = s;
327 }
328 }
329 }
330 else if (axstride < aystride)
331 {
332 for (y = 0; y < ycount; y++)
333 for (x = 0; x < xcount; x++)
334 dest[x*rxstride + y*rystride] = (GFC_INTEGER_2)0;
335
336 for (y = 0; y < ycount; y++)
337 for (n = 0; n < count; n++)
338 for (x = 0; x < xcount; x++)
339 /* dest[x,y] += a[x,n] * b[n,y] */
340 dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
341 }
342 else if (GFC_DESCRIPTOR_RANK (a) == 1)
343 {
344 const GFC_INTEGER_2 *restrict bbase_y;
345 GFC_INTEGER_2 s;
346
347 for (y = 0; y < ycount; y++)
348 {
349 bbase_y = &bbase[y*bystride];
350 s = (GFC_INTEGER_2) 0;
351 for (n = 0; n < count; n++)
352 s += abase[n*axstride] * bbase_y[n*bxstride];
353 dest[y*rxstride] = s;
354 }
355 }
356 else
357 {
358 const GFC_INTEGER_2 *restrict abase_x;
359 const GFC_INTEGER_2 *restrict bbase_y;
360 GFC_INTEGER_2 *restrict dest_y;
361 GFC_INTEGER_2 s;
362
363 for (y = 0; y < ycount; y++)
364 {
365 bbase_y = &bbase[y*bystride];
366 dest_y = &dest[y*rystride];
367 for (x = 0; x < xcount; x++)
368 {
369 abase_x = &abase[x*axstride];
370 s = (GFC_INTEGER_2) 0;
371 for (n = 0; n < count; n++)
372 s += abase_x[n*aystride] * bbase_y[n*bxstride];
373 dest_y[x*rxstride] = s;
374 }
375 }
376 }
377 }
378
379 #endif