]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Added ZSTD_get_decompressed_size
authorSean Purcell <me@seanp.xyz>
Mon, 30 Jan 2017 22:42:21 +0000 (14:42 -0800)
committerSean Purcell <me@seanp.xyz>
Mon, 30 Jan 2017 22:56:29 +0000 (14:56 -0800)
Since this implementation handles multiple concatenated frames,
to determine decompressed size we must traverse the entire input,
checking each frame's frame_content_size field

contrib/educational_decoder/harness.c
contrib/educational_decoder/zstd_decompress.c
contrib/educational_decoder/zstd_decompress.h

index 6f4765d9d76c3e2756de04b449571b3ddf05c51f..c44100fff7a1c49df90c9086b8d1673072159b3c 100644 (file)
@@ -5,8 +5,8 @@
 
 typedef unsigned char u8;
 
-// There's no good way to determine output size without decompressing
-// For this example assume we'll never decompress at a ratio larger than 16
+// If the data doesn't have decompressed size with it, fallback on assuming the
+// compression ratio is at most 16
 #define MAX_COMPRESSION_RATIO (16)
 
 u8 *input;
@@ -14,80 +14,89 @@ u8 *output;
 u8 *dict;
 
 size_t read_file(const char *path, u8 **ptr) {
-  FILE *f = fopen(path, "rb");
-  if (!f) {
-    fprintf(stderr, "failed to open file %s\n", path);
-    exit(1);
-  }
-
-  fseek(f, 0L, SEEK_END);
-  size_t size = ftell(f);
-  rewind(f);
-
-  *ptr = malloc(size);
-  if (!ptr) {
-    fprintf(stderr, "failed to allocate memory to hold %s\n", path);
-    exit(1);
-  }
-
-  size_t pos = 0;
-  while (!feof(f)) {
-    size_t read = fread(&(*ptr)[pos], 1, size, f);
-    if (ferror(f)) {
-      fprintf(stderr, "error while reading file %s\n", path);
-      exit(1);
+    FILE *f = fopen(path, "rb");
+    if (!f) {
+        fprintf(stderr, "failed to open file %s\n", path);
+        exit(1);
     }
-    pos += read;
-  }
 
-  fclose(f);
+    fseek(f, 0L, SEEK_END);
+    size_t size = ftell(f);
+    rewind(f);
 
-  return pos;
+    *ptr = malloc(size);
+    if (!ptr) {
+        fprintf(stderr, "failed to allocate memory to hold %s\n", path);
+        exit(1);
+    }
+
+    size_t pos = 0;
+    while (!feof(f)) {
+        size_t read = fread(&(*ptr)[pos], 1, size, f);
+        if (ferror(f)) {
+            fprintf(stderr, "error while reading file %s\n", path);
+            exit(1);
+        }
+        pos += read;
+    }
+
+    fclose(f);
+
+    return pos;
 }
 
 void write_file(const char *path, const u8 *ptr, size_t size) {
-  FILE *f = fopen(path, "wb");
-
-  size_t written = 0;
-  while (written < size) {
-    written += fwrite(&ptr[written], 1, size, f);
-    if (ferror(f)) {
-      fprintf(stderr, "error while writing file %s\n", path);
-      exit(1);
+    FILE *f = fopen(path, "wb");
+
+    size_t written = 0;
+    while (written < size) {
+        written += fwrite(&ptr[written], 1, size, f);
+        if (ferror(f)) {
+            fprintf(stderr, "error while writing file %s\n", path);
+            exit(1);
+        }
     }
-  }
 
-  fclose(f);
+    fclose(f);
 }
 
 int main(int argc, char **argv) {
-  if (argc < 3) {
-    fprintf(stderr, "usage: %s <file.zst> <out_path> [dictionary]\n", argv[0]);
-
-    return 1;
-  }
-
-  size_t input_size = read_file(argv[1], &input);
-  size_t dict_size = 0;
-  if (argc >= 4) {
-    dict_size = read_file(argv[3], &dict);
-  }
-
-  output = malloc(MAX_COMPRESSION_RATIO * input_size);
-  if (!output) {
-    fprintf(stderr, "failed to allocate memory\n");
-    return 1;
-  }
-
-  size_t decompressed =
-      ZSTD_decompress_with_dict(output, input_size * MAX_COMPRESSION_RATIO,
-                                input, input_size, dict, dict_size);
-
-  write_file(argv[2], output, decompressed);
-
-  free(input);
-  free(output);
-  free(dict);
-  input = output = dict = NULL;
+    if (argc < 3) {
+        fprintf(stderr, "usage: %s <file.zst> <out_path> [dictionary]\n",
+                argv[0]);
+
+        return 1;
+    }
+
+    size_t input_size = read_file(argv[1], &input);
+    size_t dict_size = 0;
+    if (argc >= 4) {
+        dict_size = read_file(argv[3], &dict);
+    }
+
+    size_t decompressed_size = ZSTD_get_decompressed_size(input, input_size);
+    if (decompressed_size == -1) {
+        decompressed_size = MAX_COMPRESSION_RATIO * input_size;
+        fprintf(stderr, "WARNING: Compressed data does contain decompressed "
+                        "size, going to assume the compression ratio is at "
+                        "most %d (decompressed size of at most %lld\n",
+                MAX_COMPRESSION_RATIO, decompressed_size);
+    }
+    output = malloc(decompressed_size);
+    if (!output) {
+        fprintf(stderr, "failed to allocate memory\n");
+        return 1;
+    }
+
+    size_t decompressed =
+        ZSTD_decompress_with_dict(output, input_size * MAX_COMPRESSION_RATIO,
+                                  input, input_size, dict, dict_size);
+
+    write_file(argv[2], output, decompressed);
+
+    free(input);
+    free(output);
+    free(dict);
+    input = output = dict = NULL;
 }
 
index 8dc159008324fbd362d4895bd97cfbed75fa21b0..7b04c4b294e91594cf4212660b7e51d4e5eb2bfe 100644 (file)
@@ -16,6 +16,10 @@ size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src,
                                  size_t src_len, const void *dict,
                                  size_t dict_len);
 
+/// Get the decompressed size of an input stream so memory can be allocated in
+/// advance
+size_t ZSTD_get_decompressed_size(const void *src, size_t src_len);
+
 /******* UTILITY MACROS AND TYPES *********************************************/
 #define MAX_WINDOW_SIZE ((size_t)512 << 20)
 // Max block size decompressed size is 128 KB and literal blocks must be smaller
@@ -232,11 +236,31 @@ typedef struct {
     size_t src_len;
 } io_streams_t;
 
-/// The context needed to decode blocks in a frame
+/// A small structure that can be reused in various places that need to access
+/// frame header information
 typedef struct {
+    // The size of window that we need to be able to contiguously store for
+    // references
     size_t window_size;
+    // The total output size of this compressed frame
     size_t frame_content_size;
 
+    // The dictionary id if this frame uses one
+    u32 dictionary_id;
+
+    // Whether or not the content of this frame has a checksum
+    int content_checksum_flag;
+    // Whether or not the output for this frame is in a single segment
+    int single_segment_flag;
+
+    // The size in bytes of this header
+    int header_size;
+} frame_header_t;
+
+/// The context needed to decode blocks in a frame
+typedef struct {
+    frame_header_t header;
+
     // The total amount of data available for backreferences, to determine if an
     // offset too large to be correct
     size_t current_total_output;
@@ -255,12 +279,6 @@ typedef struct {
     // The last 3 offsets for the special "repeat offsets".  Array size is 4 so
     // that previous_offsets[1] corresponds to the most recent offset
     u64 previous_offsets[4];
-
-    // The dictionary id for this frame if one exists
-    u32 dictionary_id;
-
-    int single_segment_flag;
-    int content_checksum_flag;
 } frame_context_t;
 
 /// The decoded contents of a dictionary so that it doesn't have to be repeated
@@ -364,10 +382,11 @@ size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src,
 /******* FRAME DECODING ******************************************************/
 
 static void decode_data_frame(io_streams_t *streams, dictionary_t *dict);
-static void init_frame_context(frame_context_t *context);
-static void free_frame_context(frame_context_t *context);
-static void parse_frame_header(io_streams_t *streams, frame_context_t *ctx,
+static void init_frame_context(io_streams_t *streams, frame_context_t *context,
                                dictionary_t *dict);
+static void free_frame_context(frame_context_t *context);
+static void parse_frame_header(frame_header_t *header, const u8 *src,
+                               size_t src_len);
 static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict);
 
 static void decompress_data(io_streams_t *streams, frame_context_t *ctx);
@@ -411,12 +430,10 @@ static void decode_data_frame(io_streams_t *streams, dictionary_t *dict) {
     frame_context_t ctx;
 
     // Initialize the context that needs to be carried from block to block
-    init_frame_context(&ctx);
-    parse_frame_header(streams, &ctx, dict);
-    frame_context_apply_dict(&ctx, dict);
+    init_frame_context(streams, &ctx, dict);
 
-    if (ctx.frame_content_size != 0 &&
-        ctx.frame_content_size > streams->dst_len) {
+    if (ctx.header.frame_content_size != 0 &&
+        ctx.header.frame_content_size > streams->dst_len) {
         OUT_SIZE();
     }
 
@@ -425,13 +442,40 @@ static void decode_data_frame(io_streams_t *streams, dictionary_t *dict) {
     free_frame_context(&ctx);
 }
 
-static void init_frame_context(frame_context_t *context) {
+/// Takes the information provided in the header and dictionary, and initializes
+/// the context for this frame
+static void init_frame_context(io_streams_t *streams, frame_context_t *context,
+                               dictionary_t *dict) {
     memset(context, 0x00, sizeof(frame_context_t));
 
+    // Parse data from the frame header
+    parse_frame_header(&context->header, streams->src, streams->src_len);
+    streams->src += context->header.header_size;
+    streams->src_len -= context->header.header_size;
+
     // Set up the offset history for the repeat offset commands
     context->previous_offsets[1] = 1;
     context->previous_offsets[2] = 4;
     context->previous_offsets[3] = 8;
+
+    {
+        // Allocate the window buffer
+        size_t buffer_size;
+        if (context->header.single_segment_flag) {
+            buffer_size = context->header.frame_content_size +
+                          (dict ? dict->content_size : 0);
+        } else {
+            buffer_size = context->header.window_size;
+        }
+
+        if (buffer_size > MAX_WINDOW_SIZE) {
+            ERROR("Requested window size too large");
+        }
+        cbuf_init(&context->window, buffer_size);
+    }
+
+    // Apply details from the dict if it exists
+    frame_context_apply_dict(context, dict);
 }
 
 static void free_frame_context(frame_context_t *context) {
@@ -446,13 +490,13 @@ static void free_frame_context(frame_context_t *context) {
     memset(context, 0, sizeof(frame_context_t));
 }
 
-static void parse_frame_header(io_streams_t *streams, frame_context_t *ctx,
-                               dictionary_t *dict) {
-    if (streams->src_len < 1) {
+static void parse_frame_header(frame_header_t *header, const u8 *src,
+                               size_t src_len) {
+    if (src_len < 1) {
         INP_SIZE();
     }
 
-    u8 descriptor = read_bits_LE(streams->src, 8, 0);
+    u8 descriptor = read_bits_LE(src, 8, 0);
 
     // decode frame header descriptor into flags
     u8 frame_content_size_flag = descriptor >> 6;
@@ -465,30 +509,28 @@ static void parse_frame_header(io_streams_t *streams, frame_context_t *ctx,
         CORRUPTION();
     }
 
-    streams->src++;
-    streams->src_len--;
+    int header_size = 1;
 
-    ctx->single_segment_flag = single_segment_flag;
-    ctx->content_checksum_flag = content_checksum_flag;
+    header->single_segment_flag = single_segment_flag;
+    header->content_checksum_flag = content_checksum_flag;
 
     // decode window size
     if (!single_segment_flag) {
-        if (streams->src_len < 1) {
+        if (src_len < header_size + 1) {
             INP_SIZE();
         }
 
         // Use the algorithm from the specification to compute window size
         // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
-        u8 window_descriptor = read_bits_LE(streams->src, 8, 0);
+        u8 window_descriptor = src[header_size];
         u8 exponent = window_descriptor >> 3;
         u8 mantissa = window_descriptor & 7;
 
         size_t window_base = (size_t)1 << (10 + exponent);
         size_t window_add = (window_base / 8) * mantissa;
-        ctx->window_size = window_base + window_add;
+        header->window_size = window_base + window_add;
 
-        streams->src++;
-        streams->src_len--;
+        header_size++;
     }
 
     // decode dictionary id if it exists
@@ -496,52 +538,40 @@ static void parse_frame_header(io_streams_t *streams, frame_context_t *ctx,
         const int bytes_array[] = {0, 1, 2, 4};
         const int bytes = bytes_array[dictionary_id_flag];
 
-        if (streams->src_len < bytes) {
+        if (src_len < header_size + bytes) {
             INP_SIZE();
         }
 
-        ctx->dictionary_id = read_bits_LE(streams->src, bytes * 8, 0);
-        streams->src += bytes;
-        streams->src_len -= bytes;
+        header->dictionary_id = read_bits_LE(src + header_size, bytes * 8, 0);
+
+        header_size += bytes;
     } else {
-        ctx->dictionary_id = 0;
+        header->dictionary_id = 0;
     }
 
     // decode frame content size if it exists
     if (single_segment_flag || frame_content_size_flag) {
         // if frame_content_size_flag == 0 but single_segment_flag is set, we
-        // still
-        // have a 1 byte field
+        // still have a 1 byte field
         const int bytes_array[] = {1, 2, 4, 8};
         const int bytes = bytes_array[frame_content_size_flag];
 
-        if (streams->src_len < bytes) {
+        if (src_len < header_size + bytes) {
             INP_SIZE();
         }
 
-        ctx->frame_content_size = read_bits_LE(streams->src, bytes * 8, 0);
+        header->frame_content_size =
+            read_bits_LE(src + header_size, bytes * 8, 0);
         if (bytes == 2) {
-            ctx->frame_content_size += 256;
+            header->frame_content_size += 256;
         }
 
-        streams->src += bytes;
-        streams->src_len -= bytes;
-    }
-
-    if (single_segment_flag) {
-        ctx->window_size =
-            ctx->frame_content_size + (dict ? dict->content_size : 0);
-        // We need to allocate a buffer to write to of size at least output +
-        // dict
-        // size
-        size_t size = ctx->frame_content_size + (dict ? dict->content_size : 0);
+        header_size += bytes;
+    } else {
+        header->frame_content_size = 0;
     }
 
-    // Allocate the window
-    if (ctx->window_size > MAX_WINDOW_SIZE) {
-        ERROR("Requested window size too large");
-    }
-    cbuf_init(&ctx->window, ctx->window_size);
+    header->header_size = header_size;
 }
 
 /// A dictionary acts as initializing values for the frame context before
@@ -552,7 +582,7 @@ static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict) {
     if (!dict || !dict->content)
         return;
 
-    if (ctx->dictionary_id == 0 && dict->dictionary_id != 0) {
+    if (ctx->header.dictionary_id == 0 && dict->dictionary_id != 0) {
         // The dictionary is unneeded, and shouldn't be used as it may interfere
         // with the default offset history
         return;
@@ -560,7 +590,8 @@ static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict) {
 
     // If the dictionary id is 0, it doesn't matter if we provide the wrong raw
     // content dict, it won't change anything
-    if (ctx->dictionary_id != 0 && ctx->dictionary_id != dict->dictionary_id) {
+    if (ctx->header.dictionary_id != 0 &&
+        ctx->header.dictionary_id != dict->dictionary_id) {
         ERROR("Wrong/no dictionary provided");
     }
 
@@ -575,8 +606,7 @@ static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict) {
     // be used in the table repeat modes
     if (dict->dictionary_id != 0) {
         // Deep copy the entropy tables so they can be freed independently of
-        // the
-        // dictionary struct
+        // the dictionary struct
         HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable);
         FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable);
         FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable);
@@ -590,14 +620,14 @@ static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict) {
 /// Decompress the data from a frame block by block
 static void decompress_data(io_streams_t *streams, frame_context_t *ctx) {
 
-    u8 last_block = 0;
+    int last_block = 0;
     do {
         if (streams->src_len < 3) {
             INP_SIZE();
         }
         // Parse the block header
         last_block = streams->src[0] & 1;
-        u8 block_type = (streams->src[0] >> 1) & 3;
+        int block_type = (streams->src[0] >> 1) & 3;
         size_t block_len = read_bits_LE(streams->src, 21, 3);
 
         streams->src += 3;
@@ -648,6 +678,10 @@ static void decompress_data(io_streams_t *streams, frame_context_t *ctx) {
             // Compressed block, this is mode complex
             decompress_block(streams, ctx, block_len);
             break;
+        case 3:
+            // Reserved block type
+            CORRUPTION();
+            break;
         }
     } while (!last_block);
 
@@ -656,10 +690,9 @@ static void decompress_data(io_streams_t *streams, frame_context_t *ctx) {
     streams->dst += written;
     streams->dst_len -= written;
 
-    if (ctx->content_checksum_flag) {
+    if (ctx->header.content_checksum_flag) {
         // This program does not support checking the checksum, so skip over it
-        // if
-        // it's present
+        // if it's present
         if (streams->src_len < 4) {
             INP_SIZE();
         }
@@ -1312,6 +1345,126 @@ static size_t execute_sequences(io_streams_t *streams, frame_context_t *ctx,
 }
 /******* END SEQUENCE EXECUTION ***********************************************/
 
+/******* OUTPUT SIZE COUNTING *************************************************/
+size_t traverse_frame(frame_header_t *header, const u8 *src, size_t src_len);
+
+/// Get the decompressed size of an input stream so memory can be allocated in
+/// advance.
+/// This is more complex than the implementation in the reference
+/// implementation, as this API allows for the decompression of multiple
+/// concatenated frames.
+size_t ZSTD_get_decompressed_size(const void *src, size_t src_len) {
+  const u8 *ip = (const u8 *) src;
+  size_t dst_size = 0;
+
+  // Each frame header only gives us the size of its frame, so iterate over all
+  // frames
+  while (src_len > 0) {
+    if (src_len < 4) {
+      INP_SIZE();
+    }
+
+    u32 magic_number = read_bits_LE(ip, 32, 0);
+
+    ip += 4;
+    src_len -= 4;
+    if (magic_number >= 0x184D2A50U && magic_number <= 0x184D2A5F) {
+        // skippable frame, this has no impact on output size
+        if (src_len < 4) {
+            INP_SIZE();
+        }
+        size_t frame_size = read_bits_LE(ip, 32, 32);
+
+        if (src_len < 4 + frame_size) {
+            INP_SIZE();
+        }
+
+        // skip over frame
+        ip += 4 + frame_size;
+        src_len -= 4 + frame_size;
+    } else if (magic_number == 0xFD2FB528U) {
+        // ZSTD frame
+        frame_header_t header;
+        parse_frame_header(&header, ip, src_len);
+
+        if (header.frame_content_size == 0 && !header.single_segment_flag) {
+            // Content size not provided, we can't tell
+            return -1;
+        }
+
+        dst_size += header.frame_content_size;
+
+        // we need to traverse the frame to find when the next one starts
+        size_t traversed = traverse_frame(&header, ip, src_len);
+        ip += traversed;
+        src_len -= traversed;
+    } else {
+        // not a real frame
+        ERROR("Invalid magic number");
+    }
+  }
+
+  return dst_size;
+}
+
+/// Iterate over each block in a frame to find the end of it, to get to the
+/// start of the next frame
+size_t traverse_frame(frame_header_t *header, const u8 *src, size_t src_len) {
+    const u8 *const src_beg = src;
+    const u8 *const src_end = src + src_len;
+    src += header->header_size;
+    src_len += header->header_size;
+
+    int last_block = 0;
+
+    do {
+        if (src + 3 > src_end) {
+            INP_SIZE();
+        }
+        // Parse the block header
+        last_block = src[0] & 1;
+        int block_type = (src[0] >> 1) & 3;
+        size_t block_len = read_bits_LE(src, 21, 3);
+
+        src += 3;
+        switch (block_type) {
+        case 0: // Raw block, block_len bytes
+            if (src + block_len > src_end) {
+                INP_SIZE();
+            }
+            src += block_len;
+            break;
+        case 1: // RLE block, 1 byte
+            if (src + 1 > src_end) {
+                INP_SIZE();
+            }
+            src++;
+            break;
+        case 2: // Compressed block, compressed size is block_len
+            if (src + block_len > src_end) {
+                INP_SIZE();
+            }
+            src += block_len;
+            break;
+        case 3:
+            // Reserved block type
+            CORRUPTION();
+            break;
+        }
+    } while (!last_block);
+
+    if (header->content_checksum_flag) {
+        if (src + 4 > src_end) {
+            INP_SIZE();
+        }
+        src += 4;
+    }
+
+    return src - src_beg;
+}
+
+/******* END OUTPUT SIZE COUNTING *********************************************/
+
 /******* DICTIONARY PARSING ***************************************************/
 static void init_raw_content_dict(dictionary_t *dict, const u8 *src,
                                   size_t src_len);
@@ -1952,8 +2105,8 @@ static void FSE_init_dtable(FSE_dtable *dtable, const i16 *norm_freqs,
                      high_threshold); // Make sure we don't occupy a spot taken
                                       // by the low prob symbols
             // Note: no other collision checking is necessary as `step` is
-            // coprime to
-            // `size`, so the cycle will visit each position exactly once
+            // coprime to `size`, so the cycle will visit each position exactly
+            // once
         }
     }
     if (pos != 0) {
@@ -1964,13 +2117,11 @@ static void FSE_init_dtable(FSE_dtable *dtable, const i16 *norm_freqs,
     for (int i = 0; i < size; i++) {
         u8 symbol = dtable->symbols[i];
         u16 next_state_desc = state_desc[symbol]++;
-        // Fills in the table appropriately
-        // next_state_desc increases by symbol over time, decreasing number of
-        // bits
+        // Fills in the table appropriately next_state_desc increases by symbol
+        // over time, decreasing number of bits
         dtable->num_bits[i] = (u8)(accuracy_log - log2inf(next_state_desc));
         // baseline increases until the bit threshold is passed, at which point
-        // it
-        // resets to 0
+        // it resets to 0
         dtable->new_state_base[i] =
             ((u16)next_state_desc << dtable->num_bits[i]) - size;
     }
@@ -2057,8 +2208,7 @@ static void FSE_init_dtable_rle(FSE_dtable *dtable, u8 symb) {
     dtable->new_state_base = malloc(sizeof(u16));
 
     // This setup will always have a state of 0, always return symbol `symb`,
-    // and
-    // never consume any bits
+    // and never consume any bits
     dtable->symbols[0] = symb;
     dtable->num_bits[0] = 0;
     dtable->new_state_base[0] = 0;
index 3671678b170d9a9337f6afce7870dba7b1ef3cfa..3e1bc568fe985b7d15d351358d0d46dd16cd3870 100644 (file)
@@ -3,4 +3,5 @@ size_t ZSTD_decompress(void *dst, size_t dst_len, const void *src,
 size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src,
                                  size_t src_len, const void *dict,
                                  size_t dict_len);
+size_t ZSTD_get_decompressed_size(const void *src, size_t src_len);