]> git.ipfire.org Git - thirdparty/vectorscan.git/commitdiff
refactor Noodle Single/Double to use masked loads
authorKonstantinos Margaritis <konstantinos@vectorcamp.gr>
Mon, 18 Dec 2023 18:08:51 +0000 (20:08 +0200)
committerKonstantinos Margaritis <konma@vectorcamp.gr>
Thu, 21 Dec 2023 23:24:45 +0000 (23:24 +0000)
src/hwlm/noodle_engine_simd.hpp

index 91c72840d56a9ed60be0d7811e1df48319b88596..9e16c2f370b78a90d2174ca3bd3cf3ea08439fc5 100644 (file)
 #include "util/supervector/supervector.hpp"
 #include "util/supervector/casemask.hpp"
 
+template <uint16_t S>
 static really_really_inline
 hwlm_error_t single_zscan(const struct noodTable *n,const u8 *d, const u8 *buf,
-                          Z_TYPE z, size_t len, const struct cb_info *cbi) {
+                          typename SuperVector<S>::comparemask_type z, size_t len, const struct cb_info *cbi) {
     while (unlikely(z)) {
-        Z_TYPE pos = JOIN(findAndClearLSB_, Z_BITS)(&z) >> Z_POSSHIFT;
+        typename SuperVector<S>::comparemask_type pos = SuperVector<S>::findLSB(z) >> Z_POSSHIFT;
         size_t matchPos = d - buf + pos;
         DEBUG_PRINTF("match pos %zu\n", matchPos);
         hwlmcb_rv_t rv = final(n, buf, len, n->msk_len != 1, cbi, matchPos);
@@ -45,12 +46,12 @@ hwlm_error_t single_zscan(const struct noodTable *n,const u8 *d, const u8 *buf,
     return HWLM_SUCCESS;
 }
 
+template <uint16_t S>
 static really_really_inline
 hwlm_error_t double_zscan(const struct noodTable *n,const u8 *d, const u8 *buf,
-                          Z_TYPE z, size_t len, const struct cb_info *cbi) {
+                          typename SuperVector<S>::comparemask_type z, size_t len, const struct cb_info *cbi) {
     while (unlikely(z)) {
-        Z_TYPE pos = JOIN(findAndClearLSB_, Z_BITS)(&z) >> Z_POSSHIFT;
-        DEBUG_PRINTF("pos %u\n", pos);
+        typename SuperVector<S>::comparemask_type pos = SuperVector<S>::findLSB(z) >> Z_POSSHIFT;
         size_t matchPos = d - buf + pos - 1;
         DEBUG_PRINTF("match pos %zu\n", matchPos);
         hwlmcb_rv_t rv = final(n, buf, len, true, cbi, matchPos);
@@ -79,18 +80,28 @@ hwlm_error_t scanSingleMain(const struct noodTable *n, const u8 *buf,
     assert(d < buf_end);
     if (d + S <= buf_end) {
         // Reach vector aligned boundaries
-        DEBUG_PRINTF("until aligned %p \n", ROUNDUP_PTR(d, S));
+        DEBUG_PRINTF("until aligned %p, S: %d \n", ROUNDUP_PTR(d, S), S);
         if (!ISALIGNED_N(d, S)) {
-            const u8 *d1 = ROUNDUP_PTR(d, S);
-            DEBUG_PRINTF("d1 - d: %ld \n", d1 - d);
-            size_t l = d1 - d;
-            SuperVector<S> chars = SuperVector<S>::loadu(d) & caseMask;
-            typename SuperVector<S>::comparemask_type mask = SINGLE_LOAD_MASK(l * SuperVector<S>::mask_width());
-            typename SuperVector<S>::comparemask_type z = mask & mask1.eqmask(chars);
-
-            hwlm_error_t rv = single_zscan(n, d, buf, z, len, cbi);
+            const u8 *d0 = ROUNDDOWN_PTR(d, S);
+            DEBUG_PRINTF("d - d0: %ld \n", d - d0);
+#if defined(HAVE_MASKED_LOADS)
+            uint8_t l = d - d0;
+            typename SuperVector<S>::comparemask_type mask = ~SuperVector<S>::single_load_mask(l);
+            SuperVector<S> chars = SuperVector<S>::loadu_maskz(d0, mask) & caseMask;
+            typename SuperVector<S>::comparemask_type z = mask1.eqmask(chars);
+            DEBUG_PRINTF("mask: %08llx\n", mask);
+            hwlm_error_t rv = single_zscan<S>(n, d0, buf, z, len, cbi);
+#else
+            uint8_t l = d0 + S - d;
+            SuperVector<S> chars = SuperVector<S>::loadu_maskz(d, l) & caseMask;
+            typename SuperVector<S>::comparemask_type z = mask1.eqmask(chars);
+            hwlm_error_t rv = single_zscan<S>(n, d, buf, z, len, cbi);
+#endif
+            chars.print32("chars");
+            DEBUG_PRINTF("z: %08llx\n", (u64a) z);
+
             RETURN_IF_TERMINATED(rv);
-            d = d1;
+            d = d0 + S;
         }
 
         while(d + S <= buf_end) {
@@ -101,7 +112,7 @@ hwlm_error_t scanSingleMain(const struct noodTable *n, const u8 *buf,
             typename SuperVector<S>::comparemask_type z = mask1.eqmask(v);
             z = SuperVector<S>::iteration_mask(z);
 
-            hwlm_error_t rv = single_zscan(n, d, buf, z, len, cbi);
+            hwlm_error_t rv = single_zscan<S>(n, d, buf, z, len, cbi);
             RETURN_IF_TERMINATED(rv);
             d += S;
         }
@@ -111,11 +122,10 @@ hwlm_error_t scanSingleMain(const struct noodTable *n, const u8 *buf,
     // finish off tail
 
     if (d != buf_end) {
-        SuperVector<S> chars = SuperVector<S>::loadu(d) & caseMask;
-        size_t l = buf_end - d;
-        typename SuperVector<S>::comparemask_type mask = SINGLE_LOAD_MASK(l * SuperVector<S>::mask_width());
-        typename SuperVector<S>::comparemask_type z = mask & mask1.eqmask(chars);
-        hwlm_error_t rv = single_zscan(n, d, buf, z, len, cbi);
+        uint8_t l = buf_end - d;
+        SuperVector<S> chars = SuperVector<S>::loadu_maskz(d, l) & caseMask;
+        typename SuperVector<S>::comparemask_type z = mask1.eqmask(chars);
+        hwlm_error_t rv = single_zscan<S>(n, d, buf, z, len, cbi);
         RETURN_IF_TERMINATED(rv);
     }
 
@@ -145,21 +155,34 @@ hwlm_error_t scanDoubleMain(const struct noodTable *n, const u8 *buf,
     assert(d < buf_end);
     if (d + S <= buf_end) {
         // Reach vector aligned boundaries
-        DEBUG_PRINTF("until aligned %p \n", ROUNDUP_PTR(d, S));
+        DEBUG_PRINTF("until aligned %p, S: %d \n", ROUNDUP_PTR(d, S), S);
         if (!ISALIGNED_N(d, S)) {
-            const u8 *d1 = ROUNDUP_PTR(d, S);
-            size_t l = d1 - d;
-            SuperVector<S> chars = SuperVector<S>::loadu(d) & caseMask;
-            typename SuperVector<S>::comparemask_type mask = DOUBLE_LOAD_MASK(l * SuperVector<S>::mask_width());
+            const u8 *d0 = ROUNDDOWN_PTR(d, S);
+#if defined(HAVE_MASKED_LOADS)
+            uint8_t l = d - d0;
+            typename SuperVector<S>::comparemask_type mask = ~SuperVector<S>::double_load_mask(l);
+            SuperVector<S> chars = SuperVector<S>::loadu_maskz(d0, mask) & caseMask;
+            typename SuperVector<S>::comparemask_type z1 = mask1.eqmask(chars);
+            typename SuperVector<S>::comparemask_type z2 = mask2.eqmask(chars);
+            typename SuperVector<S>::comparemask_type z = (z1 << SuperVector<S>::mask_width()) & z2;
+            DEBUG_PRINTF("z: %0llx\n", z);
+            lastz1 = z1 >> (S - 1);
+
+            DEBUG_PRINTF("mask: %08llx\n", mask);
+            hwlm_error_t rv = double_zscan<S>(n, d0, buf, z, len, cbi);
+#else
+            uint8_t l = d0 + S - d;
+            SuperVector<S> chars = SuperVector<S>::loadu_maskz(d, l) & caseMask;
+            chars.print8("chars");
             typename SuperVector<S>::comparemask_type z1 = mask1.eqmask(chars);
             typename SuperVector<S>::comparemask_type z2 = mask2.eqmask(chars);
-            typename SuperVector<S>::comparemask_type z = mask & (z1 << SuperVector<S>::mask_width()) & z2;
-            lastz1 = z1 >> (Z_SHIFT * SuperVector<S>::mask_width());
-            z = SuperVector<S>::iteration_mask(z);
 
-            hwlm_error_t rv = double_zscan(n, d, buf, z, len, cbi);
+            typename SuperVector<S>::comparemask_type z = (z1 << SuperVector<S>::mask_width()) & z2;
+            hwlm_error_t rv = double_zscan<S>(n, d, buf, z, len, cbi);
+            lastz1 = z1 >> (l - 1);
+#endif
             RETURN_IF_TERMINATED(rv);
-            d = d1;
+            d = d0 + S;
         }
 
         while(d + S <= buf_end) {
@@ -170,10 +193,10 @@ hwlm_error_t scanDoubleMain(const struct noodTable *n, const u8 *buf,
             typename SuperVector<S>::comparemask_type z1 = mask1.eqmask(chars);
             typename SuperVector<S>::comparemask_type z2 = mask2.eqmask(chars);
             typename SuperVector<S>::comparemask_type z = (z1 << SuperVector<S>::mask_width() | lastz1) & z2;
-            lastz1 = z1 >> (Z_SHIFT * SuperVector<S>::mask_width());
+            lastz1 = z1 >> (S - 1);
             z = SuperVector<S>::iteration_mask(z);
 
-            hwlm_error_t rv = double_zscan(n, d, buf, z, len, cbi);
+            hwlm_error_t rv = double_zscan<S>(n, d, buf, z, len, cbi);
             RETURN_IF_TERMINATED(rv);
             d += S;
         }
@@ -181,17 +204,15 @@ hwlm_error_t scanDoubleMain(const struct noodTable *n, const u8 *buf,
 
     DEBUG_PRINTF("d %p e %p \n", d, buf_end);
     // finish off tail
-
     if (d != buf_end) {
-        size_t l = buf_end - d;
-        SuperVector<S> chars = SuperVector<S>::loadu(d) & caseMask;
-        typename SuperVector<S>::comparemask_type mask = DOUBLE_LOAD_MASK(l * SuperVector<S>::mask_width());
+        uint8_t l = buf_end - d;
+        SuperVector<S> chars = SuperVector<S>::loadu_maskz(d, l) & caseMask;
         typename SuperVector<S>::comparemask_type z1 = mask1.eqmask(chars);
         typename SuperVector<S>::comparemask_type z2 = mask2.eqmask(chars);
-        typename SuperVector<S>::comparemask_type z = mask & (z1 << SuperVector<S>::mask_width() | lastz1) & z2;
+        typename SuperVector<S>::comparemask_type z = (z1 << SuperVector<S>::mask_width() | lastz1) & z2;
         z = SuperVector<S>::iteration_mask(z);
 
-        hwlm_error_t rv = double_zscan(n, d, buf, z, len, cbi);
+        hwlm_error_t rv = double_zscan<S>(n, d, buf, z, len, cbi);
         RETURN_IF_TERMINATED(rv);
     }
 
@@ -202,7 +223,9 @@ hwlm_error_t scanDoubleMain(const struct noodTable *n, const u8 *buf,
 static really_inline
 hwlm_error_t scanSingle(const struct noodTable *n, const u8 *buf, size_t len,
                         size_t start, bool noCase, const struct cb_info *cbi) {
-/*    if (len < VECTORSIZE) {
+    /*
+     * TODO: Investigate adding scalar case for smaller sizes
+    if (len < VECTORSIZE) {
       return scanSingleSlow(n, buf, len, start, noCase, n->key0, cbi);
     }*/