From: Sean Purcell Date: Mon, 30 Jan 2017 22:42:21 +0000 (-0800) Subject: Added ZSTD_get_decompressed_size X-Git-Tag: v1.1.4~1^2~77^2~7 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5657e0e07d4ecb23300390c0e95d4f9ec4ca1d66;p=thirdparty%2Fzstd.git Added ZSTD_get_decompressed_size Since this implementation handles multiple concatenated frames, to determine decompressed size we must traverse the entire input, checking each frame's frame_content_size field --- diff --git a/contrib/educational_decoder/harness.c b/contrib/educational_decoder/harness.c index 6f4765d9d..c44100fff 100644 --- a/contrib/educational_decoder/harness.c +++ b/contrib/educational_decoder/harness.c @@ -5,8 +5,8 @@ typedef unsigned char u8; -// There's no good way to determine output size without decompressing -// For this example assume we'll never decompress at a ratio larger than 16 +// If the data doesn't have decompressed size with it, fallback on assuming the +// compression ratio is at most 16 #define MAX_COMPRESSION_RATIO (16) u8 *input; @@ -14,80 +14,89 @@ u8 *output; u8 *dict; size_t read_file(const char *path, u8 **ptr) { - FILE *f = fopen(path, "rb"); - if (!f) { - fprintf(stderr, "failed to open file %s\n", path); - exit(1); - } - - fseek(f, 0L, SEEK_END); - size_t size = ftell(f); - rewind(f); - - *ptr = malloc(size); - if (!ptr) { - fprintf(stderr, "failed to allocate memory to hold %s\n", path); - exit(1); - } - - size_t pos = 0; - while (!feof(f)) { - size_t read = fread(&(*ptr)[pos], 1, size, f); - if (ferror(f)) { - fprintf(stderr, "error while reading file %s\n", path); - exit(1); + FILE *f = fopen(path, "rb"); + if (!f) { + fprintf(stderr, "failed to open file %s\n", path); + exit(1); } - pos += read; - } - fclose(f); + fseek(f, 0L, SEEK_END); + size_t size = ftell(f); + rewind(f); - return pos; + *ptr = malloc(size); + if (!ptr) { + fprintf(stderr, "failed to allocate memory to hold %s\n", path); + exit(1); + } + + size_t pos = 0; + while (!feof(f)) { + size_t read = fread(&(*ptr)[pos], 1, size, f); + if (ferror(f)) { + fprintf(stderr, "error while reading file %s\n", path); + exit(1); + } + pos += read; + } + + fclose(f); + + return pos; } void write_file(const char *path, const u8 *ptr, size_t size) { - FILE *f = fopen(path, "wb"); - - size_t written = 0; - while (written < size) { - written += fwrite(&ptr[written], 1, size, f); - if (ferror(f)) { - fprintf(stderr, "error while writing file %s\n", path); - exit(1); + FILE *f = fopen(path, "wb"); + + size_t written = 0; + while (written < size) { + written += fwrite(&ptr[written], 1, size, f); + if (ferror(f)) { + fprintf(stderr, "error while writing file %s\n", path); + exit(1); + } } - } - fclose(f); + fclose(f); } int main(int argc, char **argv) { - if (argc < 3) { - fprintf(stderr, "usage: %s [dictionary]\n", argv[0]); - - return 1; - } - - size_t input_size = read_file(argv[1], &input); - size_t dict_size = 0; - if (argc >= 4) { - dict_size = read_file(argv[3], &dict); - } - - output = malloc(MAX_COMPRESSION_RATIO * input_size); - if (!output) { - fprintf(stderr, "failed to allocate memory\n"); - return 1; - } - - size_t decompressed = - ZSTD_decompress_with_dict(output, input_size * MAX_COMPRESSION_RATIO, - input, input_size, dict, dict_size); - - write_file(argv[2], output, decompressed); - - free(input); - free(output); - free(dict); - input = output = dict = NULL; + if (argc < 3) { + fprintf(stderr, "usage: %s [dictionary]\n", + argv[0]); + + return 1; + } + + size_t input_size = read_file(argv[1], &input); + size_t dict_size = 0; + if (argc >= 4) { + dict_size = read_file(argv[3], &dict); + } + + size_t decompressed_size = ZSTD_get_decompressed_size(input, input_size); + if (decompressed_size == -1) { + decompressed_size = MAX_COMPRESSION_RATIO * input_size; + fprintf(stderr, "WARNING: Compressed data does contain decompressed " + "size, going to assume the compression ratio is at " + "most %d (decompressed size of at most %lld\n", + MAX_COMPRESSION_RATIO, decompressed_size); + } + output = malloc(decompressed_size); + if (!output) { + fprintf(stderr, "failed to allocate memory\n"); + return 1; + } + + size_t decompressed = + ZSTD_decompress_with_dict(output, input_size * MAX_COMPRESSION_RATIO, + input, input_size, dict, dict_size); + + write_file(argv[2], output, decompressed); + + free(input); + free(output); + free(dict); + input = output = dict = NULL; } diff --git a/contrib/educational_decoder/zstd_decompress.c b/contrib/educational_decoder/zstd_decompress.c index 8dc159008..7b04c4b29 100644 --- a/contrib/educational_decoder/zstd_decompress.c +++ b/contrib/educational_decoder/zstd_decompress.c @@ -16,6 +16,10 @@ size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src, size_t src_len, const void *dict, size_t dict_len); +/// Get the decompressed size of an input stream so memory can be allocated in +/// advance +size_t ZSTD_get_decompressed_size(const void *src, size_t src_len); + /******* UTILITY MACROS AND TYPES *********************************************/ #define MAX_WINDOW_SIZE ((size_t)512 << 20) // Max block size decompressed size is 128 KB and literal blocks must be smaller @@ -232,11 +236,31 @@ typedef struct { size_t src_len; } io_streams_t; -/// The context needed to decode blocks in a frame +/// A small structure that can be reused in various places that need to access +/// frame header information typedef struct { + // The size of window that we need to be able to contiguously store for + // references size_t window_size; + // The total output size of this compressed frame size_t frame_content_size; + // The dictionary id if this frame uses one + u32 dictionary_id; + + // Whether or not the content of this frame has a checksum + int content_checksum_flag; + // Whether or not the output for this frame is in a single segment + int single_segment_flag; + + // The size in bytes of this header + int header_size; +} frame_header_t; + +/// The context needed to decode blocks in a frame +typedef struct { + frame_header_t header; + // The total amount of data available for backreferences, to determine if an // offset too large to be correct size_t current_total_output; @@ -255,12 +279,6 @@ typedef struct { // The last 3 offsets for the special "repeat offsets". Array size is 4 so // that previous_offsets[1] corresponds to the most recent offset u64 previous_offsets[4]; - - // The dictionary id for this frame if one exists - u32 dictionary_id; - - int single_segment_flag; - int content_checksum_flag; } frame_context_t; /// The decoded contents of a dictionary so that it doesn't have to be repeated @@ -364,10 +382,11 @@ size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src, /******* FRAME DECODING ******************************************************/ static void decode_data_frame(io_streams_t *streams, dictionary_t *dict); -static void init_frame_context(frame_context_t *context); -static void free_frame_context(frame_context_t *context); -static void parse_frame_header(io_streams_t *streams, frame_context_t *ctx, +static void init_frame_context(io_streams_t *streams, frame_context_t *context, dictionary_t *dict); +static void free_frame_context(frame_context_t *context); +static void parse_frame_header(frame_header_t *header, const u8 *src, + size_t src_len); static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict); static void decompress_data(io_streams_t *streams, frame_context_t *ctx); @@ -411,12 +430,10 @@ static void decode_data_frame(io_streams_t *streams, dictionary_t *dict) { frame_context_t ctx; // Initialize the context that needs to be carried from block to block - init_frame_context(&ctx); - parse_frame_header(streams, &ctx, dict); - frame_context_apply_dict(&ctx, dict); + init_frame_context(streams, &ctx, dict); - if (ctx.frame_content_size != 0 && - ctx.frame_content_size > streams->dst_len) { + if (ctx.header.frame_content_size != 0 && + ctx.header.frame_content_size > streams->dst_len) { OUT_SIZE(); } @@ -425,13 +442,40 @@ static void decode_data_frame(io_streams_t *streams, dictionary_t *dict) { free_frame_context(&ctx); } -static void init_frame_context(frame_context_t *context) { +/// Takes the information provided in the header and dictionary, and initializes +/// the context for this frame +static void init_frame_context(io_streams_t *streams, frame_context_t *context, + dictionary_t *dict) { memset(context, 0x00, sizeof(frame_context_t)); + // Parse data from the frame header + parse_frame_header(&context->header, streams->src, streams->src_len); + streams->src += context->header.header_size; + streams->src_len -= context->header.header_size; + // Set up the offset history for the repeat offset commands context->previous_offsets[1] = 1; context->previous_offsets[2] = 4; context->previous_offsets[3] = 8; + + { + // Allocate the window buffer + size_t buffer_size; + if (context->header.single_segment_flag) { + buffer_size = context->header.frame_content_size + + (dict ? dict->content_size : 0); + } else { + buffer_size = context->header.window_size; + } + + if (buffer_size > MAX_WINDOW_SIZE) { + ERROR("Requested window size too large"); + } + cbuf_init(&context->window, buffer_size); + } + + // Apply details from the dict if it exists + frame_context_apply_dict(context, dict); } static void free_frame_context(frame_context_t *context) { @@ -446,13 +490,13 @@ static void free_frame_context(frame_context_t *context) { memset(context, 0, sizeof(frame_context_t)); } -static void parse_frame_header(io_streams_t *streams, frame_context_t *ctx, - dictionary_t *dict) { - if (streams->src_len < 1) { +static void parse_frame_header(frame_header_t *header, const u8 *src, + size_t src_len) { + if (src_len < 1) { INP_SIZE(); } - u8 descriptor = read_bits_LE(streams->src, 8, 0); + u8 descriptor = read_bits_LE(src, 8, 0); // decode frame header descriptor into flags u8 frame_content_size_flag = descriptor >> 6; @@ -465,30 +509,28 @@ static void parse_frame_header(io_streams_t *streams, frame_context_t *ctx, CORRUPTION(); } - streams->src++; - streams->src_len--; + int header_size = 1; - ctx->single_segment_flag = single_segment_flag; - ctx->content_checksum_flag = content_checksum_flag; + header->single_segment_flag = single_segment_flag; + header->content_checksum_flag = content_checksum_flag; // decode window size if (!single_segment_flag) { - if (streams->src_len < 1) { + if (src_len < header_size + 1) { INP_SIZE(); } // Use the algorithm from the specification to compute window size // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor - u8 window_descriptor = read_bits_LE(streams->src, 8, 0); + u8 window_descriptor = src[header_size]; u8 exponent = window_descriptor >> 3; u8 mantissa = window_descriptor & 7; size_t window_base = (size_t)1 << (10 + exponent); size_t window_add = (window_base / 8) * mantissa; - ctx->window_size = window_base + window_add; + header->window_size = window_base + window_add; - streams->src++; - streams->src_len--; + header_size++; } // decode dictionary id if it exists @@ -496,52 +538,40 @@ static void parse_frame_header(io_streams_t *streams, frame_context_t *ctx, const int bytes_array[] = {0, 1, 2, 4}; const int bytes = bytes_array[dictionary_id_flag]; - if (streams->src_len < bytes) { + if (src_len < header_size + bytes) { INP_SIZE(); } - ctx->dictionary_id = read_bits_LE(streams->src, bytes * 8, 0); - streams->src += bytes; - streams->src_len -= bytes; + header->dictionary_id = read_bits_LE(src + header_size, bytes * 8, 0); + + header_size += bytes; } else { - ctx->dictionary_id = 0; + header->dictionary_id = 0; } // decode frame content size if it exists if (single_segment_flag || frame_content_size_flag) { // if frame_content_size_flag == 0 but single_segment_flag is set, we - // still - // have a 1 byte field + // still have a 1 byte field const int bytes_array[] = {1, 2, 4, 8}; const int bytes = bytes_array[frame_content_size_flag]; - if (streams->src_len < bytes) { + if (src_len < header_size + bytes) { INP_SIZE(); } - ctx->frame_content_size = read_bits_LE(streams->src, bytes * 8, 0); + header->frame_content_size = + read_bits_LE(src + header_size, bytes * 8, 0); if (bytes == 2) { - ctx->frame_content_size += 256; + header->frame_content_size += 256; } - streams->src += bytes; - streams->src_len -= bytes; - } - - if (single_segment_flag) { - ctx->window_size = - ctx->frame_content_size + (dict ? dict->content_size : 0); - // We need to allocate a buffer to write to of size at least output + - // dict - // size - size_t size = ctx->frame_content_size + (dict ? dict->content_size : 0); + header_size += bytes; + } else { + header->frame_content_size = 0; } - // Allocate the window - if (ctx->window_size > MAX_WINDOW_SIZE) { - ERROR("Requested window size too large"); - } - cbuf_init(&ctx->window, ctx->window_size); + header->header_size = header_size; } /// A dictionary acts as initializing values for the frame context before @@ -552,7 +582,7 @@ static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict) { if (!dict || !dict->content) return; - if (ctx->dictionary_id == 0 && dict->dictionary_id != 0) { + if (ctx->header.dictionary_id == 0 && dict->dictionary_id != 0) { // The dictionary is unneeded, and shouldn't be used as it may interfere // with the default offset history return; @@ -560,7 +590,8 @@ static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict) { // If the dictionary id is 0, it doesn't matter if we provide the wrong raw // content dict, it won't change anything - if (ctx->dictionary_id != 0 && ctx->dictionary_id != dict->dictionary_id) { + if (ctx->header.dictionary_id != 0 && + ctx->header.dictionary_id != dict->dictionary_id) { ERROR("Wrong/no dictionary provided"); } @@ -575,8 +606,7 @@ static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict) { // be used in the table repeat modes if (dict->dictionary_id != 0) { // Deep copy the entropy tables so they can be freed independently of - // the - // dictionary struct + // the dictionary struct HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable); FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable); FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable); @@ -590,14 +620,14 @@ static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict) { /// Decompress the data from a frame block by block static void decompress_data(io_streams_t *streams, frame_context_t *ctx) { - u8 last_block = 0; + int last_block = 0; do { if (streams->src_len < 3) { INP_SIZE(); } // Parse the block header last_block = streams->src[0] & 1; - u8 block_type = (streams->src[0] >> 1) & 3; + int block_type = (streams->src[0] >> 1) & 3; size_t block_len = read_bits_LE(streams->src, 21, 3); streams->src += 3; @@ -648,6 +678,10 @@ static void decompress_data(io_streams_t *streams, frame_context_t *ctx) { // Compressed block, this is mode complex decompress_block(streams, ctx, block_len); break; + case 3: + // Reserved block type + CORRUPTION(); + break; } } while (!last_block); @@ -656,10 +690,9 @@ static void decompress_data(io_streams_t *streams, frame_context_t *ctx) { streams->dst += written; streams->dst_len -= written; - if (ctx->content_checksum_flag) { + if (ctx->header.content_checksum_flag) { // This program does not support checking the checksum, so skip over it - // if - // it's present + // if it's present if (streams->src_len < 4) { INP_SIZE(); } @@ -1312,6 +1345,126 @@ static size_t execute_sequences(io_streams_t *streams, frame_context_t *ctx, } /******* END SEQUENCE EXECUTION ***********************************************/ +/******* OUTPUT SIZE COUNTING *************************************************/ +size_t traverse_frame(frame_header_t *header, const u8 *src, size_t src_len); + +/// Get the decompressed size of an input stream so memory can be allocated in +/// advance. +/// This is more complex than the implementation in the reference +/// implementation, as this API allows for the decompression of multiple +/// concatenated frames. +size_t ZSTD_get_decompressed_size(const void *src, size_t src_len) { + const u8 *ip = (const u8 *) src; + size_t dst_size = 0; + + // Each frame header only gives us the size of its frame, so iterate over all + // frames + while (src_len > 0) { + if (src_len < 4) { + INP_SIZE(); + } + + u32 magic_number = read_bits_LE(ip, 32, 0); + + ip += 4; + src_len -= 4; + if (magic_number >= 0x184D2A50U && magic_number <= 0x184D2A5F) { + // skippable frame, this has no impact on output size + if (src_len < 4) { + INP_SIZE(); + } + size_t frame_size = read_bits_LE(ip, 32, 32); + + if (src_len < 4 + frame_size) { + INP_SIZE(); + } + + // skip over frame + ip += 4 + frame_size; + src_len -= 4 + frame_size; + } else if (magic_number == 0xFD2FB528U) { + // ZSTD frame + frame_header_t header; + parse_frame_header(&header, ip, src_len); + + if (header.frame_content_size == 0 && !header.single_segment_flag) { + // Content size not provided, we can't tell + return -1; + } + + dst_size += header.frame_content_size; + + // we need to traverse the frame to find when the next one starts + size_t traversed = traverse_frame(&header, ip, src_len); + ip += traversed; + src_len -= traversed; + } else { + // not a real frame + ERROR("Invalid magic number"); + } + } + + return dst_size; +} + +/// Iterate over each block in a frame to find the end of it, to get to the +/// start of the next frame +size_t traverse_frame(frame_header_t *header, const u8 *src, size_t src_len) { + const u8 *const src_beg = src; + const u8 *const src_end = src + src_len; + src += header->header_size; + src_len += header->header_size; + + int last_block = 0; + + do { + if (src + 3 > src_end) { + INP_SIZE(); + } + // Parse the block header + last_block = src[0] & 1; + int block_type = (src[0] >> 1) & 3; + size_t block_len = read_bits_LE(src, 21, 3); + + src += 3; + switch (block_type) { + case 0: // Raw block, block_len bytes + if (src + block_len > src_end) { + INP_SIZE(); + } + src += block_len; + break; + case 1: // RLE block, 1 byte + if (src + 1 > src_end) { + INP_SIZE(); + } + src++; + break; + case 2: // Compressed block, compressed size is block_len + if (src + block_len > src_end) { + INP_SIZE(); + } + src += block_len; + break; + case 3: + // Reserved block type + CORRUPTION(); + break; + } + } while (!last_block); + + if (header->content_checksum_flag) { + if (src + 4 > src_end) { + INP_SIZE(); + } + src += 4; + } + + return src - src_beg; +} + +/******* END OUTPUT SIZE COUNTING *********************************************/ + /******* DICTIONARY PARSING ***************************************************/ static void init_raw_content_dict(dictionary_t *dict, const u8 *src, size_t src_len); @@ -1952,8 +2105,8 @@ static void FSE_init_dtable(FSE_dtable *dtable, const i16 *norm_freqs, high_threshold); // Make sure we don't occupy a spot taken // by the low prob symbols // Note: no other collision checking is necessary as `step` is - // coprime to - // `size`, so the cycle will visit each position exactly once + // coprime to `size`, so the cycle will visit each position exactly + // once } } if (pos != 0) { @@ -1964,13 +2117,11 @@ static void FSE_init_dtable(FSE_dtable *dtable, const i16 *norm_freqs, for (int i = 0; i < size; i++) { u8 symbol = dtable->symbols[i]; u16 next_state_desc = state_desc[symbol]++; - // Fills in the table appropriately - // next_state_desc increases by symbol over time, decreasing number of - // bits + // Fills in the table appropriately next_state_desc increases by symbol + // over time, decreasing number of bits dtable->num_bits[i] = (u8)(accuracy_log - log2inf(next_state_desc)); // baseline increases until the bit threshold is passed, at which point - // it - // resets to 0 + // it resets to 0 dtable->new_state_base[i] = ((u16)next_state_desc << dtable->num_bits[i]) - size; } @@ -2057,8 +2208,7 @@ static void FSE_init_dtable_rle(FSE_dtable *dtable, u8 symb) { dtable->new_state_base = malloc(sizeof(u16)); // This setup will always have a state of 0, always return symbol `symb`, - // and - // never consume any bits + // and never consume any bits dtable->symbols[0] = symb; dtable->num_bits[0] = 0; dtable->new_state_base[0] = 0; diff --git a/contrib/educational_decoder/zstd_decompress.h b/contrib/educational_decoder/zstd_decompress.h index 3671678b1..3e1bc568f 100644 --- a/contrib/educational_decoder/zstd_decompress.h +++ b/contrib/educational_decoder/zstd_decompress.h @@ -3,4 +3,5 @@ size_t ZSTD_decompress(void *dst, size_t dst_len, const void *src, size_t ZSTD_decompress_with_dict(void *dst, size_t dst_len, const void *src, size_t src_len, const void *dict, size_t dict_len); +size_t ZSTD_get_decompressed_size(const void *src, size_t src_len);