]> git.ipfire.org Git - thirdparty/vectorscan.git/commitdiff
fix loadu_maskz, remove old defines
authorKonstantinos Margaritis <konstantinos@vectorcamp.gr>
Mon, 18 Dec 2023 18:07:35 +0000 (20:07 +0200)
committerKonstantinos Margaritis <konma@vectorcamp.gr>
Thu, 21 Dec 2023 23:24:31 +0000 (23:24 +0000)
src/util/supervector/arch/x86/impl.cpp
src/util/supervector/supervector.hpp

index b8a75c95c3f155f1bfe7f5823c06907136ba4505..77ffc038cca3d96b6ebccfa20ad8a2c57e9b63fd 100644 (file)
@@ -524,7 +524,28 @@ really_inline SuperVector<16> SuperVector<16>::loadu_maskz(void const *ptr, uint
 {
     SuperVector mask = Ones_vshr(16 -len);
     SuperVector v = _mm_loadu_si128((const m128 *)ptr);
-    return mask & v;
+    return v & mask;
+}
+
+template <>
+really_inline SuperVector<16> SuperVector<16>::loadu_maskz(void const *ptr, typename base_type::comparemask_type const mask)
+{
+#ifdef HAVE_AVX512
+    SuperVector<16> v = _mm_maskz_loadu_epi8(mask, (const m128 *)ptr);
+    v.print8("v");
+    return v;
+#else
+    DEBUG_PRINTF("mask = %08x\n", mask);
+    SuperVector v = _mm_loadu_si128((const m128 *)ptr);
+    (void)mask;
+    return v; // FIXME: & mask
+#endif
+}
+
+template<>
+really_inline typename SuperVector<16>::comparemask_type SuperVector<16>::findLSB(typename SuperVector<16>::comparemask_type &z)
+{
+  return findAndClearLSB_32(&z);
 }
 
 template<>
@@ -1126,22 +1147,35 @@ really_inline SuperVector<32> SuperVector<32>::load(void const *ptr)
 template <>
 really_inline SuperVector<32> SuperVector<32>::loadu_maskz(void const *ptr, uint8_t const len)
 {
+    SuperVector mask = Ones_vshr(32 -len);
+    mask.print8("mask");
+    SuperVector<32> v = _mm256_loadu_si256((const m256 *)ptr);
+    v.print8("v");
+    return v & mask;
+}
+
+template <>
+really_inline SuperVector<32> SuperVector<32>::loadu_maskz(void const *ptr, typename base_type::comparemask_type const mask)
+{
+    DEBUG_PRINTF("mask = %08llx\n", mask);
 #ifdef HAVE_AVX512
-    u32 mask = (~0ULL) >> (32 - len);
-    SuperVector<32> v = _mm256_mask_loadu_epi8(Zeroes().u.v256[0], mask, (const m256 *)ptr);
+    SuperVector<32> v = _mm256_maskz_loadu_epi8(mask, (const m256 *)ptr);
     v.print8("v");
     return v;
 #else
-    DEBUG_PRINTF("len = %d", len);
-    SuperVector<32> mask = Ones_vshr(32 -len);
-    mask.print8("mask");
-    (Ones() >> (32 - len)).print8("mask");
     SuperVector<32> v = _mm256_loadu_si256((const m256 *)ptr);
     v.print8("v");
-    return mask & v;
+    (void)mask;
+    return v; // FIXME: & mask
 #endif
 }
 
+template<>
+really_inline typename SuperVector<32>::comparemask_type SuperVector<32>::findLSB(typename SuperVector<32>::comparemask_type &z)
+{
+  return findAndClearLSB_64(&z);
+}
+
 template<>
 really_inline SuperVector<32> SuperVector<32>::alignr(SuperVector<32> &other, int8_t offset)
 {
@@ -1778,11 +1812,26 @@ really_inline SuperVector<64> SuperVector<64>::loadu_maskz(void const *ptr, uint
 {
     u64a mask = (~0ULL) >> (64 - len);
     DEBUG_PRINTF("mask = %016llx\n", mask);
-    SuperVector<64> v = _mm512_mask_loadu_epi8(Zeroes().u.v512[0], mask, (const m512 *)ptr);
+    SuperVector<64> v = _mm512_maskz_loadu_epi8(mask, (const m512 *)ptr);
     v.print8("v");
     return v;
 }
 
+template <>
+really_inline SuperVector<64> SuperVector<64>::loadu_maskz(void const *ptr, typename base_type::comparemask_type const mask)
+{
+    DEBUG_PRINTF("mask = %016llx\n", mask);
+    SuperVector<64> v = _mm512_maskz_loadu_epi8(mask, (const m512 *)ptr);
+    v.print8("v");
+    return v;
+}
+
+template<>
+really_inline typename SuperVector<64>::comparemask_type SuperVector<64>::findLSB(typename SuperVector<64>::comparemask_type &z)
+{
+  return findAndClearLSB_64(&z);
+}
+
 template<>
 template<>
 really_inline SuperVector<64> SuperVector<64>::pshufb<true>(SuperVector<64> b)
index 253907fa3e9e3467c6c466ab94ddac06b83788aa..1d72ee81f4c79374e6f93cd99abe64d78c164838 100644 (file)
 #endif
 #endif // VS_SIMDE_BACKEND
 
+#include <util/bitutils.h>
+
 #if defined(HAVE_SIMD_512_BITS)
-using Z_TYPE = u64a;
-#define Z_BITS 64
-#define Z_SHIFT 63
 #define Z_POSSHIFT 0
-#define DOUBLE_LOAD_MASK(l)        ((~0ULL) >> (Z_BITS -(l)))
-#define SINGLE_LOAD_MASK(l)        (((1ULL) << (l)) - 1ULL)
 #elif defined(HAVE_SIMD_256_BITS)
-using Z_TYPE = u32;
-#define Z_BITS 32
-#define Z_SHIFT 31
 #define Z_POSSHIFT 0
-#define DOUBLE_LOAD_MASK(l)        (((1ULL) << (l)) - 1ULL)
-#define SINGLE_LOAD_MASK(l)        (((1ULL) << (l)) - 1ULL)
 #elif defined(HAVE_SIMD_128_BITS)
 #if !defined(VS_SIMDE_BACKEND) && (defined(ARCH_ARM32) || defined(ARCH_AARCH64))
-using Z_TYPE = u64a;
-#define Z_BITS 64
 #define Z_POSSHIFT 2
-#define DOUBLE_LOAD_MASK(l) ((~0ULL) >> (Z_BITS - (l)))
 #else
-using Z_TYPE = u32;
-#define Z_BITS 32
 #define Z_POSSHIFT 0
-#define DOUBLE_LOAD_MASK(l) (((1ULL) << (l)) - 1ULL)
 #endif
-#define Z_SHIFT 15
-#define SINGLE_LOAD_MASK(l)        (((1ULL) << (l)) - 1ULL)
 #endif
 
 // Define a common assume_aligned using an appropriate compiler built-in, if
@@ -138,7 +122,7 @@ struct BaseVector<64>
   static constexpr u16  previous_size = 32;
 };
 
-// 128 bit implementation
+// 256 bit implementation
 template <>
 struct BaseVector<32>
 {
@@ -158,7 +142,7 @@ struct BaseVector<16>
   static constexpr bool      is_valid = true;
   static constexpr u16           size = 16;
   using                          type = m128;
-  using              comparemask_type = u64a;
+  using              comparemask_type = u32;
   static constexpr bool  has_previous = false;
   using                 previous_type = u64a;
   static constexpr u16  previous_size = 8;
@@ -257,9 +241,13 @@ public:
   static typename base_type::comparemask_type
   iteration_mask(typename base_type::comparemask_type mask);
 
+  static typename base_type::comparemask_type single_load_mask(uint8_t const len) { return (((1ULL) << (len)) - 1ULL); }
+  static typename base_type::comparemask_type double_load_mask(uint8_t const len) { return (((1ULL) << (len)) - 1ULL); }
+  static typename base_type::comparemask_type findLSB(typename base_type::comparemask_type &z);
   static SuperVector loadu(void const *ptr);
   static SuperVector load(void const *ptr);
   static SuperVector loadu_maskz(void const *ptr, uint8_t const len);
+  static SuperVector loadu_maskz(void const *ptr, typename base_type::comparemask_type const len);
   SuperVector alignr(SuperVector &other, int8_t offset);
 
   template<bool emulateIntel=true>