typedef unsigned char u8;
-// There's no good way to determine output size without decompressing
-// For this example assume we'll never decompress at a ratio larger than 16
+// If the data doesn't carry its decompressed size, fall back to assuming the
+// compression ratio is at most 16
#define MAX_COMPRESSION_RATIO (16)
u8 *input;
u8 *dict;
+/// Read the whole file at `path` into a freshly malloc'd buffer.
+///
+/// On return `*ptr` owns the buffer (the caller frees it) and the return value
+/// is the number of bytes actually read. Any failure prints a message to
+/// stderr and terminates the process with exit(1).
size_t read_file(const char *path, u8 **ptr) {
- FILE *f = fopen(path, "rb");
- if (!f) {
- fprintf(stderr, "failed to open file %s\n", path);
- exit(1);
- }
-
- fseek(f, 0L, SEEK_END);
- size_t size = ftell(f);
- rewind(f);
-
- *ptr = malloc(size);
- if (!ptr) {
- fprintf(stderr, "failed to allocate memory to hold %s\n", path);
- exit(1);
- }
-
- size_t pos = 0;
- while (!feof(f)) {
- size_t read = fread(&(*ptr)[pos], 1, size, f);
- if (ferror(f)) {
- fprintf(stderr, "error while reading file %s\n", path);
- exit(1);
+    FILE *f = fopen(path, "rb");
+    if (!f) {
+        fprintf(stderr, "failed to open file %s\n", path);
+        exit(1);
}
- pos += read;
- }
- fclose(f);
+    // Measure the file. ftell reports failure as a negative value, which must
+    // not be converted to size_t and used as an allocation size
+    if (fseek(f, 0L, SEEK_END) != 0) {
+        fprintf(stderr, "error while reading file %s\n", path);
+        exit(1);
+    }
+    long end = ftell(f);
+    if (end < 0) {
+        fprintf(stderr, "error while reading file %s\n", path);
+        exit(1);
+    }
+    size_t size = (size_t)end;
+    rewind(f);
- return pos;
+    // Never request 0 bytes: malloc(0) may return NULL, which would be
+    // indistinguishable from an allocation failure for an empty file
+    *ptr = malloc(size ? size : 1);
+    if (!*ptr) {
+        fprintf(stderr, "failed to allocate memory to hold %s\n", path);
+        exit(1);
+    }
+
+    size_t pos = 0;
+    while (pos < size) {
+        // Only ask for the bytes that still fit in the buffer
+        size_t got = fread(*ptr + pos, 1, size - pos, f);
+        if (ferror(f)) {
+            fprintf(stderr, "error while reading file %s\n", path);
+            exit(1);
+        }
+        if (got == 0) {
+            // End of file reached before `size` bytes were read
+            break;
+        }
+        pos += got;
+    }
+
+    fclose(f);
+
+    return pos;
}
+/// Write `size` bytes starting at `ptr` to the file at `path`, replacing any
+/// existing contents. Any failure prints a message to stderr and terminates
+/// the process with exit(1).
void write_file(const char *path, const u8 *ptr, size_t size) {
- FILE *f = fopen(path, "wb");
-
- size_t written = 0;
- while (written < size) {
- written += fwrite(&ptr[written], 1, size, f);
- if (ferror(f)) {
- fprintf(stderr, "error while writing file %s\n", path);
- exit(1);
+    FILE *f = fopen(path, "wb");
+    if (!f) {
+        fprintf(stderr, "failed to open file %s\n", path);
+        exit(1);
+    }
+
+    size_t written = 0;
+    while (written < size) {
+        // Hand fwrite only the bytes that remain; passing `size` again after
+        // a short write would read past the end of `ptr`
+        written += fwrite(&ptr[written], 1, size - written, f);
+        if (ferror(f)) {
+            fprintf(stderr, "error while writing file %s\n", path);
+            exit(1);
+        }
}
- }
- fclose(f);
+    // fclose flushes buffered data, so a failure here means the file is
+    // incomplete
+    if (fclose(f) != 0) {
+        fprintf(stderr, "error while writing file %s\n", path);
+        exit(1);
+    }
}
+/// Decompress `argv[1]` into `argv[2]`, optionally using the dictionary file
+/// given as `argv[3]`. Returns 1 on a usage or allocation error.
int main(int argc, char **argv) {
- if (argc < 3) {
- fprintf(stderr, "usage: %s <file.zst> <out_path> [dictionary]\n", argv[0]);
-
- return 1;
- }
-
- size_t input_size = read_file(argv[1], &input);
- size_t dict_size = 0;
- if (argc >= 4) {
- dict_size = read_file(argv[3], &dict);
- }
-
- output = malloc(MAX_COMPRESSION_RATIO * input_size);
- if (!output) {
- fprintf(stderr, "failed to allocate memory\n");
- return 1;
- }
-
- size_t decompressed =
- ZSTD_decompress_with_dict(output, input_size * MAX_COMPRESSION_RATIO,
- input, input_size, dict, dict_size);
-
- write_file(argv[2], output, decompressed);
-
- free(input);
- free(output);
- free(dict);
- input = output = dict = NULL;
+    if (argc < 3) {
+        fprintf(stderr, "usage: %s <file.zst> <out_path> [dictionary]\n",
+                argv[0]);
+
+        return 1;
+    }
+
+    size_t input_size = read_file(argv[1], &input);
+    size_t dict_size = 0;
+    if (argc >= 4) {
+        dict_size = read_file(argv[3], &dict);
+    }
+
+    // ZSTD_get_decompressed_size returns (size_t)-1 when a frame does not
+    // declare its content size
+    size_t decompressed_size = ZSTD_get_decompressed_size(input, input_size);
+    if (decompressed_size == (size_t)-1) {
+        decompressed_size = MAX_COMPRESSION_RATIO * input_size;
+        fprintf(stderr, "WARNING: Compressed data does not contain "
+                        "decompressed size, going to assume the compression "
+                        "ratio is at most %d (decompressed size of at most "
+                        "%zu)\n",
+                MAX_COMPRESSION_RATIO, decompressed_size);
+    }
+    // Never request 0 bytes: malloc(0) may return NULL, which would be
+    // mistaken for an allocation failure when the output is empty
+    output = malloc(decompressed_size ? decompressed_size : 1);
+    if (!output) {
+        fprintf(stderr, "failed to allocate memory\n");
+        return 1;
+    }
+
+    // The destination capacity must be the size that was actually allocated,
+    // not the worst-case ratio, or the decoder could write past the buffer
+    size_t decompressed =
+        ZSTD_decompress_with_dict(output, decompressed_size, input,
+                                  input_size, dict, dict_size);
+
+    write_file(argv[2], output, decompressed);
+
+    free(input);
+    free(output);
+    free(dict);
+    input = output = dict = NULL;
+
+    return 0;
}
size_t src_len, const void *dict,
size_t dict_len);
+/// Get the decompressed size of an input stream so memory can be allocated in
+/// advance
+size_t ZSTD_get_decompressed_size(const void *src, size_t src_len);
+
/******* UTILITY MACROS AND TYPES *********************************************/
#define MAX_WINDOW_SIZE ((size_t)512 << 20)
// Max block size decompressed size is 128 KB and literal blocks must be smaller
size_t src_len;
} io_streams_t;
-/// The context needed to decode blocks in a frame
+/// A small structure that can be reused in various places that need to access
+/// frame header information. It is filled in by parse_frame_header.
typedef struct {
+ // The size of window that we need to be able to contiguously store for
+ // references. parse_frame_header only writes this when the frame is not
+ // single-segment
size_t window_size;
+ // The total output size of this compressed frame, or 0 when the header
+ // does not carry a content size field
size_t frame_content_size;
+ // The dictionary id if this frame uses one, otherwise 0
+ u32 dictionary_id;
+
+ // Whether or not the content of this frame has a checksum
+ int content_checksum_flag;
+ // Whether or not the output for this frame is in a single segment
+ int single_segment_flag;
+
+ // The size in bytes of this header, counted from the frame header
+ // descriptor byte (parse_frame_header starts its count at 1 for that byte)
+ int header_size;
+} frame_header_t;
+
+/// The context needed to decode blocks in a frame
+typedef struct {
+ frame_header_t header;
+
// The total amount of data available for backreferences, to determine if an
// offset too large to be correct
size_t current_total_output;
// The last 3 offsets for the special "repeat offsets". Array size is 4 so
// that previous_offsets[1] corresponds to the most recent offset
u64 previous_offsets[4];
-
- // The dictionary id for this frame if one exists
- u32 dictionary_id;
-
- int single_segment_flag;
- int content_checksum_flag;
} frame_context_t;
/// The decoded contents of a dictionary so that it doesn't have to be repeated
/******* FRAME DECODING ******************************************************/
static void decode_data_frame(io_streams_t *streams, dictionary_t *dict);
-static void init_frame_context(frame_context_t *context);
-static void free_frame_context(frame_context_t *context);
-static void parse_frame_header(io_streams_t *streams, frame_context_t *ctx,
+static void init_frame_context(io_streams_t *streams, frame_context_t *context,
dictionary_t *dict);
+static void free_frame_context(frame_context_t *context);
+static void parse_frame_header(frame_header_t *header, const u8 *src,
+ size_t src_len);
static void frame_context_apply_dict(frame_context_t *ctx, dictionary_t *dict);
static void decompress_data(io_streams_t *streams, frame_context_t *ctx);
frame_context_t ctx;
// Initialize the context that needs to be carried from block to block
- init_frame_context(&ctx);
- parse_frame_header(streams, &ctx, dict);
- frame_context_apply_dict(&ctx, dict);
+ init_frame_context(streams, &ctx, dict);
- if (ctx.frame_content_size != 0 &&
- ctx.frame_content_size > streams->dst_len) {
+ if (ctx.header.frame_content_size != 0 &&
+ ctx.header.frame_content_size > streams->dst_len) {
OUT_SIZE();
}
free_frame_context(&ctx);
}
-static void init_frame_context(frame_context_t *context) {
+/// Takes the information provided in the header and dictionary, and initializes
+/// the context for this frame.
+///
+/// `streams->src` is handed to parse_frame_header as the start of the frame
+/// header; afterwards `streams->src` is advanced past the header and
+/// `streams->src_len` reduced by the same amount. `dict` may be NULL, in which
+/// case no dictionary content is counted towards the window buffer size.
+static void init_frame_context(io_streams_t *streams, frame_context_t *context,
+ dictionary_t *dict) {
memset(context, 0x00, sizeof(frame_context_t));
+ // Parse data from the frame header
+ parse_frame_header(&context->header, streams->src, streams->src_len);
+ // Consume the header bytes so the stream now points at the first block
+ streams->src += context->header.header_size;
+ streams->src_len -= context->header.header_size;
+
// Set up the offset history for the repeat offset commands
context->previous_offsets[1] = 1;
context->previous_offsets[2] = 4;
context->previous_offsets[3] = 8;
+
+ {
+ // Allocate the window buffer
+ size_t buffer_size;
+ if (context->header.single_segment_flag) {
+ // A single-segment header has no window descriptor, so size the
+ // buffer from the content size plus any dictionary content
+ buffer_size = context->header.frame_content_size +
+ (dict ? dict->content_size : 0);
+ } else {
+ buffer_size = context->header.window_size;
+ }
+
+ if (buffer_size > MAX_WINDOW_SIZE) {
+ ERROR("Requested window size too large");
+ }
+ cbuf_init(&context->window, buffer_size);
+ }
+
+ // Apply details from the dict if it exists
+ frame_context_apply_dict(context, dict);
}
static void free_frame_context(frame_context_t *context) {
memset(context, 0, sizeof(frame_context_t));
}
-static void parse_frame_header(io_streams_t *streams, frame_context_t *ctx,
- dictionary_t *dict) {
- if (streams->src_len < 1) {
+static void parse_frame_header(frame_header_t *header, const u8 *src,
+ size_t src_len) {
+ if (src_len < 1) {
INP_SIZE();
}
- u8 descriptor = read_bits_LE(streams->src, 8, 0);
+ u8 descriptor = read_bits_LE(src, 8, 0);
// decode frame header descriptor into flags
u8 frame_content_size_flag = descriptor >> 6;
CORRUPTION();
}
- streams->src++;
- streams->src_len--;
+ int header_size = 1;
- ctx->single_segment_flag = single_segment_flag;
- ctx->content_checksum_flag = content_checksum_flag;
+ header->single_segment_flag = single_segment_flag;
+ header->content_checksum_flag = content_checksum_flag;
// decode window size
if (!single_segment_flag) {
- if (streams->src_len < 1) {
+ if (src_len < header_size + 1) {
INP_SIZE();
}
// Use the algorithm from the specification to compute window size
// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
- u8 window_descriptor = read_bits_LE(streams->src, 8, 0);
+ u8 window_descriptor = src[header_size];
u8 exponent = window_descriptor >> 3;
u8 mantissa = window_descriptor & 7;
size_t window_base = (size_t)1 << (10 + exponent);
size_t window_add = (window_base / 8) * mantissa;
- ctx->window_size = window_base + window_add;
+ header->window_size = window_base + window_add;
- streams->src++;
- streams->src_len--;
+ header_size++;
}
// decode dictionary id if it exists
const int bytes_array[] = {0, 1, 2, 4};
const int bytes = bytes_array[dictionary_id_flag];
- if (streams->src_len < bytes) {
+ if (src_len < header_size + bytes) {
INP_SIZE();
}
- ctx->dictionary_id = read_bits_LE(streams->src, bytes * 8, 0);
- streams->src += bytes;
- streams->src_len -= bytes;
+ header->dictionary_id = read_bits_LE(src + header_size, bytes * 8, 0);
+
+ header_size += bytes;
} else {
- ctx->dictionary_id = 0;
+ header->dictionary_id = 0;
}
// decode frame content size if it exists
if (single_segment_flag || frame_content_size_flag) {
// if frame_content_size_flag == 0 but single_segment_flag is set, we
- // still
- // have a 1 byte field
+ // still have a 1 byte field
const int bytes_array[] = {1, 2, 4, 8};
const int bytes = bytes_array[frame_content_size_flag];
- if (streams->src_len < bytes) {
+ if (src_len < header_size + bytes) {
INP_SIZE();
}
- ctx->frame_content_size = read_bits_LE(streams->src, bytes * 8, 0);
+ header->frame_content_size =
+ read_bits_LE(src + header_size, bytes * 8, 0);
if (bytes == 2) {
- ctx->frame_content_size += 256;
+ header->frame_content_size += 256;
}
- streams->src += bytes;
- streams->src_len -= bytes;
- }
-
- if (single_segment_flag) {
- ctx->window_size =
- ctx->frame_content_size + (dict ? dict->content_size : 0);
- // We need to allocate a buffer to write to of size at least output +
- // dict
- // size
- size_t size = ctx->frame_content_size + (dict ? dict->content_size : 0);
+ header_size += bytes;
+ } else {
+ header->frame_content_size = 0;
}
- // Allocate the window
- if (ctx->window_size > MAX_WINDOW_SIZE) {
- ERROR("Requested window size too large");
- }
- cbuf_init(&ctx->window, ctx->window_size);
+ header->header_size = header_size;
}
/// A dictionary acts as initializing values for the frame context before
if (!dict || !dict->content)
return;
- if (ctx->dictionary_id == 0 && dict->dictionary_id != 0) {
+ if (ctx->header.dictionary_id == 0 && dict->dictionary_id != 0) {
// The dictionary is unneeded, and shouldn't be used as it may interfere
// with the default offset history
return;
// If the dictionary id is 0, it doesn't matter if we provide the wrong raw
// content dict, it won't change anything
- if (ctx->dictionary_id != 0 && ctx->dictionary_id != dict->dictionary_id) {
+ if (ctx->header.dictionary_id != 0 &&
+ ctx->header.dictionary_id != dict->dictionary_id) {
ERROR("Wrong/no dictionary provided");
}
// be used in the table repeat modes
if (dict->dictionary_id != 0) {
// Deep copy the entropy tables so they can be freed independently of
- // the
- // dictionary struct
+ // the dictionary struct
HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable);
FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable);
FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable);
/// Decompress the data from a frame block by block
static void decompress_data(io_streams_t *streams, frame_context_t *ctx) {
- u8 last_block = 0;
+ int last_block = 0;
do {
if (streams->src_len < 3) {
INP_SIZE();
}
// Parse the block header
last_block = streams->src[0] & 1;
- u8 block_type = (streams->src[0] >> 1) & 3;
+ int block_type = (streams->src[0] >> 1) & 3;
size_t block_len = read_bits_LE(streams->src, 21, 3);
streams->src += 3;
// Compressed block, this is mode complex
decompress_block(streams, ctx, block_len);
break;
+ case 3:
+ // Reserved block type
+ CORRUPTION();
+ break;
}
} while (!last_block);
streams->dst += written;
streams->dst_len -= written;
- if (ctx->content_checksum_flag) {
+ if (ctx->header.content_checksum_flag) {
// This program does not support checking the checksum, so skip over it
- // if
- // it's present
+ // if it's present
if (streams->src_len < 4) {
INP_SIZE();
}
}
/******* END SEQUENCE EXECUTION ***********************************************/
+/******* OUTPUT SIZE COUNTING *************************************************/
+size_t traverse_frame(frame_header_t *header, const u8 *src, size_t src_len);
+
+/// Get the decompressed size of an input stream so memory can be allocated in
+/// advance.
+/// This is more complex than the implementation in the reference
+/// implementation, as this API allows for the decompression of multiple
+/// concatenated frames.
+///
+/// Returns the sum of the content sizes declared by every ZSTD frame in `src`,
+/// or `(size_t)-1` if some frame does not declare its content size. Skippable
+/// frames contribute nothing. Malformed input is reported through the
+/// INP_SIZE() and ERROR() macros.
+size_t ZSTD_get_decompressed_size(const void *src, size_t src_len) {
+    const u8 *ip = (const u8 *)src;
+    size_t dst_size = 0;
+
+    // Each frame header only gives us the size of its frame, so iterate over
+    // all frames
+    while (src_len > 0) {
+        if (src_len < 4) {
+            INP_SIZE();
+        }
+
+        u32 magic_number = read_bits_LE(ip, 32, 0);
+
+        ip += 4;
+        src_len -= 4;
+        if (magic_number >= 0x184D2A50U && magic_number <= 0x184D2A5FU) {
+            // skippable frame, this has no impact on output size
+            if (src_len < 4) {
+                INP_SIZE();
+            }
+            // `ip` already points past the magic number, so the 4-byte frame
+            // size field starts at bit offset 0
+            size_t frame_size = read_bits_LE(ip, 32, 0);
+
+            // Compare against the bytes left after the size field instead of
+            // computing `4 + frame_size`, which could wrap where size_t is
+            // 32 bits wide
+            if (src_len - 4 < frame_size) {
+                INP_SIZE();
+            }
+
+            // skip over frame
+            ip += 4 + frame_size;
+            src_len -= 4 + frame_size;
+        } else if (magic_number == 0xFD2FB528U) {
+            // ZSTD frame
+            frame_header_t header;
+            parse_frame_header(&header, ip, src_len);
+
+            if (header.frame_content_size == 0 &&
+                !header.single_segment_flag) {
+                // Content size not provided, we can't tell
+                return (size_t)-1;
+            }
+
+            dst_size += header.frame_content_size;
+
+            // we need to traverse the frame to find when the next one starts
+            size_t traversed = traverse_frame(&header, ip, src_len);
+            ip += traversed;
+            src_len -= traversed;
+        } else {
+            // not a real frame
+            ERROR("Invalid magic number");
+        }
+    }
+
+    return dst_size;
+}
+
+/// Iterate over each block in a frame to find the end of it, to get to the
+/// start of the next frame.
+///
+/// `src` points at the frame header (just past the magic number) and `header`
+/// is the result of parsing that header from the same buffer. Returns the
+/// number of bytes from `src` to the end of the frame, including the header
+/// and, when `content_checksum_flag` is set, the 4-byte checksum.
+size_t traverse_frame(frame_header_t *header, const u8 *src, size_t src_len) {
+    const u8 *const src_beg = src;
+    const u8 *const src_end = src + src_len;
+    // Every bounds check below is made against src_end, so src_len does not
+    // need to be kept in step with src
+    src += header->header_size;
+
+    int last_block = 0;
+
+    do {
+        // Compare remaining byte counts rather than forming `src + n`, which
+        // would be an out-of-range pointer on truncated input
+        if ((size_t)(src_end - src) < 3) {
+            INP_SIZE();
+        }
+        // Parse the block header
+        last_block = src[0] & 1;
+        int block_type = (src[0] >> 1) & 3;
+        size_t block_len = read_bits_LE(src, 21, 3);
+
+        src += 3;
+        switch (block_type) {
+        case 0: // Raw block, block_len bytes
+            if ((size_t)(src_end - src) < block_len) {
+                INP_SIZE();
+            }
+            src += block_len;
+            break;
+        case 1: // RLE block, 1 byte
+            if ((size_t)(src_end - src) < 1) {
+                INP_SIZE();
+            }
+            src++;
+            break;
+        case 2: // Compressed block, compressed size is block_len
+            if ((size_t)(src_end - src) < block_len) {
+                INP_SIZE();
+            }
+            src += block_len;
+            break;
+        case 3:
+            // Reserved block type
+            CORRUPTION();
+            break;
+        }
+    } while (!last_block);
+
+    if (header->content_checksum_flag) {
+        if ((size_t)(src_end - src) < 4) {
+            INP_SIZE();
+        }
+        src += 4;
+    }
+
+    return (size_t)(src - src_beg);
+}
+
+/******* END OUTPUT SIZE COUNTING *********************************************/
+
/******* DICTIONARY PARSING ***************************************************/
static void init_raw_content_dict(dictionary_t *dict, const u8 *src,
size_t src_len);
high_threshold); // Make sure we don't occupy a spot taken
// by the low prob symbols
// Note: no other collision checking is necessary as `step` is
- // coprime to
- // `size`, so the cycle will visit each position exactly once
+ // coprime to `size`, so the cycle will visit each position exactly
+ // once
}
}
if (pos != 0) {
for (int i = 0; i < size; i++) {
u8 symbol = dtable->symbols[i];
u16 next_state_desc = state_desc[symbol]++;
- // Fills in the table appropriately
- // next_state_desc increases by symbol over time, decreasing number of
- // bits
+ // Fills in the table appropriately. next_state_desc increases by symbol
+ // over time, decreasing the number of bits
dtable->num_bits[i] = (u8)(accuracy_log - log2inf(next_state_desc));
// baseline increases until the bit threshold is passed, at which point
- // it
- // resets to 0
+ // it resets to 0
dtable->new_state_base[i] =
((u16)next_state_desc << dtable->num_bits[i]) - size;
}
dtable->new_state_base = malloc(sizeof(u16));
// This setup will always have a state of 0, always return symbol `symb`,
- // and
- // never consume any bits
+ // and never consume any bits
dtable->symbols[0] = symb;
dtable->num_bits[0] = 0;
dtable->new_state_base[0] = 0;