]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
control long length within AVX2 implementation
authorYann Collet <yann.collet.73@gmail.com>
Wed, 8 Jan 2025 00:42:36 +0000 (16:42 -0800)
committerYann Collet <cyan@fb.com>
Thu, 16 Jan 2025 01:11:27 +0000 (17:11 -0800)
lib/compress/zstd_compress.c

index c8dc86ccf9468c7becdc17557c4f2baa0cf32222..a5298031303c182a6e98e299c5f64687c2c7c449 100644 (file)
@@ -7121,8 +7121,12 @@ size_t ZSTD_compressSequences(ZSTD_CCtx* cctx,
  * At the end, instead of extracting two __m128i,
  * we use _mm256_permute4x64_epi64(..., 0xE8) to move lane2 into lane1,
  * then store the lower 16 bytes in one go.
+ *
+ * @returns 0 on succes, with no long length detected
+ * @returns > 0 if there is one long length (> 65535),
+ * indicating the position, and type.
  */
-void convertSequences_noRepcodes(
+size_t convertSequences_noRepcodes(
     SeqDef* dstSeqs,
     const ZSTD_Sequence* inSeqs,
     size_t nbSequences)
@@ -7136,6 +7140,9 @@ void convertSequences_noRepcodes(
         ZSTD_REP_NUM, 0, -MINMATCH, 0     /* for sequence i+1 */
     );
 
+    /* limit: check if there is a long length */
+    const __m256i limit = _mm256_set1_epi32(65535);
+
     /*
      * shuffle mask for byte-level rearrangement in each 128-bit half:
      *
@@ -7170,16 +7177,20 @@ void convertSequences_noRepcodes(
      */
 #define PERM_LANE_0X_E8 0xE8  /* [0,2,2,3] in lane indices */
 
-    size_t i = 0;
+    size_t longLen = 0, i = 0;
     /* Process 2 sequences per loop iteration */
     for (; i + 1 < nbSequences; i += 2) {
-        /* 1) Load 2 ZSTD_Sequence (32 bytes) */
+        /* Load 2 ZSTD_Sequence (32 bytes) */
         __m256i vin  = _mm256_loadu_si256((__m256i const*)&inSeqs[i]);
 
-        /* 2) Add {2, 0, -3, 0} in each 128-bit half */
+        /* Add {2, 0, -3, 0} in each 128-bit half */
         __m256i vadd = _mm256_add_epi32(vin, addition);
 
-        /* 3) Shuffle bytes so each half gives us the 8 bytes we need */
+        /* Check for long length */
+        __m256i cmp  = _mm256_cmpgt_epi32(vadd, limit);  // 0xFFFFFFFF for element > 65535
+        int cmp_res  = _mm256_movemask_epi8(cmp);
+
+        /* Shuffle bytes so each half gives us the 8 bytes we need */
         __m256i vshf = _mm256_shuffle_epi8(vadd, mask);
         /*
          * Now:
@@ -7189,105 +7200,47 @@ void convertSequences_noRepcodes(
          *   Lane3 = 0
          */
 
-        /* 4) Permute 64-bit lanes => move Lane2 down into Lane1. */
+        /* Permute 64-bit lanes => move Lane2 down into Lane1. */
         __m256i vperm = _mm256_permute4x64_epi64(vshf, PERM_LANE_0X_E8);
         /*
          * Now the lower 16 bytes (Lane0+Lane1) = [seq0, seq1].
          * The upper 16 bytes are [Lane2, Lane3] = [seq1, 0], but we won't use them.
          */
 
-        /* 5) Store only the lower 16 bytes => 2 SeqDef (8 bytes each) */
+        /* Store only the lower 16 bytes => 2 SeqDef (8 bytes each) */
         _mm_storeu_si128((__m128i *)&dstSeqs[i], _mm256_castsi256_si128(vperm));
         /*
          * This writes out 16 bytes total:
          *   - offset 0..7  => seq0 (offBase, litLength, mlBase)
          *   - offset 8..15 => seq1 (offBase, litLength, mlBase)
          */
-    }
 
-    /* Handle leftover if @nbSequences is odd */
-    if (i < nbSequences) {
-        /* Fallback: process last sequence */
-        assert(i == nbSequences - 1);
-        dstSeqs[i].offBase = OFFSET_TO_OFFBASE(inSeqs[i].offset);
-        /* note: doesn't work if one length is > 65535 */
-        dstSeqs[i].litLength = (U16)inSeqs[i].litLength;
-        dstSeqs[i].mlBase = (U16)(inSeqs[i].matchLength - MINMATCH);
+        /* check (unlikely) long lengths > 65535
+         * indices for lengths correspond to bits [4..7], [8..11], [20..23], [24..27]
+         * => combined mask = 0x0FF00FF0
+         */
+        if (UNLIKELY((cmp_res & 0x0FF00FF0) != 0)) {
+            /* long length detected: let's figure out which one*/
+            if (inSeqs[i].matchLength > 65535+MINMATCH) {
+                assert(longLen == 0);
+                longLen = i + 1;
+            }
+            if (inSeqs[i].litLength > 65535) {
+                assert(longLen == 0);
+                longLen = i + nbSequences + 1;
+            }
+            if (inSeqs[i+1].matchLength > 65535+MINMATCH) {
+                assert(longLen == 0);
+                longLen = i + 1 + 1;
+            }
+            if (inSeqs[i+1].litLength > 65535) {
+                assert(longLen == 0);
+                longLen = i + 1 + nbSequences + 1;
+            }
+        }
     }
-}
-
-#elif defined(__SSSE3__)
-
-#include <tmmintrin.h>  /* SSSE3 intrinsics: _mm_shuffle_epi8 */
-#include <emmintrin.h>  /* SSE2 intrinsics:  _mm_add_epi32, etc. */
 
-/*
- * Convert sequences with SSE.
- * - offset   -> offBase = offset + 2
- * - litLength (32-bit) -> (U16) litLength
- * - matchLength (32-bit) -> (U16)(matchLength - 3)
- * - rep is discarded.
- *
- * We shuffle so that only the first 8 bytes in the final 128-bit
- * register are used. We still store 16 bytes (low 8 are good, high 8 are "don't care").
- */
-static void convertSequences_noRepcodes(SeqDef* dstSeqs,
-                                const ZSTD_Sequence* inSeqs,
-                                size_t nbSequences)
-{
-    /*
-       addition = { offset+2, litLength+0, matchLength-3, rep+0 }
-       setr means the first argument is placed in the lowest 32 bits,
-       second in next-lower 32 bits, etc.
-    */
-    const __m128i addition = _mm_setr_epi32(2, 0, -3, 0);
-
-    /*
-       Shuffle mask: we reorder bytes after the addition.
-
-       Input layout in 128-bit register (after addition):
-         Bytes:   [ 0..3 | 4..7 | 8..11 | 12..15 ]
-         Fields:   offset+2   litLength  matchLength   rep
-
-       We want in output:
-         Bytes:   [ 0..3 | 4..5 | 6..7 | 8..15 ignore ]
-         Fields:   offset+2   (U16)litLength (U16)(matchLength)
-
-       _mm_shuffle_epi8 picks bytes from the source. A byte of 0x80 means “zero out”.
-       So we want:
-         out[0] = in[0], out[1] = in[1], out[2] = in[2], out[3] = in[3],     // offset+2 (4 bytes)
-         out[4] = in[4], out[5] = in[5],                                   // (U16) litLength
-         out[6] = in[8], out[7] = in[9],                                   // (U16) matchLength
-         out[8..15] = 0x80 => won't matter if we only care about first 8 bytes
-     */
-    const __m128i mask = _mm_setr_epi8(
-        0, 1, 2, 3,       /* offset (4 bytes)       */
-        4, 5,             /* litLength (2 bytes)    */
-        8, 9,             /* matchLength (2 bytes)  */
-        (char)0x80, (char)0x80, (char)0x80, (char)0x80,
-        (char)0x80, (char)0x80, (char)0x80, (char)0x80
-    );
-    size_t i;
-
-    for (i = 0; i + 1 < nbSequences; i += 2) {
-        /*-------------------------*/
-        /* Process inSeqs[i]      */
-        /*-------------------------*/
-        __m128i vin0  = _mm_loadu_si128((const __m128i  *)(const void*)&inSeqs[i]);
-        __m128i vadd0 = _mm_add_epi32(vin0, addition);
-        __m128i vshf0 = _mm_shuffle_epi8(vadd0, mask);
-        _mm_storel_epi64((__m128i *)(void*)&dstSeqs[i], vshf0);
-
-        /*-------------------------*/
-        /* Process inSeqs[i + 1]  */
-        /*-------------------------*/
-        __m128i vin1  = _mm_loadu_si128((__m128i const *)(const void*)&inSeqs[i + 1]);
-        __m128i vadd1 = _mm_add_epi32(vin1, addition);
-        __m128i vshf1 = _mm_shuffle_epi8(vadd1, mask);
-        _mm_storel_epi64((__m128i *)(void*)&dstSeqs[i + 1], vshf1);
-    }
-
-    /* Handle leftover if nbSequences is odd */
+    /* Handle leftover if @nbSequences is odd */
     if (i < nbSequences) {
         /* Fallback: process last sequence */
         assert(i == nbSequences - 1);
@@ -7295,11 +7248,24 @@ static void convertSequences_noRepcodes(SeqDef* dstSeqs,
         /* note: doesn't work if one length is > 65535 */
         dstSeqs[i].litLength = (U16)inSeqs[i].litLength;
         dstSeqs[i].mlBase = (U16)(inSeqs[i].matchLength - MINMATCH);
+        if (UNLIKELY(inSeqs[i].matchLength > 65535+MINMATCH)) {
+            assert(longLen == 0);
+            longLen = i + 1;
+        }
+        if (UNLIKELY(inSeqs[i].litLength > 65535)) {
+            assert(longLen == 0);
+            longLen = i + nbSequences + 1;
+        }
     }
 
+    return longLen;
 }
 
-#else /* no SSE */
+/* the vector implementation could also be ported to SSSE3,
+ * but since this implementation is targeting modern systems >= Sapphire Rapid,
+ * it's not useful to develop and maintain code for older platforms (before AVX2) */
+
+#else /* no AVX2 */
 
 static size_t
 convertSequences_noRepcodes(SeqDef* dstSeqs,
@@ -7312,6 +7278,10 @@ convertSequences_noRepcodes(SeqDef* dstSeqs,
         /* note: doesn't work if one length is > 65535 */
         dstSeqs[n].litLength = (U16)inSeqs[n].litLength;
         dstSeqs[n].mlBase = (U16)(inSeqs[n].matchLength - MINMATCH);
+        if (UNLIKELY(inSeqs[n].matchLength > 65535+MINMATCH)) {
+            assert(longLen == 0);
+            longLen = n + 1;
+        }
         if (UNLIKELY(inSeqs[n].litLength > 65535)) {
             assert(longLen == 0);
             longLen = n + nbSequences + 1;