]> git.ipfire.org Git - thirdparty/vectorscan.git/commitdiff
comparemask_type is u64a on Arm, use single load_mask
authorKonstantinos Margaritis <konma@vectorcamp.gr>
Mon, 18 Dec 2023 20:23:07 +0000 (20:23 +0000)
committerKonstantinos Margaritis <konma@vectorcamp.gr>
Thu, 21 Dec 2023 23:25:20 +0000 (23:25 +0000)
src/hwlm/noodle_engine_simd.hpp
src/util/supervector/arch/arm/impl.cpp
src/util/supervector/supervector.hpp

index 9af76768cf6f25967b7586d65475711da704a147..23827873fdb9c8a2a7988122a05779f7a99b2a2c 100644 (file)
@@ -86,15 +86,21 @@ hwlm_error_t scanSingleMain(const struct noodTable *n, const u8 *buf,
             DEBUG_PRINTF("d - d0: %ld \n", d - d0);
 #if defined(HAVE_MASKED_LOADS)
             uint8_t l = d - d0;
-            typename SuperVector<S>::comparemask_type mask = ~SuperVector<S>::single_load_mask(l);
+            typename SuperVector<S>::comparemask_type mask = ~SuperVector<S>::load_mask(l);
             SuperVector<S> chars = SuperVector<S>::loadu_maskz(d0, mask) & caseMask;
             typename SuperVector<S>::comparemask_type z = mask1.eqmask(chars);
             DEBUG_PRINTF("mask: %08llx\n", mask);
             hwlm_error_t rv = single_zscan<S>(n, d0, buf, z, len, cbi);
 #else
             uint8_t l = d0 + S - d;
+            DEBUG_PRINTF("l: %d \n", l);
             SuperVector<S> chars = SuperVector<S>::loadu_maskz(d, l) & caseMask;
+            chars.print8("chars");
             typename SuperVector<S>::comparemask_type z = mask1.eqmask(chars);
+            DEBUG_PRINTF("z: %08llx\n", (u64a) z);
+            z = SuperVector<S>::iteration_mask(z);
+            DEBUG_PRINTF("z: %08llx\n", (u64a) z);
+
             hwlm_error_t rv = single_zscan<S>(n, d, buf, z, len, cbi);
 #endif
             chars.print32("chars");
@@ -125,6 +131,8 @@ hwlm_error_t scanSingleMain(const struct noodTable *n, const u8 *buf,
         uint8_t l = buf_end - d;
         SuperVector<S> chars = SuperVector<S>::loadu_maskz(d, l) & caseMask;
         typename SuperVector<S>::comparemask_type z = mask1.eqmask(chars);
+        z = SuperVector<S>::iteration_mask(z);
+
         hwlm_error_t rv = single_zscan<S>(n, d, buf, z, len, cbi);
         RETURN_IF_TERMINATED(rv);
     }
@@ -160,12 +168,12 @@ hwlm_error_t scanDoubleMain(const struct noodTable *n, const u8 *buf,
             const u8 *d0 = ROUNDDOWN_PTR(d, S);
 #if defined(HAVE_MASKED_LOADS)
             uint8_t l = d - d0;
-            typename SuperVector<S>::comparemask_type mask = ~SuperVector<S>::double_load_mask(l);
+            typename SuperVector<S>::comparemask_type mask = ~SuperVector<S>::load_mask(l);
             SuperVector<S> chars = SuperVector<S>::loadu_maskz(d0, mask) & caseMask;
             typename SuperVector<S>::comparemask_type z1 = mask1.eqmask(chars);
             typename SuperVector<S>::comparemask_type z2 = mask2.eqmask(chars);
             typename SuperVector<S>::comparemask_type z = (z1 << SuperVector<S>::mask_width()) & z2;
-            DEBUG_PRINTF("z: %0llx\n", z);
+            z = SuperVector<S>::iteration_mask(z);
             lastz1 = z1 >> (S - 1);
 
             DEBUG_PRINTF("mask: %08llx\n", mask);
@@ -176,8 +184,9 @@ hwlm_error_t scanDoubleMain(const struct noodTable *n, const u8 *buf,
             chars.print8("chars");
             typename SuperVector<S>::comparemask_type z1 = mask1.eqmask(chars);
             typename SuperVector<S>::comparemask_type z2 = mask2.eqmask(chars);
-
             typename SuperVector<S>::comparemask_type z = (z1 << SuperVector<S>::mask_width()) & z2;
+            z = SuperVector<S>::iteration_mask(z);
+
             hwlm_error_t rv = double_zscan<S>(n, d, buf, z, len, cbi);
             lastz1 = z1 >> (l - 1);
 #endif
index 55f6c55c1660737c7b0c73aaf87d7b5063adff45..bd866223b7e1d6013db1ecdf7b8658d8965324b2 100644 (file)
@@ -525,11 +525,26 @@ really_inline SuperVector<16> SuperVector<16>::load(void const *ptr)
 template <>
 really_inline SuperVector<16> SuperVector<16>::loadu_maskz(void const *ptr, uint8_t const len)
 {
-    SuperVector mask = Ones_vshr(16 -len);
-    SuperVector<16> v = loadu(ptr);
+    SuperVector mask = Ones_vshr(16 - len);
+    SuperVector v = loadu(ptr);
     return mask & v;
 }
 
+template <>
+really_inline SuperVector<16> SuperVector<16>::loadu_maskz(void const *ptr, typename base_type::comparemask_type const mask)
+{
+    DEBUG_PRINTF("mask = %08llx\n", mask);
+    SuperVector v = loadu(ptr);
+    (void)mask;
+    return v; // FIXME: & mask
+}
+
+template<>
+really_inline typename SuperVector<16>::comparemask_type SuperVector<16>::findLSB(typename SuperVector<16>::comparemask_type &z)
+{
+  return findAndClearLSB_64(&z) >> 2;
+}
+
 template<>
 really_inline SuperVector<16> SuperVector<16>::alignr(SuperVector<16> &other, int8_t offset)
 {
index 3c4b1eea04cf146951cc41f65b6f58294027b395..6d2bc809203c307a7b3901b87ce6190b9eb81f67 100644 (file)
@@ -130,7 +130,11 @@ struct BaseVector<16>
   static constexpr bool      is_valid = true;
   static constexpr u16           size = 16;
   using                          type = m128;
+#if defined(ARCH_ARM32) || defined(ARCH_AARCH64)
+  using              comparemask_type = u64a;
+#else
   using              comparemask_type = u32;
+#endif
   static constexpr bool  has_previous = false;
   using                 previous_type = u64a;
   static constexpr u16  previous_size = 8;
@@ -229,8 +233,7 @@ public:
   static typename base_type::comparemask_type
   iteration_mask(typename base_type::comparemask_type mask);
 
-  static typename base_type::comparemask_type single_load_mask(uint8_t const len) { return (((1ULL) << (len)) - 1ULL); }
-  static typename base_type::comparemask_type double_load_mask(uint8_t const len) { return (((1ULL) << (len)) - 1ULL); }
+  static typename base_type::comparemask_type load_mask(uint8_t const len) { return (((1ULL) << (len)) - 1ULL); }
   static typename base_type::comparemask_type findLSB(typename base_type::comparemask_type &z);
   static SuperVector loadu(void const *ptr);
   static SuperVector load(void const *ptr);