#define OUT_SIZE() ERROR("Output buffer too small for output")
#define CORRUPTION() ERROR("Corruption detected while decompressing")
#define BAD_ALLOC() ERROR("Memory allocation error")
+#define IMPOSSIBLE() ERROR("An impossibility has occurred")
typedef uint8_t u8;
typedef uint16_t u16;
/// file. They implement low-level functionality needed for the higher level
/// decompression functions.
+/*** IO STREAM OPERATIONS *************/
+/// These structs are the interface for IO, and do bounds checking on all
+/// operations. They should be used opaquely to ensure safety.
+
+/// Output is always done byte-by-byte
+typedef struct {
+ u8 *ptr;
+ size_t len;
+} ostream_t;
+
+/// Input often reads a few bits at a time, so maintain an internal offset
+typedef struct {
+ const u8 *ptr;
+ int bit_offset;
+ size_t len;
+} istream_t;
+
+/// The following two functions are the only ones that allow the istream to be
+/// non-byte aligned
+
+/// Reads `num` bits from a bitstream, and updates the internal offset
+static inline u64 IO_read_bits(istream_t *const in, const int num);
+/// Rewinds the stream by `num` bits
+static inline void IO_rewind_bits(istream_t *const in, const int num);
+/// If the remaining bits in a byte will be unused, advance to the end of the
+/// byte
+static inline void IO_align_stream(istream_t *const in);
+
+/// Write the given byte into the output stream
+static inline void IO_write_byte(ostream_t *const out, u8 symb);
+
+/// Returns the number of bytes left to be read in this stream. The stream must
+/// be byte aligned.
+static inline size_t IO_istream_len(const istream_t *const in);
+
+/// Returns a pointer where `len` bytes can be read, and advances the internal
+/// state. The stream must be byte aligned.
+static inline const u8 *IO_read_bytes(istream_t *const in, size_t len);
+/// Returns a pointer where `len` bytes can be written, and advances the internal
+/// state. The stream must be byte aligned.
+static inline u8 *IO_write_bytes(ostream_t *const out, size_t len);
+
+/// Advance the inner state by `len` bytes. The stream must be byte aligned.
+static inline void IO_advance_input(istream_t *const in, size_t len);
+
+/// Returns an `ostream_t` constructed from the given pointer and length
+static inline ostream_t IO_make_ostream(u8 *out, size_t len);
+/// Returns an `istream_t` constructed from the given pointer and length
+static inline istream_t IO_make_istream(const u8 *in, size_t len);
+
+/// Returns an `istream_t` with the same base as `in`, and length `len`
+/// Then, advance `in` to account for the consumed bytes
+/// `in` must be byte aligned
+static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len);
+/*** END IO STREAM OPERATIONS *********/
+
/*** BITSTREAM OPERATIONS *************/
/// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits
static inline u64 read_bits_LE(const u8 *src, const int num,
/// Decompresses a single Huffman stream, returns the number of bytes decoded.
/// `src_len` must be the exact length of the Huffman-coded block.
-static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, u8 *dst,
- const size_t dst_len, const u8 *src,
- size_t src_len);
+static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
+ ostream_t *const out, istream_t *const in);
/// Same as previous but decodes 4 streams, formatted as in the Zstandard
/// specification.
/// `src_len` must be the exact length of the Huffman-coded block.
-static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, u8 *dst,
- const size_t dst_len, const u8 *const src,
- const size_t src_len);
+static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
+ ostream_t *const out, istream_t *const in);
/// Initialize a Huffman decoding table using the table of bit counts provided
static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
/// using an FSE decoding table. `src_len` must be the exact length of the
/// block.
static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
- u8 *dst, const size_t dst_len,
- const u8 *const src,
- const size_t src_len);
+ ostream_t *const out,
+ istream_t *const in);
/// Initialize a decoding table using normalized frequencies.
static void FSE_init_dtable(FSE_dtable *const dtable,
/// Decode an FSE header as defined in the Zstandard format specification and
/// use the decoded frequencies to initialize a decoding table.
-static size_t FSE_decode_header(FSE_dtable *const dtable, const u8 *const src,
- const size_t src_len,
+static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
const int max_accuracy_log);
/// Initialize an FSE table that will always return the same symbol and consume
/******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/
-/// Input and output pointers to allow them to be advanced by
-/// functions that consume input/produce output
-typedef struct {
- u8 *dst;
- size_t dst_len;
-
- const u8 *src;
- size_t src_len;
-} io_streams_t;
-
/// A small structure that can be reused in various places that need to access
/// frame header information
typedef struct {
int content_checksum_flag;
// Whether or not the output for this frame is in a single segment
int single_segment_flag;
-
- // The size in bytes of this header
- int header_size;
} frame_header_t;
/// The context needed to decode blocks in a frame
FSE_dtable ml_dtable;
FSE_dtable of_dtable;
- // The last 3 offsets for the special "repeat offsets". Array size is 4 so
- // that previous_offsets[1] corresponds to the most recent offset
- u64 previous_offsets[4];
+ // The last 3 offsets for the special "repeat offsets".
+ u64 previous_offsets[3];
} frame_context_t;
/// The decoded contents of a dictionary so that it doesn't have to be repeated
size_t content_size;
// Offset history to prepopulate the frame's history
- u64 previous_offsets[4];
+ u64 previous_offsets[3];
u32 dictionary_id;
} dictionary_t;
/// Accepts a dict argument, which may be NULL indicating no dictionary.
/// See
/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation
-static void decode_frame(io_streams_t *const streams,
+static void decode_frame(ostream_t *const out, istream_t *const in,
const dictionary_t *const dict);
// Decode data in a compressed block
-static void decompress_block(io_streams_t *const streams,
- frame_context_t *const ctx,
- const size_t block_len);
+static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
+ istream_t *const in);
// Decode the literals section of a block
-static size_t decode_literals(io_streams_t *const streams,
- frame_context_t *const ctx, u8 **const literals);
+static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
+ u8 **const literals);
// Decode the sequences part of a block
-static size_t decode_sequences(frame_context_t *const ctx, const u8 *const src,
- const size_t src_len,
+static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in,
sequence_command_t **const sequences);
// Execute the decoded sequences on the literals block
-static void execute_sequences(io_streams_t *const streams,
- frame_context_t *const ctx,
+static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
+ const u8 *const literals,
+ const size_t literals_len,
const sequence_command_t *const sequences,
- const size_t num_sequences,
- const u8 *literals,
- size_t literals_len);
+ const size_t num_sequences);
// Parse a provided dictionary blob for use in decompression
-static void parse_dictionary(dictionary_t *const dict, const u8 *const src,
- const size_t src_len);
+static void parse_dictionary(dictionary_t *const dict, const u8 *src,
+ size_t src_len);
static void free_dictionary(dictionary_t *const dict);
/******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/
parse_dictionary(&parsed_dict, (const u8 *)dict, dict_len);
}
- io_streams_t streams = {(u8 *)dst, dst_len, (const u8 *)src, src_len};
- while (streams.src_len > 0) {
- decode_frame(&streams, &parsed_dict);
+ istream_t in = {(const u8 *)src, 0, src_len};
+ ostream_t out = {(u8 *)dst, dst_len};
+ while (IO_istream_len(&in) > 0) {
+ decode_frame(&out, &in, &parsed_dict);
}
free_dictionary(&parsed_dict);
- return streams.dst - (u8 *)dst;
+ return out.ptr - (u8 *)dst;
}
/******* FRAME DECODING ******************************************************/
-static void decode_data_frame(io_streams_t *const streams,
+static void decode_data_frame(ostream_t *const out, istream_t *const in,
const dictionary_t *const dict);
-static void init_frame_context(io_streams_t *const streams,
- frame_context_t *const context,
+static void init_frame_context(frame_context_t *const context,
+ istream_t *const in,
const dictionary_t *const dict);
static void free_frame_context(frame_context_t *const context);
static void parse_frame_header(frame_header_t *const header,
- const u8 *const src, const size_t src_len);
+ istream_t *const in);
static void frame_context_apply_dict(frame_context_t *const ctx,
const dictionary_t *const dict);
-static void decompress_data(io_streams_t *const streams,
- frame_context_t *const ctx);
+static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
+ istream_t *const in);
-static void decode_frame(io_streams_t *const streams,
+static void decode_frame(ostream_t *const out, istream_t *const in,
const dictionary_t *const dict) {
- if (streams->src_len < 4) {
- INP_SIZE();
- }
- const u32 magic_number = read_bits_LE(streams->src, 32, 0);
+ const u32 magic_number = IO_read_bits(in, 32);
- streams->src += 4;
- streams->src_len -= 4;
- if (magic_number >= 0x184D2A50U && magic_number <= 0x184D2A5F) {
- // skippable frame
- if (streams->src_len < 4) {
- INP_SIZE();
- }
- const size_t frame_size = read_bits_LE(streams->src, 32, 32);
-
- if (streams->src_len < 4 + frame_size) {
- INP_SIZE();
- }
+ if ((magic_number & ~0xFU) == 0x184D2A50U) {
+ // Skippable frame
+ const size_t frame_size = IO_read_bits(in, 32);
// skip over frame
- streams->src += 4 + frame_size;
- streams->src_len -= 4 + frame_size;
+ IO_advance_input(in, frame_size);
} else if (magic_number == 0xFD2FB528U) {
// ZSTD frame
- decode_data_frame(streams, dict);
+ decode_data_frame(out, in, dict);
} else {
// not a real frame
ERROR("Invalid magic number");
/// are skippable frames.
/// See
/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format
-static void decode_data_frame(io_streams_t *const streams,
+static void decode_data_frame(ostream_t *const out, istream_t *const in,
const dictionary_t *const dict) {
frame_context_t ctx;
// Initialize the context that needs to be carried from block to block
- init_frame_context(streams, &ctx, dict);
+ init_frame_context(&ctx, in, dict);
if (ctx.header.frame_content_size != 0 &&
- ctx.header.frame_content_size > streams->dst_len) {
+ ctx.header.frame_content_size > out->len) {
OUT_SIZE();
}
- decompress_data(streams, &ctx);
+ decompress_data(&ctx, out, in);
free_frame_context(&ctx);
}
/// Takes the information provided in the header and dictionary, and initializes
/// the context for this frame
-static void init_frame_context(io_streams_t *const streams,
- frame_context_t *const context,
+static void init_frame_context(frame_context_t *const context,
+ istream_t *const in,
const dictionary_t *const dict) {
// Most fields in context are correct when initialized to 0
- memset(context, 0x00, sizeof(frame_context_t));
+ memset(context, 0, sizeof(frame_context_t));
// Parse data from the frame header
- parse_frame_header(&context->header, streams->src, streams->src_len);
- streams->src += context->header.header_size;
- streams->src_len -= context->header.header_size;
+ parse_frame_header(&context->header, in);
// Set up the offset history for the repeat offset commands
- context->previous_offsets[1] = 1;
- context->previous_offsets[2] = 4;
- context->previous_offsets[3] = 8;
+ context->previous_offsets[0] = 1;
+ context->previous_offsets[1] = 4;
+ context->previous_offsets[2] = 8;
// Apply details from the dict if it exists
frame_context_apply_dict(context, dict);
}
static void parse_frame_header(frame_header_t *const header,
- const u8 *const src, const size_t src_len) {
- if (src_len < 1) {
- INP_SIZE();
- }
-
- const u8 descriptor = read_bits_LE(src, 8, 0);
+ istream_t *const in) {
+ const u8 descriptor = IO_read_bits(in, 8);
// decode frame header descriptor into flags
const u8 frame_content_size_flag = descriptor >> 6;
CORRUPTION();
}
- int header_size = 1;
-
header->single_segment_flag = single_segment_flag;
header->content_checksum_flag = content_checksum_flag;
// decode window size
if (!single_segment_flag) {
- if (src_len < header_size + 1) {
- INP_SIZE();
- }
-
// Use the algorithm from the specification to compute window size
// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
- u8 window_descriptor = src[header_size];
+ u8 window_descriptor = IO_read_bits(in, 8);
u8 exponent = window_descriptor >> 3;
u8 mantissa = window_descriptor & 7;
size_t window_base = (size_t)1 << (10 + exponent);
size_t window_add = (window_base / 8) * mantissa;
header->window_size = window_base + window_add;
-
- header_size++;
}
// decode dictionary id if it exists
const int bytes_array[] = {0, 1, 2, 4};
const int bytes = bytes_array[dictionary_id_flag];
- if (src_len < header_size + bytes) {
- INP_SIZE();
- }
-
- header->dictionary_id = read_bits_LE(src + header_size, bytes * 8, 0);
-
- header_size += bytes;
+ header->dictionary_id = IO_read_bits(in, bytes * 8);
} else {
header->dictionary_id = 0;
}
const int bytes_array[] = {1, 2, 4, 8};
const int bytes = bytes_array[frame_content_size_flag];
- if (src_len < header_size + bytes) {
- INP_SIZE();
- }
-
- header->frame_content_size =
- read_bits_LE(src + header_size, bytes * 8, 0);
+ header->frame_content_size = IO_read_bits(in, bytes * 8);
if (bytes == 2) {
header->frame_content_size += 256;
}
-
- header_size += bytes;
} else {
header->frame_content_size = 0;
}
// back to the dictionary or not on large offsets
header->window_size = header->frame_content_size;
}
-
- header->header_size = header_size;
}
/// A dictionary acts as initializing values for the frame context before
if (!dict || !dict->content)
return;
- if (ctx->header.dictionary_id == 0 && dict->dictionary_id != 0) {
- // The dictionary is unneeded, and shouldn't be used as it may interfere
- // with the default offset history
- return;
- }
-
- // If the dictionary id is 0, it doesn't matter if we provide the wrong raw
- // content dict, it won't change anything
+ // If the requested dictionary_id is non-zero, the correct dictionary must
+ // be present
if (ctx->header.dictionary_id != 0 &&
ctx->header.dictionary_id != dict->dictionary_id) {
- ERROR("Wrong/no dictionary provided");
+ ERROR("Wrong dictionary provided");
}
- // Copy the pointer in so we can reference it in sequence execution
+ // Copy the dict content to the context for references during sequence
+ // execution
ctx->dict_content = dict->content;
ctx->dict_content_len = dict->content_size;
}
/// Decompress the data from a frame block by block
-static void decompress_data(io_streams_t *const streams,
- frame_context_t *const ctx) {
+static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
+ istream_t *const in) {
int last_block = 0;
do {
- if (streams->src_len < 3) {
- INP_SIZE();
- }
// Parse the block header
- last_block = streams->src[0] & 1;
- const int block_type = (streams->src[0] >> 1) & 3;
- const size_t block_len = read_bits_LE(streams->src, 21, 3);
-
- streams->src += 3;
- streams->src_len -= 3;
+ last_block = IO_read_bits(in, 1);
+ const int block_type = IO_read_bits(in, 2);
+ const size_t block_len = IO_read_bits(in, 21);
switch (block_type) {
case 0: {
// Raw, uncompressed block
- if (streams->src_len < block_len) {
- INP_SIZE();
- }
- if (streams->dst_len < block_len) {
- OUT_SIZE();
- }
-
+ const u8 *const read_ptr = IO_read_bytes(in, block_len);
+ u8 *const write_ptr = IO_write_bytes(out, block_len);
+ //
// Copy the raw data into the output
- memcpy(streams->dst, streams->src, block_len);
-
- streams->src += block_len;
- streams->src_len -= block_len;
-
- streams->dst += block_len;
- streams->dst_len -= block_len;
+ memcpy(write_ptr, read_ptr, block_len);
ctx->current_total_output += block_len;
break;
}
case 1: {
// RLE block, repeat the first byte N times
- if (streams->src_len < 1) {
- INP_SIZE();
- }
- if (streams->dst_len < block_len) {
- OUT_SIZE();
- }
+ const u8 *const read_ptr = IO_read_bytes(in, 1);
+ u8 *const write_ptr = IO_write_bytes(out, block_len);
// Copy `block_len` copies of `streams->src[0]` to the output
- memset(streams->dst, streams->src[0], block_len);
-
- streams->dst += block_len;
- streams->dst_len -= block_len;
-
- streams->src += 1;
- streams->src_len -= 1;
+ memset(write_ptr, read_ptr[0], block_len);
ctx->current_total_output += block_len;
break;
}
- case 2:
- // Compressed block, this is mode complex
- decompress_block(streams, ctx, block_len);
+ case 2: {
+ // Compressed block
+ // Create a sub-stream for the block
+ istream_t block_stream = IO_make_sub_istream(in, block_len);
+ decompress_block(ctx, out, &block_stream);
break;
+ }
case 3:
// Reserved block type
CORRUPTION();
break;
+ default:
+ IMPOSSIBLE();
}
} while (!last_block);
if (ctx->header.content_checksum_flag) {
// This program does not support checking the checksum, so skip over it
// if it's present
- if (streams->src_len < 4) {
- INP_SIZE();
- }
- streams->src += 4;
- streams->src_len -= 4;
+ IO_advance_input(in, 4);
}
}
/******* END FRAME DECODING ***************************************************/
/******* BLOCK DECOMPRESSION **************************************************/
-static void decompress_block(io_streams_t *const streams, frame_context_t *const ctx,
- const size_t block_len) {
- if (streams->src_len < block_len) {
- INP_SIZE();
- }
- // We need this to determine how long the compressed literals block was
- const u8 *const end_of_block = streams->src + block_len;
-
+static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
+ istream_t *const in) {
// Part 1: decode the literals block
u8 *literals = NULL;
- const size_t literals_size = decode_literals(streams, ctx, &literals);
+ const size_t literals_size = decode_literals(ctx, in, &literals);
// Part 2: decode the sequences block
- if (streams->src > end_of_block) {
- INP_SIZE();
- }
- const size_t sequences_size = end_of_block - streams->src;
sequence_command_t *sequences = NULL;
const size_t num_sequences =
- decode_sequences(ctx, streams->src, sequences_size, &sequences);
-
- streams->src += sequences_size;
- streams->src_len -= sequences_size;
+ decode_sequences(ctx, in, &sequences);
// Part 3: combine literals and sequence commands to generate output
- execute_sequences(streams, ctx, sequences, num_sequences, literals,
- literals_size);
+ execute_sequences(ctx, out, literals, literals_size, sequences,
+ num_sequences);
free(literals);
free(sequences);
}
/******* END BLOCK DECOMPRESSION **********************************************/
/******* LITERALS DECODING ****************************************************/
-static size_t decode_literals_simple(io_streams_t *const streams,
- u8 **const literals, const int block_type,
+static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
+ const int block_type,
const int size_format);
-static size_t decode_literals_compressed(io_streams_t *const streams,
- frame_context_t *const ctx,
+static size_t decode_literals_compressed(frame_context_t *const ctx,
+ istream_t *const in,
u8 **const literals,
const int block_type,
const int size_format);
-static size_t decode_huf_table(const u8 *src, size_t src_len,
- HUF_dtable *const dtable);
-static size_t fse_decode_hufweights(const u8 *const src, const size_t src_len,
- u8 *const weights, int *const num_symbs,
- const size_t compressed_size);
-
-static size_t decode_literals(io_streams_t *const streams,
- frame_context_t *const ctx, u8 **const literals) {
- if (streams->src_len < 1) {
- INP_SIZE();
- }
+static void decode_huf_table(istream_t *const in, HUF_dtable *const dtable);
+static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
+ int *const num_symbs);
+
+static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
+ u8 **const literals) {
// Decode literals header
- int block_type = streams->src[0] & 3;
- int size_format = (streams->src[0] >> 2) & 3;
+ int block_type = IO_read_bits(in, 2);
+ int size_format = IO_read_bits(in, 2);
if (block_type <= 1) {
// Raw or RLE literals block
- return decode_literals_simple(streams, literals, block_type,
+ return decode_literals_simple(in, literals, block_type,
size_format);
} else {
// Huffman compressed literals
- return decode_literals_compressed(streams, ctx, literals, block_type,
+ return decode_literals_compressed(ctx, in, literals, block_type,
size_format);
}
}
/// Decodes literals blocks in raw or RLE form
-static size_t decode_literals_simple(io_streams_t *const streams,
- u8 **const literals, const int block_type,
+static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
+ const int block_type,
const int size_format) {
size_t size;
switch (size_format) {
- // These cases are in the form X0
- // In this case, the X bit is actually part of the size field
+ // These cases are in the form ?0
+ // In this case, the ? bit is actually part of the size field
case 0:
case 2:
- size = read_bits_LE(streams->src, 5, 3);
- streams->src += 1;
- streams->src_len -= 1;
+ // "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)."
+ IO_rewind_bits(in, 1);
+ size = IO_read_bits(in, 2);
break;
case 1:
- if (streams->src_len < 2) {
- INP_SIZE();
- }
- size = read_bits_LE(streams->src, 12, 4);
- streams->src += 2;
- streams->src_len -= 2;
+ // "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)."
+ size = IO_read_bits(in, 12);
break;
case 3:
- if (streams->src_len < 2) {
- INP_SIZE();
- }
- size = read_bits_LE(streams->src, 20, 4);
- streams->src += 3;
- streams->src_len -= 3;
+ // "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)."
+ size = IO_read_bits(in, 20);
break;
default:
- // Impossible
- size = -1;
+ // Size format is in range 0-3
+ IMPOSSIBLE();
}
if (size > MAX_LITERALS_SIZE) {
}
switch (block_type) {
- case 0:
+ case 0: {
// Raw data
- if (size > streams->src_len) {
- INP_SIZE();
- }
- memcpy(*literals, streams->src, size);
- streams->src += size;
- streams->src_len -= size;
+ const u8 *const read_ptr = IO_read_bytes(in, size);
+ memcpy(*literals, read_ptr, size);
break;
- case 1:
+ }
+ case 1: {
// Single repeated byte
- if (1 > streams->src_len) {
- INP_SIZE();
- }
- memset(*literals, streams->src[0], size);
- streams->src += 1;
- streams->src_len -= 1;
+ const u8 *const read_ptr = IO_read_bytes(in, 1);
+ memset(*literals, read_ptr[0], size);
break;
}
+ default:
+ IMPOSSIBLE();
+ }
return size;
}
/// Decodes Huffman compressed literals
-static size_t decode_literals_compressed(io_streams_t *const streams,
- frame_context_t *const ctx,
+static size_t decode_literals_compressed(frame_context_t *const ctx,
+ istream_t *const in,
u8 **const literals,
const int block_type,
const int size_format) {
int num_streams = 4;
switch (size_format) {
case 0:
+ // "A single stream. Both Compressed_Size and Regenerated_Size use 10
+ // bits (0-1023)."
num_streams = 1;
// Fall through as it has the same size format
case 1:
- if (streams->src_len < 3) {
- INP_SIZE();
- }
- regenerated_size = read_bits_LE(streams->src, 10, 4);
- compressed_size = read_bits_LE(streams->src, 10, 14);
- streams->src += 3;
- streams->src_len -= 3;
+ // "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits
+ // (0-1023)."
+ regenerated_size = IO_read_bits(in, 10);
+ compressed_size = IO_read_bits(in, 10);
break;
case 2:
- if (streams->src_len < 4) {
- INP_SIZE();
- }
- regenerated_size = read_bits_LE(streams->src, 14, 4);
- compressed_size = read_bits_LE(streams->src, 14, 18);
- streams->src += 4;
- streams->src_len -= 4;
+ // "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits
+ // (0-16383)."
+ regenerated_size = IO_read_bits(in, 14);
+ compressed_size = IO_read_bits(in, 14);
break;
case 3:
- if (streams->src_len < 5) {
- INP_SIZE();
- }
- regenerated_size = read_bits_LE(streams->src, 18, 4);
- compressed_size = read_bits_LE(streams->src, 18, 22);
- streams->src += 5;
- streams->src_len -= 5;
+ // "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits
+ // (0-262143)."
+ regenerated_size = IO_read_bits(in, 18);
+ compressed_size = IO_read_bits(in, 18);
break;
default:
// Impossible
- compressed_size = regenerated_size = -1;
+ IMPOSSIBLE();
}
if (regenerated_size > MAX_LITERALS_SIZE ||
compressed_size > regenerated_size) {
CORRUPTION();
}
- if (compressed_size > streams->src_len) {
- INP_SIZE();
- }
-
*literals = malloc(regenerated_size);
if (!*literals) {
BAD_ALLOC();
}
+ ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size);
+ istream_t huf_stream = IO_make_sub_istream(in, compressed_size);
+
if (block_type == 2) {
// Decode provided Huffman table
HUF_free_dtable(&ctx->literals_dtable);
- const size_t size = decode_huf_table(streams->src, compressed_size,
- &ctx->literals_dtable);
- streams->src += size;
- streams->src_len -= size;
- compressed_size -= size;
+ decode_huf_table(&huf_stream, &ctx->literals_dtable);
} else {
- // If we're to repeat the previous Huffman table, make sure it exists
+ // If the previous Huffman table is being repeated, ensure it exists
if (!ctx->literals_dtable.symbols) {
CORRUPTION();
}
}
+ size_t symbols_decoded;
if (num_streams == 1) {
- HUF_decompress_1stream(&ctx->literals_dtable, *literals,
- regenerated_size, streams->src, compressed_size);
+ symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
} else {
- HUF_decompress_4stream(&ctx->literals_dtable, *literals,
- regenerated_size, streams->src, compressed_size);
+ symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
+ }
+
+ if (symbols_decoded != regenerated_size) {
+ CORRUPTION();
}
- streams->src += compressed_size;
- streams->src_len -= compressed_size;
return regenerated_size;
}
// Decode the Huffman table description
-static size_t decode_huf_table(const u8 *src, size_t src_len,
- HUF_dtable *const dtable) {
- if (src_len < 1) {
- INP_SIZE();
- }
+static void decode_huf_table(istream_t *const in, HUF_dtable *const dtable) {
+ const u8 header = IO_read_bits(in, 8);
- const u8 *const osrc = src;
-
- const u8 header = src[0];
u8 weights[HUF_MAX_SYMBS];
memset(weights, 0, sizeof(weights));
- src++;
- src_len--;
-
int num_symbs;
if (header >= 128) {
num_symbs = header - 127;
const size_t bytes = (num_symbs + 1) / 2;
- if (bytes > src_len) {
- INP_SIZE();
- }
+ const u8 *const weight_src = IO_read_bytes(in, bytes);
for (int i = 0; i < num_symbs; i++) {
// read_bits_LE isn't applicable here because the weights are order
// reversed within each byte
// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#huffman-tree-header
if (i % 2 == 0) {
- weights[i] = src[i / 2] >> 4;
+ weights[i] = weight_src[i / 2] >> 4;
} else {
- weights[i] = src[i / 2] & 0xf;
+ weights[i] = weight_src[i / 2] & 0xf;
}
}
-
- src += bytes;
- src_len -= bytes;
} else {
// The weights are FSE encoded, decode them before we can construct the
// table
- const size_t size =
- fse_decode_hufweights(src, src_len, weights, &num_symbs, header);
- src += size;
- src_len -= size;
+ istream_t fse_stream = IO_make_sub_istream(in, header);
+ ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS);
+ fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs);
}
// Construct the table using the decoded weights
HUF_init_dtable_usingweights(dtable, weights, num_symbs);
- return src - osrc;
}
-static size_t fse_decode_hufweights(const u8 *const src, const size_t src_len,
- u8 *const weights, int *const num_symbs,
- const size_t compressed_size) {
+static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
+ int *const num_symbs) {
const int MAX_ACCURACY_LOG = 7;
FSE_dtable dtable;
// Construct the FSE table
- const size_t read =
- FSE_decode_header(&dtable, src, src_len, MAX_ACCURACY_LOG);
-
- if (src_len < compressed_size) {
- INP_SIZE();
- }
+ FSE_decode_header(&dtable, in, MAX_ACCURACY_LOG);
// Decode the weights
- *num_symbs = FSE_decompress_interleaved2(
- &dtable, weights, HUF_MAX_SYMBS, src + read, compressed_size - read);
+ *num_symbs = FSE_decompress_interleaved2(&dtable, weights, in);
FSE_free_dtable(&dtable);
-
- return compressed_size;
}
/******* END LITERALS DECODING ************************************************/
/******* SEQUENCE DECODING ****************************************************/
/// The combination of FSE states needed to decode sequences
typedef struct {
- u16 ll_state, of_state, ml_state;
- FSE_dtable ll_table, of_table, ml_table;
+ FSE_dtable ll_table;
+ FSE_dtable of_table;
+ FSE_dtable ml_table;
+
+ u16 ll_state;
+ u16 of_state;
+ u16 ml_state;
} sequence_state_t;
/// Different modes to signal to decode_seq_tables what to do
/// Offset decoding is simpler so we just need a maximum code value
static const u8 SEQ_MAX_CODES[3] = {35, -1, 52};
-static void decompress_sequences(frame_context_t *const ctx, const u8 *src,
- size_t src_len,
+static void decompress_sequences(frame_context_t *const ctx,
+ istream_t *const in,
sequence_command_t *const sequences,
const size_t num_sequences);
static sequence_command_t decode_sequence(sequence_state_t *const state,
const u8 *const src,
i64 *const offset);
-static size_t decode_seq_table(const u8 *src, size_t src_len,
- FSE_dtable *const table, const seq_part_t type,
- const seq_mode_t mode);
+static void decode_seq_table(istream_t *const in, FSE_dtable *const table,
+ const seq_part_t type, const seq_mode_t mode);
-static size_t decode_sequences(frame_context_t *const ctx, const u8 *src,
- size_t src_len,
+static size_t decode_sequences(frame_context_t *const ctx, istream_t *in,
sequence_command_t **const sequences) {
size_t num_sequences;
// Decode the sequence header and allocate space for the output
- if (src_len < 1) {
- INP_SIZE();
- }
- if (src[0] == 0) {
+ u8 header = IO_read_bits(in, 8);
+ if (header == 0) {
+ // "There are no sequences. The sequence section stops there.
+ // Regenerated content is defined entirely by literals section."
*sequences = NULL;
return 0;
- } else if (src[0] < 128) {
- num_sequences = src[0];
- src++;
- src_len--;
- } else if (src[0] < 255) {
- if (src_len < 2) {
- INP_SIZE();
- }
- num_sequences = ((src[0] - 128) << 8) + src[1];
- src += 2;
- src_len -= 2;
+ } else if (header < 128) {
+ // "Number_of_Sequences = byte0 . Uses 1 byte."
+ num_sequences = header;
+ } else if (header < 255) {
+ // "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes."
+ num_sequences = ((header - 128) << 8) + IO_read_bits(in, 8);
} else {
- if (src_len < 3) {
- INP_SIZE();
- }
- num_sequences = src[1] + ((u64)src[2] << 8) + 0x7F00;
- src += 3;
- src_len -= 3;
+ // "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes."
+ num_sequences = IO_read_bits(in, 16) + 0x7F00;
}
*sequences = malloc(num_sequences * sizeof(sequence_command_t));
BAD_ALLOC();
}
- decompress_sequences(ctx, src, src_len, *sequences, num_sequences);
+ decompress_sequences(ctx, in, *sequences, num_sequences);
return num_sequences;
}
/// Decompress the FSE encoded sequence commands
-static void decompress_sequences(frame_context_t *const ctx, const u8 *src,
- size_t src_len,
+static void decompress_sequences(frame_context_t *const ctx, istream_t *in,
sequence_command_t *const sequences,
const size_t num_sequences) {
- if (src_len < 1) {
- INP_SIZE();
- }
- u8 compression_modes = src[0];
- src++;
- src_len--;
+ u8 compression_modes = IO_read_bits(in, 8);
if ((compression_modes & 3) != 0) {
CORRUPTION();
}
- {
- size_t read;
- // Update the tables we have stored in the context
- read = decode_seq_table(src, src_len, &ctx->ll_dtable,
- seq_literal_length,
- (compression_modes >> 6) & 3);
- src += read;
- src_len -= read;
- }
+ // Update the tables we have stored in the context
+ decode_seq_table(in, &ctx->ll_dtable, seq_literal_length,
+ (compression_modes >> 6) & 3);
- {
- const size_t read =
- decode_seq_table(src, src_len, &ctx->of_dtable, seq_offset,
- (compression_modes >> 4) & 3);
- src += read;
- src_len -= read;
- }
+ decode_seq_table(in, &ctx->of_dtable, seq_offset,
+ (compression_modes >> 4) & 3);
- {
- const size_t read = decode_seq_table(src, src_len, &ctx->ml_dtable,
- seq_match_length,
- (compression_modes >> 2) & 3);
- src += read;
- src_len -= read;
- }
+ decode_seq_table(in, &ctx->ml_dtable, seq_match_length,
+ (compression_modes >> 2) & 3);
// Check to make sure none of the tables are uninitialized
if (!ctx->ll_dtable.symbols || !ctx->of_dtable.symbols ||
memcpy(&state.of_table, &ctx->of_dtable, sizeof(FSE_dtable));
memcpy(&state.ml_table, &ctx->ml_dtable, sizeof(FSE_dtable));
- const int padding = 8 - log2inf(src[src_len - 1]);
- i64 offset = src_len * 8 - padding;
+ size_t len = IO_istream_len(in);
+ const u8 *const src = IO_read_bytes(in, len);
+
+ // "After writing the last bit containing information, the compressor writes
+ // a single 1-bit and then fills the byte with 0-7 0 bits of padding."
+ const int padding = 8 - log2inf(src[len - 1]);
+ i64 offset = len * 8 - padding;
FSE_init_state(&state.ll_table, &state.ll_state, src, &offset);
FSE_init_state(&state.of_table, &state.of_state, src, &offset);
CORRUPTION();
}
- // Don't free our tables so they can be used in the next block
+ // Don't free tables so they can be used in the next block
}
// Decode a single sequence and update the state
}
/// Given a sequence part and table mode, decode the FSE distribution
-static size_t decode_seq_table(const u8 *src, size_t src_len,
- FSE_dtable *const table, const seq_part_t type,
- const seq_mode_t mode) {
+static void decode_seq_table(istream_t *const in, FSE_dtable *const table,
+ const seq_part_t type, const seq_mode_t mode) {
// Constant arrays indexed by seq_part_t
const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST,
SEQ_OFFSET_DEFAULT_DIST,
const size_t max_accuracies[] = {9, 8, 9};
if (mode != seq_repeat) {
- // ree old one before overwriting
+ // Free old one before overwriting
FSE_free_dtable(table);
}
const size_t accuracy_log = default_distribution_accuracies[type];
FSE_init_dtable(table, distribution, symbs, accuracy_log);
-
- return 0;
+ break;
}
case seq_rle: {
- if (src_len < 1) {
- INP_SIZE();
- }
- const u8 symb = src[0];
- src++;
- src_len--;
+ const u8 symb = IO_read_bits(in, 8);
FSE_init_dtable_rle(table, symb);
-
- return 1;
+ break;
}
case seq_fse: {
- size_t read =
- FSE_decode_header(table, src, src_len, max_accuracies[type]);
- src += read;
- src_len -= read;
-
- return read;
+ FSE_decode_header(table, in, max_accuracies[type]);
+ break;
}
case seq_repeat:
- // Don't have to do anything here as we're not changing the table
- return 0;
+ // Nothing to do here, table will be unchanged
+ break;
default:
// Impossible, as mode is from 0-3
- return -1;
+ IMPOSSIBLE();
+ break;
}
}
/******* END SEQUENCE DECODING ************************************************/
/******* SEQUENCE EXECUTION ***************************************************/
-static void execute_sequences(io_streams_t *const streams,
- frame_context_t *const ctx,
+static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
+ const u8 *const literals,
+ const size_t literals_len,
const sequence_command_t *const sequences,
- const size_t num_sequences,
- const u8 *literals,
- size_t literals_len) {
+ const size_t num_sequences) {
+ istream_t litstream = IO_make_istream(literals, literals_len);
+
u64 *const offset_hist = ctx->previous_offsets;
size_t total_output = ctx->current_total_output;
for (size_t i = 0; i < num_sequences; i++) {
const sequence_command_t seq = sequences[i];
- if (seq.literal_length > literals_len) {
- CORRUPTION();
- }
-
- if (streams->dst_len < seq.literal_length + seq.match_length) {
- OUT_SIZE();
- }
- // Copy literals to output
- memcpy(streams->dst, literals, seq.literal_length);
-
- literals += seq.literal_length;
- literals_len -= seq.literal_length;
+ {
+ if (seq.literal_length > IO_istream_len(&litstream)) {
+ CORRUPTION();
+ }
- streams->dst += seq.literal_length;
- streams->dst_len -= seq.literal_length;
+ u8 *const write_ptr = IO_write_bytes(out, seq.literal_length);
+ const u8 *const read_ptr =
+ IO_read_bytes(&litstream, seq.literal_length);
+ // Copy literals to output
+ memcpy(write_ptr, read_ptr, seq.literal_length);
- total_output += seq.literal_length;
+ total_output += seq.literal_length;
+ }
size_t offset;
// Offsets are special, we need to handle the repeat offsets
if (seq.offset <= 3) {
- u32 idx = seq.offset;
+ // "The first 3 values define a repeated offset and we will call
+ // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3.
+ // They are sorted in recency order, with Repeated_Offset1 meaning
+ // 'most recent one'".
+
+ // Use 0 indexing for the array
+ u32 idx = seq.offset - 1;
if (seq.literal_length == 0) {
- // Special case when literal length is 0
+ // "There is an exception though, when current sequence's
+ // literals length is 0. In this case, repeated offsets are
+ // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2,
+ // Repeated_Offset2 becomes Repeated_Offset3, and
+ // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte."
idx++;
}
- if (idx == 1) {
- offset = offset_hist[1];
+ if (idx == 0) {
+ offset = offset_hist[0];
} else {
- // If idx == 4 then literal length was 0 and the offset was 3
- offset = idx < 4 ? offset_hist[idx] : offset_hist[1] - 1;
+ // If idx == 3 then literal length was 0 and the offset was 3,
+ // as per the exception listed above
+ offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1;
- // If idx == 2 we don't need to modify offset_hist[3]
- if (idx > 2) {
- offset_hist[3] = offset_hist[2];
+ // If idx == 1 we don't need to modify offset_hist[2]
+ if (idx > 1) {
+ offset_hist[2] = offset_hist[1];
}
- offset_hist[2] = offset_hist[1];
- offset_hist[1] = offset;
+ offset_hist[1] = offset_hist[0];
+ offset_hist[0] = offset;
}
} else {
offset = seq.offset - 3;
// Shift back history
- offset_hist[3] = offset_hist[2];
offset_hist[2] = offset_hist[1];
- offset_hist[1] = offset;
+ offset_hist[1] = offset_hist[0];
+ offset_hist[0] = offset;
}
size_t match_length = seq.match_length;
+
+ u8 *write_ptr = IO_write_bytes(out, match_length);
if (total_output <= ctx->header.window_size) {
// In this case offset might go back into the dictionary
if (offset > total_output + ctx->dict_content_len) {
}
if (offset > total_output) {
+ // "The rest of the dictionary is its content. The content act
+ // as a "past" in front of data to compress or decompress, so it
+ // can be referenced in sequence commands."
const size_t dict_copy =
MIN(offset - total_output, match_length);
const size_t dict_offset =
ctx->dict_content_len - (offset - total_output);
- for (size_t i = 0; i < dict_copy; i++) {
- *streams->dst++ = ctx->dict_content[dict_offset + i];
- }
+
+ memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy);
+ write_ptr += dict_copy;
match_length -= dict_copy;
}
} else if (offset > ctx->header.window_size) {
// ex: if the output so far was "abc", a command with offset=3 and
// match_length=6 would produce "abcabcabc" as the new output
for (size_t i = 0; i < match_length; i++) {
- *streams->dst = *(streams->dst - offset);
- streams->dst++;
+ *write_ptr = *(write_ptr - offset);
+ write_ptr++;
}
- streams->dst_len -= seq.match_length;
total_output += seq.match_length;
}
- if (streams->dst_len < literals_len) {
- OUT_SIZE();
- }
- // Copy any leftover literals
- memcpy(streams->dst, literals, literals_len);
- streams->dst += literals_len;
- streams->dst_len -= literals_len;
+ {
+ size_t len = IO_istream_len(&litstream);
+ u8 *const write_ptr = IO_write_bytes(out, len);
+ const u8 *const read_ptr = IO_read_bytes(&litstream, len);
+ // Copy any leftover literals
+ memcpy(write_ptr, read_ptr, len);
- total_output += literals_len;
+ total_output += len;
+ }
ctx->current_total_output = total_output;
}
/******* END SEQUENCE EXECUTION ***********************************************/
/******* OUTPUT SIZE COUNTING *************************************************/
-size_t traverse_frame(const frame_header_t *const header, const u8 *src,
- size_t src_len);
+static void traverse_frame(const frame_header_t *const header, istream_t *const in);
/// Get the decompressed size of an input stream so memory can be allocated in
/// advance.
/// implementation, as this API allows for the decompression of multiple
/// concatenated frames.
size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) {
- const u8 *ip = (const u8 *) src;
- size_t ip_len = src_len;
- size_t dst_size = 0;
-
- // Each frame header only gives us the size of its frame, so iterate over all
- // frames
- while (ip_len > 0) {
- if (ip_len < 4) {
- INP_SIZE();
- }
-
- const u32 magic_number = read_bits_LE(ip, 32, 0);
-
- ip += 4;
- ip_len -= 4;
- if (magic_number >= 0x184D2A50U && magic_number <= 0x184D2A5F) {
- // skippable frame, this has no impact on output size
- if (ip_len < 4) {
- INP_SIZE();
- }
- const size_t frame_size = read_bits_LE(ip, 32, 32);
+ istream_t in = IO_make_istream(src, src_len);
+ size_t dst_size = 0;
+
+ // Each frame header only gives us the size of its frame, so iterate over
+ // all
+ // frames
+ while (IO_istream_len(&in) > 0) {
+ const u32 magic_number = IO_read_bits(&in, 32);
+
+ if ((magic_number & ~0xFU) == 0x184D2A50U) {
+ // skippable frame, this has no impact on output size
+ const size_t frame_size = IO_read_bits(&in, 32);
+ IO_advance_input(&in, frame_size);
+ } else if (magic_number == 0xFD2FB528U) {
+ // ZSTD frame
+ frame_header_t header;
+ parse_frame_header(&header, &in);
+
+ if (header.frame_content_size == 0 && !header.single_segment_flag) {
+ // Content size not provided, we can't tell
+ return -1;
+ }
- if (ip_len < 4 + frame_size) {
- INP_SIZE();
- }
+ dst_size += header.frame_content_size;
- // skip over frame
- ip += 4 + frame_size;
- ip_len -= 4 + frame_size;
- } else if (magic_number == 0xFD2FB528U) {
- // ZSTD frame
- frame_header_t header;
- parse_frame_header(&header, ip, ip_len);
-
- if (header.frame_content_size == 0 && !header.single_segment_flag) {
- // Content size not provided, we can't tell
- return -1;
+ // Consume the input from the frame to reach the start of the next
+ traverse_frame(&header, &in);
+ } else {
+ // not a real frame
+ ERROR("Invalid magic number");
}
-
- dst_size += header.frame_content_size;
-
- // we need to traverse the frame to find when the next one starts
- const size_t traversed = traverse_frame(&header, ip, ip_len);
- ip += traversed;
- ip_len -= traversed;
- } else {
- // not a real frame
- ERROR("Invalid magic number");
}
- }
- return dst_size;
+ return dst_size;
}
/// Iterate over each block in a frame to find the end of it, to get to the
/// start of the next frame
-size_t traverse_frame(const frame_header_t *const header, const u8 *src,
- size_t src_len) {
- const u8 *const src_beg = src;
- const u8 *const src_end = src + src_len;
- src += header->header_size;
- src_len += header->header_size;
-
+static void traverse_frame(const frame_header_t *const header, istream_t *const in) {
int last_block = 0;
do {
- if (src + 3 > src_end) {
- INP_SIZE();
- }
// Parse the block header
- last_block = src[0] & 1;
- const int block_type = (src[0] >> 1) & 3;
- const size_t block_len = read_bits_LE(src, 21, 3);
+ last_block = IO_read_bits(in, 1);
+ const int block_type = IO_read_bits(in, 2);
+ const size_t block_len = IO_read_bits(in, 21);
- src += 3;
switch (block_type) {
case 0: // Raw block, block_len bytes
- if (src + block_len > src_end) {
- INP_SIZE();
- }
- src += block_len;
+ IO_advance_input(in, block_len);
break;
case 1: // RLE block, 1 byte
- if (src + 1 > src_end) {
- INP_SIZE();
- }
- src++;
+ IO_advance_input(in, 1);
break;
case 2: // Compressed block, compressed size is block_len
- if (src + block_len > src_end) {
- INP_SIZE();
- }
- src += block_len;
+ IO_advance_input(in, block_len);
break;
case 3:
// Reserved block type
CORRUPTION();
break;
+ default:
+ IMPOSSIBLE();
}
} while (!last_block);
if (header->content_checksum_flag) {
- if (src + 4 > src_end) {
- INP_SIZE();
- }
- src += 4;
+ IO_advance_input(in, 4);
}
-
- return src - src_beg;
}
/******* END OUTPUT SIZE COUNTING *********************************************/
if (src_len < 8) {
INP_SIZE();
}
- const u32 magic_number = read_bits_LE(src, 32, 0);
+
+ istream_t in = IO_make_istream(src, src_len);
+
+ const u32 magic_number = IO_read_bits(&in, 32);
if (magic_number != 0xEC30A437) {
// raw content dict
init_raw_content_dict(dict, src, src_len);
return;
}
- dict->dictionary_id = read_bits_LE(src, 32, 32);
- src += 8;
- src_len -= 8;
+ dict->dictionary_id = IO_read_bits(&in, 32);
// Parse the provided entropy tables in order
- {
- const size_t read =
- decode_huf_table(src, src_len, &dict->literals_dtable);
- src += read;
- src_len -= read;
- }
- {
- const size_t read = decode_seq_table(src, src_len, &dict->of_dtable,
- seq_offset, seq_fse);
- src += read;
- src_len -= read;
- }
- {
- const size_t read = decode_seq_table(src, src_len, &dict->ml_dtable,
- seq_match_length, seq_fse);
- src += read;
- src_len -= read;
- }
- {
- const size_t read = decode_seq_table(src, src_len, &dict->ll_dtable,
- seq_literal_length, seq_fse);
- src += read;
- src_len -= read;
- }
+ decode_huf_table(&in, &dict->literals_dtable);
+ decode_seq_table(&in, &dict->of_dtable, seq_offset, seq_fse);
+ decode_seq_table(&in, &dict->ml_dtable, seq_match_length, seq_fse);
+ decode_seq_table(&in, &dict->ll_dtable, seq_literal_length, seq_fse);
- if (src_len < 12) {
- INP_SIZE();
- }
// Read in the previous offset history
- dict->previous_offsets[1] = read_bits_LE(src, 32, 0);
- dict->previous_offsets[2] = read_bits_LE(src, 32, 32);
- dict->previous_offsets[3] = read_bits_LE(src, 32, 64);
-
- src += 12;
- src_len -= 12;
+ dict->previous_offsets[0] = IO_read_bits(&in, 32);
+ dict->previous_offsets[1] = IO_read_bits(&in, 32);
+ dict->previous_offsets[2] = IO_read_bits(&in, 32);
// Ensure the provided offsets aren't too large
- for (int i = 1; i <= 3; i++) {
+ for (int i = 0; i < 3; i++) {
if (dict->previous_offsets[i] > src_len) {
ERROR("Dictionary corrupted");
}
}
+
// The rest is the content
- dict->content = malloc(src_len);
+ dict->content_size = IO_istream_len(&in);
+ dict->content = malloc(dict->content_size);
if (!dict->content) {
BAD_ALLOC();
}
- dict->content_size = src_len;
- memcpy(dict->content, src, src_len);
+ const u8 *const content = IO_read_bytes(&in, dict->content_size);
+
+ memcpy(dict->content, content, dict->content_size);
}
/// If parse_dictionary is given a raw content dictionary, it delegates here
}
/******* END DICTIONARY PARSING ***********************************************/
+/******* IO STREAM OPERATIONS *************************************************/
+#define UNALIGNED() ERROR("Attempting to operate on a non-byte aligned stream")
+/// Reads `num` bits from a bitstream, and updates the internal offset
+static inline u64 IO_read_bits(istream_t *const in, const int num) {
+ if (num > 64) {
+ return -1;
+ }
+
+ const size_t bytes = (num + in->bit_offset + 7) / 8;
+ const size_t full_bytes = (num + in->bit_offset) / 8;
+ if (bytes > in->len) {
+ INP_SIZE();
+ }
+
+ const u64 result = read_bits_LE(in->ptr, num, in->bit_offset);
+
+ in->bit_offset = (num + in->bit_offset) % 8;
+ in->ptr += full_bytes;
+ in->len -= full_bytes;
+
+ return result;
+}
+
+/// If a non-zero number of bits have been read from the current byte, advance
+/// the offset to the next byte
+static inline void IO_rewind_bits(istream_t *const in, int num) {
+ if (num < 0) {
+ ERROR("Attempting to rewind stream by a negative number of bits");
+ }
+
+ const int new_offset = in->bit_offset - num;
+ const i64 bytes = (new_offset - 7) / 8;
+
+ in->ptr += bytes;
+ in->len -= bytes;
+ in->bit_offset = ((new_offset % 8) + 8) % 8;
+}
+
+/// If the remaining bits in a byte will be unused, advance to the end of the
+/// byte
+static inline void IO_align_stream(istream_t *const in) {
+ if (in->bit_offset != 0) {
+ if (in->len == 0) {
+ INP_SIZE();
+ }
+ in->ptr++;
+ in->len--;
+ in->bit_offset = 0;
+ }
+}
+
+/// Write the given byte into the output stream
+static inline void IO_write_byte(ostream_t *const out, u8 symb) {
+ if (out->len == 0) {
+ OUT_SIZE();
+ }
+
+ out->ptr[0] = symb;
+ out->ptr++;
+ out->len--;
+}
+
+/// Returns the number of bytes left to be read in this stream. The stream must
+/// be byte aligned.
+static inline size_t IO_istream_len(const istream_t *const in) {
+ return in->len;
+}
+
+/// Returns a pointer where `len` bytes can be read, and advances the internal
+/// state. The stream must be byte aligned.
+static inline const u8 *IO_read_bytes(istream_t *const in, size_t len) {
+ if (len > in->len) {
+ INP_SIZE();
+ }
+ if (in->bit_offset != 0) {
+ UNALIGNED();
+ }
+ const u8 *const ptr = in->ptr;
+ in->ptr += len;
+ in->len -= len;
+
+ return ptr;
+}
+/// Returns a pointer to write `len` bytes to, and advances the internal state
+static inline u8 *IO_write_bytes(ostream_t *const out, size_t len) {
+ if (len > out->len) {
+ INP_SIZE();
+ }
+ u8 *const ptr = out->ptr;
+ out->ptr += len;
+ out->len -= len;
+
+ return ptr;
+}
+
+/// Advance the inner state by `len` bytes
+static inline void IO_advance_input(istream_t *const in, size_t len) {
+ if (len > in->len) {
+ INP_SIZE();
+ }
+ if (in->bit_offset != 0) {
+ UNALIGNED();
+ }
+
+ in->ptr += len;
+ in->len -= len;
+}
+
+/// Returns an `ostream_t` constructed from the given pointer and length
+static inline ostream_t IO_make_ostream(u8 *out, size_t len) {
+ return (ostream_t) { out, len };
+}
+
+/// Returns an `istream_t` constructed from the given pointer and length
+static inline istream_t IO_make_istream(const u8 *in, size_t len) {
+ return (istream_t) { in, 0, len };
+}
+
+/// Returns an `istream_t` with the same base as `in`, and length `len`
+/// Then, advance `in` to account for the consumed bytes
+/// `in` must be byte aligned
+static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) {
+ if (len > in->len) {
+ INP_SIZE();
+ }
+ if (in->bit_offset != 0) {
+ UNALIGNED();
+ }
+ const istream_t sub = { in->ptr, in->bit_offset, len };
+
+ in->ptr += len;
+ in->len -= len;
+
+ return sub;
+}
+/******* END IO STREAM OPERATIONS *********************************************/
+
/******* BITSTREAM OPERATIONS *************************************************/
/// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits
static inline u64 read_bits_LE(const u8 *src, const int num,
*state = STREAM_read_bits(src, bits, offset);
}
-static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, u8 *dst,
- const size_t dst_len, const u8 *src,
- size_t src_len) {
- const u8 *const dst_max = dst + dst_len;
- const u8 *const odst = dst;
+static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
+ ostream_t *const out,
+ istream_t *const in) {
+ const size_t len = IO_istream_len(in);
+ if (len == 0) {
+ INP_SIZE();
+ }
+ const u8 *const src = IO_read_bytes(in, len);
// To maintain similarity with FSE, start from the end
// Find the last 1 bit
- const int padding = 8 - log2inf(src[src_len - 1]);
+ const int padding = 8 - log2inf(src[len - 1]);
- i64 offset = src_len * 8 - padding;
+ i64 offset = len * 8 - padding;
u16 state;
HUF_init_state(dtable, &state, src, &offset);
- while (dst < dst_max && offset > -dtable->max_bits) {
+ size_t symbols_written = 0;
+ while (offset > -dtable->max_bits) {
// Iterate over the stream, decoding one symbol at a time
- *dst++ = HUF_decode_symbol(dtable, &state, src, &offset);
- }
- // If we stopped before consuming all the input, we didn't have enough space
- if (dst == dst_max && offset > -dtable->max_bits) {
- OUT_SIZE();
+ IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &offset));
+ symbols_written++;
}
// When all symbols have been decoded, the final state value shouldn't have
CORRUPTION();
}
- return dst - odst;
+ return symbols_written;
}
-static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, u8 *dst,
- const size_t dst_len, const u8 *const src,
- const size_t src_len) {
- if (src_len < 6) {
- INP_SIZE();
- }
+static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
+ ostream_t *const out, istream_t *const in) {
+ const size_t csize1 = IO_read_bits(in, 16);
+ const size_t csize2 = IO_read_bits(in, 16);
+ const size_t csize3 = IO_read_bits(in, 16);
- const u8 *const src1 = src + 6;
- const u8 *const src2 = src1 + read_bits_LE(src, 16, 0);
- const u8 *const src3 = src2 + read_bits_LE(src, 16, 16);
- const u8 *const src4 = src3 + read_bits_LE(src, 16, 32);
- const u8 *const src_end = src + src_len;
-
- // We can't test with all 4 sizes because the 4th size is a function of the
- // other 3 and the provided length
- if (src4 - src >= src_len) {
- INP_SIZE();
- }
-
- const size_t segment_size = (dst_len + 3) / 4;
- u8 *const dst1 = dst;
- u8 *const dst2 = dst1 + segment_size;
- u8 *const dst3 = dst2 + segment_size;
- u8 *const dst4 = dst3 + segment_size;
- u8 *const dst_end = dst + dst_len;
-
- size_t total_out = 0;
+ istream_t in1 = IO_make_sub_istream(in, csize1);
+ istream_t in2 = IO_make_sub_istream(in, csize2);
+ istream_t in3 = IO_make_sub_istream(in, csize3);
+ istream_t in4 = IO_make_sub_istream(in, IO_istream_len(in));
+ size_t total_output = 0;
// Decode each stream independently for simplicity
// If we wanted to we could decode all 4 at the same time for speed,
// utilizing more execution units
- total_out += HUF_decompress_1stream(dtable, dst1, segment_size, src1,
- src2 - src1);
- total_out += HUF_decompress_1stream(dtable, dst2, segment_size, src2,
- src3 - src2);
- total_out += HUF_decompress_1stream(dtable, dst3, segment_size, src3,
- src4 - src3);
- total_out += HUF_decompress_1stream(dtable, dst4, dst_end - dst4, src4,
- src_end - src4);
-
- return total_out;
+ total_output += HUF_decompress_1stream(dtable, out, &in1);
+ total_output += HUF_decompress_1stream(dtable, out, &in2);
+ total_output += HUF_decompress_1stream(dtable, out, &in3);
+ total_output += HUF_decompress_1stream(dtable, out, &in4);
+
+ return total_output;
}
static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
u64 weight_sum = 0;
for (int i = 0; i < num_symbs; i++) {
+ // Weights are in the same range as bit count
+ if (weights[i] > HUF_MAX_BITS) {
+ CORRUPTION();
+ }
weight_sum += weights[i] > 0 ? (u64)1 << (weights[i] - 1) : 0;
}
}
static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
- u8 *dst, const size_t dst_len,
- const u8 *const src,
- const size_t src_len) {
- if (src_len == 0) {
+ ostream_t *const out,
+ istream_t *const in) {
+ const size_t len = IO_istream_len(in);
+ if (len == 0) {
INP_SIZE();
}
-
- const u8 *const dst_max = dst + dst_len;
- const u8 *const odst = dst;
+ const u8 *const src = IO_read_bytes(in, len);
// Find the last 1 bit
- const int padding = 8 - log2inf(src[src_len - 1]);
-
- i64 offset = src_len * 8 - padding;
+ const int padding = 8 - log2inf(src[len - 1]);
+ i64 offset = len * 8 - padding;
// The end of the stream contains the 2 states, in this order
u16 state1, state2;
// Decode until we overflow the stream
// Since we decode in reverse order, overflowing the stream is offset going
// negative
+ size_t symbols_written = 0;
while (1) {
- if (dst > dst_max - 2) {
- OUT_SIZE();
- }
- *dst++ = FSE_decode_symbol(dtable, &state1, src, &offset);
+ IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset));
+ symbols_written++;
if (offset < 0) {
// There's still a symbol to decode in state2
- *dst++ = FSE_peek_symbol(dtable, state2);
+ IO_write_byte(out, FSE_peek_symbol(dtable, state2));
+ symbols_written++;
break;
}
- if (dst > dst_max - 2) {
- OUT_SIZE();
- }
- *dst++ = FSE_decode_symbol(dtable, &state2, src, &offset);
+ IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset));
+ symbols_written++;
if (offset < 0) {
// There's still a symbol to decode in state1
- *dst++ = FSE_peek_symbol(dtable, state1);
+ IO_write_byte(out, FSE_peek_symbol(dtable, state1));
+ symbols_written++;
break;
}
}
- // Number of symbols read
- return dst - odst;
+ return symbols_written;
}
static void FSE_init_dtable(FSE_dtable *const dtable,
/// Decode an FSE header as defined in the Zstandard format specification and
/// use the decoded frequencies to initialize a decoding table.
-static size_t FSE_decode_header(FSE_dtable *const dtable, const u8 *const src,
- const size_t src_len,
+static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
const int max_accuracy_log) {
if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) {
ERROR("FSE accuracy too large");
}
- if (src_len < 1) {
- INP_SIZE();
- }
- const int accuracy_log = 5 + read_bits_LE(src, 4, 0);
+ const int accuracy_log = 5 + IO_read_bits(in, 4);
if (accuracy_log > max_accuracy_log) {
ERROR("FSE accuracy too large");
}
i16 frequencies[FSE_MAX_SYMBS];
int symb = 0;
- // Offset of 4 because 4 bits were already read in for accuracy
- size_t offset = 4;
while (remaining > 1 && symb < FSE_MAX_SYMBS) {
// Log of the number of possible values we could read
int bits = log2inf(remaining) + 1;
- u16 val = read_bits_LE(src, bits, offset);
- offset += bits;
+ u16 val = IO_read_bits(in, bits);
// Try to mask out the lower bits to see if it qualifies for the "small
// value" threshold
const u16 threshold = ((u16)1 << bits) - 1 - remaining;
if ((val & lower_mask) < threshold) {
- offset--;
+ IO_rewind_bits(in, 1);
val = val & lower_mask;
} else if (val > lower_mask) {
val = val - threshold;
// Handle the special probability = 0 case
if (proba == 0) {
// Read the next two bits to see how many more 0s
- int repeat = read_bits_LE(src, 2, offset);
- offset += 2;
+ int repeat = IO_read_bits(in, 2);
while (1) {
for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) {
frequencies[symb++] = 0;
}
if (repeat == 3) {
- repeat = read_bits_LE(src, 2, offset);
- offset += 2;
+ repeat = IO_read_bits(in, 2);
} else {
break;
}
}
}
}
+ IO_align_stream(in);
if (remaining != 1 || symb >= FSE_MAX_SYMBS) {
CORRUPTION();
// Initialize the decoding table using the determined weights
FSE_init_dtable(dtable, frequencies, symb, accuracy_log);
-
- return (offset + 7) / 8;
}
static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) {