]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
[lazy] Optimize ZSTD_row_getMatchMask for level 8-10
authorDanila Kutenin <kutdanila@yandex.ru>
Sun, 22 May 2022 10:34:33 +0000 (10:34 +0000)
committerDanila Kutenin <kutdanila@yandex.ru>
Sun, 22 May 2022 10:44:24 +0000 (10:44 +0000)
We found that movemask is not used properly or consumes too much CPU.
This effort helps to optimize the movemask emulation on ARM.

For level 8-9 we saw 3-5% improvements. For level 10 we say 1.5%
improvement.

The key idea is not to use pure movemasks but to have groups of bits.
For rowEntries == 16, 32 we are going to have groups of size 4 and 2
respectively. It means that each bit will be duplicated within the group

Then we do AND to have only one bit set in the group so that iteration
with lowering bit `a &= (a - 1)` works as well.

Also, aarch64 does not have rotate instructions for 16 bit, only for 32
and 64, that's why we see more improvements for level 8-9.

vshrn_n_u16 instruction is used to achieve that: vshrn_n_u16 shifts by
4 every u16 and narrows to 8 lower bits. See the picture below. It's
also used in
[Folly](https://github.com/facebook/folly/blob/c5702590080aa5d0e8d666d91861d64634065132/folly/container/detail/F14Table.h#L446).
It also uses 2 cycles according to Neoverse-N{1,2} guidelines.

64 bit movemask is already well optimized. We have ongoing experiments
but were not able to validate other implementations work reliably faster.

lib/compress/zstd_lazy.c

index 912b59b9c606f0ea589130bc231d725bc7d57ee2..d897832bae098495025e74378103f0cdb8922a13 100644 (file)
@@ -974,20 +974,45 @@ ZSTD_row_getSSEMask(int nbChunks, const BYTE* const src, const BYTE tag, const U
 }
 #endif
 
-/* Returns a ZSTD_VecMask (U32) that has the nth bit set to 1 if the newly-computed "tag" matches
- * the hash at the nth position in a row of the tagTable.
- * Each row is a circular buffer beginning at the value of "head". So we must rotate the "matches" bitfield
- * to match up with the actual layout of the entries within the hashTable */
+/* Returns the mask width of bits group of which will be set to 1. Given not all
+ * architectures have easy movemask instruction, this helps to iterate over
+ * groups of bits easier and faster.
+ */
+FORCE_INLINE_TEMPLATE U32
+ZSTD_row_matchMaskGroupWidth(const U32 rowEntries) {
+  assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64);
+  assert(rowEntries <= ZSTD_ROW_HASH_MAX_ENTRIES);
+#if defined(ZSTD_ARCH_ARM_NEON)
+  if (rowEntries == 16) {
+    return 4;
+  }
+  if (rowEntries == 32) {
+    return 2;
+  }
+  if (rowEntries == 64) {
+    return 1;
+  }
+#endif
+  return 1;
+}
+
+/* Returns a ZSTD_VecMask (U64) that has the nth group (determined by
+ * ZSTD_row_matchMaskGroupWidth) of bits set to 1 if the newly-computed "tag"
+ * matches the hash at the nth position in a row of the tagTable.
+ * Each row is a circular buffer beginning at the value of "headGrouped". So we
+ * must rotate the "matches" bitfield to match up with the actual layout of the
+ * entries within the hashTable */
 FORCE_INLINE_TEMPLATE ZSTD_VecMask
-ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, const U32 rowEntries)
+ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 headGrouped, const U32 rowEntries)
 {
     const BYTE* const src = tagRow + ZSTD_ROW_HASH_TAG_OFFSET;
     assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64);
     assert(rowEntries <= ZSTD_ROW_HASH_MAX_ENTRIES);
+    assert(ZSTD_row_matchMaskGroupWidth(rowEntries) * rowEntries <= sizeof(ZSTD_VecMask) * 8);
 
 #if defined(ZSTD_ARCH_X86_SSE2)
 
-    return ZSTD_row_getSSEMask(rowEntries / 16, src, tag, head);
+    return ZSTD_row_getSSEMask(rowEntries / 16, src, tag, headGrouped);
 
 #else /* SW or NEON-LE */
 
@@ -995,30 +1020,29 @@ ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head,
   /* This NEON path only works for little endian - otherwise use SWAR below */
     if (MEM_isLittleEndian()) {
         if (rowEntries == 16) {
+            /* vshrn_n_u16 shifts by 4 every u16 and narrows to 8 lower bits.
+             * After that groups of 4 bits represent the equalMask. We lower
+             * all bits except the highest in these groups by doing AND with
+             * 0x88 = 0b10001000.
+             */
             const uint8x16_t chunk = vld1q_u8(src);
             const uint16x8_t equalMask = vreinterpretq_u16_u8(vceqq_u8(chunk, vdupq_n_u8(tag)));
-            const uint16x8_t t0 = vshlq_n_u16(equalMask, 7);
-            const uint32x4_t t1 = vreinterpretq_u32_u16(vsriq_n_u16(t0, t0, 14));
-            const uint64x2_t t2 = vreinterpretq_u64_u32(vshrq_n_u32(t1, 14));
-            const uint8x16_t t3 = vreinterpretq_u8_u64(vsraq_n_u64(t2, t2, 28));
-            const U16 hi = (U16)vgetq_lane_u8(t3, 8);
-            const U16 lo = (U16)vgetq_lane_u8(t3, 0);
-            return ZSTD_rotateRight_U16((hi << 8) | lo, head);
+            const uint8x8_t res = vshrn_n_u16(equalMask, 4);
+            const U64 matches = vget_lane_u64(vreinterpret_u64_u8(res), 0);
+            return ZSTD_rotateRight_U64(matches, headGrouped) & 0x8888888888888888ull;
         } else if (rowEntries == 32) {
-            const uint16x8x2_t chunk = vld2q_u16((const U16*)(const void*)src);
+            /* Same idea as with rowEntries == 16 but doing AND with
+             * 0x55 = 0b01010101.
+             */
+            const uint16x8x2_t chunk = vld2q_u16((const uint16_t*)(const void*)src);
             const uint8x16_t chunk0 = vreinterpretq_u8_u16(chunk.val[0]);
             const uint8x16_t chunk1 = vreinterpretq_u8_u16(chunk.val[1]);
-            const uint8x16_t equalMask0 = vceqq_u8(chunk0, vdupq_n_u8(tag));
-            const uint8x16_t equalMask1 = vceqq_u8(chunk1, vdupq_n_u8(tag));
-            const int8x8_t pack0 = vqmovn_s16(vreinterpretq_s16_u8(equalMask0));
-            const int8x8_t pack1 = vqmovn_s16(vreinterpretq_s16_u8(equalMask1));
-            const uint8x8_t t0 = vreinterpret_u8_s8(pack0);
-            const uint8x8_t t1 = vreinterpret_u8_s8(pack1);
-            const uint8x8_t t2 = vsri_n_u8(t1, t0, 2);
-            const uint8x8x2_t t3 = vuzp_u8(t2, t0);
-            const uint8x8_t t4 = vsri_n_u8(t3.val[1], t3.val[0], 4);
-            const U32 matches = vget_lane_u32(vreinterpret_u32_u8(t4), 0);
-            return ZSTD_rotateRight_U32(matches, head);
+            const uint8x16_t dup = vdupq_n_u8(tag);
+            const uint8x8_t t0 = vshrn_n_u16(vreinterpretq_u16_u8(vceqq_u8(chunk0, dup)), 6);
+            const uint8x8_t t1 = vshrn_n_u16(vreinterpretq_u16_u8(vceqq_u8(chunk1, dup)), 6);
+            const uint8x8_t res = vsli_n_u8(t0, t1, 4);
+            const U64 matches = vget_lane_u64(vreinterpret_u64_u8(res), 0) ;
+            return ZSTD_rotateRight_U64(matches, headGrouped) & 0x5555555555555555ull;
         } else { /* rowEntries == 64 */
             const uint8x16x4_t chunk = vld4q_u8(src);
             const uint8x16_t dup = vdupq_n_u8(tag);
@@ -1033,7 +1057,7 @@ ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head,
             const uint8x16_t t3 = vsriq_n_u8(t2, t2, 4);
             const uint8x8_t t4 = vshrn_n_u16(vreinterpretq_u16_u8(t3), 4);
             const U64 matches = vget_lane_u64(vreinterpret_u64_u8(t4), 0);
-            return ZSTD_rotateRight_U64(matches, head);
+            return ZSTD_rotateRight_U64(matches, headGrouped);
         }
     }
 # endif /* ZSTD_ARCH_ARM_NEON */
@@ -1071,11 +1095,11 @@ ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head,
         }
         matches = ~matches;
         if (rowEntries == 16) {
-            return ZSTD_rotateRight_U16((U16)matches, head);
+            return ZSTD_rotateRight_U16((U16)matches, headGrouped);
         } else if (rowEntries == 32) {
-            return ZSTD_rotateRight_U32((U32)matches, head);
+            return ZSTD_rotateRight_U32((U32)matches, headGrouped);
         } else {
-            return ZSTD_rotateRight_U64((U64)matches, head);
+            return ZSTD_rotateRight_U64((U64)matches, headGrouped);
         }
     }
 #endif
@@ -1123,6 +1147,7 @@ size_t ZSTD_RowFindBestMatch(
     const U32 rowEntries = (1U << rowLog);
     const U32 rowMask = rowEntries - 1;
     const U32 cappedSearchLog = MIN(cParams->searchLog, rowLog); /* nb of searches is capped at nb entries per row */
+    const U32 groupWidth = ZSTD_row_matchMaskGroupWidth(rowEntries);
     U32 nbAttempts = 1U << cappedSearchLog;
     size_t ml=4-1;
 
@@ -1165,15 +1190,15 @@ size_t ZSTD_RowFindBestMatch(
         U32 const tag = hash & ZSTD_ROW_HASH_TAG_MASK;
         U32* const row = hashTable + relRow;
         BYTE* tagRow = (BYTE*)(tagTable + relRow);
-        U32 const head = *tagRow & rowMask;
+        U32 const headGrouped = (*tagRow & rowMask) * groupWidth;
         U32 matchBuffer[ZSTD_ROW_HASH_MAX_ENTRIES];
         size_t numMatches = 0;
         size_t currMatch = 0;
-        ZSTD_VecMask matches = ZSTD_row_getMatchMask(tagRow, (BYTE)tag, head, rowEntries);
+        ZSTD_VecMask matches = ZSTD_row_getMatchMask(tagRow, (BYTE)tag, headGrouped, rowEntries);
 
         /* Cycle through the matches and prefetch */
         for (; (matches > 0) && (nbAttempts > 0); --nbAttempts, matches &= (matches - 1)) {
-            U32 const matchPos = (head + ZSTD_VecMask_next(matches)) & rowMask;
+            U32 const matchPos = ((headGrouped + ZSTD_VecMask_next(matches)) / groupWidth) & rowMask;
             U32 const matchIndex = row[matchPos];
             assert(numMatches < rowEntries);
             if (matchIndex < lowLimit)
@@ -1234,14 +1259,14 @@ size_t ZSTD_RowFindBestMatch(
         const U32 dmsSize              = (U32)(dmsEnd - dmsBase);
         const U32 dmsIndexDelta        = dictLimit - dmsSize;
 
-        {   U32 const head = *dmsTagRow & rowMask;
+        {   U32 const headGrouped = (*dmsTagRow & rowMask) * groupWidth;
             U32 matchBuffer[ZSTD_ROW_HASH_MAX_ENTRIES];
             size_t numMatches = 0;
             size_t currMatch = 0;
-            ZSTD_VecMask matches = ZSTD_row_getMatchMask(dmsTagRow, (BYTE)dmsTag, head, rowEntries);
+            ZSTD_VecMask matches = ZSTD_row_getMatchMask(dmsTagRow, (BYTE)dmsTag, headGrouped, rowEntries);
 
             for (; (matches > 0) && (nbAttempts > 0); --nbAttempts, matches &= (matches - 1)) {
-                U32 const matchPos = (head + ZSTD_VecMask_next(matches)) & rowMask;
+                U32 const matchPos = ((headGrouped + ZSTD_VecMask_next(matches)) / groupWidth) & rowMask;
                 U32 const matchIndex = dmsRow[matchPos];
                 if (matchIndex < dmsLowestIndex)
                     break;