]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Move NEON version to a separate function and fix indentation
authorDanila Kutenin <kutdanila@yandex.ru>
Mon, 23 May 2022 14:49:35 +0000 (14:49 +0000)
committerDanila Kutenin <kutdanila@yandex.ru>
Mon, 23 May 2022 14:49:35 +0000 (14:49 +0000)
lib/compress/zstd_lazy.c

index 5e9a06b9eee3ddf1a94bb15b725f987130e34ede..6404d29f998ea08ebc5acf06263ccdaa070ffc88 100644 (file)
@@ -954,6 +954,29 @@ void ZSTD_row_update(ZSTD_matchState_t* const ms, const BYTE* ip) {
     ZSTD_row_update_internal(ms, ip, mls, rowLog, rowMask, 0 /* don't use cache */);
 }
 
+/* Returns the mask width of bits group of which will be set to 1. Given not all
+ * architectures have easy movemask instruction, this helps to iterate over
+ * groups of bits easier and faster.
+ */
+FORCE_INLINE_TEMPLATE U32
+ZSTD_row_matchMaskGroupWidth(const U32 rowEntries)
+{
+    assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64);
+    assert(rowEntries <= ZSTD_ROW_HASH_MAX_ENTRIES);
+#if defined(ZSTD_ARCH_ARM_NEON)
+    if (rowEntries == 16) {
+        return 4;
+    }
+    if (rowEntries == 32) {
+        return 2;
+    }
+    if (rowEntries == 64) {
+        return 1;
+    }
+#endif
+    return 1;
+}
+
 #if defined(ZSTD_ARCH_X86_SSE2)
 FORCE_INLINE_TEMPLATE ZSTD_VecMask
 ZSTD_row_getSSEMask(int nbChunks, const BYTE* const src, const BYTE tag, const U32 head)
@@ -974,28 +997,53 @@ ZSTD_row_getSSEMask(int nbChunks, const BYTE* const src, const BYTE tag, const U
 }
 #endif
 
-/* Returns the mask width of bits group of which will be set to 1. Given not all
- * architectures have easy movemask instruction, this helps to iterate over
- * groups of bits easier and faster.
- */
-FORCE_INLINE_TEMPLATE U32
-ZSTD_row_matchMaskGroupWidth(const U32 rowEntries) {
-  assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64);
-  assert(rowEntries <= ZSTD_ROW_HASH_MAX_ENTRIES);
-  (void)rowEntries;
 #if defined(ZSTD_ARCH_ARM_NEON)
-  if (rowEntries == 16) {
-    return 4;
-  }
-  if (rowEntries == 32) {
-    return 2;
-  }
-  if (rowEntries == 64) {
-    return 1;
-  }
-#endif
-  return 1;
+FORCE_INLINE_TEMPLATE ZSTD_VecMask
+ZSTD_row_getNEONMask(const U32 rowEntries, const BYTE* const src, const BYTE tag, const U32 headGrouped)
+{
+    assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64);
+    if (rowEntries == 16) {
+        /* vshrn_n_u16 shifts by 4 every u16 and narrows to 8 lower bits.
+         * After that groups of 4 bits represent the equalMask. We lower
+         * all bits except the highest in these groups by doing AND with
+         * 0x88 = 0b10001000.
+         */
+        const uint8x16_t chunk = vld1q_u8(src);
+        const uint16x8_t equalMask = vreinterpretq_u16_u8(vceqq_u8(chunk, vdupq_n_u8(tag)));
+        const uint8x8_t res = vshrn_n_u16(equalMask, 4);
+        const U64 matches = vget_lane_u64(vreinterpret_u64_u8(res), 0);
+        return ZSTD_rotateRight_U64(matches, headGrouped) & 0x8888888888888888ull;
+    } else if (rowEntries == 32) {
+        /* Same idea as with rowEntries == 16 but doing AND with
+         * 0x55 = 0b01010101.
+         */
+        const uint16x8x2_t chunk = vld2q_u16((const uint16_t*)(const void*)src);
+        const uint8x16_t chunk0 = vreinterpretq_u8_u16(chunk.val[0]);
+        const uint8x16_t chunk1 = vreinterpretq_u8_u16(chunk.val[1]);
+        const uint8x16_t dup = vdupq_n_u8(tag);
+        const uint8x8_t t0 = vshrn_n_u16(vreinterpretq_u16_u8(vceqq_u8(chunk0, dup)), 6);
+        const uint8x8_t t1 = vshrn_n_u16(vreinterpretq_u16_u8(vceqq_u8(chunk1, dup)), 6);
+        const uint8x8_t res = vsli_n_u8(t0, t1, 4);
+        const U64 matches = vget_lane_u64(vreinterpret_u64_u8(res), 0) ;
+        return ZSTD_rotateRight_U64(matches, headGrouped) & 0x5555555555555555ull;
+    } else { /* rowEntries == 64 */
+        const uint8x16x4_t chunk = vld4q_u8(src);
+        const uint8x16_t dup = vdupq_n_u8(tag);
+        const uint8x16_t cmp0 = vceqq_u8(chunk.val[0], dup);
+        const uint8x16_t cmp1 = vceqq_u8(chunk.val[1], dup);
+        const uint8x16_t cmp2 = vceqq_u8(chunk.val[2], dup);
+        const uint8x16_t cmp3 = vceqq_u8(chunk.val[3], dup);
+
+        const uint8x16_t t0 = vsriq_n_u8(cmp1, cmp0, 1);
+        const uint8x16_t t1 = vsriq_n_u8(cmp3, cmp2, 1);
+        const uint8x16_t t2 = vsriq_n_u8(t1, t0, 2);
+        const uint8x16_t t3 = vsriq_n_u8(t2, t2, 4);
+        const uint8x8_t t4 = vshrn_n_u16(vreinterpretq_u16_u8(t3), 4);
+        const U64 matches = vget_lane_u64(vreinterpret_u64_u8(t4), 0);
+        return ZSTD_rotateRight_U64(matches, headGrouped);
+    }
 }
+#endif
 
 /* Returns a ZSTD_VecMask (U64) that has the nth group (determined by
  * ZSTD_row_matchMaskGroupWidth) of bits set to 1 if the newly-computed "tag"
@@ -1020,46 +1068,7 @@ ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 headGr
 # if defined(ZSTD_ARCH_ARM_NEON)
   /* This NEON path only works for little endian - otherwise use SWAR below */
     if (MEM_isLittleEndian()) {
-        if (rowEntries == 16) {
-            /* vshrn_n_u16 shifts by 4 every u16 and narrows to 8 lower bits.
-             * After that groups of 4 bits represent the equalMask. We lower
-             * all bits except the highest in these groups by doing AND with
-             * 0x88 = 0b10001000.
-             */
-            const uint8x16_t chunk = vld1q_u8(src);
-            const uint16x8_t equalMask = vreinterpretq_u16_u8(vceqq_u8(chunk, vdupq_n_u8(tag)));
-            const uint8x8_t res = vshrn_n_u16(equalMask, 4);
-            const U64 matches = vget_lane_u64(vreinterpret_u64_u8(res), 0);
-            return ZSTD_rotateRight_U64(matches, headGrouped) & 0x8888888888888888ull;
-        } else if (rowEntries == 32) {
-            /* Same idea as with rowEntries == 16 but doing AND with
-             * 0x55 = 0b01010101.
-             */
-            const uint16x8x2_t chunk = vld2q_u16((const uint16_t*)(const void*)src);
-            const uint8x16_t chunk0 = vreinterpretq_u8_u16(chunk.val[0]);
-            const uint8x16_t chunk1 = vreinterpretq_u8_u16(chunk.val[1]);
-            const uint8x16_t dup = vdupq_n_u8(tag);
-            const uint8x8_t t0 = vshrn_n_u16(vreinterpretq_u16_u8(vceqq_u8(chunk0, dup)), 6);
-            const uint8x8_t t1 = vshrn_n_u16(vreinterpretq_u16_u8(vceqq_u8(chunk1, dup)), 6);
-            const uint8x8_t res = vsli_n_u8(t0, t1, 4);
-            const U64 matches = vget_lane_u64(vreinterpret_u64_u8(res), 0) ;
-            return ZSTD_rotateRight_U64(matches, headGrouped) & 0x5555555555555555ull;
-        } else { /* rowEntries == 64 */
-            const uint8x16x4_t chunk = vld4q_u8(src);
-            const uint8x16_t dup = vdupq_n_u8(tag);
-            const uint8x16_t cmp0 = vceqq_u8(chunk.val[0], dup);
-            const uint8x16_t cmp1 = vceqq_u8(chunk.val[1], dup);
-            const uint8x16_t cmp2 = vceqq_u8(chunk.val[2], dup);
-            const uint8x16_t cmp3 = vceqq_u8(chunk.val[3], dup);
-
-            const uint8x16_t t0 = vsriq_n_u8(cmp1, cmp0, 1);
-            const uint8x16_t t1 = vsriq_n_u8(cmp3, cmp2, 1);
-            const uint8x16_t t2 = vsriq_n_u8(t1, t0, 2);
-            const uint8x16_t t3 = vsriq_n_u8(t2, t2, 4);
-            const uint8x8_t t4 = vshrn_n_u16(vreinterpretq_u16_u8(t3), 4);
-            const U64 matches = vget_lane_u64(vreinterpret_u64_u8(t4), 0);
-            return ZSTD_rotateRight_U64(matches, headGrouped);
-        }
+        return ZSTD_row_getNEONMask(rowEntries, src, tag, headGrouped);
     }
 # endif /* ZSTD_ARCH_ARM_NEON */
     /* SWAR */