]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
fixed Huff0 quad-symbols decoder (#173)
authorYann Collet <yann.collet.73@gmail.com>
Thu, 5 May 2016 10:41:36 +0000 (12:41 +0200)
committerYann Collet <yann.collet.73@gmail.com>
Thu, 5 May 2016 10:41:36 +0000 (12:41 +0200)
lib/common/huf_static.h
lib/decompress/huf_decompress.c

index d2d29a42bdbd4eee9cdf65df753fbff24253b9ed..e68ec33ddc122672b76e14c41f2e23eb229d6d54 100644 (file)
@@ -76,7 +76,7 @@ extern "C" {
 ******************************************/
 size_t HUF_decompress4X2 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize);   /* single-symbol decoder */
 size_t HUF_decompress4X4 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize);   /* double-symbols decoder */
-size_t HUF_decompress4X6 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize);   /* quad-symbols decoder */
+size_t HUF_decompress4X6 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize);   /* quad-symbols decoder, only works for dstSize >= 64 */
 
 
 /* ****************************************
@@ -122,7 +122,7 @@ size_t HUF_compress1X_usingCTable(void* dst, size_t dstSize, const void* src, si
 
 size_t HUF_decompress1X2 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize);   /* single-symbol decoder */
 size_t HUF_decompress1X4 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize);   /* double-symbol decoder */
-size_t HUF_decompress1X6 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize);   /* quad-symbol decoder */
+size_t HUF_decompress1X6 (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize);   /* quad-symbols decoder, only works for dstSize >= 64 */
 
 size_t HUF_decompress1X2_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const unsigned short* DTable);
 size_t HUF_decompress1X4_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const unsigned* DTable);
@@ -157,7 +157,6 @@ MEM_STATIC size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                             const void* src, size_t srcSize)
 {
     U32 weightTotal;
-    U32 tableLog;
     const BYTE* ip = (const BYTE*) src;
     size_t iSize = ip[0];
     size_t oSize;
@@ -191,31 +190,31 @@ MEM_STATIC size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats,
     /* collect weight stats */
     memset(rankStats, 0, (HUF_ABSOLUTEMAX_TABLELOG + 1) * sizeof(U32));
     weightTotal = 0;
-    { U32 n; for (n=0; n<oSize; n++) {
-        if (huffWeight[n] >= HUF_ABSOLUTEMAX_TABLELOG) return ERROR(corruption_detected);
-        rankStats[huffWeight[n]]++;
-        weightTotal += (1 << huffWeight[n]) >> 1;
-    }}
+    {   U32 n; for (n=0; n<oSize; n++) {
+            if (huffWeight[n] >= HUF_ABSOLUTEMAX_TABLELOG) return ERROR(corruption_detected);
+            rankStats[huffWeight[n]]++;
+            weightTotal += (1 << huffWeight[n]) >> 1;
+    }   }
 
     /* get last non-null symbol weight (implied, total must be 2^n) */
-    tableLog = BIT_highbit32(weightTotal) + 1;
-    if (tableLog > HUF_ABSOLUTEMAX_TABLELOG) return ERROR(corruption_detected);
-    /* determine last weight */
-    {   U32 const total = 1 << tableLog;
-        U32 const rest = total - weightTotal;
-        U32 const verif = 1 << BIT_highbit32(rest);
-        U32 const lastWeight = BIT_highbit32(rest) + 1;
-        if (verif != rest) return ERROR(corruption_detected);    /* last value must be a clean power of 2 */
-        huffWeight[oSize] = (BYTE)lastWeight;
-        rankStats[lastWeight]++;
-    }
+    {   U32 const tableLog = BIT_highbit32(weightTotal) + 1;
+        if (tableLog > HUF_ABSOLUTEMAX_TABLELOG) return ERROR(corruption_detected);
+        *tableLogPtr = tableLog;
+        /* determine last weight */
+        {   U32 const total = 1 << tableLog;
+            U32 const rest = total - weightTotal;
+            U32 const verif = 1 << BIT_highbit32(rest);
+            U32 const lastWeight = BIT_highbit32(rest) + 1;
+            if (verif != rest) return ERROR(corruption_detected);    /* last value must be a clean power of 2 */
+            huffWeight[oSize] = (BYTE)lastWeight;
+            rankStats[lastWeight]++;
+    }   }
 
     /* check tree construction validity */
     if ((rankStats[1] < 2) || (rankStats[1] & 1)) return ERROR(corruption_detected);   /* by construction : at least 2 elts of rank 1, must be even */
 
     /* results */
     *nbSymbolsPtr = (U32)(oSize+1);
-    *tableLogPtr = tableLog;
     return iSize+1;
 }
 
index 7637bd9413f1b55b0e472ce7c19f6aca54e09526..d64808f65a1831917e1da0c55a8e03f3b44b1514 100644 (file)
@@ -847,17 +847,17 @@ size_t HUF_readDTableX6 (U32* DTable, const void* src, size_t srcSize)
 
 static U32 HUF_decodeSymbolX6(void* op, BIT_DStream_t* DStream, const HUF_DDescX6* dd, const HUF_DSeqX6* ds, const U32 dtLog)
 {
-    const size_t val = BIT_lookBitsFast(DStream, dtLog);   /* note : dtLog >= 1 */
+    size_t const val = BIT_lookBitsFast(DStream, dtLog);   /* note : dtLog >= 1 */
     memcpy(op, ds+val, sizeof(HUF_DSeqX6));
     BIT_skipBits(DStream, dd[val].nbBits);
     return dd[val].nbBytes;
 }
 
-static U32 HUF_decodeLastSymbolsX6(void* op, const U32 maxL, BIT_DStream_t* DStream,
+static U32 HUF_decodeLastSymbolsX6(void* op, U32 const maxL, BIT_DStream_t* DStream,
                                   const HUF_DDescX6* dd, const HUF_DSeqX6* ds, const U32 dtLog)
 {
-    const size_t val = BIT_lookBitsFast(DStream, dtLog);   /* note : dtLog >= 1 */
-    U32 length = dd[val].nbBytes;
+    size_t const val = BIT_lookBitsFast(DStream, dtLog);   /* note : dtLog >= 1 */
+    U32 const length = dd[val].nbBytes;
     if (length <= maxL) {
         memcpy(op, ds+val, length);
         BIT_skipBits(DStream, dd[val].nbBits);
@@ -910,7 +910,6 @@ static inline size_t HUF_decodeStreamX6(BYTE* p, BIT_DStream_t* bitDPtr, BYTE* c
     return p-pStart;
 }
 
-
 size_t HUF_decompress1X6_usingDTable(
           void* dst,  size_t dstSize,
     const void* cSrc, size_t cSrcSize,
@@ -919,17 +918,15 @@ size_t HUF_decompress1X6_usingDTable(
     const BYTE* const istart = (const BYTE*) cSrc;
     BYTE* const ostart = (BYTE*) dst;
     BYTE* const oend = ostart + dstSize;
-
-    const U32 dtLog = DTable[0];
-    size_t errorCode;
+    BIT_DStream_t bitD;
 
     /* Init */
-    BIT_DStream_t bitD;
-    errorCode = BIT_initDStream(&bitD, istart, cSrcSize);
-    if (HUF_isError(errorCode)) return errorCode;
+    { size_t const errorCode = BIT_initDStream(&bitD, istart, cSrcSize);
+      if (HUF_isError(errorCode)) return errorCode; }
 
     /* finish bitStreams one by one */
-    HUF_decodeStreamX6(ostart, &bitD, oend, DTable, dtLog);
+    { U32 const dtLog = DTable[0];
+      HUF_decodeStreamX6(ostart, &bitD, oend, DTable, dtLog); }
 
     /* check */
     if (!BIT_endOfDStream(&bitD)) return ERROR(corruption_detected);
@@ -943,7 +940,7 @@ size_t HUF_decompress1X6 (void* dst, size_t dstSize, const void* cSrc, size_t cS
     HUF_CREATE_STATIC_DTABLEX6(DTable, HUF_MAX_TABLELOG);
     const BYTE* ip = (const BYTE*) cSrc;
 
-    size_t hSize = HUF_readDTableX6 (DTable, cSrc, cSrcSize);
+    size_t const hSize = HUF_readDTableX6 (DTable, cSrc, cSrcSize);
     if (HUF_isError(hSize)) return hSize;
     if (hSize >= cSrcSize) return ERROR(srcSize_wrong);
     ip += hSize;
@@ -953,6 +950,24 @@ size_t HUF_decompress1X6 (void* dst, size_t dstSize, const void* cSrc, size_t cS
 }
 
 
+#define HUF_DECODE_ROUNDX6 \
+            HUF_DECODE_SYMBOLX6_2(op1, &bitD1); \
+            HUF_DECODE_SYMBOLX6_2(op2, &bitD2); \
+            HUF_DECODE_SYMBOLX6_2(op3, &bitD3); \
+            HUF_DECODE_SYMBOLX6_2(op4, &bitD4); \
+            HUF_DECODE_SYMBOLX6_1(op1, &bitD1); \
+            HUF_DECODE_SYMBOLX6_1(op2, &bitD2); \
+            HUF_DECODE_SYMBOLX6_1(op3, &bitD3); \
+            HUF_DECODE_SYMBOLX6_1(op4, &bitD4); \
+            HUF_DECODE_SYMBOLX6_2(op1, &bitD1); \
+            HUF_DECODE_SYMBOLX6_2(op2, &bitD2); \
+            HUF_DECODE_SYMBOLX6_2(op3, &bitD3); \
+            HUF_DECODE_SYMBOLX6_2(op4, &bitD4); \
+            HUF_DECODE_SYMBOLX6_0(op1, &bitD1); \
+            HUF_DECODE_SYMBOLX6_0(op2, &bitD2); \
+            HUF_DECODE_SYMBOLX6_0(op3, &bitD3); \
+            HUF_DECODE_SYMBOLX6_0(op4, &bitD4);
+
 size_t HUF_decompress4X6_usingDTable(
           void* dst,  size_t dstSize,
     const void* cSrc, size_t cSrcSize,
@@ -960,6 +975,7 @@ size_t HUF_decompress4X6_usingDTable(
 {
     /* Check */
     if (cSrcSize < 10) return ERROR(corruption_detected);   /* strict minimum : jump table + 1 byte per stream */
+    if (dstSize  < 64) return ERROR(dstSize_tooSmall);      /* only work for dstSize >= 64 */
 
     {   const BYTE* const istart = (const BYTE*) cSrc;
         BYTE* const ostart = (BYTE*) dst;
@@ -970,7 +986,6 @@ size_t HUF_decompress4X6_usingDTable(
         const HUF_DDescX6* dd = (const HUF_DDescX6*)ddPtr;
         const void* const dsPtr = DTable + 1 + ((size_t)1<<(dtLog-1));
         const HUF_DSeqX6* ds = (const HUF_DSeqX6*)dsPtr;
-        size_t errorCode;
 
         /* Init */
         BIT_DStream_t bitD1;
@@ -997,43 +1012,41 @@ size_t HUF_decompress4X6_usingDTable(
 
         length4 = cSrcSize - (length1 + length2 + length3 + 6);
         if (length4 > cSrcSize) return ERROR(corruption_detected);   /* overflow */
-        errorCode = BIT_initDStream(&bitD1, istart1, length1);
-        if (HUF_isError(errorCode)) return errorCode;
-        errorCode = BIT_initDStream(&bitD2, istart2, length2);
-        if (HUF_isError(errorCode)) return errorCode;
-        errorCode = BIT_initDStream(&bitD3, istart3, length3);
-        if (HUF_isError(errorCode)) return errorCode;
-        errorCode = BIT_initDStream(&bitD4, istart4, length4);
-        if (HUF_isError(errorCode)) return errorCode;
-
-        /* 16-64 symbols per loop (4-16 symbols per stream) */
+        { size_t const errorCode = BIT_initDStream(&bitD1, istart1, length1);
+          if (HUF_isError(errorCode)) return errorCode; }
+        { size_t const errorCode = BIT_initDStream(&bitD2, istart2, length2);
+          if (HUF_isError(errorCode)) return errorCode; }
+        { size_t const errorCode = BIT_initDStream(&bitD3, istart3, length3);
+          if (HUF_isError(errorCode)) return errorCode; }
+        { size_t const errorCode = BIT_initDStream(&bitD4, istart4, length4);
+          if (HUF_isError(errorCode)) return errorCode; }
+
+        /* 4-64 symbols per loop (1-16 symbols per stream) */
         endSignal = BIT_reloadDStream(&bitD1) | BIT_reloadDStream(&bitD2) | BIT_reloadDStream(&bitD3) | BIT_reloadDStream(&bitD4);
-        for ( ; (op3 <= opStart4) && (endSignal==BIT_DStream_unfinished) && (op4<=(oend-16)) ; ) {
-            HUF_DECODE_SYMBOLX6_2(op1, &bitD1);
-            HUF_DECODE_SYMBOLX6_2(op2, &bitD2);
-            HUF_DECODE_SYMBOLX6_2(op3, &bitD3);
-            HUF_DECODE_SYMBOLX6_2(op4, &bitD4);
-            HUF_DECODE_SYMBOLX6_1(op1, &bitD1);
-            HUF_DECODE_SYMBOLX6_1(op2, &bitD2);
-            HUF_DECODE_SYMBOLX6_1(op3, &bitD3);
-            HUF_DECODE_SYMBOLX6_1(op4, &bitD4);
-            HUF_DECODE_SYMBOLX6_2(op1, &bitD1);
-            HUF_DECODE_SYMBOLX6_2(op2, &bitD2);
-            HUF_DECODE_SYMBOLX6_2(op3, &bitD3);
-            HUF_DECODE_SYMBOLX6_2(op4, &bitD4);
-            HUF_DECODE_SYMBOLX6_0(op1, &bitD1);
-            HUF_DECODE_SYMBOLX6_0(op2, &bitD2);
-            HUF_DECODE_SYMBOLX6_0(op3, &bitD3);
-            HUF_DECODE_SYMBOLX6_0(op4, &bitD4);
-
-            endSignal = BIT_reloadDStream(&bitD1) | BIT_reloadDStream(&bitD2) | BIT_reloadDStream(&bitD3) | BIT_reloadDStream(&bitD4);
-        }
+        if (endSignal==BIT_DStream_unfinished) {
+            HUF_DECODE_ROUNDX6;
+            if (sizeof(bitD1.bitContainer)==4) {   /* need to decode at least 4 bytes per stream */
+                    endSignal = BIT_reloadDStream(&bitD1) | BIT_reloadDStream(&bitD2) | BIT_reloadDStream(&bitD3) | BIT_reloadDStream(&bitD4);
+                    HUF_DECODE_ROUNDX6;
+            }
+            {   U32 const saved2 = MEM_read32(opStart2);   /* saved from overwrite */
+                U32 const saved3 = MEM_read32(opStart3);
+                U32 const saved4 = MEM_read32(opStart4);
+                endSignal = BIT_reloadDStream(&bitD1) | BIT_reloadDStream(&bitD2) | BIT_reloadDStream(&bitD3) | BIT_reloadDStream(&bitD4);
+                for ( ; (op3 <= opStart4) && (endSignal==BIT_DStream_unfinished) && (op4<=(oend-16)) ; ) {
+                    HUF_DECODE_ROUNDX6;
+                    endSignal = BIT_reloadDStream(&bitD1) | BIT_reloadDStream(&bitD2) | BIT_reloadDStream(&bitD3) | BIT_reloadDStream(&bitD4);
+                }
+                MEM_write32(opStart2, saved2);
+                MEM_write32(opStart3, saved3);
+                MEM_write32(opStart4, saved4);
+        }   }
 
         /* check corruption */
         if (op1 > opStart2) return ERROR(corruption_detected);
         if (op2 > opStart3) return ERROR(corruption_detected);
         if (op3 > opStart4) return ERROR(corruption_detected);
-        /* note : op4 supposed already verified within main loop */
+        /* note : op4 already verified within main loop */
 
         /* finish bitStreams one by one */
         HUF_decodeStreamX6(op1, &bitD1, opStart2, DTable, dtLog);
@@ -1097,12 +1110,7 @@ typedef size_t (*decompressionAlgo)(void* dst, size_t dstSize, const void* cSrc,
 size_t HUF_decompress (void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize)
 {
     static const decompressionAlgo decompress[3] = { HUF_decompress4X2, HUF_decompress4X4, HUF_decompress4X6 };
-    /* estimate decompression time */
-    U32 Q;
-    const U32 D256 = (U32)(dstSize >> 8);
-    U32 Dtime[3];
-    U32 algoNb = 0;
-    int n;
+    U32 Dtime[3];   /* decompression time estimation */
 
     /* validation checks */
     if (dstSize == 0) return ERROR(dstSize_tooSmall);
@@ -1111,16 +1119,19 @@ size_t HUF_decompress (void* dst, size_t dstSize, const void* cSrc, size_t cSrcS
     if (cSrcSize == 1) { memset(dst, *(const BYTE*)cSrc, dstSize); return dstSize; }   /* RLE */
 
     /* decoder timing evaluation */
-    Q = (U32)(cSrcSize * 16 / dstSize);   /* Q < 16 since dstSize > cSrcSize */
-    for (n=0; n<3; n++)
-        Dtime[n] = algoTime[Q][n].tableTime + (algoTime[Q][n].decode256Time * D256);
+    {   U32 const Q = (U32)(cSrcSize * 16 / dstSize);   /* Q < 16 since dstSize > cSrcSize */
+        U32 const D256 = (U32)(dstSize >> 8);
+        U32 n; for (n=0; n<3; n++)
+            Dtime[n] = algoTime[Q][n].tableTime + (algoTime[Q][n].decode256Time * D256);
+    }
 
     Dtime[1] += Dtime[1] >> 4; Dtime[2] += Dtime[2] >> 3; /* advantage to algorithms using less memory, for cache eviction */
 
-    if (Dtime[1] < Dtime[0]) algoNb = 1;
-    //if (Dtime[2] < Dtime[algoNb]) algoNb = 2;
-
-    return decompress[algoNb](dst, dstSize, cSrc, cSrcSize);
+    {   U32 algoNb = 0;
+        if (Dtime[1] < Dtime[0]) algoNb = 1;
+        if (Dtime[2] < Dtime[algoNb]) algoNb = 2;
+        return decompress[algoNb](dst, dstSize, cSrc, cSrcSize);
+    }
 
     //return HUF_decompress4X2(dst, dstSize, cSrc, cSrcSize);   /* multi-streams single-symbol decoding */
     //return HUF_decompress4X4(dst, dstSize, cSrc, cSrcSize);   /* multi-streams double-symbols decoding */