]>
git.ipfire.org Git - thirdparty/gcc.git/blob - libgfortran/generated/matmul_r8.c
1 /* Implementation of the MATMUL intrinsic
2 Copyright (C) 2002-2021 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
5 This file is part of the GNU Fortran runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
26 #include "libgfortran.h"
31 #if defined (HAVE_GFC_REAL_8)
33 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
34 passed to us by the front-end, in which case we call it for large
37 typedef void (*blas_call
)(const char *, const char *, const int *, const int *,
38 const int *, const GFC_REAL_8
*, const GFC_REAL_8
*,
39 const int *, const GFC_REAL_8
*, const int *,
40 const GFC_REAL_8
*, GFC_REAL_8
*, const int *,
43 /* The order of loops is different in the case of plain matrix
44 multiplication C=MATMUL(A,B), and in the frequent special case where
45 the argument A is the temporary result of a TRANSPOSE intrinsic:
46 C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
47 looking at their strides.
49 The equivalent Fortran pseudo-code is:
51 DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
52 IF (.NOT.IS_TRANSPOSED(A)) THEN
57 C(I,J) = C(I,J)+A(I,K)*B(K,J)
68 /* If try_blas is set to a nonzero value, then the matmul function will
69 see if there is a way to perform the matrix multiplication by a call
70 to the BLAS gemm function. */
72 extern void matmul_r8 (gfc_array_r8
* const restrict retarray
,
73 gfc_array_r8
* const restrict a
, gfc_array_r8
* const restrict b
, int try_blas
,
74 int blas_limit
, blas_call gemm
);
75 export_proto(matmul_r8
);
77 /* Put exhaustive list of possible architectures here here, ORed together. */
79 #if defined(HAVE_AVX) || defined(HAVE_AVX2) || defined(HAVE_AVX512F)
83 matmul_r8_avx (gfc_array_r8
* const restrict retarray
,
84 gfc_array_r8
* const restrict a
, gfc_array_r8
* const restrict b
, int try_blas
,
85 int blas_limit
, blas_call gemm
) __attribute__((__target__("avx")));
87 matmul_r8_avx (gfc_array_r8
* const restrict retarray
,
88 gfc_array_r8
* const restrict a
, gfc_array_r8
* const restrict b
, int try_blas
,
89 int blas_limit
, blas_call gemm
)
91 const GFC_REAL_8
* restrict abase
;
92 const GFC_REAL_8
* restrict bbase
;
93 GFC_REAL_8
* restrict dest
;
95 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
96 index_type x
, y
, n
, count
, xcount
, ycount
;
98 assert (GFC_DESCRIPTOR_RANK (a
) == 2
99 || GFC_DESCRIPTOR_RANK (b
) == 2);
101 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
103 Either A or B (but not both) can be rank 1:
105 o One-dimensional argument A is implicitly treated as a row matrix
106 dimensioned [1,count], so xcount=1.
108 o One-dimensional argument B is implicitly treated as a column matrix
109 dimensioned [count, 1], so ycount=1.
112 if (retarray
->base_addr
== NULL
)
114 if (GFC_DESCRIPTOR_RANK (a
) == 1)
116 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
117 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
119 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
121 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
122 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
126 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
127 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
129 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
130 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
131 GFC_DESCRIPTOR_EXTENT(retarray
,0));
135 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_REAL_8
));
136 retarray
->offset
= 0;
138 else if (unlikely (compile_options
.bounds_check
))
140 index_type ret_extent
, arg_extent
;
142 if (GFC_DESCRIPTOR_RANK (a
) == 1)
144 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
145 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
146 if (arg_extent
!= ret_extent
)
147 runtime_error ("Array bound mismatch for dimension 1 of "
149 (long int) ret_extent
, (long int) arg_extent
);
151 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
153 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
154 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
155 if (arg_extent
!= ret_extent
)
156 runtime_error ("Array bound mismatch for dimension 1 of "
158 (long int) ret_extent
, (long int) arg_extent
);
162 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
163 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
164 if (arg_extent
!= ret_extent
)
165 runtime_error ("Array bound mismatch for dimension 1 of "
167 (long int) ret_extent
, (long int) arg_extent
);
169 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
170 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
171 if (arg_extent
!= ret_extent
)
172 runtime_error ("Array bound mismatch for dimension 2 of "
174 (long int) ret_extent
, (long int) arg_extent
);
179 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
181 /* One-dimensional result may be addressed in the code below
182 either as a row or a column matrix. We want both cases to
184 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
188 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
189 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
193 if (GFC_DESCRIPTOR_RANK (a
) == 1)
195 /* Treat it as a a row matrix A[1,count]. */
196 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
200 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
204 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
205 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
207 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
208 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
211 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
213 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
214 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
215 "in dimension 1: is %ld, should be %ld",
216 (long int) GFC_DESCRIPTOR_EXTENT(b
,0), (long int) count
);
219 if (GFC_DESCRIPTOR_RANK (b
) == 1)
221 /* Treat it as a column matrix B[count,1] */
222 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
224 /* bystride should never be used for 1-dimensional b.
225 The value is only used for calculation of the
226 memory by the buffer. */
232 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
233 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
234 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
237 abase
= a
->base_addr
;
238 bbase
= b
->base_addr
;
239 dest
= retarray
->base_addr
;
241 /* Now that everything is set up, we perform the multiplication
244 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
245 #define min(a,b) ((a) <= (b) ? (a) : (b))
246 #define max(a,b) ((a) >= (b) ? (a) : (b))
248 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
249 && (bxstride
== 1 || bystride
== 1)
250 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
253 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
254 const GFC_REAL_8 one
= 1, zero
= 0;
255 const int lda
= (axstride
== 1) ? aystride
: axstride
,
256 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
258 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
260 assert (gemm
!= NULL
);
261 const char *transa
, *transb
;
265 transa
= axstride
== 1 ? "N" : "T";
270 transb
= bxstride
== 1 ? "N" : "T";
272 gemm (transa
, transb
, &m
,
273 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
279 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1)
281 /* This block of code implements a tuned matmul, derived from
282 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
284 Bo Kagstrom and Per Ling
285 Department of Computing Science
287 S-901 87 Umea, Sweden
289 from netlib.org, translated to C, and modified for matmul.m4. */
291 const GFC_REAL_8
*a
, *b
;
293 const index_type m
= xcount
, n
= ycount
, k
= count
;
295 /* System generated locals */
296 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
297 i1
, i2
, i3
, i4
, i5
, i6
;
299 /* Local variables */
300 GFC_REAL_8 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
301 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
302 index_type i
, j
, l
, ii
, jj
, ll
;
303 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
308 c
= retarray
->base_addr
;
310 /* Parameter adjustments */
312 c_offset
= 1 + c_dim1
;
315 a_offset
= 1 + a_dim1
;
318 b_offset
= 1 + b_dim1
;
324 c
[i
+ j
* c_dim1
] = (GFC_REAL_8
)0;
326 /* Early exit if possible */
327 if (m
== 0 || n
== 0 || k
== 0)
330 /* Adjust size of t1 to what is needed. */
331 index_type t1_dim
, a_sz
;
337 t1_dim
= a_sz
* 256 + b_dim1
;
341 t1
= malloc (t1_dim
* sizeof(GFC_REAL_8
));
343 /* Start turning the crank. */
345 for (jj
= 1; jj
<= i1
; jj
+= 512)
351 ujsec
= jsec
- jsec
% 4;
353 for (ll
= 1; ll
<= i2
; ll
+= 256)
359 ulsec
= lsec
- lsec
% 2;
362 for (ii
= 1; ii
<= i3
; ii
+= 256)
368 uisec
= isec
- isec
% 2;
370 for (l
= ll
; l
<= i4
; l
+= 2)
373 for (i
= ii
; i
<= i5
; i
+= 2)
375 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
377 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
378 a
[i
+ (l
+ 1) * a_dim1
];
379 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
380 a
[i
+ 1 + l
* a_dim1
];
381 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
382 a
[i
+ 1 + (l
+ 1) * a_dim1
];
386 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
387 a
[ii
+ isec
- 1 + l
* a_dim1
];
388 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
389 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
395 for (i
= ii
; i
<= i4
; ++i
)
397 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
398 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
402 uisec
= isec
- isec
% 4;
404 for (j
= jj
; j
<= i4
; j
+= 4)
407 for (i
= ii
; i
<= i5
; i
+= 4)
409 f11
= c
[i
+ j
* c_dim1
];
410 f21
= c
[i
+ 1 + j
* c_dim1
];
411 f12
= c
[i
+ (j
+ 1) * c_dim1
];
412 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
413 f13
= c
[i
+ (j
+ 2) * c_dim1
];
414 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
415 f14
= c
[i
+ (j
+ 3) * c_dim1
];
416 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
417 f31
= c
[i
+ 2 + j
* c_dim1
];
418 f41
= c
[i
+ 3 + j
* c_dim1
];
419 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
420 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
421 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
422 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
423 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
424 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
426 for (l
= ll
; l
<= i6
; ++l
)
428 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
430 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
432 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
433 * b
[l
+ (j
+ 1) * b_dim1
];
434 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
435 * b
[l
+ (j
+ 1) * b_dim1
];
436 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
437 * b
[l
+ (j
+ 2) * b_dim1
];
438 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
439 * b
[l
+ (j
+ 2) * b_dim1
];
440 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
441 * b
[l
+ (j
+ 3) * b_dim1
];
442 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
443 * b
[l
+ (j
+ 3) * b_dim1
];
444 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
446 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
448 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
449 * b
[l
+ (j
+ 1) * b_dim1
];
450 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
451 * b
[l
+ (j
+ 1) * b_dim1
];
452 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
453 * b
[l
+ (j
+ 2) * b_dim1
];
454 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
455 * b
[l
+ (j
+ 2) * b_dim1
];
456 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
457 * b
[l
+ (j
+ 3) * b_dim1
];
458 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
459 * b
[l
+ (j
+ 3) * b_dim1
];
461 c
[i
+ j
* c_dim1
] = f11
;
462 c
[i
+ 1 + j
* c_dim1
] = f21
;
463 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
464 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
465 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
466 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
467 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
468 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
469 c
[i
+ 2 + j
* c_dim1
] = f31
;
470 c
[i
+ 3 + j
* c_dim1
] = f41
;
471 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
472 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
473 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
474 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
475 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
476 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
481 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
483 f11
= c
[i
+ j
* c_dim1
];
484 f12
= c
[i
+ (j
+ 1) * c_dim1
];
485 f13
= c
[i
+ (j
+ 2) * c_dim1
];
486 f14
= c
[i
+ (j
+ 3) * c_dim1
];
488 for (l
= ll
; l
<= i6
; ++l
)
490 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
491 257] * b
[l
+ j
* b_dim1
];
492 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
493 257] * b
[l
+ (j
+ 1) * b_dim1
];
494 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
495 257] * b
[l
+ (j
+ 2) * b_dim1
];
496 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
497 257] * b
[l
+ (j
+ 3) * b_dim1
];
499 c
[i
+ j
* c_dim1
] = f11
;
500 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
501 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
502 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
509 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
512 for (i
= ii
; i
<= i5
; i
+= 4)
514 f11
= c
[i
+ j
* c_dim1
];
515 f21
= c
[i
+ 1 + j
* c_dim1
];
516 f31
= c
[i
+ 2 + j
* c_dim1
];
517 f41
= c
[i
+ 3 + j
* c_dim1
];
519 for (l
= ll
; l
<= i6
; ++l
)
521 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
522 257] * b
[l
+ j
* b_dim1
];
523 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
524 257] * b
[l
+ j
* b_dim1
];
525 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
526 257] * b
[l
+ j
* b_dim1
];
527 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
528 257] * b
[l
+ j
* b_dim1
];
530 c
[i
+ j
* c_dim1
] = f11
;
531 c
[i
+ 1 + j
* c_dim1
] = f21
;
532 c
[i
+ 2 + j
* c_dim1
] = f31
;
533 c
[i
+ 3 + j
* c_dim1
] = f41
;
536 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
538 f11
= c
[i
+ j
* c_dim1
];
540 for (l
= ll
; l
<= i6
; ++l
)
542 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
543 257] * b
[l
+ j
* b_dim1
];
545 c
[i
+ j
* c_dim1
] = f11
;
555 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
557 if (GFC_DESCRIPTOR_RANK (a
) != 1)
559 const GFC_REAL_8
*restrict abase_x
;
560 const GFC_REAL_8
*restrict bbase_y
;
561 GFC_REAL_8
*restrict dest_y
;
564 for (y
= 0; y
< ycount
; y
++)
566 bbase_y
= &bbase
[y
*bystride
];
567 dest_y
= &dest
[y
*rystride
];
568 for (x
= 0; x
< xcount
; x
++)
570 abase_x
= &abase
[x
*axstride
];
572 for (n
= 0; n
< count
; n
++)
573 s
+= abase_x
[n
] * bbase_y
[n
];
580 const GFC_REAL_8
*restrict bbase_y
;
583 for (y
= 0; y
< ycount
; y
++)
585 bbase_y
= &bbase
[y
*bystride
];
587 for (n
= 0; n
< count
; n
++)
588 s
+= abase
[n
*axstride
] * bbase_y
[n
];
589 dest
[y
*rystride
] = s
;
593 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
595 const GFC_REAL_8
*restrict bbase_y
;
598 for (y
= 0; y
< ycount
; y
++)
600 bbase_y
= &bbase
[y
*bystride
];
602 for (n
= 0; n
< count
; n
++)
603 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
604 dest
[y
*rxstride
] = s
;
607 else if (axstride
< aystride
)
609 for (y
= 0; y
< ycount
; y
++)
610 for (x
= 0; x
< xcount
; x
++)
611 dest
[x
*rxstride
+ y
*rystride
] = (GFC_REAL_8
)0;
613 for (y
= 0; y
< ycount
; y
++)
614 for (n
= 0; n
< count
; n
++)
615 for (x
= 0; x
< xcount
; x
++)
616 /* dest[x,y] += a[x,n] * b[n,y] */
617 dest
[x
*rxstride
+ y
*rystride
] +=
618 abase
[x
*axstride
+ n
*aystride
] *
619 bbase
[n
*bxstride
+ y
*bystride
];
623 const GFC_REAL_8
*restrict abase_x
;
624 const GFC_REAL_8
*restrict bbase_y
;
625 GFC_REAL_8
*restrict dest_y
;
628 for (y
= 0; y
< ycount
; y
++)
630 bbase_y
= &bbase
[y
*bystride
];
631 dest_y
= &dest
[y
*rystride
];
632 for (x
= 0; x
< xcount
; x
++)
634 abase_x
= &abase
[x
*axstride
];
636 for (n
= 0; n
< count
; n
++)
637 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
638 dest_y
[x
*rxstride
] = s
;
647 #endif /* HAVE_AVX */
651 matmul_r8_avx2 (gfc_array_r8
* const restrict retarray
,
652 gfc_array_r8
* const restrict a
, gfc_array_r8
* const restrict b
, int try_blas
,
653 int blas_limit
, blas_call gemm
) __attribute__((__target__("avx2,fma")));
655 matmul_r8_avx2 (gfc_array_r8
* const restrict retarray
,
656 gfc_array_r8
* const restrict a
, gfc_array_r8
* const restrict b
, int try_blas
,
657 int blas_limit
, blas_call gemm
)
659 const GFC_REAL_8
* restrict abase
;
660 const GFC_REAL_8
* restrict bbase
;
661 GFC_REAL_8
* restrict dest
;
663 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
664 index_type x
, y
, n
, count
, xcount
, ycount
;
666 assert (GFC_DESCRIPTOR_RANK (a
) == 2
667 || GFC_DESCRIPTOR_RANK (b
) == 2);
669 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
671 Either A or B (but not both) can be rank 1:
673 o One-dimensional argument A is implicitly treated as a row matrix
674 dimensioned [1,count], so xcount=1.
676 o One-dimensional argument B is implicitly treated as a column matrix
677 dimensioned [count, 1], so ycount=1.
680 if (retarray
->base_addr
== NULL
)
682 if (GFC_DESCRIPTOR_RANK (a
) == 1)
684 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
685 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
687 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
689 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
690 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
694 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
695 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
697 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
698 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
699 GFC_DESCRIPTOR_EXTENT(retarray
,0));
703 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_REAL_8
));
704 retarray
->offset
= 0;
706 else if (unlikely (compile_options
.bounds_check
))
708 index_type ret_extent
, arg_extent
;
710 if (GFC_DESCRIPTOR_RANK (a
) == 1)
712 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
713 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
714 if (arg_extent
!= ret_extent
)
715 runtime_error ("Array bound mismatch for dimension 1 of "
717 (long int) ret_extent
, (long int) arg_extent
);
719 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
721 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
722 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
723 if (arg_extent
!= ret_extent
)
724 runtime_error ("Array bound mismatch for dimension 1 of "
726 (long int) ret_extent
, (long int) arg_extent
);
730 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
731 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
732 if (arg_extent
!= ret_extent
)
733 runtime_error ("Array bound mismatch for dimension 1 of "
735 (long int) ret_extent
, (long int) arg_extent
);
737 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
738 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
739 if (arg_extent
!= ret_extent
)
740 runtime_error ("Array bound mismatch for dimension 2 of "
742 (long int) ret_extent
, (long int) arg_extent
);
747 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
749 /* One-dimensional result may be addressed in the code below
750 either as a row or a column matrix. We want both cases to
752 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
756 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
757 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
761 if (GFC_DESCRIPTOR_RANK (a
) == 1)
763 /* Treat it as a a row matrix A[1,count]. */
764 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
768 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
772 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
773 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
775 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
776 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
779 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
781 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
782 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
783 "in dimension 1: is %ld, should be %ld",
784 (long int) GFC_DESCRIPTOR_EXTENT(b
,0), (long int) count
);
787 if (GFC_DESCRIPTOR_RANK (b
) == 1)
789 /* Treat it as a column matrix B[count,1] */
790 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
792 /* bystride should never be used for 1-dimensional b.
793 The value is only used for calculation of the
794 memory by the buffer. */
800 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
801 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
802 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
805 abase
= a
->base_addr
;
806 bbase
= b
->base_addr
;
807 dest
= retarray
->base_addr
;
809 /* Now that everything is set up, we perform the multiplication
812 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
813 #define min(a,b) ((a) <= (b) ? (a) : (b))
814 #define max(a,b) ((a) >= (b) ? (a) : (b))
816 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
817 && (bxstride
== 1 || bystride
== 1)
818 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
821 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
822 const GFC_REAL_8 one
= 1, zero
= 0;
823 const int lda
= (axstride
== 1) ? aystride
: axstride
,
824 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
826 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
828 assert (gemm
!= NULL
);
829 const char *transa
, *transb
;
833 transa
= axstride
== 1 ? "N" : "T";
838 transb
= bxstride
== 1 ? "N" : "T";
840 gemm (transa
, transb
, &m
,
841 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
847 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1)
849 /* This block of code implements a tuned matmul, derived from
850 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
852 Bo Kagstrom and Per Ling
853 Department of Computing Science
855 S-901 87 Umea, Sweden
857 from netlib.org, translated to C, and modified for matmul.m4. */
859 const GFC_REAL_8
*a
, *b
;
861 const index_type m
= xcount
, n
= ycount
, k
= count
;
863 /* System generated locals */
864 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
865 i1
, i2
, i3
, i4
, i5
, i6
;
867 /* Local variables */
868 GFC_REAL_8 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
869 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
870 index_type i
, j
, l
, ii
, jj
, ll
;
871 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
876 c
= retarray
->base_addr
;
878 /* Parameter adjustments */
880 c_offset
= 1 + c_dim1
;
883 a_offset
= 1 + a_dim1
;
886 b_offset
= 1 + b_dim1
;
892 c
[i
+ j
* c_dim1
] = (GFC_REAL_8
)0;
894 /* Early exit if possible */
895 if (m
== 0 || n
== 0 || k
== 0)
898 /* Adjust size of t1 to what is needed. */
899 index_type t1_dim
, a_sz
;
905 t1_dim
= a_sz
* 256 + b_dim1
;
909 t1
= malloc (t1_dim
* sizeof(GFC_REAL_8
));
911 /* Start turning the crank. */
913 for (jj
= 1; jj
<= i1
; jj
+= 512)
919 ujsec
= jsec
- jsec
% 4;
921 for (ll
= 1; ll
<= i2
; ll
+= 256)
927 ulsec
= lsec
- lsec
% 2;
930 for (ii
= 1; ii
<= i3
; ii
+= 256)
936 uisec
= isec
- isec
% 2;
938 for (l
= ll
; l
<= i4
; l
+= 2)
941 for (i
= ii
; i
<= i5
; i
+= 2)
943 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
945 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
946 a
[i
+ (l
+ 1) * a_dim1
];
947 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
948 a
[i
+ 1 + l
* a_dim1
];
949 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
950 a
[i
+ 1 + (l
+ 1) * a_dim1
];
954 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
955 a
[ii
+ isec
- 1 + l
* a_dim1
];
956 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
957 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
963 for (i
= ii
; i
<= i4
; ++i
)
965 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
966 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
970 uisec
= isec
- isec
% 4;
972 for (j
= jj
; j
<= i4
; j
+= 4)
975 for (i
= ii
; i
<= i5
; i
+= 4)
977 f11
= c
[i
+ j
* c_dim1
];
978 f21
= c
[i
+ 1 + j
* c_dim1
];
979 f12
= c
[i
+ (j
+ 1) * c_dim1
];
980 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
981 f13
= c
[i
+ (j
+ 2) * c_dim1
];
982 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
983 f14
= c
[i
+ (j
+ 3) * c_dim1
];
984 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
985 f31
= c
[i
+ 2 + j
* c_dim1
];
986 f41
= c
[i
+ 3 + j
* c_dim1
];
987 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
988 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
989 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
990 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
991 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
992 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
994 for (l
= ll
; l
<= i6
; ++l
)
996 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
998 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1000 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1001 * b
[l
+ (j
+ 1) * b_dim1
];
1002 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1003 * b
[l
+ (j
+ 1) * b_dim1
];
1004 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1005 * b
[l
+ (j
+ 2) * b_dim1
];
1006 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1007 * b
[l
+ (j
+ 2) * b_dim1
];
1008 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1009 * b
[l
+ (j
+ 3) * b_dim1
];
1010 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1011 * b
[l
+ (j
+ 3) * b_dim1
];
1012 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1013 * b
[l
+ j
* b_dim1
];
1014 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1015 * b
[l
+ j
* b_dim1
];
1016 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1017 * b
[l
+ (j
+ 1) * b_dim1
];
1018 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1019 * b
[l
+ (j
+ 1) * b_dim1
];
1020 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1021 * b
[l
+ (j
+ 2) * b_dim1
];
1022 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1023 * b
[l
+ (j
+ 2) * b_dim1
];
1024 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1025 * b
[l
+ (j
+ 3) * b_dim1
];
1026 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1027 * b
[l
+ (j
+ 3) * b_dim1
];
1029 c
[i
+ j
* c_dim1
] = f11
;
1030 c
[i
+ 1 + j
* c_dim1
] = f21
;
1031 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
1032 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
1033 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
1034 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
1035 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
1036 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
1037 c
[i
+ 2 + j
* c_dim1
] = f31
;
1038 c
[i
+ 3 + j
* c_dim1
] = f41
;
1039 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
1040 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
1041 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
1042 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
1043 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
1044 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
1049 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
1051 f11
= c
[i
+ j
* c_dim1
];
1052 f12
= c
[i
+ (j
+ 1) * c_dim1
];
1053 f13
= c
[i
+ (j
+ 2) * c_dim1
];
1054 f14
= c
[i
+ (j
+ 3) * c_dim1
];
1056 for (l
= ll
; l
<= i6
; ++l
)
1058 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1059 257] * b
[l
+ j
* b_dim1
];
1060 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1061 257] * b
[l
+ (j
+ 1) * b_dim1
];
1062 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1063 257] * b
[l
+ (j
+ 2) * b_dim1
];
1064 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1065 257] * b
[l
+ (j
+ 3) * b_dim1
];
1067 c
[i
+ j
* c_dim1
] = f11
;
1068 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
1069 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
1070 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
1077 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
1079 i5
= ii
+ uisec
- 1;
1080 for (i
= ii
; i
<= i5
; i
+= 4)
1082 f11
= c
[i
+ j
* c_dim1
];
1083 f21
= c
[i
+ 1 + j
* c_dim1
];
1084 f31
= c
[i
+ 2 + j
* c_dim1
];
1085 f41
= c
[i
+ 3 + j
* c_dim1
];
1087 for (l
= ll
; l
<= i6
; ++l
)
1089 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1090 257] * b
[l
+ j
* b_dim1
];
1091 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
1092 257] * b
[l
+ j
* b_dim1
];
1093 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
1094 257] * b
[l
+ j
* b_dim1
];
1095 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
1096 257] * b
[l
+ j
* b_dim1
];
1098 c
[i
+ j
* c_dim1
] = f11
;
1099 c
[i
+ 1 + j
* c_dim1
] = f21
;
1100 c
[i
+ 2 + j
* c_dim1
] = f31
;
1101 c
[i
+ 3 + j
* c_dim1
] = f41
;
1104 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
1106 f11
= c
[i
+ j
* c_dim1
];
1108 for (l
= ll
; l
<= i6
; ++l
)
1110 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1111 257] * b
[l
+ j
* b_dim1
];
1113 c
[i
+ j
* c_dim1
] = f11
;
1123 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
1125 if (GFC_DESCRIPTOR_RANK (a
) != 1)
1127 const GFC_REAL_8
*restrict abase_x
;
1128 const GFC_REAL_8
*restrict bbase_y
;
1129 GFC_REAL_8
*restrict dest_y
;
1132 for (y
= 0; y
< ycount
; y
++)
1134 bbase_y
= &bbase
[y
*bystride
];
1135 dest_y
= &dest
[y
*rystride
];
1136 for (x
= 0; x
< xcount
; x
++)
1138 abase_x
= &abase
[x
*axstride
];
1140 for (n
= 0; n
< count
; n
++)
1141 s
+= abase_x
[n
] * bbase_y
[n
];
1148 const GFC_REAL_8
*restrict bbase_y
;
1151 for (y
= 0; y
< ycount
; y
++)
1153 bbase_y
= &bbase
[y
*bystride
];
1155 for (n
= 0; n
< count
; n
++)
1156 s
+= abase
[n
*axstride
] * bbase_y
[n
];
1157 dest
[y
*rystride
] = s
;
1161 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
1163 const GFC_REAL_8
*restrict bbase_y
;
1166 for (y
= 0; y
< ycount
; y
++)
1168 bbase_y
= &bbase
[y
*bystride
];
1170 for (n
= 0; n
< count
; n
++)
1171 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
1172 dest
[y
*rxstride
] = s
;
1175 else if (axstride
< aystride
)
1177 for (y
= 0; y
< ycount
; y
++)
1178 for (x
= 0; x
< xcount
; x
++)
1179 dest
[x
*rxstride
+ y
*rystride
] = (GFC_REAL_8
)0;
1181 for (y
= 0; y
< ycount
; y
++)
1182 for (n
= 0; n
< count
; n
++)
1183 for (x
= 0; x
< xcount
; x
++)
1184 /* dest[x,y] += a[x,n] * b[n,y] */
1185 dest
[x
*rxstride
+ y
*rystride
] +=
1186 abase
[x
*axstride
+ n
*aystride
] *
1187 bbase
[n
*bxstride
+ y
*bystride
];
1191 const GFC_REAL_8
*restrict abase_x
;
1192 const GFC_REAL_8
*restrict bbase_y
;
1193 GFC_REAL_8
*restrict dest_y
;
1196 for (y
= 0; y
< ycount
; y
++)
1198 bbase_y
= &bbase
[y
*bystride
];
1199 dest_y
= &dest
[y
*rystride
];
1200 for (x
= 0; x
< xcount
; x
++)
1202 abase_x
= &abase
[x
*axstride
];
1204 for (n
= 0; n
< count
; n
++)
1205 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
1206 dest_y
[x
*rxstride
] = s
;
1215 #endif /* HAVE_AVX2 */
1219 matmul_r8_avx512f (gfc_array_r8
* const restrict retarray
,
1220 gfc_array_r8
* const restrict a
, gfc_array_r8
* const restrict b
, int try_blas
,
1221 int blas_limit
, blas_call gemm
) __attribute__((__target__("avx512f")));
1223 matmul_r8_avx512f (gfc_array_r8
* const restrict retarray
,
1224 gfc_array_r8
* const restrict a
, gfc_array_r8
* const restrict b
, int try_blas
,
1225 int blas_limit
, blas_call gemm
)
1227 const GFC_REAL_8
* restrict abase
;
1228 const GFC_REAL_8
* restrict bbase
;
1229 GFC_REAL_8
* restrict dest
;
1231 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
1232 index_type x
, y
, n
, count
, xcount
, ycount
;
1234 assert (GFC_DESCRIPTOR_RANK (a
) == 2
1235 || GFC_DESCRIPTOR_RANK (b
) == 2);
1237 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
1239 Either A or B (but not both) can be rank 1:
1241 o One-dimensional argument A is implicitly treated as a row matrix
1242 dimensioned [1,count], so xcount=1.
1244 o One-dimensional argument B is implicitly treated as a column matrix
1245 dimensioned [count, 1], so ycount=1.
1248 if (retarray
->base_addr
== NULL
)
1250 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1252 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1253 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
1255 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
1257 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1258 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
1262 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1263 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
1265 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
1266 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
1267 GFC_DESCRIPTOR_EXTENT(retarray
,0));
1271 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_REAL_8
));
1272 retarray
->offset
= 0;
1274 else if (unlikely (compile_options
.bounds_check
))
1276 index_type ret_extent
, arg_extent
;
1278 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1280 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
1281 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1282 if (arg_extent
!= ret_extent
)
1283 runtime_error ("Array bound mismatch for dimension 1 of "
1285 (long int) ret_extent
, (long int) arg_extent
);
1287 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
1289 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
1290 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1291 if (arg_extent
!= ret_extent
)
1292 runtime_error ("Array bound mismatch for dimension 1 of "
1294 (long int) ret_extent
, (long int) arg_extent
);
1298 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
1299 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1300 if (arg_extent
!= ret_extent
)
1301 runtime_error ("Array bound mismatch for dimension 1 of "
1303 (long int) ret_extent
, (long int) arg_extent
);
1305 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
1306 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
1307 if (arg_extent
!= ret_extent
)
1308 runtime_error ("Array bound mismatch for dimension 2 of "
1310 (long int) ret_extent
, (long int) arg_extent
);
1315 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
1317 /* One-dimensional result may be addressed in the code below
1318 either as a row or a column matrix. We want both cases to
1320 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
1324 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
1325 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
1329 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1331 /* Treat it as a row matrix A[1,count]. */
1332 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
1336 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
1340 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
1341 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
1343 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
1344 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
1347 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
1349 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
1350 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
1351 "in dimension 1: is %ld, should be %ld",
1352 (long int) GFC_DESCRIPTOR_EXTENT(b
,0), (long int) count
);
1355 if (GFC_DESCRIPTOR_RANK (b
) == 1)
1357 /* Treat it as a column matrix B[count,1] */
1358 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
1360 /* bystride should never be used for 1-dimensional b.
1361 The value is only used to calculate the
1362 memory required by the buffer. */
1368 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
1369 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
1370 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
1373 abase
= a
->base_addr
;
1374 bbase
= b
->base_addr
;
1375 dest
= retarray
->base_addr
;
1377 /* Now that everything is set up, we perform the multiplication
1380 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
1381 #define min(a,b) ((a) <= (b) ? (a) : (b))
1382 #define max(a,b) ((a) >= (b) ? (a) : (b))
1384 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
1385 && (bxstride
== 1 || bystride
== 1)
1386 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
1387 > POW3(blas_limit
)))
1389 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
1390 const GFC_REAL_8 one
= 1, zero
= 0;
1391 const int lda
= (axstride
== 1) ? aystride
: axstride
,
1392 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
1394 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
1396 assert (gemm
!= NULL
);
1397 const char *transa
, *transb
;
1401 transa
= axstride
== 1 ? "N" : "T";
1406 transb
= bxstride
== 1 ? "N" : "T";
1408 gemm (transa
, transb
, &m
,
1409 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
1415 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1)
1417 /* This block of code implements a tuned matmul, derived from
1418 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
1420 Bo Kagstrom and Per Ling
1421 Department of Computing Science
1423 S-901 87 Umea, Sweden
1425 from netlib.org, translated to C, and modified for matmul.m4. */
1427 const GFC_REAL_8
*a
, *b
;
1429 const index_type m
= xcount
, n
= ycount
, k
= count
;
1431 /* System generated locals */
1432 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
1433 i1
, i2
, i3
, i4
, i5
, i6
;
1435 /* Local variables */
1436 GFC_REAL_8 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
1437 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
1438 index_type i
, j
, l
, ii
, jj
, ll
;
1439 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
1444 c
= retarray
->base_addr
;
1446 /* Parameter adjustments */
1448 c_offset
= 1 + c_dim1
;
1451 a_offset
= 1 + a_dim1
;
1454 b_offset
= 1 + b_dim1
;
1457 /* Empty c first. */
1458 for (j
=1; j
<=n
; j
++)
1459 for (i
=1; i
<=m
; i
++)
1460 c
[i
+ j
* c_dim1
] = (GFC_REAL_8
)0;
1462 /* Early exit if possible */
1463 if (m
== 0 || n
== 0 || k
== 0)
1466 /* Adjust size of t1 to what is needed. */
1467 index_type t1_dim
, a_sz
;
1473 t1_dim
= a_sz
* 256 + b_dim1
;
1477 t1
= malloc (t1_dim
* sizeof(GFC_REAL_8
));
1479 /* Start turning the crank. */
1481 for (jj
= 1; jj
<= i1
; jj
+= 512)
1487 ujsec
= jsec
- jsec
% 4;
1489 for (ll
= 1; ll
<= i2
; ll
+= 256)
1495 ulsec
= lsec
- lsec
% 2;
1498 for (ii
= 1; ii
<= i3
; ii
+= 256)
1504 uisec
= isec
- isec
% 2;
1505 i4
= ll
+ ulsec
- 1;
1506 for (l
= ll
; l
<= i4
; l
+= 2)
1508 i5
= ii
+ uisec
- 1;
1509 for (i
= ii
; i
<= i5
; i
+= 2)
1511 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
1513 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
1514 a
[i
+ (l
+ 1) * a_dim1
];
1515 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
1516 a
[i
+ 1 + l
* a_dim1
];
1517 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
1518 a
[i
+ 1 + (l
+ 1) * a_dim1
];
1522 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
1523 a
[ii
+ isec
- 1 + l
* a_dim1
];
1524 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
1525 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
1531 for (i
= ii
; i
<= i4
; ++i
)
1533 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
1534 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
1538 uisec
= isec
- isec
% 4;
1539 i4
= jj
+ ujsec
- 1;
1540 for (j
= jj
; j
<= i4
; j
+= 4)
1542 i5
= ii
+ uisec
- 1;
1543 for (i
= ii
; i
<= i5
; i
+= 4)
1545 f11
= c
[i
+ j
* c_dim1
];
1546 f21
= c
[i
+ 1 + j
* c_dim1
];
1547 f12
= c
[i
+ (j
+ 1) * c_dim1
];
1548 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
1549 f13
= c
[i
+ (j
+ 2) * c_dim1
];
1550 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
1551 f14
= c
[i
+ (j
+ 3) * c_dim1
];
1552 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
1553 f31
= c
[i
+ 2 + j
* c_dim1
];
1554 f41
= c
[i
+ 3 + j
* c_dim1
];
1555 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
1556 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
1557 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
1558 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
1559 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
1560 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
1562 for (l
= ll
; l
<= i6
; ++l
)
1564 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1565 * b
[l
+ j
* b_dim1
];
1566 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1567 * b
[l
+ j
* b_dim1
];
1568 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1569 * b
[l
+ (j
+ 1) * b_dim1
];
1570 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1571 * b
[l
+ (j
+ 1) * b_dim1
];
1572 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1573 * b
[l
+ (j
+ 2) * b_dim1
];
1574 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1575 * b
[l
+ (j
+ 2) * b_dim1
];
1576 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1577 * b
[l
+ (j
+ 3) * b_dim1
];
1578 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1579 * b
[l
+ (j
+ 3) * b_dim1
];
1580 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1581 * b
[l
+ j
* b_dim1
];
1582 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1583 * b
[l
+ j
* b_dim1
];
1584 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1585 * b
[l
+ (j
+ 1) * b_dim1
];
1586 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1587 * b
[l
+ (j
+ 1) * b_dim1
];
1588 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1589 * b
[l
+ (j
+ 2) * b_dim1
];
1590 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1591 * b
[l
+ (j
+ 2) * b_dim1
];
1592 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1593 * b
[l
+ (j
+ 3) * b_dim1
];
1594 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1595 * b
[l
+ (j
+ 3) * b_dim1
];
1597 c
[i
+ j
* c_dim1
] = f11
;
1598 c
[i
+ 1 + j
* c_dim1
] = f21
;
1599 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
1600 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
1601 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
1602 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
1603 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
1604 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
1605 c
[i
+ 2 + j
* c_dim1
] = f31
;
1606 c
[i
+ 3 + j
* c_dim1
] = f41
;
1607 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
1608 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
1609 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
1610 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
1611 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
1612 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
1617 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
1619 f11
= c
[i
+ j
* c_dim1
];
1620 f12
= c
[i
+ (j
+ 1) * c_dim1
];
1621 f13
= c
[i
+ (j
+ 2) * c_dim1
];
1622 f14
= c
[i
+ (j
+ 3) * c_dim1
];
1624 for (l
= ll
; l
<= i6
; ++l
)
1626 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1627 257] * b
[l
+ j
* b_dim1
];
1628 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1629 257] * b
[l
+ (j
+ 1) * b_dim1
];
1630 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1631 257] * b
[l
+ (j
+ 2) * b_dim1
];
1632 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1633 257] * b
[l
+ (j
+ 3) * b_dim1
];
1635 c
[i
+ j
* c_dim1
] = f11
;
1636 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
1637 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
1638 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
1645 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
1647 i5
= ii
+ uisec
- 1;
1648 for (i
= ii
; i
<= i5
; i
+= 4)
1650 f11
= c
[i
+ j
* c_dim1
];
1651 f21
= c
[i
+ 1 + j
* c_dim1
];
1652 f31
= c
[i
+ 2 + j
* c_dim1
];
1653 f41
= c
[i
+ 3 + j
* c_dim1
];
1655 for (l
= ll
; l
<= i6
; ++l
)
1657 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1658 257] * b
[l
+ j
* b_dim1
];
1659 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
1660 257] * b
[l
+ j
* b_dim1
];
1661 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
1662 257] * b
[l
+ j
* b_dim1
];
1663 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
1664 257] * b
[l
+ j
* b_dim1
];
1666 c
[i
+ j
* c_dim1
] = f11
;
1667 c
[i
+ 1 + j
* c_dim1
] = f21
;
1668 c
[i
+ 2 + j
* c_dim1
] = f31
;
1669 c
[i
+ 3 + j
* c_dim1
] = f41
;
1672 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
1674 f11
= c
[i
+ j
* c_dim1
];
1676 for (l
= ll
; l
<= i6
; ++l
)
1678 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1679 257] * b
[l
+ j
* b_dim1
];
1681 c
[i
+ j
* c_dim1
] = f11
;
1691 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
1693 if (GFC_DESCRIPTOR_RANK (a
) != 1)
1695 const GFC_REAL_8
*restrict abase_x
;
1696 const GFC_REAL_8
*restrict bbase_y
;
1697 GFC_REAL_8
*restrict dest_y
;
1700 for (y
= 0; y
< ycount
; y
++)
1702 bbase_y
= &bbase
[y
*bystride
];
1703 dest_y
= &dest
[y
*rystride
];
1704 for (x
= 0; x
< xcount
; x
++)
1706 abase_x
= &abase
[x
*axstride
];
1708 for (n
= 0; n
< count
; n
++)
1709 s
+= abase_x
[n
] * bbase_y
[n
];
1716 const GFC_REAL_8
*restrict bbase_y
;
1719 for (y
= 0; y
< ycount
; y
++)
1721 bbase_y
= &bbase
[y
*bystride
];
1723 for (n
= 0; n
< count
; n
++)
1724 s
+= abase
[n
*axstride
] * bbase_y
[n
];
1725 dest
[y
*rystride
] = s
;
1729 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
1731 const GFC_REAL_8
*restrict bbase_y
;
1734 for (y
= 0; y
< ycount
; y
++)
1736 bbase_y
= &bbase
[y
*bystride
];
1738 for (n
= 0; n
< count
; n
++)
1739 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
1740 dest
[y
*rxstride
] = s
;
1743 else if (axstride
< aystride
)
1745 for (y
= 0; y
< ycount
; y
++)
1746 for (x
= 0; x
< xcount
; x
++)
1747 dest
[x
*rxstride
+ y
*rystride
] = (GFC_REAL_8
)0;
1749 for (y
= 0; y
< ycount
; y
++)
1750 for (n
= 0; n
< count
; n
++)
1751 for (x
= 0; x
< xcount
; x
++)
1752 /* dest[x,y] += a[x,n] * b[n,y] */
1753 dest
[x
*rxstride
+ y
*rystride
] +=
1754 abase
[x
*axstride
+ n
*aystride
] *
1755 bbase
[n
*bxstride
+ y
*bystride
];
1759 const GFC_REAL_8
*restrict abase_x
;
1760 const GFC_REAL_8
*restrict bbase_y
;
1761 GFC_REAL_8
*restrict dest_y
;
1764 for (y
= 0; y
< ycount
; y
++)
1766 bbase_y
= &bbase
[y
*bystride
];
1767 dest_y
= &dest
[y
*rystride
];
1768 for (x
= 0; x
< xcount
; x
++)
1770 abase_x
= &abase
[x
*axstride
];
1772 for (n
= 0; n
< count
; n
++)
1773 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
1774 dest_y
[x
*rxstride
] = s
;
1783 #endif /* HAVE_AVX512F */
1785 /* AMD-specific functions with AVX128 and FMA3/FMA4. */
1787 #if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
1789 matmul_r8_avx128_fma3 (gfc_array_r8
* const restrict retarray
,
1790 gfc_array_r8
* const restrict a
, gfc_array_r8
* const restrict b
, int try_blas
,
1791 int blas_limit
, blas_call gemm
) __attribute__((__target__("avx,fma")));
1792 internal_proto(matmul_r8_avx128_fma3
);
1795 #if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
1797 matmul_r8_avx128_fma4 (gfc_array_r8
* const restrict retarray
,
1798 gfc_array_r8
* const restrict a
, gfc_array_r8
* const restrict b
, int try_blas
,
1799 int blas_limit
, blas_call gemm
) __attribute__((__target__("avx,fma4")));
1800 internal_proto(matmul_r8_avx128_fma4
);
1803 /* Function to fall back to if there is no special processor-specific version. */
1805 matmul_r8_vanilla (gfc_array_r8
* const restrict retarray
,
1806 gfc_array_r8
* const restrict a
, gfc_array_r8
* const restrict b
, int try_blas
,
1807 int blas_limit
, blas_call gemm
)
1809 const GFC_REAL_8
* restrict abase
;
1810 const GFC_REAL_8
* restrict bbase
;
1811 GFC_REAL_8
* restrict dest
;
1813 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
1814 index_type x
, y
, n
, count
, xcount
, ycount
;
1816 assert (GFC_DESCRIPTOR_RANK (a
) == 2
1817 || GFC_DESCRIPTOR_RANK (b
) == 2);
1819 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
1821 Either A or B (but not both) can be rank 1:
1823 o One-dimensional argument A is implicitly treated as a row matrix
1824 dimensioned [1,count], so xcount=1.
1826 o One-dimensional argument B is implicitly treated as a column matrix
1827 dimensioned [count, 1], so ycount=1.
1830 if (retarray
->base_addr
== NULL
)
1832 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1834 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1835 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
1837 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
1839 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1840 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
1844 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1845 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
1847 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
1848 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
1849 GFC_DESCRIPTOR_EXTENT(retarray
,0));
1853 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_REAL_8
));
1854 retarray
->offset
= 0;
1856 else if (unlikely (compile_options
.bounds_check
))
1858 index_type ret_extent
, arg_extent
;
1860 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1862 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
1863 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1864 if (arg_extent
!= ret_extent
)
1865 runtime_error ("Array bound mismatch for dimension 1 of "
1867 (long int) ret_extent
, (long int) arg_extent
);
1869 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
1871 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
1872 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1873 if (arg_extent
!= ret_extent
)
1874 runtime_error ("Array bound mismatch for dimension 1 of "
1876 (long int) ret_extent
, (long int) arg_extent
);
1880 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
1881 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1882 if (arg_extent
!= ret_extent
)
1883 runtime_error ("Array bound mismatch for dimension 1 of "
1885 (long int) ret_extent
, (long int) arg_extent
);
1887 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
1888 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
1889 if (arg_extent
!= ret_extent
)
1890 runtime_error ("Array bound mismatch for dimension 2 of "
1892 (long int) ret_extent
, (long int) arg_extent
);
1897 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
1899 /* One-dimensional result may be addressed in the code below
1900 either as a row or a column matrix. We want both cases to
1902 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
1906 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
1907 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
1911 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1913 /* Treat it as a row matrix A[1,count]. */
1914 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
1918 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
1922 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
1923 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
1925 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
1926 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
1929 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
1931 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
1932 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
1933 "in dimension 1: is %ld, should be %ld",
1934 (long int) GFC_DESCRIPTOR_EXTENT(b
,0), (long int) count
);
1937 if (GFC_DESCRIPTOR_RANK (b
) == 1)
1939 /* Treat it as a column matrix B[count,1] */
1940 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
1942 /* bystride should never be used for 1-dimensional b.
1943 The value is only used to calculate the
1944 memory required by the buffer. */
1950 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
1951 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
1952 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
1955 abase
= a
->base_addr
;
1956 bbase
= b
->base_addr
;
1957 dest
= retarray
->base_addr
;
1959 /* Now that everything is set up, we perform the multiplication
1962 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
1963 #define min(a,b) ((a) <= (b) ? (a) : (b))
1964 #define max(a,b) ((a) >= (b) ? (a) : (b))
1966 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
1967 && (bxstride
== 1 || bystride
== 1)
1968 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
1969 > POW3(blas_limit
)))
1971 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
1972 const GFC_REAL_8 one
= 1, zero
= 0;
1973 const int lda
= (axstride
== 1) ? aystride
: axstride
,
1974 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
1976 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
1978 assert (gemm
!= NULL
);
1979 const char *transa
, *transb
;
1983 transa
= axstride
== 1 ? "N" : "T";
1988 transb
= bxstride
== 1 ? "N" : "T";
1990 gemm (transa
, transb
, &m
,
1991 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
1997 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1)
1999 /* This block of code implements a tuned matmul, derived from
2000 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
2002 Bo Kagstrom and Per Ling
2003 Department of Computing Science
2005 S-901 87 Umea, Sweden
2007 from netlib.org, translated to C, and modified for matmul.m4. */
2009 const GFC_REAL_8
*a
, *b
;
2011 const index_type m
= xcount
, n
= ycount
, k
= count
;
2013 /* System generated locals */
2014 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
2015 i1
, i2
, i3
, i4
, i5
, i6
;
2017 /* Local variables */
2018 GFC_REAL_8 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
2019 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
2020 index_type i
, j
, l
, ii
, jj
, ll
;
2021 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
2026 c
= retarray
->base_addr
;
2028 /* Parameter adjustments */
2030 c_offset
= 1 + c_dim1
;
2033 a_offset
= 1 + a_dim1
;
2036 b_offset
= 1 + b_dim1
;
2039 /* Empty c first. */
2040 for (j
=1; j
<=n
; j
++)
2041 for (i
=1; i
<=m
; i
++)
2042 c
[i
+ j
* c_dim1
] = (GFC_REAL_8
)0;
2044 /* Early exit if possible */
2045 if (m
== 0 || n
== 0 || k
== 0)
2048 /* Adjust size of t1 to what is needed. */
2049 index_type t1_dim
, a_sz
;
2055 t1_dim
= a_sz
* 256 + b_dim1
;
2059 t1
= malloc (t1_dim
* sizeof(GFC_REAL_8
));
2061 /* Start turning the crank. */
2063 for (jj
= 1; jj
<= i1
; jj
+= 512)
2069 ujsec
= jsec
- jsec
% 4;
2071 for (ll
= 1; ll
<= i2
; ll
+= 256)
2077 ulsec
= lsec
- lsec
% 2;
2080 for (ii
= 1; ii
<= i3
; ii
+= 256)
2086 uisec
= isec
- isec
% 2;
2087 i4
= ll
+ ulsec
- 1;
2088 for (l
= ll
; l
<= i4
; l
+= 2)
2090 i5
= ii
+ uisec
- 1;
2091 for (i
= ii
; i
<= i5
; i
+= 2)
2093 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
2095 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
2096 a
[i
+ (l
+ 1) * a_dim1
];
2097 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
2098 a
[i
+ 1 + l
* a_dim1
];
2099 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
2100 a
[i
+ 1 + (l
+ 1) * a_dim1
];
2104 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
2105 a
[ii
+ isec
- 1 + l
* a_dim1
];
2106 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
2107 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
2113 for (i
= ii
; i
<= i4
; ++i
)
2115 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
2116 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
2120 uisec
= isec
- isec
% 4;
2121 i4
= jj
+ ujsec
- 1;
2122 for (j
= jj
; j
<= i4
; j
+= 4)
2124 i5
= ii
+ uisec
- 1;
2125 for (i
= ii
; i
<= i5
; i
+= 4)
2127 f11
= c
[i
+ j
* c_dim1
];
2128 f21
= c
[i
+ 1 + j
* c_dim1
];
2129 f12
= c
[i
+ (j
+ 1) * c_dim1
];
2130 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
2131 f13
= c
[i
+ (j
+ 2) * c_dim1
];
2132 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
2133 f14
= c
[i
+ (j
+ 3) * c_dim1
];
2134 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
2135 f31
= c
[i
+ 2 + j
* c_dim1
];
2136 f41
= c
[i
+ 3 + j
* c_dim1
];
2137 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
2138 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
2139 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
2140 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
2141 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
2142 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
2144 for (l
= ll
; l
<= i6
; ++l
)
2146 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
2147 * b
[l
+ j
* b_dim1
];
2148 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
2149 * b
[l
+ j
* b_dim1
];
2150 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
2151 * b
[l
+ (j
+ 1) * b_dim1
];
2152 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
2153 * b
[l
+ (j
+ 1) * b_dim1
];
2154 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
2155 * b
[l
+ (j
+ 2) * b_dim1
];
2156 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
2157 * b
[l
+ (j
+ 2) * b_dim1
];
2158 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
2159 * b
[l
+ (j
+ 3) * b_dim1
];
2160 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
2161 * b
[l
+ (j
+ 3) * b_dim1
];
2162 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
2163 * b
[l
+ j
* b_dim1
];
2164 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
2165 * b
[l
+ j
* b_dim1
];
2166 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
2167 * b
[l
+ (j
+ 1) * b_dim1
];
2168 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
2169 * b
[l
+ (j
+ 1) * b_dim1
];
2170 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
2171 * b
[l
+ (j
+ 2) * b_dim1
];
2172 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
2173 * b
[l
+ (j
+ 2) * b_dim1
];
2174 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
2175 * b
[l
+ (j
+ 3) * b_dim1
];
2176 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
2177 * b
[l
+ (j
+ 3) * b_dim1
];
2179 c
[i
+ j
* c_dim1
] = f11
;
2180 c
[i
+ 1 + j
* c_dim1
] = f21
;
2181 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
2182 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
2183 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
2184 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
2185 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
2186 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
2187 c
[i
+ 2 + j
* c_dim1
] = f31
;
2188 c
[i
+ 3 + j
* c_dim1
] = f41
;
2189 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
2190 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
2191 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
2192 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
2193 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
2194 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
2199 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
2201 f11
= c
[i
+ j
* c_dim1
];
2202 f12
= c
[i
+ (j
+ 1) * c_dim1
];
2203 f13
= c
[i
+ (j
+ 2) * c_dim1
];
2204 f14
= c
[i
+ (j
+ 3) * c_dim1
];
2206 for (l
= ll
; l
<= i6
; ++l
)
2208 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2209 257] * b
[l
+ j
* b_dim1
];
2210 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2211 257] * b
[l
+ (j
+ 1) * b_dim1
];
2212 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2213 257] * b
[l
+ (j
+ 2) * b_dim1
];
2214 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2215 257] * b
[l
+ (j
+ 3) * b_dim1
];
2217 c
[i
+ j
* c_dim1
] = f11
;
2218 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
2219 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
2220 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
2227 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
2229 i5
= ii
+ uisec
- 1;
2230 for (i
= ii
; i
<= i5
; i
+= 4)
2232 f11
= c
[i
+ j
* c_dim1
];
2233 f21
= c
[i
+ 1 + j
* c_dim1
];
2234 f31
= c
[i
+ 2 + j
* c_dim1
];
2235 f41
= c
[i
+ 3 + j
* c_dim1
];
2237 for (l
= ll
; l
<= i6
; ++l
)
2239 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2240 257] * b
[l
+ j
* b_dim1
];
2241 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
2242 257] * b
[l
+ j
* b_dim1
];
2243 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
2244 257] * b
[l
+ j
* b_dim1
];
2245 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
2246 257] * b
[l
+ j
* b_dim1
];
2248 c
[i
+ j
* c_dim1
] = f11
;
2249 c
[i
+ 1 + j
* c_dim1
] = f21
;
2250 c
[i
+ 2 + j
* c_dim1
] = f31
;
2251 c
[i
+ 3 + j
* c_dim1
] = f41
;
2254 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
2256 f11
= c
[i
+ j
* c_dim1
];
2258 for (l
= ll
; l
<= i6
; ++l
)
2260 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2261 257] * b
[l
+ j
* b_dim1
];
2263 c
[i
+ j
* c_dim1
] = f11
;
2273 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
2275 if (GFC_DESCRIPTOR_RANK (a
) != 1)
2277 const GFC_REAL_8
*restrict abase_x
;
2278 const GFC_REAL_8
*restrict bbase_y
;
2279 GFC_REAL_8
*restrict dest_y
;
2282 for (y
= 0; y
< ycount
; y
++)
2284 bbase_y
= &bbase
[y
*bystride
];
2285 dest_y
= &dest
[y
*rystride
];
2286 for (x
= 0; x
< xcount
; x
++)
2288 abase_x
= &abase
[x
*axstride
];
2290 for (n
= 0; n
< count
; n
++)
2291 s
+= abase_x
[n
] * bbase_y
[n
];
2298 const GFC_REAL_8
*restrict bbase_y
;
2301 for (y
= 0; y
< ycount
; y
++)
2303 bbase_y
= &bbase
[y
*bystride
];
2305 for (n
= 0; n
< count
; n
++)
2306 s
+= abase
[n
*axstride
] * bbase_y
[n
];
2307 dest
[y
*rystride
] = s
;
2311 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
2313 const GFC_REAL_8
*restrict bbase_y
;
2316 for (y
= 0; y
< ycount
; y
++)
2318 bbase_y
= &bbase
[y
*bystride
];
2320 for (n
= 0; n
< count
; n
++)
2321 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
2322 dest
[y
*rxstride
] = s
;
2325 else if (axstride
< aystride
)
2327 for (y
= 0; y
< ycount
; y
++)
2328 for (x
= 0; x
< xcount
; x
++)
2329 dest
[x
*rxstride
+ y
*rystride
] = (GFC_REAL_8
)0;
2331 for (y
= 0; y
< ycount
; y
++)
2332 for (n
= 0; n
< count
; n
++)
2333 for (x
= 0; x
< xcount
; x
++)
2334 /* dest[x,y] += a[x,n] * b[n,y] */
2335 dest
[x
*rxstride
+ y
*rystride
] +=
2336 abase
[x
*axstride
+ n
*aystride
] *
2337 bbase
[n
*bxstride
+ y
*bystride
];
2341 const GFC_REAL_8
*restrict abase_x
;
2342 const GFC_REAL_8
*restrict bbase_y
;
2343 GFC_REAL_8
*restrict dest_y
;
2346 for (y
= 0; y
< ycount
; y
++)
2348 bbase_y
= &bbase
[y
*bystride
];
2349 dest_y
= &dest
[y
*rystride
];
2350 for (x
= 0; x
< xcount
; x
++)
2352 abase_x
= &abase
[x
*axstride
];
2354 for (n
= 0; n
< count
; n
++)
2355 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
2356 dest_y
[x
*rxstride
] = s
;
2366 /* Compiling main function, with selection code for the processor. */
2368 /* Currently, this is i386 only. Adjust for other architectures. */
/* Entry point for MATMUL on GFC_REAL_8 arrays: forwards the call to a
   processor-specific implementation.

   The selected implementation is kept in the function-local static pointer
   matmul_p, which is read and written with relaxed __atomic builtins.  While
   that pointer is still NULL, the choice starts from matmul_r8_vanilla and
   is refined with __builtin_cpu_is / __builtin_cpu_supports:

     "intel":  "avx512f"          -> matmul_r8_avx512f
               "avx2" and "fma"   -> matmul_r8_avx2
               "avx"              -> matmul_r8_avx
     "amd":    "avx" and "fma"    -> matmul_r8_avx128_fma3
               "avx" and "fma4"   -> matmul_r8_avx128_fma4

   All six arguments (retarray, a, b, try_blas, blas_limit, gemm) are passed
   through unchanged to the selected function.

   NOTE(review): this listing is missing lines (braces, the #if lines that
   pair with the #endif markers below, and whatever terminates each branch),
   so how an earlier match is protected from being overridden by a later,
   less specific test cannot be seen here -- confirm against the upstream
   file before changing the selection order.  */
2370 void matmul_r8 (gfc_array_r8
* const restrict retarray
,
2371 gfc_array_r8
* const restrict a
, gfc_array_r8
* const restrict b
, int try_blas
,
2372 int blas_limit
, blas_call gemm
)
/* Cached result of the CPU probe; NULL until the first call has stored it.  */
2374 static void (*matmul_p
) (gfc_array_r8
* const restrict retarray
,
2375 gfc_array_r8
* const restrict a
, gfc_array_r8
* const restrict b
, int try_blas
,
2376 int blas_limit
, blas_call gemm
);
/* Local copy of the pointer, so the cached value is loaded only once.  */
2378 void (*matmul_fn
) (gfc_array_r8
* const restrict retarray
,
2379 gfc_array_r8
* const restrict a
, gfc_array_r8
* const restrict b
, int try_blas
,
2380 int blas_limit
, blas_call gemm
);
2382 matmul_fn
= __atomic_load_n (&matmul_p
, __ATOMIC_RELAXED
);
2383 if (matmul_fn
== NULL
)
/* Nothing cached yet: start from the generic version, then probe the CPU.  */
2385 matmul_fn
= matmul_r8_vanilla
;
2386 if (__builtin_cpu_is ("intel"))
2388 /* Run down the available processors in order of preference. */
2390 if (__builtin_cpu_supports ("avx512f"))
2392 matmul_fn
= matmul_r8_avx512f
;
2396 #endif /* HAVE_AVX512F */
2399 if (__builtin_cpu_supports ("avx2")
2400 && __builtin_cpu_supports ("fma"))
2402 matmul_fn
= matmul_r8_avx2
;
2409 if (__builtin_cpu_supports ("avx"))
2411 matmul_fn
= matmul_r8_avx
;
2414 #endif /* HAVE_AVX */
2416 else if (__builtin_cpu_is ("amd"))
2418 #if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
2419 if (__builtin_cpu_supports ("avx")
2420 && __builtin_cpu_supports ("fma"))
2422 matmul_fn
= matmul_r8_avx128_fma3
;
2426 #if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
2427 if (__builtin_cpu_supports ("avx")
2428 && __builtin_cpu_supports ("fma4"))
2430 matmul_fn
= matmul_r8_avx128_fma4
;
/* Remember the selection for subsequent calls.  */
2437 __atomic_store_n (&matmul_p
, matmul_fn
, __ATOMIC_RELAXED
);
/* Hand the original arguments to whichever implementation was selected.  */
2440 (*matmul_fn
) (retarray
, a
, b
, try_blas
, blas_limit
, gemm
);
2443 #else /* Just the vanilla function. */
2446 matmul_r8 (gfc_array_r8
* const restrict retarray
,
2447 gfc_array_r8
* const restrict a
, gfc_array_r8
* const restrict b
, int try_blas
,
2448 int blas_limit
, blas_call gemm
)
2450 const GFC_REAL_8
* restrict abase
;
2451 const GFC_REAL_8
* restrict bbase
;
2452 GFC_REAL_8
* restrict dest
;
2454 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
2455 index_type x
, y
, n
, count
, xcount
, ycount
;
2457 assert (GFC_DESCRIPTOR_RANK (a
) == 2
2458 || GFC_DESCRIPTOR_RANK (b
) == 2);
2460 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
2462 Either A or B (but not both) can be rank 1:
2464 o One-dimensional argument A is implicitly treated as a row matrix
2465 dimensioned [1,count], so xcount=1.
2467 o One-dimensional argument B is implicitly treated as a column matrix
2468 dimensioned [count, 1], so ycount=1.
2471 if (retarray
->base_addr
== NULL
)
2473 if (GFC_DESCRIPTOR_RANK (a
) == 1)
2475 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
2476 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
2478 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
2480 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
2481 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
2485 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
2486 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
2488 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
2489 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
2490 GFC_DESCRIPTOR_EXTENT(retarray
,0));
2494 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_REAL_8
));
2495 retarray
->offset
= 0;
2497 else if (unlikely (compile_options
.bounds_check
))
2499 index_type ret_extent
, arg_extent
;
2501 if (GFC_DESCRIPTOR_RANK (a
) == 1)
2503 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
2504 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
2505 if (arg_extent
!= ret_extent
)
2506 runtime_error ("Array bound mismatch for dimension 1 of "
2508 (long int) ret_extent
, (long int) arg_extent
);
2510 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
2512 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
2513 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
2514 if (arg_extent
!= ret_extent
)
2515 runtime_error ("Array bound mismatch for dimension 1 of "
2517 (long int) ret_extent
, (long int) arg_extent
);
2521 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
2522 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
2523 if (arg_extent
!= ret_extent
)
2524 runtime_error ("Array bound mismatch for dimension 1 of "
2526 (long int) ret_extent
, (long int) arg_extent
);
2528 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
2529 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
2530 if (arg_extent
!= ret_extent
)
2531 runtime_error ("Array bound mismatch for dimension 2 of "
2533 (long int) ret_extent
, (long int) arg_extent
);
2538 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
2540 /* One-dimensional result may be addressed in the code below
2541 either as a row or a column matrix. We want both cases to
2543 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
2547 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
2548 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
2552 if (GFC_DESCRIPTOR_RANK (a
) == 1)
2554 /* Treat it as a row matrix A[1,count]. */
2555 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
2559 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
2563 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
2564 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
2566 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
2567 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
2570 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
2572 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
2573 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
2574 "in dimension 1: is %ld, should be %ld",
2575 (long int) GFC_DESCRIPTOR_EXTENT(b
,0), (long int) count
);
2578 if (GFC_DESCRIPTOR_RANK (b
) == 1)
2580 /* Treat it as a column matrix B[count,1] */
2581 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
2583 /* bystride should never be used for 1-dimensional b.
2584 The value is only used for calculation of the
2585 memory by the buffer. */
2591 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
2592 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
2593 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
2596 abase
= a
->base_addr
;
2597 bbase
= b
->base_addr
;
2598 dest
= retarray
->base_addr
;
2600 /* Now that everything is set up, we perform the multiplication
2603 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
2604 #define min(a,b) ((a) <= (b) ? (a) : (b))
2605 #define max(a,b) ((a) >= (b) ? (a) : (b))
2607 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
2608 && (bxstride
== 1 || bystride
== 1)
2609 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
2610 > POW3(blas_limit
)))
2612 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
2613 const GFC_REAL_8 one
= 1, zero
= 0;
2614 const int lda
= (axstride
== 1) ? aystride
: axstride
,
2615 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
2617 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
2619 assert (gemm
!= NULL
);
2620 const char *transa
, *transb
;
2624 transa
= axstride
== 1 ? "N" : "T";
2629 transb
= bxstride
== 1 ? "N" : "T";
2631 gemm (transa
, transb
, &m
,
2632 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
2638 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1)
2640 /* This block of code implements a tuned matmul, derived from
2641 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
2643 Bo Kagstrom and Per Ling
2644 Department of Computing Science
2646 S-901 87 Umea, Sweden
2648 from netlib.org, translated to C, and modified for matmul.m4. */
2650 const GFC_REAL_8
*a
, *b
;
2652 const index_type m
= xcount
, n
= ycount
, k
= count
;
2654 /* System generated locals */
2655 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
2656 i1
, i2
, i3
, i4
, i5
, i6
;
2658 /* Local variables */
2659 GFC_REAL_8 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
2660 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
2661 index_type i
, j
, l
, ii
, jj
, ll
;
2662 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
2667 c
= retarray
->base_addr
;
2669 /* Parameter adjustments */
2671 c_offset
= 1 + c_dim1
;
2674 a_offset
= 1 + a_dim1
;
2677 b_offset
= 1 + b_dim1
;
2680 /* Empty c first. */
2681 for (j
=1; j
<=n
; j
++)
2682 for (i
=1; i
<=m
; i
++)
2683 c
[i
+ j
* c_dim1
] = (GFC_REAL_8
)0;
2685 /* Early exit if possible */
2686 if (m
== 0 || n
== 0 || k
== 0)
2689 /* Adjust size of t1 to what is needed. */
2690 index_type t1_dim
, a_sz
;
2696 t1_dim
= a_sz
* 256 + b_dim1
;
2700 t1
= malloc (t1_dim
* sizeof(GFC_REAL_8
));
2702 /* Start turning the crank. */
2704 for (jj
= 1; jj
<= i1
; jj
+= 512)
2710 ujsec
= jsec
- jsec
% 4;
2712 for (ll
= 1; ll
<= i2
; ll
+= 256)
2718 ulsec
= lsec
- lsec
% 2;
2721 for (ii
= 1; ii
<= i3
; ii
+= 256)
2727 uisec
= isec
- isec
% 2;
2728 i4
= ll
+ ulsec
- 1;
2729 for (l
= ll
; l
<= i4
; l
+= 2)
2731 i5
= ii
+ uisec
- 1;
2732 for (i
= ii
; i
<= i5
; i
+= 2)
2734 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
2736 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
2737 a
[i
+ (l
+ 1) * a_dim1
];
2738 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
2739 a
[i
+ 1 + l
* a_dim1
];
2740 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
2741 a
[i
+ 1 + (l
+ 1) * a_dim1
];
2745 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
2746 a
[ii
+ isec
- 1 + l
* a_dim1
];
2747 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
2748 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
2754 for (i
= ii
; i
<= i4
; ++i
)
2756 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
2757 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
2761 uisec
= isec
- isec
% 4;
2762 i4
= jj
+ ujsec
- 1;
2763 for (j
= jj
; j
<= i4
; j
+= 4)
2765 i5
= ii
+ uisec
- 1;
2766 for (i
= ii
; i
<= i5
; i
+= 4)
2768 f11
= c
[i
+ j
* c_dim1
];
2769 f21
= c
[i
+ 1 + j
* c_dim1
];
2770 f12
= c
[i
+ (j
+ 1) * c_dim1
];
2771 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
2772 f13
= c
[i
+ (j
+ 2) * c_dim1
];
2773 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
2774 f14
= c
[i
+ (j
+ 3) * c_dim1
];
2775 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
2776 f31
= c
[i
+ 2 + j
* c_dim1
];
2777 f41
= c
[i
+ 3 + j
* c_dim1
];
2778 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
2779 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
2780 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
2781 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
2782 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
2783 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
2785 for (l
= ll
; l
<= i6
; ++l
)
2787 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
2788 * b
[l
+ j
* b_dim1
];
2789 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
2790 * b
[l
+ j
* b_dim1
];
2791 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
2792 * b
[l
+ (j
+ 1) * b_dim1
];
2793 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
2794 * b
[l
+ (j
+ 1) * b_dim1
];
2795 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
2796 * b
[l
+ (j
+ 2) * b_dim1
];
2797 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
2798 * b
[l
+ (j
+ 2) * b_dim1
];
2799 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
2800 * b
[l
+ (j
+ 3) * b_dim1
];
2801 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
2802 * b
[l
+ (j
+ 3) * b_dim1
];
2803 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
2804 * b
[l
+ j
* b_dim1
];
2805 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
2806 * b
[l
+ j
* b_dim1
];
2807 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
2808 * b
[l
+ (j
+ 1) * b_dim1
];
2809 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
2810 * b
[l
+ (j
+ 1) * b_dim1
];
2811 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
2812 * b
[l
+ (j
+ 2) * b_dim1
];
2813 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
2814 * b
[l
+ (j
+ 2) * b_dim1
];
2815 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
2816 * b
[l
+ (j
+ 3) * b_dim1
];
2817 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
2818 * b
[l
+ (j
+ 3) * b_dim1
];
2820 c
[i
+ j
* c_dim1
] = f11
;
2821 c
[i
+ 1 + j
* c_dim1
] = f21
;
2822 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
2823 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
2824 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
2825 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
2826 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
2827 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
2828 c
[i
+ 2 + j
* c_dim1
] = f31
;
2829 c
[i
+ 3 + j
* c_dim1
] = f41
;
2830 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
2831 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
2832 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
2833 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
2834 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
2835 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
2840 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
2842 f11
= c
[i
+ j
* c_dim1
];
2843 f12
= c
[i
+ (j
+ 1) * c_dim1
];
2844 f13
= c
[i
+ (j
+ 2) * c_dim1
];
2845 f14
= c
[i
+ (j
+ 3) * c_dim1
];
2847 for (l
= ll
; l
<= i6
; ++l
)
2849 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2850 257] * b
[l
+ j
* b_dim1
];
2851 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2852 257] * b
[l
+ (j
+ 1) * b_dim1
];
2853 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2854 257] * b
[l
+ (j
+ 2) * b_dim1
];
2855 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2856 257] * b
[l
+ (j
+ 3) * b_dim1
];
2858 c
[i
+ j
* c_dim1
] = f11
;
2859 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
2860 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
2861 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
2868 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
2870 i5
= ii
+ uisec
- 1;
2871 for (i
= ii
; i
<= i5
; i
+= 4)
2873 f11
= c
[i
+ j
* c_dim1
];
2874 f21
= c
[i
+ 1 + j
* c_dim1
];
2875 f31
= c
[i
+ 2 + j
* c_dim1
];
2876 f41
= c
[i
+ 3 + j
* c_dim1
];
2878 for (l
= ll
; l
<= i6
; ++l
)
2880 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2881 257] * b
[l
+ j
* b_dim1
];
2882 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
2883 257] * b
[l
+ j
* b_dim1
];
2884 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
2885 257] * b
[l
+ j
* b_dim1
];
2886 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
2887 257] * b
[l
+ j
* b_dim1
];
2889 c
[i
+ j
* c_dim1
] = f11
;
2890 c
[i
+ 1 + j
* c_dim1
] = f21
;
2891 c
[i
+ 2 + j
* c_dim1
] = f31
;
2892 c
[i
+ 3 + j
* c_dim1
] = f41
;
2895 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
2897 f11
= c
[i
+ j
* c_dim1
];
2899 for (l
= ll
; l
<= i6
; ++l
)
2901 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2902 257] * b
[l
+ j
* b_dim1
];
2904 c
[i
+ j
* c_dim1
] = f11
;
2914 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
2916 if (GFC_DESCRIPTOR_RANK (a
) != 1)
2918 const GFC_REAL_8
*restrict abase_x
;
2919 const GFC_REAL_8
*restrict bbase_y
;
2920 GFC_REAL_8
*restrict dest_y
;
2923 for (y
= 0; y
< ycount
; y
++)
2925 bbase_y
= &bbase
[y
*bystride
];
2926 dest_y
= &dest
[y
*rystride
];
2927 for (x
= 0; x
< xcount
; x
++)
2929 abase_x
= &abase
[x
*axstride
];
2931 for (n
= 0; n
< count
; n
++)
2932 s
+= abase_x
[n
] * bbase_y
[n
];
2939 const GFC_REAL_8
*restrict bbase_y
;
2942 for (y
= 0; y
< ycount
; y
++)
2944 bbase_y
= &bbase
[y
*bystride
];
2946 for (n
= 0; n
< count
; n
++)
2947 s
+= abase
[n
*axstride
] * bbase_y
[n
];
2948 dest
[y
*rystride
] = s
;
2952 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
2954 const GFC_REAL_8
*restrict bbase_y
;
2957 for (y
= 0; y
< ycount
; y
++)
2959 bbase_y
= &bbase
[y
*bystride
];
2961 for (n
= 0; n
< count
; n
++)
2962 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
2963 dest
[y
*rxstride
] = s
;
2966 else if (axstride
< aystride
)
2968 for (y
= 0; y
< ycount
; y
++)
2969 for (x
= 0; x
< xcount
; x
++)
2970 dest
[x
*rxstride
+ y
*rystride
] = (GFC_REAL_8
)0;
2972 for (y
= 0; y
< ycount
; y
++)
2973 for (n
= 0; n
< count
; n
++)
2974 for (x
= 0; x
< xcount
; x
++)
2975 /* dest[x,y] += a[x,n] * b[n,y] */
2976 dest
[x
*rxstride
+ y
*rystride
] +=
2977 abase
[x
*axstride
+ n
*aystride
] *
2978 bbase
[n
*bxstride
+ y
*bystride
];
2982 const GFC_REAL_8
*restrict abase_x
;
2983 const GFC_REAL_8
*restrict bbase_y
;
2984 GFC_REAL_8
*restrict dest_y
;
2987 for (y
= 0; y
< ycount
; y
++)
2989 bbase_y
= &bbase
[y
*bystride
];
2990 dest_y
= &dest
[y
*rystride
];
2991 for (x
= 0; x
< xcount
; x
++)
2993 abase_x
= &abase
[x
*axstride
];
2995 for (n
= 0; n
< count
; n
++)
2996 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
2997 dest_y
[x
*rxstride
] = s
;