]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
refactored literal segment
authorYann Collet <yann.collet.73@gmail.com>
Mon, 19 Oct 2015 18:25:44 +0000 (19:25 +0100)
committerYann Collet <yann.collet.73@gmail.com>
Mon, 19 Oct 2015 18:25:44 +0000 (19:25 +0100)
lib/zstd.c

index e4e34c49aabda62bc719d00eadd49d0f2797df1e..716a467a006826e132d8fae2373e416f8a43f8de 100644 (file)
 #define BIT6  64
 #define BIT5  32
 #define BIT4  16
+#define BIT1   2
+#define BIT0   1
 
 #define KB *(1 <<10)
 #define MB *(1 <<20)
 #define GB *(1U<<30)
 
 #define BLOCKSIZE (128 KB)                 /* define, for static allocation */
+#define MIN_SEQUENCES_SIZE (2 /*seqNb*/ + 2 /*dumps*/ + 3 /*seqTables*/ + 1 /*bitStream*/)
+#define MIN_CBLOCK_SIZE (3 /*litCSize*/ + MIN_SEQUENCES_SIZE)
+#define IS_RAW BIT0
+#define IS_RLE BIT1
+
 static const U32 g_maxDistance = 4 * BLOCKSIZE;
 static const U32 g_maxLimit = 1 GB;
 static const U32 g_searchStrength = 8;
@@ -364,74 +371,72 @@ size_t ZSTD_compressBound(size_t srcSize)   /* maximum compressed size */
 }
 
 
-static size_t ZSTD_compressRle (void* dst, size_t maxDstSize, const void* src, size_t srcSize)
+static size_t ZSTD_noCompressBlock (void* dst, size_t maxDstSize, const void* src, size_t srcSize)
 {
     BYTE* const ostart = (BYTE* const)dst;
 
-    /* at this stage : dstSize >= FSE_compressBound(srcSize) > (ZSTD_blockHeaderSize+1) (checked by ZSTD_compressLiterals()) */
-    (void)maxDstSize;
-
-    ostart[ZSTD_blockHeaderSize] = *(const BYTE*)src;
+    if (srcSize + ZSTD_blockHeaderSize > maxDstSize) return ERROR(dstSize_tooSmall);
+    memcpy(ostart + ZSTD_blockHeaderSize, src, srcSize);
 
     /* Build header */
     ostart[0]  = (BYTE)(srcSize>>16);
     ostart[1]  = (BYTE)(srcSize>>8);
     ostart[2]  = (BYTE) srcSize;
-    ostart[0] += (BYTE)(bt_rle<<6);
+    ostart[0] += (BYTE)(bt_raw<<6);   /* is a raw (uncompressed) block */
 
-    return ZSTD_blockHeaderSize+1;
+    return ZSTD_blockHeaderSize+srcSize;
 }
 
 
-static size_t ZSTD_noCompressBlock (void* dst, size_t maxDstSize, const void* src, size_t srcSize)
+static size_t ZSTD_compressRawLiteralsBlock (void* dst, size_t maxDstSize, const void* src, size_t srcSize)
 {
     BYTE* const ostart = (BYTE* const)dst;
 
-    if (srcSize + ZSTD_blockHeaderSize > maxDstSize) return ERROR(dstSize_tooSmall);
-    memcpy(ostart + ZSTD_blockHeaderSize, src, srcSize);
-
-    /* Build header */
-    ostart[0]  = (BYTE)(srcSize>>16);
-    ostart[1]  = (BYTE)(srcSize>>8);
-    ostart[2]  = (BYTE) srcSize;
-    ostart[0] += (BYTE)(bt_raw<<6);   /* is a raw (uncompressed) block */
+    if (srcSize + 3 > maxDstSize) return ERROR(dstSize_tooSmall);
 
-    return ZSTD_blockHeaderSize+srcSize;
+    MEM_writeLE32(dst, ((U32)srcSize << 2) | IS_RAW);
+    memcpy(ostart + 3, src, srcSize);
+    return srcSize + 3;
 }
 
+static size_t ZSTD_compressRleLiteralsBlock (void* dst, size_t maxDstSize, const void* src, size_t srcSize)
+{
+    BYTE* const ostart = (BYTE* const)dst;
 
-size_t ZSTD_minGain(size_t srcSize) { return (srcSize >> 6) + 1; }
+    (void)maxDstSize;
+    MEM_writeLE32(dst, ((U32)srcSize << 2) | IS_RLE);  /* note : maxDstSize > litHeaderSize > 4 */
+    ostart[3] = *(const BYTE*)src;
+    return 4;
+}
 
+size_t ZSTD_minGain(size_t srcSize) { return (srcSize >> 6) + 1; }
 
-static size_t ZSTD_compressLiterals (void* dst, size_t dstSize,
+static size_t ZSTD_compressLiterals (void* dst, size_t maxDstSize,
                                      const void* src, size_t srcSize)
 {
     const size_t minGain = ZSTD_minGain(srcSize);
     BYTE* const ostart = (BYTE*)dst;
     size_t hsize;
-    static const size_t LHSIZE = 5;
+    static const size_t litHeaderSize = 5;
 
-    if (dstSize < LHSIZE+1) return ERROR(dstSize_tooSmall);   /* not enough space for compression */
+    if (maxDstSize < litHeaderSize+1) return ERROR(dstSize_tooSmall);   /* not enough space for compression */
 
-    hsize = HUF_compress(ostart+LHSIZE, dstSize-LHSIZE, src, srcSize);
-    if (hsize<2) return hsize;   /* special cases */
-    if (hsize >= srcSize - minGain) return 0;
+    hsize = HUF_compress(ostart+litHeaderSize, maxDstSize-litHeaderSize, src, srcSize);
 
-    hsize += 2;  /* work around vs fixed 3-bytes header */
+    if ((hsize==0) || (hsize >= srcSize - minGain)) return ZSTD_compressRawLiteralsBlock(dst, maxDstSize, src, srcSize);
+    if (hsize==1) return ZSTD_compressRleLiteralsBlock(dst, maxDstSize, src, srcSize);
 
     /* Build header */
     {
-        ostart[0]  = (BYTE)(bt_compressed<<6); /* is a block, is compressed */
-        ostart[0] += (BYTE)(hsize>>16);
-        ostart[1]  = (BYTE)(hsize>>8);
-        ostart[2]  = (BYTE)(hsize>>0);
-        ostart[0] += (BYTE)((srcSize>>16)<<3);
-        ostart[3]  = (BYTE)(srcSize>>8);
-        ostart[4]  = (BYTE)(srcSize>>0);
+        ostart[0]  = (BYTE)(srcSize << 2); /* is a block, is compressed */
+        ostart[1]  = (BYTE)(srcSize >> 6);
+        ostart[2]  = (BYTE)(srcSize >>14);
+        ostart[2] += (BYTE)(hsize << 5);
+        ostart[3]  = (BYTE)(hsize >> 3);
+        ostart[4]  = (BYTE)(hsize >>11);
     }
 
-    hsize -= 2;
-    return hsize+LHSIZE;
+    return hsize+litHeaderSize;
 }
 
 
@@ -451,17 +456,15 @@ static size_t ZSTD_compressSequences(BYTE* dst, size_t maxDstSize,
     const BYTE* const op_lit_start = seqStorePtr->litStart;
     const BYTE* op_lit = seqStorePtr->lit;
     const BYTE* const llTable = seqStorePtr->litLengthStart;
-    const BYTE* op_litLength = seqStorePtr->litLength;
+    const BYTE* const llPtr = seqStorePtr->litLength;
     const BYTE* const mlTable = seqStorePtr->matchLengthStart;
     const U32*  const offsetTable = seqStorePtr->offsetStart;
     BYTE* const offCodeTable = seqStorePtr->offCodeStart;
     BYTE* op = dst;
     BYTE* const oend = dst + maxDstSize;
-    const size_t nbSeq = op_litLength - llTable;
+    const size_t nbSeq = llPtr - llTable;
     const size_t minGain = ZSTD_minGain(srcSize);
     const size_t maxCSize = srcSize - minGain;
-    const size_t minSeqSize = 1 /*lastL*/ + 2 /*dHead*/ + 2 /*dumpsIn*/ + 5 /*SeqHead*/ + 3 /*SeqIn*/ + 1 /*margin*/ + ZSTD_blockHeaderSize;
-    const size_t maxLSize = maxCSize > minSeqSize ? maxCSize - minSeqSize : 0;
     BYTE* seqHead;
 
 
@@ -470,23 +473,16 @@ static size_t ZSTD_compressSequences(BYTE* dst, size_t maxDstSize,
         size_t cSize;
         size_t litSize = op_lit - op_lit_start;
 
-        if (litSize <= LITERAL_NOENTROPY) cSize = ZSTD_noCompressBlock (op, maxDstSize, op_lit_start, litSize);
+        if (litSize <= LITERAL_NOENTROPY)
+            cSize = ZSTD_compressRawLiteralsBlock(op, maxDstSize, op_lit_start, litSize);
         else
-        {
             cSize = ZSTD_compressLiterals(op, maxDstSize, op_lit_start, litSize);
-            if (cSize == 1) cSize = ZSTD_compressRle (op, maxDstSize, op_lit_start, litSize);
-            else if (cSize == 0)
-            {
-                if (litSize >= maxLSize) return 0;   /* block not compressible enough */
-                cSize = ZSTD_noCompressBlock (op, maxDstSize, op_lit_start, litSize);
-            }
-        }
         if (ZSTD_isError(cSize)) return cSize;
         op += cSize;
     }
 
     /* Sequences Header */
-    if ((oend-op) < 2+3+6)  /* nbSeq + dumpsLength + 3*rleCTable*/
+    if ((oend-op) < MIN_SEQUENCES_SIZE)
         return ERROR(dstSize_tooSmall);
     MEM_writeLE16(op, (U16)nbSeq); op+=2;
     seqHead = op;
@@ -1018,6 +1014,7 @@ struct ZSTD_Dctx_s
     BYTE litBuffer[BLOCKSIZE];
 };   /* typedef'd to ZSTD_Dctx within "zstd_static.h" */
 
+
 size_t ZSTD_getcBlockSize(const void* src, size_t srcSize, blockProperties_t* bpPtr)
 {
     const BYTE* const in = (const BYTE* const)src;
@@ -1037,7 +1034,6 @@ size_t ZSTD_getcBlockSize(const void* src, size_t srcSize, blockProperties_t* bp
     return cSize;
 }
 
-
 static size_t ZSTD_copyUncompressedBlock(void* dst, size_t maxDstSize, const void* src, size_t srcSize)
 {
     if (srcSize > maxDstSize) return ERROR(dstSize_tooSmall);
@@ -1046,77 +1042,69 @@ static size_t ZSTD_copyUncompressedBlock(void* dst, size_t maxDstSize, const voi
 }
 
 
-static size_t ZSTD_decompressLiterals(void* ctx,
-                                      void* dst, size_t maxDstSize,
+/** ZSTD_decompressLiterals
+    @return : nb of bytes read from src, or an error code*/
+static size_t ZSTD_decompressLiterals(void* dst, size_t* maxDstSizePtr,
                                 const void* src, size_t srcSize)
 {
-    BYTE* op = (BYTE*)dst;
     const BYTE* ip = (const BYTE*)src;
-    size_t errorCode;
-    size_t litSize;
 
-    /* check : minimum 2, for litSize, +1, for content */
-    if (srcSize <= 3) return ERROR(corruption_detected);
+    const size_t litSize = (MEM_readLE32(src) & 0x1FFFFF) >> 2;   /* no buffer issue : srcSize >= MIN_CBLOCK_SIZE */
+    const size_t litCSize = (MEM_readLE32(ip+2) & 0xFFFFFF) >> 5;   /* no buffer issue : srcSize >= MIN_CBLOCK_SIZE */
 
-    litSize = ip[1] + (ip[0]<<8);
-    litSize += ((ip[-3] >> 3) & 7) << 16;   // mmmmh....
+    if (litSize > *maxDstSizePtr) return ERROR(corruption_detected);
+    if (litCSize + 5 > srcSize) return ERROR(corruption_detected);
 
-    (void)ctx;
-    if (litSize > maxDstSize) return ERROR(dstSize_tooSmall);
-    errorCode = HUF_decompress(op, litSize, ip+2, srcSize-2);
-    if (HUF_isError(errorCode)) return ERROR(GENERIC);
-    return litSize;
+    if (HUF_isError(HUF_decompress(dst, litSize, ip+5, litCSize))) return ERROR(corruption_detected);
+
+    *maxDstSizePtr = litSize;
+    return litCSize + 5;
 }
 
 
+/** ZSTD_decodeLiteralsBlock
+    @return : nb of bytes read from src (< srcSize )*/
 size_t ZSTD_decodeLiteralsBlock(void* ctx,
                           const void* src, size_t srcSize)
 {
     ZSTD_Dctx* dctx = (ZSTD_Dctx*)ctx;
     const BYTE* const istart = (const BYTE* const)src;
-    const BYTE* ip = istart;
-    blockProperties_t litbp;
 
-    size_t litcSize = ZSTD_getcBlockSize(src, srcSize, &litbp);
-    if (ZSTD_isError(litcSize)) return litcSize;
-    if (litcSize > srcSize - ZSTD_blockHeaderSize) return ERROR(srcSize_wrong);
-    ip += ZSTD_blockHeaderSize;
+    /* any compressed block with literals segment must be at least this size */
+    if (srcSize < MIN_CBLOCK_SIZE) return ERROR(corruption_detected);
 
-    switch(litbp.blockType)
+    switch(*istart & 3)
     {
-    case bt_raw:
-        dctx->litPtr = ip;
-        dctx->litBufSize = srcSize - ZSTD_blockHeaderSize;
-        dctx->litSize = litcSize;
-        ip += litcSize;
-        break;
-    case bt_rle:
+    default:
+    case 0:
         {
-            size_t rleSize = litbp.origSize;
-            if (rleSize>BLOCKSIZE) return ERROR(dstSize_tooSmall);
-            memset(dctx->litBuffer, *ip, rleSize);
+            size_t nbLiterals = BLOCKSIZE;
+            const size_t readSize = ZSTD_decompressLiterals(dctx->litBuffer, &nbLiterals, src, srcSize);
             dctx->litPtr = dctx->litBuffer;
             dctx->litBufSize = BLOCKSIZE;
-            dctx->litSize = rleSize;
-            ip++;
-            break;
+            dctx->litSize = nbLiterals;
+            return readSize;   /* works if it's an error too */
+        }
+    case IS_RAW:
+        {
+            const size_t litSize = (MEM_readLE32(istart) & 0xFFFFFF) >> 2;   /* no buffer issue : srcSize >= MIN_CBLOCK_SIZE */
+            if (litSize > srcSize-3) return ERROR(corruption_detected);
+            dctx->litPtr = istart+3;
+            dctx->litBufSize = srcSize-3;
+            dctx->litSize = litSize;
+            return litSize+3;
         }
-    case bt_compressed:
+    case IS_RLE:
         {
-            size_t decodedLitSize = ZSTD_decompressLiterals(ctx, dctx->litBuffer, BLOCKSIZE, ip, litcSize);
-            if (ZSTD_isError(decodedLitSize)) return decodedLitSize;
+            const size_t litSize = (MEM_readLE32(istart) & 0xFFFFFF) >> 2;   /* no buffer issue : srcSize >= MIN_CBLOCK_SIZE */
+            if (litSize > BLOCKSIZE) return ERROR(corruption_detected);
+            memset(dctx->litBuffer, istart[3], litSize);
             dctx->litPtr = dctx->litBuffer;
             dctx->litBufSize = BLOCKSIZE;
-            dctx->litSize = decodedLitSize;
-            ip += litcSize;
-            break;
+            dctx->litSize = litSize;
+            return 4;
         }
-    default:
-        /* unknown blockType (impossible) */
-        return ERROR(GENERIC);
     }
-
-    return ip-istart;
 }
 
 
@@ -1449,15 +1437,14 @@ static size_t ZSTD_decompressBlock(
                             void* dst, size_t maxDstSize,
                       const void* src, size_t srcSize)
 {
-    /* blockType == blockCompressed; srcSize is trusted */
+    /* blockType == blockCompressed */
     const BYTE* ip = (const BYTE*)src;
-    size_t errorCode;
 
     /* Decode literals sub-block */
-    errorCode = ZSTD_decodeLiteralsBlock(ctx, src, srcSize);
-    if (ZSTD_isError(errorCode)) return errorCode;
-    ip += errorCode;
-    srcSize -= errorCode;
+    size_t litCSize = ZSTD_decodeLiteralsBlock(ctx, src, srcSize);
+    if (ZSTD_isError(litCSize)) return litCSize;
+    ip += litCSize;
+    srcSize -= litCSize;
 
     return ZSTD_decompressSequences(ctx, dst, maxDstSize, ip, srcSize);
 }
@@ -1472,7 +1459,6 @@ static size_t ZSTD_decompressDCtx(void* ctx, void* dst, size_t maxDstSize, const
     BYTE* const oend = ostart + maxDstSize;
     size_t remainingSize = srcSize;
     U32 magicNumber;
-    size_t errorCode=0;
     blockProperties_t blockProperties;
 
     /* Frame Header */
@@ -1487,20 +1473,21 @@ static size_t ZSTD_decompressDCtx(void* ctx, void* dst, size_t maxDstSize, const
     /* Loop on each block */
     while (1)
     {
-        size_t blockSize = ZSTD_getcBlockSize(ip, iend-ip, &blockProperties);
-        if (ZSTD_isError(blockSize)) return blockSize;
+        size_t decodedSize=0;
+        size_t cBlockSize = ZSTD_getcBlockSize(ip, iend-ip, &blockProperties);
+        if (ZSTD_isError(cBlockSize)) return cBlockSize;
 
         ip += ZSTD_blockHeaderSize;
         remainingSize -= ZSTD_blockHeaderSize;
-        if (blockSize > remainingSize) return ERROR(srcSize_wrong);
+        if (cBlockSize > remainingSize) return ERROR(srcSize_wrong);
 
         switch(blockProperties.blockType)
         {
         case bt_compressed:
-            errorCode = ZSTD_decompressBlock(ctx, op, oend-op, ip, blockSize);
+            decodedSize = ZSTD_decompressBlock(ctx, op, oend-op, ip, cBlockSize);
             break;
         case bt_raw :
-            errorCode = ZSTD_copyUncompressedBlock(op, oend-op, ip, blockSize);
+            decodedSize = ZSTD_copyUncompressedBlock(op, oend-op, ip, cBlockSize);
             break;
         case bt_rle :
             return ERROR(GENERIC);   /* not yet supported */
@@ -1510,14 +1497,14 @@ static size_t ZSTD_decompressDCtx(void* ctx, void* dst, size_t maxDstSize, const
             if (remainingSize) return ERROR(srcSize_wrong);
             break;
         default:
-            return ERROR(GENERIC);
+            return ERROR(GENERIC);   /* impossible */
         }
-        if (blockSize == 0) break;   /* bt_end */
+        if (cBlockSize == 0) break;   /* bt_end */
 
-        if (ZSTD_isError(errorCode)) return errorCode;
-        op += errorCode;
-        ip += blockSize;
-        remainingSize -= blockSize;
+        if (ZSTD_isError(decodedSize)) return decodedSize;
+        op += decodedSize;
+        ip += cBlockSize;
+        remainingSize -= cBlockSize;
     }
 
     return op-ostart;