]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Convert Checks in zstd_decompress_block.c to RETURN_ERROR_IF
authorW. Felix Handte <w@felixhandte.com>
Thu, 6 Dec 2018 01:17:11 +0000 (17:17 -0800)
committerW. Felix Handte <w@felixhandte.com>
Mon, 28 Jan 2019 16:56:39 +0000 (11:56 -0500)
lib/decompress/zstd_decompress_block.c

index 32baad9fbb5c48fef64f7960e4af041c27b60a72..b1932e5deded8fb1a5255f84614a3afb0ffaee1c 100644 (file)
@@ -56,14 +56,15 @@ static void ZSTD_copy4(void* dst, const void* src) { memcpy(dst, src, 4); }
 size_t ZSTD_getcBlockSize(const void* src, size_t srcSize,
                           blockProperties_t* bpPtr)
 {
-    if (srcSize < ZSTD_blockHeaderSize) return ERROR(srcSize_wrong);
+    RETURN_ERROR_IF(srcSize < ZSTD_blockHeaderSize, srcSize_wrong);
+
     {   U32 const cBlockHeader = MEM_readLE24(src);
         U32 const cSize = cBlockHeader >> 3;
         bpPtr->lastBlock = cBlockHeader & 1;
         bpPtr->blockType = (blockType_e)((cBlockHeader >> 1) & 3);
         bpPtr->origSize = cSize;   /* only useful for RLE */
         if (bpPtr->blockType == bt_rle) return 1;
-        if (bpPtr->blockType == bt_reserved) return ERROR(corruption_detected);
+        RETURN_ERROR_IF(bpPtr->blockType == bt_reserved, corruption_detected);
         return cSize;
     }
 }
@@ -78,7 +79,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx,
 size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx,
                           const void* src, size_t srcSize)   /* note : srcSize < BLOCKSIZE */
 {
-    if (srcSize < MIN_CBLOCK_SIZE) return ERROR(corruption_detected);
+    RETURN_ERROR_IF(srcSize < MIN_CBLOCK_SIZE, corruption_detected);
 
     {   const BYTE* const istart = (const BYTE*) src;
         symbolEncodingType_e const litEncType = (symbolEncodingType_e)(istart[0] & 3);
@@ -86,11 +87,11 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx,
         switch(litEncType)
         {
         case set_repeat:
-            if (dctx->litEntropy==0) return ERROR(dictionary_corrupted);
+            RETURN_ERROR_IF(dctx->litEntropy==0, dictionary_corrupted);
             /* fall-through */
 
         case set_compressed:
-            if (srcSize < 5) return ERROR(corruption_detected);   /* srcSize >= MIN_CBLOCK_SIZE == 3; here we need up to 5 for case 3 */
+            RETURN_ERROR_IF(srcSize < 5, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 3; here we need up to 5 for case 3");
             {   size_t lhSize, litSize, litCSize;
                 U32 singleStream=0;
                 U32 const lhlCode = (istart[0] >> 2) & 3;
@@ -118,8 +119,8 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx,
                     litCSize = (lhc >> 22) + (istart[4] << 10);
                     break;
                 }
-                if (litSize > ZSTD_BLOCKSIZE_MAX) return ERROR(corruption_detected);
-                if (litCSize + lhSize > srcSize) return ERROR(corruption_detected);
+                RETURN_ERROR_IF(litSize > ZSTD_BLOCKSIZE_MAX, corruption_detected);
+                RETURN_ERROR_IF(litCSize + lhSize > srcSize, corruption_detected);
 
                 /* prefetch huffman table if cold */
                 if (dctx->ddictIsCold && (litSize > 768 /* heuristic */)) {
@@ -157,7 +158,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx,
                     }
                 }
 
-                if (HUF_isError(hufSuccess)) return ERROR(corruption_detected);
+                RETURN_ERROR_IF(HUF_isError(hufSuccess), corruption_detected);
 
                 dctx->litPtr = dctx->litBuffer;
                 dctx->litSize = litSize;
@@ -187,7 +188,7 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx,
                 }
 
                 if (lhSize+litSize+WILDCOPY_OVERLENGTH > srcSize) {  /* risk reading beyond src buffer with wildcopy */
-                    if (litSize+lhSize > srcSize) return ERROR(corruption_detected);
+                    RETURN_ERROR_IF(litSize+lhSize > srcSize, corruption_detected);
                     memcpy(dctx->litBuffer, istart+lhSize, litSize);
                     dctx->litPtr = dctx->litBuffer;
                     dctx->litSize = litSize;
@@ -216,17 +217,17 @@ size_t ZSTD_decodeLiteralsBlock(ZSTD_DCtx* dctx,
                 case 3:
                     lhSize = 3;
                     litSize = MEM_readLE24(istart) >> 4;
-                    if (srcSize<4) return ERROR(corruption_detected);   /* srcSize >= MIN_CBLOCK_SIZE == 3; here we need lhSize+1 = 4 */
+                    RETURN_ERROR_IF(srcSize<4, corruption_detected, "srcSize >= MIN_CBLOCK_SIZE == 3; here we need lhSize+1 = 4");
                     break;
                 }
-                if (litSize > ZSTD_BLOCKSIZE_MAX) return ERROR(corruption_detected);
+                RETURN_ERROR_IF(litSize > ZSTD_BLOCKSIZE_MAX, corruption_detected);
                 memset(dctx->litBuffer, istart[lhSize], litSize + WILDCOPY_OVERLENGTH);
                 dctx->litPtr = dctx->litBuffer;
                 dctx->litSize = litSize;
                 return lhSize+1;
             }
         default:
-            return ERROR(corruption_detected);   /* impossible */
+            RETURN_ERROR_IF(1, corruption_detected, "impossible");
         }
     }
 }
@@ -436,8 +437,8 @@ static size_t ZSTD_buildSeqTable(ZSTD_seqSymbol* DTableSpace, const ZSTD_seqSymb
     switch(type)
     {
     case set_rle :
-        if (!srcSize) return ERROR(srcSize_wrong);
-        if ( (*(const BYTE*)src) > max) return ERROR(corruption_detected);
+        RETURN_ERROR_IF(!srcSize, srcSize_wrong);
+        RETURN_ERROR_IF((*(const BYTE*)src) > max, corruption_detected);
         {   U32 const symbol = *(const BYTE*)src;
             U32 const baseline = baseValue[symbol];
             U32 const nbBits = nbAdditionalBits[symbol];
@@ -449,7 +450,7 @@ static size_t ZSTD_buildSeqTable(ZSTD_seqSymbol* DTableSpace, const ZSTD_seqSymb
         *DTablePtr = defaultTable;
         return 0;
     case set_repeat:
-        if (!flagRepeatTable) return ERROR(corruption_detected);
+        RETURN_ERROR_IF(!flagRepeatTable, corruption_detected);
         /* prefetch FSE table if used */
         if (ddictIsCold && (nbSeq > 24 /* heuristic */)) {
             const void* const pStart = *DTablePtr;
@@ -461,15 +462,15 @@ static size_t ZSTD_buildSeqTable(ZSTD_seqSymbol* DTableSpace, const ZSTD_seqSymb
         {   unsigned tableLog;
             S16 norm[MaxSeq+1];
             size_t const headerSize = FSE_readNCount(norm, &max, &tableLog, src, srcSize);
-            if (FSE_isError(headerSize)) return ERROR(corruption_detected);
-            if (tableLog > maxLog) return ERROR(corruption_detected);
+            RETURN_ERROR_IF(FSE_isError(headerSize), corruption_detected);
+            RETURN_ERROR_IF(tableLog > maxLog, corruption_detected);
             ZSTD_buildFSETable(DTableSpace, norm, max, baseValue, nbAdditionalBits, tableLog);
             *DTablePtr = DTableSpace;
             return headerSize;
         }
-    default :   /* impossible */
+    default :
         assert(0);
-        return ERROR(GENERIC);
+        RETURN_ERROR_IF(1, GENERIC, "impossible");
     }
 }
 
@@ -483,28 +484,28 @@ size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx* dctx, int* nbSeqPtr,
     DEBUGLOG(5, "ZSTD_decodeSeqHeaders");
 
     /* check */
-    if (srcSize < MIN_SEQUENCES_SIZE) return ERROR(srcSize_wrong);
+    RETURN_ERROR_IF(srcSize < MIN_SEQUENCES_SIZE, srcSize_wrong);
 
     /* SeqHead */
     nbSeq = *ip++;
     if (!nbSeq) {
         *nbSeqPtr=0;
-        if (srcSize != 1) return ERROR(srcSize_wrong);
+        RETURN_ERROR_IF(srcSize != 1, srcSize_wrong);
         return 1;
     }
     if (nbSeq > 0x7F) {
         if (nbSeq == 0xFF) {
-            if (ip+2 > iend) return ERROR(srcSize_wrong);
+            RETURN_ERROR_IF(ip+2 > iend, srcSize_wrong);
             nbSeq = MEM_readLE16(ip) + LONGNBSEQ, ip+=2;
         } else {
-            if (ip >= iend) return ERROR(srcSize_wrong);
+            RETURN_ERROR_IF(ip >= iend, srcSize_wrong);
             nbSeq = ((nbSeq-0x80)<<8) + *ip++;
         }
     }
     *nbSeqPtr = nbSeq;
 
     /* FSE table descriptors */
-    if (ip+4 > iend) return ERROR(srcSize_wrong); /* minimum possible size */
+    RETURN_ERROR_IF(ip+4 > iend, srcSize_wrong); /* minimum possible size */
     {   symbolEncodingType_e const LLtype = (symbolEncodingType_e)(*ip >> 6);
         symbolEncodingType_e const OFtype = (symbolEncodingType_e)((*ip >> 4) & 3);
         symbolEncodingType_e const MLtype = (symbolEncodingType_e)((*ip >> 2) & 3);
@@ -517,7 +518,7 @@ size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx* dctx, int* nbSeqPtr,
                                                       LL_base, LL_bits,
                                                       LL_defaultDTable, dctx->fseEntropy,
                                                       dctx->ddictIsCold, nbSeq);
-            if (ZSTD_isError(llhSize)) return ERROR(corruption_detected);
+            RETURN_ERROR_IF(ZSTD_isError(llhSize), corruption_detected);
             ip += llhSize;
         }
 
@@ -527,7 +528,7 @@ size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx* dctx, int* nbSeqPtr,
                                                       OF_base, OF_bits,
                                                       OF_defaultDTable, dctx->fseEntropy,
                                                       dctx->ddictIsCold, nbSeq);
-            if (ZSTD_isError(ofhSize)) return ERROR(corruption_detected);
+            RETURN_ERROR_IF(ZSTD_isError(ofhSize), corruption_detected);
             ip += ofhSize;
         }
 
@@ -537,7 +538,7 @@ size_t ZSTD_decodeSeqHeaders(ZSTD_DCtx* dctx, int* nbSeqPtr,
                                                       ML_base, ML_bits,
                                                       ML_defaultDTable, dctx->fseEntropy,
                                                       dctx->ddictIsCold, nbSeq);
-            if (ZSTD_isError(mlhSize)) return ERROR(corruption_detected);
+            RETURN_ERROR_IF(ZSTD_isError(mlhSize), corruption_detected);
             ip += mlhSize;
         }
     }
@@ -590,8 +591,8 @@ size_t ZSTD_execSequenceLast7(BYTE* op,
     const BYTE* match = oLitEnd - sequence.offset;
 
     /* check */
-    if (oMatchEnd>oend) return ERROR(dstSize_tooSmall);   /* last match must fit within dstBuffer */
-    if (iLitEnd > litLimit) return ERROR(corruption_detected);   /* try to read beyond literal buffer */
+    RETURN_ERROR_IF(oMatchEnd>oend, dstSize_tooSmall, "last match must fit within dstBuffer");
+    RETURN_ERROR_IF(iLitEnd > litLimit, corruption_detected, "try to read beyond literal buffer");
 
     /* copy literals */
     while (op < oLitEnd) *op++ = *(*litPtr)++;
@@ -599,7 +600,7 @@ size_t ZSTD_execSequenceLast7(BYTE* op,
     /* copy Match */
     if (sequence.offset > (size_t)(oLitEnd - base)) {
         /* offset beyond prefix */
-        if (sequence.offset > (size_t)(oLitEnd - vBase)) return ERROR(corruption_detected);
+        RETURN_ERROR_IF(sequence.offset > (size_t)(oLitEnd - vBase),corruption_detected);
         match = dictEnd - (base-match);
         if (match + sequence.matchLength <= dictEnd) {
             memmove(oLitEnd, match, sequence.matchLength);
@@ -631,8 +632,8 @@ size_t ZSTD_execSequence(BYTE* op,
     const BYTE* match = oLitEnd - sequence.offset;
 
     /* check */
-    if (oMatchEnd>oend) return ERROR(dstSize_tooSmall); /* last match must start at a minimum distance of WILDCOPY_OVERLENGTH from oend */
-    if (iLitEnd > litLimit) return ERROR(corruption_detected);   /* over-read beyond lit buffer */
+    RETURN_ERROR_IF(oMatchEnd>oend, dstSize_tooSmall, "last match must start at a minimum distance of WILDCOPY_OVERLENGTH from oend");
+    RETURN_ERROR_IF(iLitEnd > litLimit, corruption_detected, "over-read beyond lit buffer");
     if (oLitEnd>oend_w) return ZSTD_execSequenceLast7(op, oend, sequence, litPtr, litLimit, prefixStart, virtualStart, dictEnd);
 
     /* copy Literals */
@@ -645,8 +646,7 @@ size_t ZSTD_execSequence(BYTE* op,
     /* copy Match */
     if (sequence.offset > (size_t)(oLitEnd - prefixStart)) {
         /* offset beyond prefix -> go into extDict */
-        if (sequence.offset > (size_t)(oLitEnd - virtualStart))
-            return ERROR(corruption_detected);
+        RETURN_ERROR_IF(sequence.offset > (size_t)(oLitEnd - virtualStart), corruption_detected);
         match = dictEnd + (match - prefixStart);
         if (match + sequence.matchLength <= dictEnd) {
             memmove(oLitEnd, match, sequence.matchLength);
@@ -712,8 +712,8 @@ size_t ZSTD_execSequenceLong(BYTE* op,
     const BYTE* match = sequence.match;
 
     /* check */
-    if (oMatchEnd > oend) return ERROR(dstSize_tooSmall); /* last match must start at a minimum distance of WILDCOPY_OVERLENGTH from oend */
-    if (iLitEnd > litLimit) return ERROR(corruption_detected);   /* over-read beyond lit buffer */
+    RETURN_ERROR_IF(oMatchEnd > oend, dstSize_tooSmall, "last match must start at a minimum distance of WILDCOPY_OVERLENGTH from oend");
+    RETURN_ERROR_IF(iLitEnd > litLimit, corruption_detected, "over-read beyond lit buffer");
     if (oLitEnd > oend_w) return ZSTD_execSequenceLast7(op, oend, sequence, litPtr, litLimit, prefixStart, dictStart, dictEnd);
 
     /* copy Literals */
@@ -726,7 +726,7 @@ size_t ZSTD_execSequenceLong(BYTE* op,
     /* copy Match */
     if (sequence.offset > (size_t)(oLitEnd - prefixStart)) {
         /* offset beyond prefix */
-        if (sequence.offset > (size_t)(oLitEnd - dictStart)) return ERROR(corruption_detected);
+        RETURN_ERROR_IF(sequence.offset > (size_t)(oLitEnd - dictStart), corruption_detected);
         if (match + sequence.matchLength <= dictEnd) {
             memmove(oLitEnd, match, sequence.matchLength);
             return sequenceLength;
@@ -927,14 +927,14 @@ ZSTD_decompressSequences_body( ZSTD_DCtx* dctx,
 
         /* check if reached exact end */
         DEBUGLOG(5, "ZSTD_decompressSequences_body: after decode loop, remaining nbSeq : %i", nbSeq);
-        if (nbSeq) return ERROR(corruption_detected);
+        RETURN_ERROR_IF(nbSeq, corruption_detected);
         /* save reps for next block */
         { U32 i; for (i=0; i<ZSTD_REP_NUM; i++) dctx->entropy.rep[i] = (U32)(seqState.prevOffset[i]); }
     }
 
     /* last literal segment */
     {   size_t const lastLLSize = litEnd - litPtr;
-        if (lastLLSize > (size_t)(oend-op)) return ERROR(dstSize_tooSmall);
+        RETURN_ERROR_IF(lastLLSize > (size_t)(oend-op), dstSize_tooSmall);
         memcpy(op, litPtr, lastLLSize);
         op += lastLLSize;
     }
@@ -1076,7 +1076,7 @@ ZSTD_decompressSequencesLong_body(
             sequences[seqNb] = ZSTD_decodeSequenceLong(&seqState, isLongOffset);
             PREFETCH_L1(sequences[seqNb].match); PREFETCH_L1(sequences[seqNb].match + sequences[seqNb].matchLength - 1); /* note : it's safe to invoke PREFETCH() on any memory address, including invalid ones */
         }
-        if (seqNb<seqAdvance) return ERROR(corruption_detected);
+        RETURN_ERROR_IF(seqNb<seqAdvance, corruption_detected);
 
         /* decode and decompress */
         for ( ; (BIT_reloadDStream(&(seqState.DStream)) <= BIT_DStream_completed) && (seqNb<nbSeq) ; seqNb++) {
@@ -1087,7 +1087,7 @@ ZSTD_decompressSequencesLong_body(
             sequences[seqNb & STORED_SEQS_MASK] = sequence;
             op += oneSeqSize;
         }
-        if (seqNb<nbSeq) return ERROR(corruption_detected);
+        RETURN_ERROR_IF(seqNb<nbSeq, corruption_detected);
 
         /* finish queue */
         seqNb -= seqAdvance;
@@ -1103,7 +1103,7 @@ ZSTD_decompressSequencesLong_body(
 
     /* last literal segment */
     {   size_t const lastLLSize = litEnd - litPtr;
-        if (lastLLSize > (size_t)(oend-op)) return ERROR(dstSize_tooSmall);
+        RETURN_ERROR_IF(lastLLSize > (size_t)(oend-op), dstSize_tooSmall);
         memcpy(op, litPtr, lastLLSize);
         op += lastLLSize;
     }
@@ -1240,7 +1240,7 @@ ZSTD_decompressBlock_internal(ZSTD_DCtx* dctx,
     ZSTD_longOffset_e const isLongOffset = (ZSTD_longOffset_e)(MEM_32bits() && (!frame || (dctx->fParams.windowSize > (1ULL << STREAM_ACCUMULATOR_MIN))));
     DEBUGLOG(5, "ZSTD_decompressBlock_internal (size : %u)", (U32)srcSize);
 
-    if (srcSize >= ZSTD_BLOCKSIZE_MAX) return ERROR(srcSize_wrong);
+    RETURN_ERROR_IF(srcSize >= ZSTD_BLOCKSIZE_MAX, srcSize_wrong);
 
     /* Decode literals section */
     {   size_t const litCSize = ZSTD_decodeLiteralsBlock(dctx, src, srcSize);