]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Simplify 32-bit long offsets decoding logic
authorNick Terrell <terrelln@fb.com>
Mon, 30 Jan 2023 19:15:15 +0000 (11:15 -0800)
committerNick Terrell <nickrterrell@gmail.com>
Mon, 30 Jan 2023 20:21:42 +0000 (12:21 -0800)
The previous code had an issue when `bitsConsumed == 32` it would read 0
bits for the `ofBits` read, which violates the precondition of
`BIT_readBitsFast()`. This can happen when the stream is corrupted.

Fix thie issue by always reading the maximum possible number of extra
bits. I've measured neutral decoding performance, likely because this
branch is unlikely, but this should be faster anyways. And if not, it is
only 32-bit decoding, so performance isn't as critical.

Credit to OSS-Fuzz

lib/common/zstd_internal.h
lib/decompress/zstd_decompress_block.c

index 8878aa1007384cbded68be2275cbd720ca118008..37836dc703f8b09de79d8dbaa69c39ec311e02c2 100644 (file)
@@ -113,6 +113,8 @@ typedef enum { set_basic, set_rle, set_compressed, set_repeat } symbolEncodingTy
 #define LLFSELog    9
 #define OffFSELog   8
 #define MaxFSELog  MAX(MAX(MLFSELog, LLFSELog), OffFSELog)
+#define MaxMLBits 16
+#define MaxLLBits 16
 
 #define ZSTD_MAX_HUF_HEADER_SIZE 128 /* header + <= 127 byte tree description */
 /* Each table cannot take more than #symbols * FSELog bits */
index 0d934043be64d503558a4ae41c677c67162b3095..a2a6eb3eddf510d5f58207e127001cd5b1750141 100644 (file)
@@ -1220,6 +1220,10 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets)
         U32 const llnbBits = llDInfo->nbBits;
         U32 const mlnbBits = mlDInfo->nbBits;
         U32 const ofnbBits = ofDInfo->nbBits;
+
+        assert(llBits <= MaxLLBits);
+        assert(mlBits <= MaxMLBits);
+        assert(ofBits <= MaxOff);
         /*
          * As gcc has better branch and block analyzers, sometimes it is only
          * valuable to mark likeliness for clang, it gives around 3-4% of
@@ -1235,19 +1239,16 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets)
     #endif
                 ZSTD_STATIC_ASSERT(ZSTD_lo_isLongOffset == 1);
                 ZSTD_STATIC_ASSERT(LONG_OFFSETS_MAX_EXTRA_BITS_32 == 5);
-                assert(ofBits <= MaxOff);
+                ZSTD_STATIC_ASSERT(STREAM_ACCUMULATOR_MIN_32 > LONG_OFFSETS_MAX_EXTRA_BITS_32);
+                ZSTD_STATIC_ASSERT(STREAM_ACCUMULATOR_MIN_32 - LONG_OFFSETS_MAX_EXTRA_BITS_32 >= MaxMLBits);
                 if (MEM_32bits() && longOffsets && (ofBits >= STREAM_ACCUMULATOR_MIN_32)) {
-                    U32 const extraBits = ofBits - MIN(ofBits, 32 - seqState->DStream.bitsConsumed);
+                    /* Always read extra bits, this keeps the logic simple,
+                     * avoids branches, and avoids accidentally reading 0 bits.
+                     */
+                    U32 const extraBits = LONG_OFFSETS_MAX_EXTRA_BITS_32;
                     offset = ofBase + (BIT_readBitsFast(&seqState->DStream, ofBits - extraBits) << extraBits);
                     BIT_reloadDStream(&seqState->DStream);
-                    if (extraBits) offset += BIT_readBitsFast(&seqState->DStream, extraBits);
-#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE)
-                    /* This assert is only valid when decoding valid sequences.
-                     * It cal fail when we consume more bits than are in the bitstream,
-                     * which can happen on corruption.
-                     */
-                    assert(extraBits <= LONG_OFFSETS_MAX_EXTRA_BITS_32);   /* to avoid another reload */
-#endif
+                    offset += BIT_readBitsFast(&seqState->DStream, extraBits);
                 } else {
                     offset = ofBase + BIT_readBitsFast(&seqState->DStream, ofBits/*>0*/);   /* <=  (ZSTD_WINDOWLOG_MAX-1) bits */
                     if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream);