]> git.ipfire.org Git - thirdparty/gcc.git/blobdiff - libgfortran/m4/matmul.m4
Update copyright years.
[thirdparty/gcc.git] / libgfortran / m4 / matmul.m4
index 7a54b05595cacec5136b1982219ed3500665ab8e..7fc1f5fa75fb1ce4f4607e1d2e2e2310dfe37f28 100644 (file)
 `/* Implementation of the MATMUL intrinsic
-   Copyright 2002 Free Software Foundation, Inc.
+   Copyright (C) 2002-2024 Free Software Foundation, Inc.
    Contributed by Paul Brook <paul@nowt.org>
 
-This file is part of the GNU Fortran 95 runtime library (libgfortran).
+This file is part of the GNU Fortran runtime library (libgfortran).
 
 Libgfortran is free software; you can redistribute it and/or
-modify it under the terms of the GNU Lesser General Public
+modify it under the terms of the GNU General Public
 License as published by the Free Software Foundation; either
-version 2.1 of the License, or (at your option) any later version.
+version 3 of the License, or (at your option) any later version.
 
 Libgfortran is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-GNU Lesser General Public License for more details.
+GNU General Public License for more details.
 
-You should have received a copy of the GNU Lesser General Public
-License along with libgfor; see the file COPYING.LIB.  If not,
-write to the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
-Boston, MA 02111-1307, USA.  */
+Under Section 7 of GPL version 3, you are granted additional
+permissions described in the GCC Runtime Library Exception, version
+3.1, as published by the Free Software Foundation.
 
-#include "config.h"
-#include <stdlib.h>
-#include <assert.h>
-#include "libgfortran.h"'
-include(iparm.m4)dnl
+You should have received a copy of the GNU General Public License and
+a copy of the GCC Runtime Library Exception along with this program;
+see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
+<http://www.gnu.org/licenses/>.  */
 
-/* Dimensions: retarray(x,y) a(x, count) b(count,y).
-   Either a or b can be rank 1.  In this case x or y is 1.  */
-void
-`__matmul_'rtype_code (rtype * retarray, rtype * a, rtype * b)
-{
-  rtype_name *abase;
-  rtype_name *bbase;
-  rtype_name *dest;
-  rtype_name res;
-  index_type rxstride;
-  index_type rystride;
-  index_type xcount;
-  index_type ycount;
-  index_type xstride;
-  index_type ystride;
-  index_type x;
-  index_type y;
-
-  rtype_name *pa;
-  rtype_name *pb;
-  index_type astride;
-  index_type bstride;
-  index_type count;
-  index_type n;
-
-  assert (GFC_DESCRIPTOR_RANK (a) == 2
-          || GFC_DESCRIPTOR_RANK (b) == 2);
-
-  if (retarray->data == NULL)
-    {
-      if (GFC_DESCRIPTOR_RANK (a) == 1)
-        {
-          retarray->dim[0].lbound = 0;
-          retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
-          retarray->dim[0].stride = 1;
-        }
-      else if (GFC_DESCRIPTOR_RANK (b) == 1)
-        {
-          retarray->dim[0].lbound = 0;
-          retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
-          retarray->dim[0].stride = 1;
-        }
-      else
-        {
-          retarray->dim[0].lbound = 0;
-          retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
-          retarray->dim[0].stride = 1;
-          
-          retarray->dim[1].lbound = 0;
-          retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
-          retarray->dim[1].stride = retarray->dim[0].ubound+1;
-        }
-          
-      retarray->data = internal_malloc (sizeof (rtype_name) * size0 (retarray));
-      retarray->base = 0;
-    }
-
-  abase = a->data;
-  bbase = b->data;
-  dest = retarray->data;
+#include "libgfortran.h"
+#include <string.h>
+#include <assert.h>'
 
-  if (retarray->dim[0].stride == 0)
-    retarray->dim[0].stride = 1;
-  if (a->dim[0].stride == 0)
-    a->dim[0].stride = 1;
-  if (b->dim[0].stride == 0)
-    b->dim[0].stride = 1;
-
-sinclude(`matmul_asm_'rtype_code`.m4')dnl
+include(iparm.m4)dnl
 
-  if (GFC_DESCRIPTOR_RANK (retarray) == 1)
-    {
-      rxstride = retarray->dim[0].stride;
-      rystride = rxstride;
-    }
-  else
-    {
-      rxstride = retarray->dim[0].stride;
-      rystride = retarray->dim[1].stride;
-    }
+`#if defined (HAVE_'rtype_name`)
+
+/* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
+   passed to us by the front-end, in which case we call it for large
+   matrices.  */
+
+typedef void (*blas_call)(const char *, const char *, const int *, const int *,
+                          const int *, const 'rtype_name` *, const 'rtype_name` *,
+                          const int *, const 'rtype_name` *, const int *,
+                          const 'rtype_name` *, 'rtype_name` *, const int *,
+                          int, int);
+
+/* The order of loops is different in the case of plain matrix
+   multiplication C=MATMUL(A,B), and in the frequent special case where
+   the argument A is the temporary result of a TRANSPOSE intrinsic:
+   C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
+   looking at their strides.
+
+   The equivalent Fortran pseudo-code is:
+
+   DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
+   IF (.NOT.IS_TRANSPOSED(A)) THEN
+     C = 0
+     DO J=1,N
+       DO K=1,COUNT
+         DO I=1,M
+           C(I,J) = C(I,J)+A(I,K)*B(K,J)
+   ELSE
+     DO J=1,N
+       DO I=1,M
+         S = 0
+         DO K=1,COUNT
+           S = S+A(I,K)*B(K,J)
+         C(I,J) = S
+   ENDIF
+*/
+
+/* If try_blas is set to a nonzero value, then the matmul function will
+   see if there is a way to perform the matrix multiplication by a call
+   to the BLAS gemm function.  */
+
+extern void matmul_'rtype_code` ('rtype` * const restrict retarray, 
+       'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
+       int blas_limit, blas_call gemm);
+export_proto(matmul_'rtype_code`);
+
+/* Put exhaustive list of possible architectures here here, ORed together.  */
+
+#if defined(HAVE_AVX) || defined(HAVE_AVX2) || defined(HAVE_AVX512F)
+
+#ifdef HAVE_AVX
+'define(`matmul_name',`matmul_'rtype_code`_avx')dnl
+`static void
+'matmul_name` ('rtype` * const restrict retarray, 
+       'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
+       int blas_limit, blas_call gemm) __attribute__((__target__("avx")));
+static' include(matmul_internal.m4)dnl
+`#endif /* HAVE_AVX */
+
+#ifdef HAVE_AVX2
+'define(`matmul_name',`matmul_'rtype_code`_avx2')dnl
+`static void
+'matmul_name` ('rtype` * const restrict retarray, 
+       'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
+       int blas_limit, blas_call gemm) __attribute__((__target__("avx2,fma")));
+static' include(matmul_internal.m4)dnl
+`#endif /* HAVE_AVX2 */
+
+#ifdef HAVE_AVX512F
+'define(`matmul_name',`matmul_'rtype_code`_avx512f')dnl
+`static void
+'matmul_name` ('rtype` * const restrict retarray, 
+       'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
+       int blas_limit, blas_call gemm) __attribute__((__target__("avx512f")));
+static' include(matmul_internal.m4)dnl
+`#endif  /* HAVE_AVX512F */
+
+/* AMD-specifix funtions with AVX128 and FMA3/FMA4.  */
+
+#if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
+'define(`matmul_name',`matmul_'rtype_code`_avx128_fma3')dnl
+`void
+'matmul_name` ('rtype` * const restrict retarray, 
+       'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
+       int blas_limit, blas_call gemm) __attribute__((__target__("avx,fma")));
+internal_proto('matmul_name`);
+#endif
+
+#if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
+'define(`matmul_name',`matmul_'rtype_code`_avx128_fma4')dnl
+`void
+'matmul_name` ('rtype` * const restrict retarray, 
+       'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
+       int blas_limit, blas_call gemm) __attribute__((__target__("avx,fma4")));
+internal_proto('matmul_name`);
+#endif
+
+/* Function to fall back to if there is no special processor-specific version.  */
+'define(`matmul_name',`matmul_'rtype_code`_vanilla')dnl
+`static' include(matmul_internal.m4)dnl
+
+`/* Compiling main function, with selection code for the processor.  */
+
+/* Currently, this is i386 only.  Adjust for other architectures.  */
+
+void matmul_'rtype_code` ('rtype` * const restrict retarray, 
+       'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
+       int blas_limit, blas_call gemm)
+{
+  static void (*matmul_p) ('rtype` * const restrict retarray, 
+       'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
+       int blas_limit, blas_call gemm);
 
-  /* If we have rank 1 parameters, zero the absent stride, and set the size to
-     one.  */
-  if (GFC_DESCRIPTOR_RANK (a) == 1)
-    {
-      astride = a->dim[0].stride;
-      count = a->dim[0].ubound + 1 - a->dim[0].lbound;
-      xstride = 0;
-      rxstride = 0;
-      xcount = 1;
-    }
-  else
-    {
-      astride = a->dim[1].stride;
-      count = a->dim[1].ubound + 1 - a->dim[1].lbound;
-      xstride = a->dim[0].stride;
-      xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
-    }
-  if (GFC_DESCRIPTOR_RANK (b) == 1)
-    {
-      bstride = b->dim[0].stride;
-      assert(count == b->dim[0].ubound + 1 - b->dim[0].lbound);
-      ystride = 0;
-      rystride = 0;
-      ycount = 1;
-    }
-  else
-    {
-      bstride = b->dim[0].stride;
-      assert(count == b->dim[0].ubound + 1 - b->dim[0].lbound);
-      ystride = b->dim[1].stride;
-      ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
-    }
+  void (*matmul_fn) ('rtype` * const restrict retarray, 
+       'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
+       int blas_limit, blas_call gemm);
 
-  for (y = 0; y < ycount; y++)
+  matmul_fn = __atomic_load_n (&matmul_p, __ATOMIC_RELAXED);
+  if (matmul_fn == NULL)
     {
-      for (x = 0; x < xcount; x++)
-        {
-          /* Do the summation for this element.  For real and integer types
-             this is the same as DOT_PRODUCT.  For complex types we use do
-             a*b, not conjg(a)*b.  */
-          pa = abase;
-          pb = bbase;
-          res = 0;
-
-          for (n = 0; n < count; n++)
-            {
-              res += *pa * *pb;
-              pa += astride;
-              pb += bstride;
-            }
-
-          *dest = res;
-
-          dest += rxstride;
-          abase += xstride;
+      matmul_fn = matmul_'rtype_code`_vanilla;
+      if (__builtin_cpu_is ("intel"))
+       {
+          /* Run down the available processors in order of preference.  */
+#ifdef HAVE_AVX512F
+         if (__builtin_cpu_supports ("avx512f"))
+           {
+             matmul_fn = matmul_'rtype_code`_avx512f;
+             goto store;
+           }
+
+#endif  /* HAVE_AVX512F */
+
+#ifdef HAVE_AVX2
+         if (__builtin_cpu_supports ("avx2")
+             && __builtin_cpu_supports ("fma"))
+           {
+             matmul_fn = matmul_'rtype_code`_avx2;
+             goto store;
+           }
+
+#endif
+
+#ifdef HAVE_AVX
+         if (__builtin_cpu_supports ("avx"))
+           {
+              matmul_fn = matmul_'rtype_code`_avx;
+             goto store;
+           }
+#endif  /* HAVE_AVX */
         }
-      abase -= xstride * xcount;
-      bbase += ystride;
-      dest += rystride - (rxstride * xcount);
-    }
+    else if (__builtin_cpu_is ("amd"))
+      {
+#if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
+       if (__builtin_cpu_supports ("avx")
+           && __builtin_cpu_supports ("fma"))
+         {
+            matmul_fn = matmul_'rtype_code`_avx128_fma3;
+           goto store;
+         }
+#endif
+#if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
+       if (__builtin_cpu_supports ("avx")
+           && __builtin_cpu_supports ("fma4"))
+         {
+            matmul_fn = matmul_'rtype_code`_avx128_fma4;
+           goto store;
+         }
+#endif
+
+      }
+   store:
+      __atomic_store_n (&matmul_p, matmul_fn, __ATOMIC_RELAXED);
+   }
+
+   (*matmul_fn) (retarray, a, b, try_blas, blas_limit, gemm);
 }
 
+#else  /* Just the vanilla function.  */
+
+'define(`matmul_name',`matmul_'rtype_code)dnl
+define(`target_attribute',`')dnl
+include(matmul_internal.m4)dnl
+`#endif
+#endif
+'