1 `/* Implementation of the MATMUL intrinsic
2 Copyright (C) 2002-2016 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
5 This file is part of the GNU Fortran runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
26 #include "libgfortran.h"
33 `#if defined (HAVE_'rtype_name`)
35 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
36 passed to us by the front-end, in which case we call it for large
39 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
40 const int *, const 'rtype_name` *, const 'rtype_name` *,
41 const int *, const 'rtype_name` *, const int *,
42 const 'rtype_name` *, 'rtype_name` *, const int *,
45 /* The order of loops is different in the case of plain matrix
46 multiplication C=MATMUL(A,B), and in the frequent special case where
47 the argument A is the temporary result of a TRANSPOSE intrinsic:
48 C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
49 looking at their strides.
51 The equivalent Fortran pseudo-code is:
53 DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
54 IF (.NOT.IS_TRANSPOSED(A)) THEN
59 C(I,J) = C(I,J)+A(I,K)*B(K,J)
70 /* If try_blas is set to a nonzero value, then the matmul function will
71 see if there is a way to perform the matrix multiplication by a call
72 to the BLAS gemm function.

   Arguments, as used by the definition in this file:

   retarray   - descriptor of the result.  When its base_addr is NULL,
                its bounds are set from the extents of a and b and its
                storage is allocated with xmallocarray; otherwise its
                extents are checked against a and b only when
                compile_options.bounds_check is set.
   a, b       - descriptors of the operands.  At least one of them must
                have rank 2 (enforced by an assert).
   try_blas   - nonzero to consider the BLAS path at all.
   blas_limit - NOTE(review): presumably the size threshold above which
                gemm is preferred over the built-in loops; the comparison
                is not fully visible here, confirm against the complete
                file.
   gemm       - pointer to a BLAS ?gemm routine; asserted to be non-NULL
                right before it is called.  */
74 extern void matmul_'rtype_code` ('rtype` * const restrict retarray,
75 'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
76 int blas_limit, blas_call gemm);
77 export_proto(matmul_'rtype_code`);
80 matmul_'rtype_code` ('rtype` * const restrict retarray,
81 'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
82 int blas_limit, blas_call gemm)
84 const 'rtype_name` * restrict abase;
85 const 'rtype_name` * restrict bbase;
86 'rtype_name` * restrict dest;
88 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
89 index_type x, y, n, count, xcount, ycount;
91 assert (GFC_DESCRIPTOR_RANK (a) == 2
92 || GFC_DESCRIPTOR_RANK (b) == 2);
94 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
96 Either A or B (but not both) can be rank 1:
98 o One-dimensional argument A is implicitly treated as a row matrix
99 dimensioned [1,count], so xcount=1.
101 o One-dimensional argument B is implicitly treated as a column matrix
102 dimensioned [count, 1], so ycount=1.
105 if (retarray->base_addr == NULL)
107 if (GFC_DESCRIPTOR_RANK (a) == 1)
109 GFC_DIMENSION_SET(retarray->dim[0], 0,
110 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
112 else if (GFC_DESCRIPTOR_RANK (b) == 1)
114 GFC_DIMENSION_SET(retarray->dim[0], 0,
115 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
119 GFC_DIMENSION_SET(retarray->dim[0], 0,
120 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
122 GFC_DIMENSION_SET(retarray->dim[1], 0,
123 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
124 GFC_DESCRIPTOR_EXTENT(retarray,0));
128 = xmallocarray (size0 ((array_t *) retarray), sizeof ('rtype_name`));
129 retarray->offset = 0;
131 else if (unlikely (compile_options.bounds_check))
133 index_type ret_extent, arg_extent;
135 if (GFC_DESCRIPTOR_RANK (a) == 1)
137 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
138 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
139 if (arg_extent != ret_extent)
140 runtime_error ("Incorrect extent in return array in"
141 " MATMUL intrinsic: is %ld, should be %ld",
142 (long int) ret_extent, (long int) arg_extent);
144 else if (GFC_DESCRIPTOR_RANK (b) == 1)
146 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
147 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
148 if (arg_extent != ret_extent)
149 runtime_error ("Incorrect extent in return array in"
150 " MATMUL intrinsic: is %ld, should be %ld",
151 (long int) ret_extent, (long int) arg_extent);
155 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
156 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
157 if (arg_extent != ret_extent)
158 runtime_error ("Incorrect extent in return array in"
159 " MATMUL intrinsic for dimension 1:"
160 " is %ld, should be %ld",
161 (long int) ret_extent, (long int) arg_extent);
163 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
164 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
165 if (arg_extent != ret_extent)
166 runtime_error ("Incorrect extent in return array in"
167 " MATMUL intrinsic for dimension 2:"
168 " is %ld, should be %ld",
169 (long int) ret_extent, (long int) arg_extent);
173 sinclude(`matmul_asm_'rtype_code`.m4')dnl
175 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
177 /* One-dimensional result may be addressed in the code below
178 either as a row or a column matrix. We want both cases to
180 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
184 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
185 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
189 if (GFC_DESCRIPTOR_RANK (a) == 1)
191 /* Treat it as a row matrix A[1,count]. */
192 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
196 count = GFC_DESCRIPTOR_EXTENT(a,0);
200 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
201 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
203 count = GFC_DESCRIPTOR_EXTENT(a,1);
204 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
207 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
209 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
210 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
213 if (GFC_DESCRIPTOR_RANK (b) == 1)
215 /* Treat it as a column matrix B[count,1] */
216 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
218 /* bystride should never be used for 1-dimensional b.
219 In case it is, we want it to cause a segfault, rather than
220 an incorrect result. */
221 bystride = 0xDEADBEEF;
226 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
227 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
228 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
231 abase = a->base_addr;
232 bbase = b->base_addr;
233 dest = retarray->base_addr;
235 /* Now that everything is set up, we perform the multiplication
238 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
239 #define min(a,b) ((a) <= (b) ? (a) : (b))
240 #define max(a,b) ((a) >= (b) ? (a) : (b))
242 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
243 && (bxstride == 1 || bystride == 1)
244 && (((float) xcount) * ((float) ycount) * ((float) count)
247 const int m = xcount, n = ycount, k = count, ldc = rystride;
248 const 'rtype_name` one = 1, zero = 0;
249 const int lda = (axstride == 1) ? aystride : axstride,
250 ldb = (bxstride == 1) ? bystride : bxstride;
252 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
254 assert (gemm != NULL);
255 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m,
256 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
262 if (rxstride == 1 && axstride == 1 && bxstride == 1)
264 /* This block of code implements a tuned matmul, derived from
265 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
267 Bo Kagstrom and Per Ling
268 Department of Computing Science
270 S-901 87 Umea, Sweden
272 from netlib.org, translated to C, and modified for matmul.m4. */
274 const 'rtype_name` *a, *b;
276 const index_type m = xcount, n = ycount, k = count;
278 /* System generated locals */
279 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
280 i1, i2, i3, i4, i5, i6;
282 /* Local variables */
283 'rtype_name` t1[65536], /* was [256][256] */
284 f11, f12, f21, f22, f31, f32, f41, f42,
285 f13, f14, f23, f24, f33, f34, f43, f44;
286 index_type i, j, l, ii, jj, ll;
287 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
291 c = retarray->base_addr;
293 /* Parameter adjustments */
295 c_offset = 1 + c_dim1;
298 a_offset = 1 + a_dim1;
301 b_offset = 1 + b_dim1;
304 /* Early exit if possible */
305 if (m == 0 || n == 0 || k == 0)
311 c[i + j * c_dim1] = ('rtype_name`)0;
313 /* Start turning the crank. */
315 for (jj = 1; jj <= i1; jj += 512)
321 ujsec = jsec - jsec % 4;
323 for (ll = 1; ll <= i2; ll += 256)
329 ulsec = lsec - lsec % 2;
332 for (ii = 1; ii <= i3; ii += 256)
338 uisec = isec - isec % 2;
340 for (l = ll; l <= i4; l += 2)
343 for (i = ii; i <= i5; i += 2)
345 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
347 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
348 a[i + (l + 1) * a_dim1];
349 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
350 a[i + 1 + l * a_dim1];
351 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
352 a[i + 1 + (l + 1) * a_dim1];
356 t1[l - ll + 1 + (isec << 8) - 257] =
357 a[ii + isec - 1 + l * a_dim1];
358 t1[l - ll + 2 + (isec << 8) - 257] =
359 a[ii + isec - 1 + (l + 1) * a_dim1];
365 for (i = ii; i<= i4; ++i)
367 t1[lsec + ((i - ii + 1) << 8) - 257] =
368 a[i + (ll + lsec - 1) * a_dim1];
372 uisec = isec - isec % 4;
374 for (j = jj; j <= i4; j += 4)
377 for (i = ii; i <= i5; i += 4)
379 f11 = c[i + j * c_dim1];
380 f21 = c[i + 1 + j * c_dim1];
381 f12 = c[i + (j + 1) * c_dim1];
382 f22 = c[i + 1 + (j + 1) * c_dim1];
383 f13 = c[i + (j + 2) * c_dim1];
384 f23 = c[i + 1 + (j + 2) * c_dim1];
385 f14 = c[i + (j + 3) * c_dim1];
386 f24 = c[i + 1 + (j + 3) * c_dim1];
387 f31 = c[i + 2 + j * c_dim1];
388 f41 = c[i + 3 + j * c_dim1];
389 f32 = c[i + 2 + (j + 1) * c_dim1];
390 f42 = c[i + 3 + (j + 1) * c_dim1];
391 f33 = c[i + 2 + (j + 2) * c_dim1];
392 f43 = c[i + 3 + (j + 2) * c_dim1];
393 f34 = c[i + 2 + (j + 3) * c_dim1];
394 f44 = c[i + 3 + (j + 3) * c_dim1];
396 for (l = ll; l <= i6; ++l)
398 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
400 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
402 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
403 * b[l + (j + 1) * b_dim1];
404 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
405 * b[l + (j + 1) * b_dim1];
406 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
407 * b[l + (j + 2) * b_dim1];
408 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
409 * b[l + (j + 2) * b_dim1];
410 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
411 * b[l + (j + 3) * b_dim1];
412 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
413 * b[l + (j + 3) * b_dim1];
414 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
416 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
418 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
419 * b[l + (j + 1) * b_dim1];
420 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
421 * b[l + (j + 1) * b_dim1];
422 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
423 * b[l + (j + 2) * b_dim1];
424 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
425 * b[l + (j + 2) * b_dim1];
426 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
427 * b[l + (j + 3) * b_dim1];
428 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
429 * b[l + (j + 3) * b_dim1];
431 c[i + j * c_dim1] = f11;
432 c[i + 1 + j * c_dim1] = f21;
433 c[i + (j + 1) * c_dim1] = f12;
434 c[i + 1 + (j + 1) * c_dim1] = f22;
435 c[i + (j + 2) * c_dim1] = f13;
436 c[i + 1 + (j + 2) * c_dim1] = f23;
437 c[i + (j + 3) * c_dim1] = f14;
438 c[i + 1 + (j + 3) * c_dim1] = f24;
439 c[i + 2 + j * c_dim1] = f31;
440 c[i + 3 + j * c_dim1] = f41;
441 c[i + 2 + (j + 1) * c_dim1] = f32;
442 c[i + 3 + (j + 1) * c_dim1] = f42;
443 c[i + 2 + (j + 2) * c_dim1] = f33;
444 c[i + 3 + (j + 2) * c_dim1] = f43;
445 c[i + 2 + (j + 3) * c_dim1] = f34;
446 c[i + 3 + (j + 3) * c_dim1] = f44;
451 for (i = ii + uisec; i <= i5; ++i)
453 f11 = c[i + j * c_dim1];
454 f12 = c[i + (j + 1) * c_dim1];
455 f13 = c[i + (j + 2) * c_dim1];
456 f14 = c[i + (j + 3) * c_dim1];
458 for (l = ll; l <= i6; ++l)
460 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
461 257] * b[l + j * b_dim1];
462 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
463 257] * b[l + (j + 1) * b_dim1];
464 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
465 257] * b[l + (j + 2) * b_dim1];
466 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
467 257] * b[l + (j + 3) * b_dim1];
469 c[i + j * c_dim1] = f11;
470 c[i + (j + 1) * c_dim1] = f12;
471 c[i + (j + 2) * c_dim1] = f13;
472 c[i + (j + 3) * c_dim1] = f14;
479 for (j = jj + ujsec; j <= i4; ++j)
482 for (i = ii; i <= i5; i += 4)
484 f11 = c[i + j * c_dim1];
485 f21 = c[i + 1 + j * c_dim1];
486 f31 = c[i + 2 + j * c_dim1];
487 f41 = c[i + 3 + j * c_dim1];
489 for (l = ll; l <= i6; ++l)
491 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
492 257] * b[l + j * b_dim1];
493 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
494 257] * b[l + j * b_dim1];
495 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
496 257] * b[l + j * b_dim1];
497 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
498 257] * b[l + j * b_dim1];
500 c[i + j * c_dim1] = f11;
501 c[i + 1 + j * c_dim1] = f21;
502 c[i + 2 + j * c_dim1] = f31;
503 c[i + 3 + j * c_dim1] = f41;
506 for (i = ii + uisec; i <= i5; ++i)
508 f11 = c[i + j * c_dim1];
510 for (l = ll; l <= i6; ++l)
512 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
513 257] * b[l + j * b_dim1];
515 c[i + j * c_dim1] = f11;
524 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
526 if (GFC_DESCRIPTOR_RANK (a) != 1)
528 const 'rtype_name` *restrict abase_x;
529 const 'rtype_name` *restrict bbase_y;
530 'rtype_name` *restrict dest_y;
533 for (y = 0; y < ycount; y++)
535 bbase_y = &bbase[y*bystride];
536 dest_y = &dest[y*rystride];
537 for (x = 0; x < xcount; x++)
539 abase_x = &abase[x*axstride];
540 s = ('rtype_name`) 0;
541 for (n = 0; n < count; n++)
542 s += abase_x[n] * bbase_y[n];
549 const 'rtype_name` *restrict bbase_y;
552 for (y = 0; y < ycount; y++)
554 bbase_y = &bbase[y*bystride];
555 s = ('rtype_name`) 0;
556 for (n = 0; n < count; n++)
557 s += abase[n*axstride] * bbase_y[n];
558 dest[y*rystride] = s;
562 else if (axstride < aystride)
564 for (y = 0; y < ycount; y++)
565 for (x = 0; x < xcount; x++)
566 dest[x*rxstride + y*rystride] = ('rtype_name`)0;
568 for (y = 0; y < ycount; y++)
569 for (n = 0; n < count; n++)
570 for (x = 0; x < xcount; x++)
571 /* dest[x,y] += a[x,n] * b[n,y] */
572 dest[x*rxstride + y*rystride] +=
573 abase[x*axstride + n*aystride] *
574 bbase[n*bxstride + y*bystride];
576 else if (GFC_DESCRIPTOR_RANK (a) == 1)
578 const 'rtype_name` *restrict bbase_y;
581 for (y = 0; y < ycount; y++)
583 bbase_y = &bbase[y*bystride];
584 s = ('rtype_name`) 0;
585 for (n = 0; n < count; n++)
586 s += abase[n*axstride] * bbase_y[n*bxstride];
587 dest[y*rxstride] = s;
592 const 'rtype_name` *restrict abase_x;
593 const 'rtype_name` *restrict bbase_y;
594 'rtype_name` *restrict dest_y;
597 for (y = 0; y < ycount; y++)
599 bbase_y = &bbase[y*bystride];
600 dest_y = &dest[y*rystride];
601 for (x = 0; x < xcount; x++)
603 abase_x = &abase[x*axstride];
604 s = ('rtype_name`) 0;
605 for (n = 0; n < count; n++)
606 s += abase_x[n*aystride] * bbase_y[n*bxstride];
607 dest_y[x*rxstride] = s;