]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Streaming decompression can detect incorrect header ID sooner 3175/head
authorYann Collet <cyan@fb.com>
Wed, 22 Jun 2022 01:14:11 +0000 (18:14 -0700)
committerYann Collet <cyan@fb.com>
Wed, 22 Jun 2022 06:09:03 +0000 (23:09 -0700)
Streaming decompression used to wait for a minimum of 5 bytes before attempting decoding.
This meant that, in the case that only a few bytes (<5) were provided,
and assuming these bytes are incorrect,
there would be no error reported.
The streaming API would simply request more data, waiting for at least 5 bytes.

This PR makes it possible to detect incorrect Frame IDs as soon as the first byte is provided.

Fix #3169

lib/decompress/zstd_decompress.c
tests/zstreamtest.c

index 85f4d2202e9f14f6799c1757268caf0fdeeeb182..5bd412df436664ce43ab716ad1cbc1793fe25015 100644 (file)
  *************************************/
 
 #define DDICT_HASHSET_MAX_LOAD_FACTOR_COUNT_MULT 4
-#define DDICT_HASHSET_MAX_LOAD_FACTOR_SIZE_MULT 3   /* These two constants represent SIZE_MULT/COUNT_MULT load factor without using a float.
-                                                     * Currently, that means a 0.75 load factor.
-                                                     * So, if count * COUNT_MULT / size * SIZE_MULT != 0, then we've exceeded
-                                                     * the load factor of the ddict hash set.
-                                                     */
+#define DDICT_HASHSET_MAX_LOAD_FACTOR_SIZE_MULT 3  /* These two constants represent SIZE_MULT/COUNT_MULT load factor without using a float.
+                                                    * Currently, that means a 0.75 load factor.
+                                                    * So, if count * COUNT_MULT / size * SIZE_MULT != 0, then we've exceeded
+                                                    * the load factor of the ddict hash set.
+                                                    */
 
 #define DDICT_HASHSET_TABLE_BASE_SIZE 64
 #define DDICT_HASHSET_RESIZE_FACTOR 2
@@ -439,16 +439,40 @@ size_t ZSTD_frameHeaderSize(const void* src, size_t srcSize)
  *  note : only works for formats ZSTD_f_zstd1 and ZSTD_f_zstd1_magicless
  * @return : 0, `zfhPtr` is correctly filled,
  *          >0, `srcSize` is too small, value is wanted `srcSize` amount,
- *           or an error code, which can be tested using ZSTD_isError() */
+**           or an error code, which can be tested using ZSTD_isError() */
 size_t ZSTD_getFrameHeader_advanced(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize, ZSTD_format_e format)
 {
     const BYTE* ip = (const BYTE*)src;
     size_t const minInputSize = ZSTD_startingInputLength(format);
 
-    ZSTD_memset(zfhPtr, 0, sizeof(*zfhPtr));   /* not strictly necessary, but static analyzer do not understand that zfhPtr is only going to be read only if return value is zero, since they are 2 different signals */
-    if (srcSize < minInputSize) return minInputSize;
-    RETURN_ERROR_IF(src==NULL, GENERIC, "invalid parameter");
+    DEBUGLOG(5, "ZSTD_getFrameHeader_advanced: minInputSize = %zu, srcSize = %zu", minInputSize, srcSize);
+
+    if (srcSize > 0) {
+        /* note : technically could be considered an assert(), since it's an invalid entry */
+        RETURN_ERROR_IF(src==NULL, GENERIC, "invalid parameter : src==NULL, but srcSize>0");
+    }
+    if (srcSize < minInputSize) {
+        if (srcSize > 0 && format != ZSTD_f_zstd1_magicless) {
+            /* when receiving less than @minInputSize bytes,
+             * control these bytes at least correspond to a supported magic number
+             * in order to error out early if they don't.
+            **/
+            size_t const toCopy = MIN(4, srcSize);
+            unsigned char hbuf[4]; MEM_writeLE32(hbuf, ZSTD_MAGICNUMBER);
+            assert(src != NULL);
+            ZSTD_memcpy(hbuf, src, toCopy);
+            if ( MEM_readLE32(hbuf) != ZSTD_MAGICNUMBER ) {
+                /* not a zstd frame : let's check if it's a skippable frame */
+                MEM_writeLE32(hbuf, ZSTD_MAGIC_SKIPPABLE_START);
+                ZSTD_memcpy(hbuf, src, toCopy);
+                if ((MEM_readLE32(hbuf) & ZSTD_MAGIC_SKIPPABLE_MASK) != ZSTD_MAGIC_SKIPPABLE_START) {
+                    RETURN_ERROR(prefix_unknown,
+                                "first bytes don't correspond to any supported magic number");
+        }   }   }
+        return minInputSize;
+    }
 
+    ZSTD_memset(zfhPtr, 0, sizeof(*zfhPtr));   /* not strictly necessary, but static analyzers may not understand that zfhPtr will be read only if return value is zero, since they are 2 different signals */
     if ( (format != ZSTD_f_zstd1_magicless)
       && (MEM_readLE32(src) != ZSTD_MAGICNUMBER) ) {
         if ((MEM_readLE32(src) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) {
@@ -1981,7 +2005,6 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
                 if (zds->refMultipleDDicts && zds->ddictSet) {
                     ZSTD_DCtx_selectFrameDDict(zds);
                 }
-                DEBUGLOG(5, "header size : %u", (U32)hSize);
                 if (ZSTD_isError(hSize)) {
 #if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT>=1)
                     U32 const legacyVersion = ZSTD_isLegacy(istart, iend-istart);
@@ -2013,6 +2036,11 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
                             zds->lhSize += remainingInput;
                         }
                         input->pos = input->size;
+                        /* check first few bytes */
+                        FORWARD_IF_ERROR(
+                            ZSTD_getFrameHeader_advanced(&zds->fParams, zds->headerBuffer, zds->lhSize, zds->format),
+                            "First few bytes detected incorrect" );
+                        /* return hint input size */
                         return (MAX((size_t)ZSTD_FRAMEHEADERSIZE_MIN(zds->format), hSize) - zds->lhSize) + ZSTD_blockHeaderSize;   /* remaining header bytes + next block header */
                     }
                     assert(ip != NULL);
index 20a05a75c9f049715f574a3f2029789a4dd2ecf1..3fcdd5399a4ee071ca198b535614527e752a15a4 100644 (file)
@@ -424,6 +424,15 @@ static int basicUnitTests(U32 seed, double compressibility)
     }   }
     DISPLAYLEVEL(3, "OK \n");
 
+    /* check decompression fails early if first bytes are wrong */
+    DISPLAYLEVEL(3, "test%3i : early decompression error if first bytes are incorrect : ", testNb++);
+    {   const char buf[3] = { 0 };  /* too short, not enough to start decoding header */
+        ZSTD_inBuffer inb = { buf, sizeof(buf), 0 };
+        size_t const remaining = ZSTD_decompressStream(zd, &outBuff, &inb);
+        if (!ZSTD_isError(remaining)) goto _output_error; /* should have errored out immediately (note: this does not test the exact error code) */
+    }
+    DISPLAYLEVEL(3, "OK \n");
+
     /* context size functions */
     DISPLAYLEVEL(3, "test%3i : estimate DStream size : ", testNb++);
     {   ZSTD_frameHeader fhi;