]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Switch IO to go through streams
authorSean Purcell <me@seanp.xyz>
Thu, 2 Feb 2017 01:05:45 +0000 (17:05 -0800)
committerSean Purcell <me@seanp.xyz>
Fri, 3 Feb 2017 23:22:52 +0000 (15:22 -0800)
contrib/educational_decoder/harness.c
contrib/educational_decoder/zstd_decompress.c

index cff8239d6f0884179597dca00ceba6f399c9dc89..683278dfcd0147272acaa05102c4dc1233771982 100644 (file)
@@ -18,6 +18,9 @@ typedef unsigned char u8;
 // compression ratio is at most 16
 #define MAX_COMPRESSION_RATIO (16)
 
+// Protect against allocating too much memory for output
+#define MAX_OUTPUT_SIZE ((size_t)1024 * 1024 * 1024)
+
 u8 *input;
 u8 *output;
 u8 *dict;
@@ -86,11 +89,17 @@ int main(int argc, char **argv) {
     size_t decompressed_size = ZSTD_get_decompressed_size(input, input_size);
     if (decompressed_size == -1) {
         decompressed_size = MAX_COMPRESSION_RATIO * input_size;
-        fprintf(stderr, "WARNING: Compressed data does contain decompressed "
-                        "size, going to assume the compression ratio is at "
-                        "most %d (decompressed size of at most %zu)\n",
+        fprintf(stderr, "WARNING: Compressed data does not contain "
+                        "decompressed size, going to assume the compression "
+                        "ratio is at most %d (decompressed size of at most "
+                        "%zu)\n",
                 MAX_COMPRESSION_RATIO, decompressed_size);
     }
+    if (decompressed_size > MAX_OUTPUT_SIZE) {
+        fprintf(stderr,
+                "Required output size too large for this implementation\n");
+        return 1;
+    }
     output = malloc(decompressed_size);
     if (!output) {
         fprintf(stderr, "failed to allocate memory\n");
index e2fbcf2cf06014b76fa73521f0bc8185fecf2e6c..8f28313e408305022bd56a2ae8cdf77570bda627 100644 (file)
@@ -48,6 +48,7 @@ size_t ZSTD_get_decompressed_size(const void *const src, const size_t src_len);
 #define OUT_SIZE() ERROR("Output buffer too small for output")
 #define CORRUPTION() ERROR("Corruption detected while decompressing")
 #define BAD_ALLOC() ERROR("Memory allocation error")
+#define IMPOSSIBLE() ERROR("An impossibility has occurred")
 
 typedef uint8_t u8;
 typedef uint16_t u16;
@@ -65,6 +66,62 @@ typedef int64_t i64;
 /// file.  They implement low-level functionality needed for the higher level
 /// decompression functions.
 
+/*** IO STREAM OPERATIONS *************/
+/// These structs are the interface for IO, and do bounds checking on all
+/// operations.  They should be used opaquely to ensure safety.
+
+/// Output is always done byte-by-byte
+typedef struct {
+    u8 *ptr;
+    size_t len;
+} ostream_t;
+
+/// Input often reads a few bits at a time, so maintain an internal offset
+typedef struct {
+    const u8 *ptr;
+    int bit_offset;
+    size_t len;
+} istream_t;
+
+/// The following two functions are the only ones that allow the istream to be
+/// non-byte aligned
+
+/// Reads `num` bits from a bitstream, and updates the internal offset
+static inline u64 IO_read_bits(istream_t *const in, const int num);
+/// Rewinds the stream by `num` bits
+static inline void IO_rewind_bits(istream_t *const in, const int num);
+/// If the remaining bits in a byte will be unused, advance to the end of the
+/// byte
+static inline void IO_align_stream(istream_t *const in);
+
+/// Write the given byte into the output stream
+static inline void IO_write_byte(ostream_t *const out, u8 symb);
+
+/// Returns the number of bytes left to be read in this stream.  The stream must
+/// be byte aligned.
+static inline size_t IO_istream_len(const istream_t *const in);
+
+/// Returns a pointer where `len` bytes can be read, and advances the internal
+/// state.  The stream must be byte aligned.
+static inline const u8 *IO_read_bytes(istream_t *const in, size_t len);
+/// Returns a pointer where `len` bytes can be written, and advances the internal
+/// state.  The stream must be byte aligned.
+static inline u8 *IO_write_bytes(ostream_t *const out, size_t len);
+
+/// Advance the inner state by `len` bytes.  The stream must be byte aligned.
+static inline void IO_advance_input(istream_t *const in, size_t len);
+
+/// Returns an `ostream_t` constructed from the given pointer and length
+static inline ostream_t IO_make_ostream(u8 *out, size_t len);
+/// Returns an `istream_t` constructed from the given pointer and length
+static inline istream_t IO_make_istream(const u8 *in, size_t len);
+
+/// Returns an `istream_t` with the same base as `in`, and length `len`
+/// Then, advance `in` to account for the consumed bytes
+/// `in` must be byte aligned
+static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len);
+/*** END IO STREAM OPERATIONS *********/
+
 /*** BITSTREAM OPERATIONS *************/
 /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits
 static inline u64 read_bits_LE(const u8 *src, const int num,
@@ -109,15 +166,13 @@ static inline void HUF_init_state(const HUF_dtable *const dtable,
 
 /// Decompresses a single Huffman stream, returns the number of bytes decoded.
 /// `src_len` must be the exact length of the Huffman-coded block.
-static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, u8 *dst,
-                                     const size_t dst_len, const u8 *src,
-                                     size_t src_len);
+static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
+                                     ostream_t *const out, istream_t *const in);
 /// Same as previous but decodes 4 streams, formatted as in the Zstandard
 /// specification.
 /// `src_len` must be the exact length of the Huffman-coded block.
-static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, u8 *dst,
-                                     const size_t dst_len, const u8 *const src,
-                                     const size_t src_len);
+static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
+                                     ostream_t *const out, istream_t *const in);
 
 /// Initialize a Huffman decoding table using the table of bit counts provided
 static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
@@ -176,9 +231,8 @@ static inline void FSE_init_state(const FSE_dtable *const dtable,
 /// using an FSE decoding table.  `src_len` must be the exact length of the
 /// block.
 static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
-                                          u8 *dst, const size_t dst_len,
-                                          const u8 *const src,
-                                          const size_t src_len);
+                                          ostream_t *const out,
+                                          istream_t *const in);
 
 /// Initialize a decoding table using normalized frequencies.
 static void FSE_init_dtable(FSE_dtable *const dtable,
@@ -187,8 +241,7 @@ static void FSE_init_dtable(FSE_dtable *const dtable,
 
 /// Decode an FSE header as defined in the Zstandard format specification and
 /// use the decoded frequencies to initialize a decoding table.
-static size_t FSE_decode_header(FSE_dtable *const dtable, const u8 *const src,
-                                const size_t src_len,
+static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
                                 const int max_accuracy_log);
 
 /// Initialize an FSE table that will always return the same symbol and consume
@@ -207,16 +260,6 @@ static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src);
 
 /******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/
 
-/// Input and output pointers to allow them to be advanced by
-/// functions that consume input/produce output
-typedef struct {
-    u8 *dst;
-    size_t dst_len;
-
-    const u8 *src;
-    size_t src_len;
-} io_streams_t;
-
 /// A small structure that can be reused in various places that need to access
 /// frame header information
 typedef struct {
@@ -233,9 +276,6 @@ typedef struct {
     int content_checksum_flag;
     // Whether or not the output for this frame is in a single segment
     int single_segment_flag;
-
-    // The size in bytes of this header
-    int header_size;
 } frame_header_t;
 
 /// The context needed to decode blocks in a frame
@@ -256,9 +296,8 @@ typedef struct {
     FSE_dtable ml_dtable;
     FSE_dtable of_dtable;
 
-    // The last 3 offsets for the special "repeat offsets".  Array size is 4 so
-    // that previous_offsets[1] corresponds to the most recent offset
-    u64 previous_offsets[4];
+    // The last 3 offsets for the special "repeat offsets".
+    u64 previous_offsets[3];
 } frame_context_t;
 
 /// The decoded contents of a dictionary so that it doesn't have to be repeated
@@ -275,7 +314,7 @@ typedef struct {
     size_t content_size;
 
     // Offset history to prepopulate the frame's history
-    u64 previous_offsets[4];
+    u64 previous_offsets[3];
 
     u32 dictionary_id;
 } dictionary_t;
@@ -301,34 +340,31 @@ typedef struct {
 /// Accepts a dict argument, which may be NULL indicating no dictionary.
 /// See
 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation
-static void decode_frame(io_streams_t *const streams,
+static void decode_frame(ostream_t *const out, istream_t *const in,
                          const dictionary_t *const dict);
 
 // Decode data in a compressed block
-static void decompress_block(io_streams_t *const streams,
-                             frame_context_t *const ctx,
-                             const size_t block_len);
+static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
+                             istream_t *const in);
 
 // Decode the literals section of a block
-static size_t decode_literals(io_streams_t *const streams,
-                              frame_context_t *const ctx, u8 **const literals);
+static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
+                              u8 **const literals);
 
 // Decode the sequences part of a block
-static size_t decode_sequences(frame_context_t *const ctx, const u8 *const src,
-                               const size_t src_len,
+static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in,
                                sequence_command_t **const sequences);
 
 // Execute the decoded sequences on the literals block
-static void execute_sequences(io_streams_t *const streams,
-                              frame_context_t *const ctx,
+static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
+                              const u8 *const literals,
+                              const size_t literals_len,
                               const sequence_command_t *const sequences,
-                              const size_t num_sequences,
-                              const u8 *literals,
-                              size_t literals_len);
+                              const size_t num_sequences);
 
 // Parse a provided dictionary blob for use in decompression
-static void parse_dictionary(dictionary_t *const dict, const u8 *const src,
-                             const size_t src_len);
+static void parse_dictionary(dictionary_t *const dict, const u8 *src,
+                             size_t src_len);
 static void free_dictionary(dictionary_t *const dict);
 /******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/
 
@@ -348,58 +384,46 @@ size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len,
         parse_dictionary(&parsed_dict, (const u8 *)dict, dict_len);
     }
 
-    io_streams_t streams = {(u8 *)dst, dst_len, (const u8 *)src, src_len};
-    while (streams.src_len > 0) {
-        decode_frame(&streams, &parsed_dict);
+    istream_t in = {(const u8 *)src, 0, src_len};
+    ostream_t out = {(u8 *)dst, dst_len};
+    while (IO_istream_len(&in) > 0) {
+        decode_frame(&out, &in, &parsed_dict);
     }
 
     free_dictionary(&parsed_dict);
 
-    return streams.dst - (u8 *)dst;
+    return out.ptr - (u8 *)dst;
 }
 
 /******* FRAME DECODING ******************************************************/
 
-static void decode_data_frame(io_streams_t *const streams,
+static void decode_data_frame(ostream_t *const out, istream_t *const in,
                               const dictionary_t *const dict);
-static void init_frame_context(io_streams_t *const streams,
-                               frame_context_t *const context,
+static void init_frame_context(frame_context_t *const context,
+                               istream_t *const in,
                                const dictionary_t *const dict);
 static void free_frame_context(frame_context_t *const context);
 static void parse_frame_header(frame_header_t *const header,
-                               const u8 *const src, const size_t src_len);
+                               istream_t *const in);
 static void frame_context_apply_dict(frame_context_t *const ctx,
                                      const dictionary_t *const dict);
 
-static void decompress_data(io_streams_t *const streams,
-                            frame_context_t *const ctx);
+static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
+                            istream_t *const in);
 
-static void decode_frame(io_streams_t *const streams,
+static void decode_frame(ostream_t *const out, istream_t *const in,
                          const dictionary_t *const dict) {
-    if (streams->src_len < 4) {
-        INP_SIZE();
-    }
-    const u32 magic_number = read_bits_LE(streams->src, 32, 0);
+    const u32 magic_number = IO_read_bits(in, 32);
 
-    streams->src += 4;
-    streams->src_len -= 4;
-    if (magic_number >= 0x184D2A50U && magic_number <= 0x184D2A5F) {
-        // skippable frame
-        if (streams->src_len < 4) {
-            INP_SIZE();
-        }
-        const size_t frame_size = read_bits_LE(streams->src, 32, 32);
-
-        if (streams->src_len < 4 + frame_size) {
-            INP_SIZE();
-        }
+    if ((magic_number & ~0xFU) == 0x184D2A50U) {
+        // Skippable frame
+        const size_t frame_size = IO_read_bits(in, 32);
 
         // skip over frame
-        streams->src += 4 + frame_size;
-        streams->src_len -= 4 + frame_size;
+        IO_advance_input(in, frame_size);
     } else if (magic_number == 0xFD2FB528U) {
         // ZSTD frame
-        decode_data_frame(streams, dict);
+        decode_data_frame(out, in, dict);
     } else {
         // not a real frame
         ERROR("Invalid magic number");
@@ -410,40 +434,38 @@ static void decode_frame(io_streams_t *const streams,
 /// are skippable frames.
 /// See
 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format
-static void decode_data_frame(io_streams_t *const streams,
+static void decode_data_frame(ostream_t *const out, istream_t *const in,
                               const dictionary_t *const dict) {
     frame_context_t ctx;
 
     // Initialize the context that needs to be carried from block to block
-    init_frame_context(streams, &ctx, dict);
+    init_frame_context(&ctx, in, dict);
 
     if (ctx.header.frame_content_size != 0 &&
-        ctx.header.frame_content_size > streams->dst_len) {
+        ctx.header.frame_content_size > out->len) {
         OUT_SIZE();
     }
 
-    decompress_data(streams, &ctx);
+    decompress_data(&ctx, out, in);
 
     free_frame_context(&ctx);
 }
 
 /// Takes the information provided in the header and dictionary, and initializes
 /// the context for this frame
-static void init_frame_context(io_streams_t *const streams,
-                               frame_context_t *const context,
+static void init_frame_context(frame_context_t *const context,
+                               istream_t *const in,
                                const dictionary_t *const dict) {
     // Most fields in context are correct when initialized to 0
-    memset(context, 0x00, sizeof(frame_context_t));
+    memset(context, 0, sizeof(frame_context_t));
 
     // Parse data from the frame header
-    parse_frame_header(&context->header, streams->src, streams->src_len);
-    streams->src += context->header.header_size;
-    streams->src_len -= context->header.header_size;
+    parse_frame_header(&context->header, in);
 
     // Set up the offset history for the repeat offset commands
-    context->previous_offsets[1] = 1;
-    context->previous_offsets[2] = 4;
-    context->previous_offsets[3] = 8;
+    context->previous_offsets[0] = 1;
+    context->previous_offsets[1] = 4;
+    context->previous_offsets[2] = 8;
 
     // Apply details from the dict if it exists
     frame_context_apply_dict(context, dict);
@@ -460,12 +482,8 @@ static void free_frame_context(frame_context_t *const context) {
 }
 
 static void parse_frame_header(frame_header_t *const header,
-                               const u8 *const src, const size_t src_len) {
-    if (src_len < 1) {
-        INP_SIZE();
-    }
-
-    const u8 descriptor = read_bits_LE(src, 8, 0);
+                               istream_t *const in) {
+    const u8 descriptor = IO_read_bits(in, 8);
 
     // decode frame header descriptor into flags
     const u8 frame_content_size_flag = descriptor >> 6;
@@ -478,28 +496,20 @@ static void parse_frame_header(frame_header_t *const header,
         CORRUPTION();
     }
 
-    int header_size = 1;
-
     header->single_segment_flag = single_segment_flag;
     header->content_checksum_flag = content_checksum_flag;
 
     // decode window size
     if (!single_segment_flag) {
-        if (src_len < header_size + 1) {
-            INP_SIZE();
-        }
-
         // Use the algorithm from the specification to compute window size
         // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
-        u8 window_descriptor = src[header_size];
+        u8 window_descriptor = IO_read_bits(in, 8);
         u8 exponent = window_descriptor >> 3;
         u8 mantissa = window_descriptor & 7;
 
         size_t window_base = (size_t)1 << (10 + exponent);
         size_t window_add = (window_base / 8) * mantissa;
         header->window_size = window_base + window_add;
-
-        header_size++;
     }
 
     // decode dictionary id if it exists
@@ -507,13 +517,7 @@ static void parse_frame_header(frame_header_t *const header,
         const int bytes_array[] = {0, 1, 2, 4};
         const int bytes = bytes_array[dictionary_id_flag];
 
-        if (src_len < header_size + bytes) {
-            INP_SIZE();
-        }
-
-        header->dictionary_id = read_bits_LE(src + header_size, bytes * 8, 0);
-
-        header_size += bytes;
+        header->dictionary_id = IO_read_bits(in, bytes * 8);
     } else {
         header->dictionary_id = 0;
     }
@@ -525,17 +529,10 @@ static void parse_frame_header(frame_header_t *const header,
         const int bytes_array[] = {1, 2, 4, 8};
         const int bytes = bytes_array[frame_content_size_flag];
 
-        if (src_len < header_size + bytes) {
-            INP_SIZE();
-        }
-
-        header->frame_content_size =
-            read_bits_LE(src + header_size, bytes * 8, 0);
+        header->frame_content_size = IO_read_bits(in, bytes * 8);
         if (bytes == 2) {
             header->frame_content_size += 256;
         }
-
-        header_size += bytes;
     } else {
         header->frame_content_size = 0;
     }
@@ -546,8 +543,6 @@ static void parse_frame_header(frame_header_t *const header,
         // back to the dictionary or not on large offsets
         header->window_size = header->frame_content_size;
     }
-
-    header->header_size = header_size;
 }
 
 /// A dictionary acts as initializing values for the frame context before
@@ -559,20 +554,15 @@ static void frame_context_apply_dict(frame_context_t *const ctx,
     if (!dict || !dict->content)
         return;
 
-    if (ctx->header.dictionary_id == 0 && dict->dictionary_id != 0) {
-        // The dictionary is unneeded, and shouldn't be used as it may interfere
-        // with the default offset history
-        return;
-    }
-
-    // If the dictionary id is 0, it doesn't matter if we provide the wrong raw
-    // content dict, it won't change anything
+    // If the requested dictionary_id is non-zero, the correct dictionary must
+    // be present
     if (ctx->header.dictionary_id != 0 &&
         ctx->header.dictionary_id != dict->dictionary_id) {
-        ERROR("Wrong/no dictionary provided");
+        ERROR("Wrong dictionary provided");
     }
 
-    // Copy the pointer in so we can reference it in sequence execution
+    // Copy the dict content to the context for references during sequence
+    // execution
     ctx->dict_content = dict->content;
     ctx->dict_content_len = dict->content_size;
 
@@ -592,188 +582,137 @@ static void frame_context_apply_dict(frame_context_t *const ctx,
 }
 
 /// Decompress the data from a frame block by block
-static void decompress_data(io_streams_t *const streams,
-                            frame_context_t *const ctx) {
+static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
+                            istream_t *const in) {
     int last_block = 0;
     do {
-        if (streams->src_len < 3) {
-            INP_SIZE();
-        }
         // Parse the block header
-        last_block = streams->src[0] & 1;
-        const int block_type = (streams->src[0] >> 1) & 3;
-        const size_t block_len = read_bits_LE(streams->src, 21, 3);
-
-        streams->src += 3;
-        streams->src_len -= 3;
+        last_block = IO_read_bits(in, 1);
+        const int block_type = IO_read_bits(in, 2);
+        const size_t block_len = IO_read_bits(in, 21);
 
         switch (block_type) {
         case 0: {
             // Raw, uncompressed block
-            if (streams->src_len < block_len) {
-                INP_SIZE();
-            }
-            if (streams->dst_len < block_len) {
-                OUT_SIZE();
-            }
-
+            const u8 *const read_ptr = IO_read_bytes(in, block_len);
+            u8 *const write_ptr = IO_write_bytes(out, block_len);
+            //
             // Copy the raw data into the output
-            memcpy(streams->dst, streams->src, block_len);
-
-            streams->src += block_len;
-            streams->src_len -= block_len;
-
-            streams->dst += block_len;
-            streams->dst_len -= block_len;
+            memcpy(write_ptr, read_ptr, block_len);
 
             ctx->current_total_output += block_len;
             break;
         }
         case 1: {
             // RLE block, repeat the first byte N times
-            if (streams->src_len < 1) {
-                INP_SIZE();
-            }
-            if (streams->dst_len < block_len) {
-                OUT_SIZE();
-            }
+            const u8 *const read_ptr = IO_read_bytes(in, 1);
+            u8 *const write_ptr = IO_write_bytes(out, block_len);
 
             // Copy `block_len` copies of `streams->src[0]` to the output
-            memset(streams->dst, streams->src[0], block_len);
-
-            streams->dst += block_len;
-            streams->dst_len -= block_len;
-
-            streams->src += 1;
-            streams->src_len -= 1;
+            memset(write_ptr, read_ptr[0], block_len);
 
             ctx->current_total_output += block_len;
             break;
         }
-        case 2:
-            // Compressed block, this is mode complex
-            decompress_block(streams, ctx, block_len);
+        case 2: {
+            // Compressed block
+            // Create a sub-stream for the block
+            istream_t block_stream = IO_make_sub_istream(in, block_len);
+            decompress_block(ctx, out, &block_stream);
             break;
+        }
         case 3:
             // Reserved block type
             CORRUPTION();
             break;
+        default:
+            IMPOSSIBLE();
         }
     } while (!last_block);
 
     if (ctx->header.content_checksum_flag) {
         // This program does not support checking the checksum, so skip over it
         // if it's present
-        if (streams->src_len < 4) {
-            INP_SIZE();
-        }
-        streams->src += 4;
-        streams->src_len -= 4;
+        IO_advance_input(in, 4);
     }
 }
 /******* END FRAME DECODING ***************************************************/
 
 /******* BLOCK DECOMPRESSION **************************************************/
-static void decompress_block(io_streams_t *const streams, frame_context_t *const ctx,
-                             const size_t block_len) {
-    if (streams->src_len < block_len) {
-        INP_SIZE();
-    }
-    // We need this to determine how long the compressed literals block was
-    const u8 *const end_of_block = streams->src + block_len;
-
+static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
+                             istream_t *const in) {
     // Part 1: decode the literals block
     u8 *literals = NULL;
-    const size_t literals_size = decode_literals(streams, ctx, &literals);
+    const size_t literals_size = decode_literals(ctx, in, &literals);
 
     // Part 2: decode the sequences block
-    if (streams->src > end_of_block) {
-        INP_SIZE();
-    }
-    const size_t sequences_size = end_of_block - streams->src;
     sequence_command_t *sequences = NULL;
     const size_t num_sequences =
-        decode_sequences(ctx, streams->src, sequences_size, &sequences);
-
-    streams->src += sequences_size;
-    streams->src_len -= sequences_size;
+        decode_sequences(ctx, in, &sequences);
 
     // Part 3: combine literals and sequence commands to generate output
-    execute_sequences(streams, ctx, sequences, num_sequences, literals,
-                      literals_size);
+    execute_sequences(ctx, out, literals, literals_size, sequences,
+                      num_sequences);
     free(literals);
     free(sequences);
 }
 /******* END BLOCK DECOMPRESSION **********************************************/
 
 /******* LITERALS DECODING ****************************************************/
-static size_t decode_literals_simple(io_streams_t *const streams,
-                                     u8 **const literals, const int block_type,
+static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
+                                     const int block_type,
                                      const int size_format);
-static size_t decode_literals_compressed(io_streams_t *const streams,
-                                         frame_context_t *const ctx,
+static size_t decode_literals_compressed(frame_context_t *const ctx,
+                                         istream_t *const in,
                                          u8 **const literals,
                                          const int block_type,
                                          const int size_format);
-static size_t decode_huf_table(const u8 *src, size_t src_len,
-                               HUF_dtable *const dtable);
-static size_t fse_decode_hufweights(const u8 *const src, const size_t src_len,
-                                    u8 *const weights, int *const num_symbs,
-                                    const size_t compressed_size);
-
-static size_t decode_literals(io_streams_t *const streams,
-                              frame_context_t *const ctx, u8 **const literals) {
-    if (streams->src_len < 1) {
-        INP_SIZE();
-    }
+static void decode_huf_table(istream_t *const in, HUF_dtable *const dtable);
+static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
+                                    int *const num_symbs);
+
+static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
+                              u8 **const literals) {
     // Decode literals header
-    int block_type = streams->src[0] & 3;
-    int size_format = (streams->src[0] >> 2) & 3;
+    int block_type = IO_read_bits(in, 2);
+    int size_format = IO_read_bits(in, 2);
 
     if (block_type <= 1) {
         // Raw or RLE literals block
-        return decode_literals_simple(streams, literals, block_type,
+        return decode_literals_simple(in, literals, block_type,
                                       size_format);
     } else {
         // Huffman compressed literals
-        return decode_literals_compressed(streams, ctx, literals, block_type,
+        return decode_literals_compressed(ctx, in, literals, block_type,
                                           size_format);
     }
 }
 
 /// Decodes literals blocks in raw or RLE form
-static size_t decode_literals_simple(io_streams_t *const streams,
-                                     u8 **const literals, const int block_type,
+static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
+                                     const int block_type,
                                      const int size_format) {
     size_t size;
     switch (size_format) {
-    // These cases are in the form X0
-    // In this case, the X bit is actually part of the size field
+    // These cases are in the form ?0
+    // In this case, the ? bit is actually part of the size field
     case 0:
     case 2:
-        size = read_bits_LE(streams->src, 5, 3);
-        streams->src += 1;
-        streams->src_len -= 1;
+        // "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)."
+        IO_rewind_bits(in, 1);
+        size = IO_read_bits(in, 2);
         break;
     case 1:
-        if (streams->src_len < 2) {
-            INP_SIZE();
-        }
-        size = read_bits_LE(streams->src, 12, 4);
-        streams->src += 2;
-        streams->src_len -= 2;
+        // "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)."
+        size = IO_read_bits(in, 12);
         break;
     case 3:
-        if (streams->src_len < 2) {
-            INP_SIZE();
-        }
-        size = read_bits_LE(streams->src, 20, 4);
-        streams->src += 3;
-        streams->src_len -= 3;
+        // "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)."
+        size = IO_read_bits(in, 20);
         break;
     default:
-        // Impossible
-        size = -1;
+        // Size format is in range 0-3
+        IMPOSSIBLE();
     }
 
     if (size > MAX_LITERALS_SIZE) {
@@ -786,32 +725,28 @@ static size_t decode_literals_simple(io_streams_t *const streams,
     }
 
     switch (block_type) {
-    case 0:
+    case 0: {
         // Raw data
-        if (size > streams->src_len) {
-            INP_SIZE();
-        }
-        memcpy(*literals, streams->src, size);
-        streams->src += size;
-        streams->src_len -= size;
+        const u8 *const read_ptr = IO_read_bytes(in, size);
+        memcpy(*literals, read_ptr, size);
         break;
-    case 1:
+    }
+    case 1: {
         // Single repeated byte
-        if (1 > streams->src_len) {
-            INP_SIZE();
-        }
-        memset(*literals, streams->src[0], size);
-        streams->src += 1;
-        streams->src_len -= 1;
+        const u8 *const read_ptr = IO_read_bytes(in, 1);
+        memset(*literals, read_ptr[0], size);
         break;
     }
+    default:
+        IMPOSSIBLE();
+    }
 
     return size;
 }
 
 /// Decodes Huffman compressed literals
-static size_t decode_literals_compressed(io_streams_t *const streams,
-                                         frame_context_t *const ctx,
+static size_t decode_literals_compressed(frame_context_t *const ctx,
+                                         istream_t *const in,
                                          u8 **const literals,
                                          const int block_type,
                                          const int size_format) {
@@ -820,98 +755,78 @@ static size_t decode_literals_compressed(io_streams_t *const streams,
     int num_streams = 4;
     switch (size_format) {
     case 0:
+        // "A single stream. Both Compressed_Size and Regenerated_Size use 10
+        // bits (0-1023)."
         num_streams = 1;
     // Fall through as it has the same size format
     case 1:
-        if (streams->src_len < 3) {
-            INP_SIZE();
-        }
-        regenerated_size = read_bits_LE(streams->src, 10, 4);
-        compressed_size = read_bits_LE(streams->src, 10, 14);
-        streams->src += 3;
-        streams->src_len -= 3;
+        // "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits
+        // (0-1023)."
+        regenerated_size = IO_read_bits(in, 10);
+        compressed_size = IO_read_bits(in, 10);
         break;
     case 2:
-        if (streams->src_len < 4) {
-            INP_SIZE();
-        }
-        regenerated_size = read_bits_LE(streams->src, 14, 4);
-        compressed_size = read_bits_LE(streams->src, 14, 18);
-        streams->src += 4;
-        streams->src_len -= 4;
+        // "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits
+        // (0-16383)."
+        regenerated_size = IO_read_bits(in, 14);
+        compressed_size = IO_read_bits(in, 14);
         break;
     case 3:
-        if (streams->src_len < 5) {
-            INP_SIZE();
-        }
-        regenerated_size = read_bits_LE(streams->src, 18, 4);
-        compressed_size = read_bits_LE(streams->src, 18, 22);
-        streams->src += 5;
-        streams->src_len -= 5;
+        // "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits
+        // (0-262143)."
+        regenerated_size = IO_read_bits(in, 18);
+        compressed_size = IO_read_bits(in, 18);
         break;
     default:
         // Impossible
-        compressed_size = regenerated_size = -1;
+        IMPOSSIBLE();
     }
     if (regenerated_size > MAX_LITERALS_SIZE ||
         compressed_size > regenerated_size) {
         CORRUPTION();
     }
 
-    if (compressed_size > streams->src_len) {
-        INP_SIZE();
-    }
-
     *literals = malloc(regenerated_size);
     if (!*literals) {
         BAD_ALLOC();
     }
 
+    ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size);
+    istream_t huf_stream = IO_make_sub_istream(in, compressed_size);
+
     if (block_type == 2) {
         // Decode provided Huffman table
 
         HUF_free_dtable(&ctx->literals_dtable);
-        const size_t size = decode_huf_table(streams->src, compressed_size,
-                                             &ctx->literals_dtable);
-        streams->src += size;
-        streams->src_len -= size;
-        compressed_size -= size;
+        decode_huf_table(&huf_stream, &ctx->literals_dtable);
     } else {
-        // If we're to repeat the previous Huffman table, make sure it exists
+        // If the previous Huffman table is being repeated, ensure it exists
         if (!ctx->literals_dtable.symbols) {
             CORRUPTION();
         }
     }
 
+    size_t symbols_decoded;
     if (num_streams == 1) {
-        HUF_decompress_1stream(&ctx->literals_dtable, *literals,
-                               regenerated_size, streams->src, compressed_size);
+        symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
     } else {
-        HUF_decompress_4stream(&ctx->literals_dtable, *literals,
-                               regenerated_size, streams->src, compressed_size);
+        symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
+    }
+
+    if (symbols_decoded != regenerated_size) {
+        CORRUPTION();
     }
-    streams->src += compressed_size;
-    streams->src_len -= compressed_size;
 
     return regenerated_size;
 }
 
 // Decode the Huffman table description
-static size_t decode_huf_table(const u8 *src, size_t src_len,
-                               HUF_dtable *const dtable) {
-    if (src_len < 1) {
-        INP_SIZE();
-    }
+static void decode_huf_table(istream_t *const in, HUF_dtable *const dtable) {
+    const u8 header = IO_read_bits(in, 8);
 
-    const u8 *const osrc = src;
-
-    const u8 header = src[0];
     u8 weights[HUF_MAX_SYMBS];
     memset(weights, 0, sizeof(weights));
 
-    src++;
-    src_len--;
-
     int num_symbs;
 
     if (header >= 128) {
@@ -919,67 +834,56 @@ static size_t decode_huf_table(const u8 *src, size_t src_len,
         num_symbs = header - 127;
         const size_t bytes = (num_symbs + 1) / 2;
 
-        if (bytes > src_len) {
-            INP_SIZE();
-        }
+        const u8 *const weight_src = IO_read_bytes(in, bytes);
 
         for (int i = 0; i < num_symbs; i++) {
             // read_bits_LE isn't applicable here because the weights are order
             // reversed within each byte
             // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#huffman-tree-header
             if (i % 2 == 0) {
-                weights[i] = src[i / 2] >> 4;
+                weights[i] = weight_src[i / 2] >> 4;
             } else {
-                weights[i] = src[i / 2] & 0xf;
+                weights[i] = weight_src[i / 2] & 0xf;
             }
         }
-
-        src += bytes;
-        src_len -= bytes;
     } else {
         // The weights are FSE encoded, decode them before we can construct the
         // table
-        const size_t size =
-            fse_decode_hufweights(src, src_len, weights, &num_symbs, header);
-        src += size;
-        src_len -= size;
+        istream_t fse_stream = IO_make_sub_istream(in, header);
+        ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS);
+        fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs);
     }
 
     // Construct the table using the decoded weights
     HUF_init_dtable_usingweights(dtable, weights, num_symbs);
-    return src - osrc;
 }
 
-static size_t fse_decode_hufweights(const u8 *const src, const size_t src_len,
-                                    u8 *const weights, int *const num_symbs,
-                                    const size_t compressed_size) {
+static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
+                                    int *const num_symbs) {
     const int MAX_ACCURACY_LOG = 7;
 
     FSE_dtable dtable;
 
     // Construct the FSE table
-    const size_t read =
-            FSE_decode_header(&dtable, src, src_len, MAX_ACCURACY_LOG);
-
-    if (src_len < compressed_size) {
-        INP_SIZE();
-    }
+    FSE_decode_header(&dtable, in, MAX_ACCURACY_LOG);
 
     // Decode the weights
-    *num_symbs = FSE_decompress_interleaved2(
-        &dtable, weights, HUF_MAX_SYMBS, src + read, compressed_size - read);
+    *num_symbs = FSE_decompress_interleaved2(&dtable, weights, in);
 
     FSE_free_dtable(&dtable);
-
-    return compressed_size;
 }
 /******* END LITERALS DECODING ************************************************/
 
 /******* SEQUENCE DECODING ****************************************************/
 /// The combination of FSE states needed to decode sequences
 typedef struct {
-    u16 ll_state, of_state, ml_state;
-    FSE_dtable ll_table, of_table, ml_table;
+    FSE_dtable ll_table;
+    FSE_dtable of_table;
+    FSE_dtable ml_table;
+
+    u16 ll_state;
+    u16 of_state;
+    u16 ml_state;
 } sequence_state_t;
 
 /// Different modes to signal to decode_seq_tables what to do
@@ -1031,47 +935,36 @@ static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = {
 /// Offset decoding is simpler so we just need a maximum code value
 static const u8 SEQ_MAX_CODES[3] = {35, -1, 52};
 
-static void decompress_sequences(frame_context_t *const ctx, const u8 *src,
-                                 size_t src_len,
+static void decompress_sequences(frame_context_t *const ctx,
+                                 istream_t *const in,
                                  sequence_command_t *const sequences,
                                  const size_t num_sequences);
 static sequence_command_t decode_sequence(sequence_state_t *const state,
                                           const u8 *const src,
                                           i64 *const offset);
-static size_t decode_seq_table(const u8 *src, size_t src_len,
-                               FSE_dtable *const table, const seq_part_t type,
-                               const seq_mode_t mode);
+static void decode_seq_table(istream_t *const in, FSE_dtable *const table,
+                               const seq_part_t type, const seq_mode_t mode);
 
-static size_t decode_sequences(frame_context_t *const ctx, const u8 *src,
-                               size_t src_len,
+static size_t decode_sequences(frame_context_t *const ctx, istream_t *in,
                                sequence_command_t **const sequences) {
     size_t num_sequences;
 
     // Decode the sequence header and allocate space for the output
-    if (src_len < 1) {
-        INP_SIZE();
-    }
-    if (src[0] == 0) {
+    u8 header = IO_read_bits(in, 8);
+    if (header == 0) {
+        // "There are no sequences. The sequence section stops there.
+        // Regenerated content is defined entirely by literals section."
         *sequences = NULL;
         return 0;
-    } else if (src[0] < 128) {
-        num_sequences = src[0];
-        src++;
-        src_len--;
-    } else if (src[0] < 255) {
-        if (src_len < 2) {
-            INP_SIZE();
-        }
-        num_sequences = ((src[0] - 128) << 8) + src[1];
-        src += 2;
-        src_len -= 2;
+    } else if (header < 128) {
+        // "Number_of_Sequences = byte0 . Uses 1 byte."
+        num_sequences = header;
+    } else if (header < 255) {
+        // "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes."
+        num_sequences = ((header - 128) << 8) + IO_read_bits(in, 8);
     } else {
-        if (src_len < 3) {
-            INP_SIZE();
-        }
-        num_sequences = src[1] + ((u64)src[2] << 8) + 0x7F00;
-        src += 3;
-        src_len -= 3;
+        // "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes."
+        num_sequences = IO_read_bits(in, 16) + 0x7F00;
     }
 
     *sequences = malloc(num_sequences * sizeof(sequence_command_t));
@@ -1079,51 +972,29 @@ static size_t decode_sequences(frame_context_t *const ctx, const u8 *src,
         BAD_ALLOC();
     }
 
-    decompress_sequences(ctx, src, src_len, *sequences, num_sequences);
+    decompress_sequences(ctx, in, *sequences, num_sequences);
     return num_sequences;
 }
 
 /// Decompress the FSE encoded sequence commands
-static void decompress_sequences(frame_context_t *const ctx, const u8 *src,
-                                 size_t src_len,
+static void decompress_sequences(frame_context_t *const ctx, istream_t *in,
                                  sequence_command_t *const sequences,
                                  const size_t num_sequences) {
-    if (src_len < 1) {
-        INP_SIZE();
-    }
-    u8 compression_modes = src[0];
-    src++;
-    src_len--;
+    u8 compression_modes = IO_read_bits(in, 8);
 
     if ((compression_modes & 3) != 0) {
         CORRUPTION();
     }
 
-    {
-        size_t read;
-        // Update the tables we have stored in the context
-        read = decode_seq_table(src, src_len, &ctx->ll_dtable,
-                                seq_literal_length,
-                                (compression_modes >> 6) & 3);
-        src += read;
-        src_len -= read;
-    }
+    // Update the tables we have stored in the context
+    decode_seq_table(in, &ctx->ll_dtable, seq_literal_length,
+                     (compression_modes >> 6) & 3);
 
-    {
-        const size_t read =
-                decode_seq_table(src, src_len, &ctx->of_dtable, seq_offset,
-                                 (compression_modes >> 4) & 3);
-        src += read;
-        src_len -= read;
-    }
+    decode_seq_table(in, &ctx->of_dtable, seq_offset,
+                     (compression_modes >> 4) & 3);
 
-    {
-        const size_t read = decode_seq_table(src, src_len, &ctx->ml_dtable,
-                                             seq_match_length,
-                                             (compression_modes >> 2) & 3);
-        src += read;
-        src_len -= read;
-    }
+    decode_seq_table(in, &ctx->ml_dtable, seq_match_length,
+                     (compression_modes >> 2) & 3);
 
     // Check to make sure none of the tables are uninitialized
     if (!ctx->ll_dtable.symbols || !ctx->of_dtable.symbols ||
@@ -1137,8 +1008,13 @@ static void decompress_sequences(frame_context_t *const ctx, const u8 *src,
     memcpy(&state.of_table, &ctx->of_dtable, sizeof(FSE_dtable));
     memcpy(&state.ml_table, &ctx->ml_dtable, sizeof(FSE_dtable));
 
-    const int padding = 8 - log2inf(src[src_len - 1]);
-    i64 offset = src_len * 8 - padding;
+    size_t len = IO_istream_len(in);
+    const u8 *const src = IO_read_bytes(in, len);
+
+    // "After writing the last bit containing information, the compressor writes
+    // a single 1-bit and then fills the byte with 0-7 0 bits of padding."
+    const int padding = 8 - log2inf(src[len - 1]);
+    i64 offset = len * 8 - padding;
 
     FSE_init_state(&state.ll_table, &state.ll_state, src, &offset);
     FSE_init_state(&state.of_table, &state.of_state, src, &offset);
@@ -1153,7 +1029,7 @@ static void decompress_sequences(frame_context_t *const ctx, const u8 *src,
         CORRUPTION();
     }
 
-    // Don't free our tables so they can be used in the next block
+    // Don't free tables so they can be used in the next block
 }
 
 // Decode a single sequence and update the state
@@ -1194,9 +1070,8 @@ static sequence_command_t decode_sequence(sequence_state_t *const state,
 }
 
 /// Given a sequence part and table mode, decode the FSE distribution
-static size_t decode_seq_table(const u8 *src, size_t src_len,
-                               FSE_dtable *const table, const seq_part_t type,
-                               const seq_mode_t mode) {
+static void decode_seq_table(istream_t *const in, FSE_dtable *const table,
+                               const seq_part_t type, const seq_mode_t mode) {
     // Constant arrays indexed by seq_part_t
     const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST,
                                                 SEQ_OFFSET_DEFAULT_DIST,
@@ -1207,7 +1082,7 @@ static size_t decode_seq_table(const u8 *src, size_t src_len,
     const size_t max_accuracies[] = {9, 8, 9};
 
     if (mode != seq_repeat) {
-        // ree old one before overwriting
+        // Free old one before overwriting
         FSE_free_dtable(table);
     }
 
@@ -1218,102 +1093,102 @@ static size_t decode_seq_table(const u8 *src, size_t src_len,
         const size_t accuracy_log = default_distribution_accuracies[type];
 
         FSE_init_dtable(table, distribution, symbs, accuracy_log);
-
-        return 0;
+        break;
     }
     case seq_rle: {
-        if (src_len < 1) {
-            INP_SIZE();
-        }
-        const u8 symb = src[0];
-        src++;
-        src_len--;
+        const u8 symb = IO_read_bits(in, 8);
         FSE_init_dtable_rle(table, symb);
-
-        return 1;
+        break;
     }
     case seq_fse: {
-        size_t read =
-            FSE_decode_header(table, src, src_len, max_accuracies[type]);
-        src += read;
-        src_len -= read;
-
-        return read;
+        FSE_decode_header(table, in, max_accuracies[type]);
+        break;
     }
     case seq_repeat:
-        // Don't have to do anything here as we're not changing the table
-        return 0;
+        // Nothing to do here, table will be unchanged
+        break;
     default:
         // Impossible, as mode is from 0-3
-        return -1;
+        IMPOSSIBLE();
+        break;
     }
 }
 /******* END SEQUENCE DECODING ************************************************/
 
 /******* SEQUENCE EXECUTION ***************************************************/
-static void execute_sequences(io_streams_t *const streams,
-                              frame_context_t *const ctx,
+static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
+                              const u8 *const literals,
+                              const size_t literals_len,
                               const sequence_command_t *const sequences,
-                              const size_t num_sequences,
-                              const u8 *literals,
-                              size_t literals_len) {
+                              const size_t num_sequences) {
+    istream_t litstream = IO_make_istream(literals, literals_len);
+
     u64 *const offset_hist = ctx->previous_offsets;
     size_t total_output = ctx->current_total_output;
 
     for (size_t i = 0; i < num_sequences; i++) {
         const sequence_command_t seq = sequences[i];
 
-        if (seq.literal_length > literals_len) {
-            CORRUPTION();
-        }
-
-        if (streams->dst_len < seq.literal_length + seq.match_length) {
-            OUT_SIZE();
-        }
-        // Copy literals to output
-        memcpy(streams->dst, literals, seq.literal_length);
-
-        literals += seq.literal_length;
-        literals_len -= seq.literal_length;
+        {
+            if (seq.literal_length > IO_istream_len(&litstream)) {
+                CORRUPTION();
+            }
 
-        streams->dst += seq.literal_length;
-        streams->dst_len -= seq.literal_length;
+            u8 *const write_ptr = IO_write_bytes(out, seq.literal_length);
+            const u8 *const read_ptr =
+                    IO_read_bytes(&litstream, seq.literal_length);
+            // Copy literals to output
+            memcpy(write_ptr, read_ptr, seq.literal_length);
 
-        total_output += seq.literal_length;
+            total_output += seq.literal_length;
+        }
 
         size_t offset;
 
         // Offsets are special, we need to handle the repeat offsets
         if (seq.offset <= 3) {
-            u32 idx = seq.offset;
+            // "The first 3 values define a repeated offset and we will call
+            // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3.
+            // They are sorted in recency order, with Repeated_Offset1 meaning
+            // 'most recent one'".
+
+            // Use 0 indexing for the array
+            u32 idx = seq.offset - 1;
             if (seq.literal_length == 0) {
-                // Special case when literal length is 0
+                // "There is an exception though, when current sequence's
+                // literals length is 0. In this case, repeated offsets are
+                // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2,
+                // Repeated_Offset2 becomes Repeated_Offset3, and
+                // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte."
                 idx++;
             }
 
-            if (idx == 1) {
-                offset = offset_hist[1];
+            if (idx == 0) {
+                offset = offset_hist[0];
             } else {
-                // If idx == 4 then literal length was 0 and the offset was 3
-                offset = idx < 4 ? offset_hist[idx] : offset_hist[1] - 1;
+                // If idx == 3 then literal length was 0 and the offset was 3,
+                // as per the exception listed above
+                offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1;
 
-                // If idx == 2 we don't need to modify offset_hist[3]
-                if (idx > 2) {
-                    offset_hist[3] = offset_hist[2];
+                // If idx == 1 we don't need to modify offset_hist[2]
+                if (idx > 1) {
+                    offset_hist[2] = offset_hist[1];
                 }
-                offset_hist[2] = offset_hist[1];
-                offset_hist[1] = offset;
+                offset_hist[1] = offset_hist[0];
+                offset_hist[0] = offset;
             }
         } else {
             offset = seq.offset - 3;
 
             // Shift back history
-            offset_hist[3] = offset_hist[2];
             offset_hist[2] = offset_hist[1];
-            offset_hist[1] = offset;
+            offset_hist[1] = offset_hist[0];
+            offset_hist[0] = offset;
         }
 
         size_t match_length = seq.match_length;
+
+        u8 *write_ptr = IO_write_bytes(out, match_length);
         if (total_output <= ctx->header.window_size) {
             // In this case offset might go back into the dictionary
             if (offset > total_output + ctx->dict_content_len) {
@@ -1322,13 +1197,16 @@ static void execute_sequences(io_streams_t *const streams,
             }
 
             if (offset > total_output) {
+                // "The rest of the dictionary is its content. The content act
+                // as a "past" in front of data to compress or decompress, so it
+                // can be referenced in sequence commands."
                 const size_t dict_copy =
                     MIN(offset - total_output, match_length);
                 const size_t dict_offset =
                     ctx->dict_content_len - (offset - total_output);
-                for (size_t i = 0; i < dict_copy; i++) {
-                    *streams->dst++ = ctx->dict_content[dict_offset + i];
-                }
+
+                memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy);
+                write_ptr += dict_copy;
                 match_length -= dict_copy;
             }
         } else if (offset > ctx->header.window_size) {
@@ -1340,31 +1218,29 @@ static void execute_sequences(io_streams_t *const streams,
         // ex: if the output so far was "abc", a command with offset=3 and
         // match_length=6 would produce "abcabcabc" as the new output
         for (size_t i = 0; i < match_length; i++) {
-            *streams->dst = *(streams->dst - offset);
-            streams->dst++;
+            *write_ptr = *(write_ptr - offset);
+            write_ptr++;
         }
 
-        streams->dst_len -= seq.match_length;
         total_output += seq.match_length;
     }
 
-    if (streams->dst_len < literals_len) {
-        OUT_SIZE();
-    }
-    // Copy any leftover literals
-    memcpy(streams->dst, literals, literals_len);
-    streams->dst += literals_len;
-    streams->dst_len -= literals_len;
+    {
+        size_t len = IO_istream_len(&litstream);
+        u8 *const write_ptr = IO_write_bytes(out, len);
+        const u8 *const read_ptr = IO_read_bytes(&litstream, len);
+        // Copy any leftover literals
+        memcpy(write_ptr, read_ptr, len);
 
-    total_output += literals_len;
+        total_output += len;
+    }
 
     ctx->current_total_output = total_output;
 }
 /******* END SEQUENCE EXECUTION ***********************************************/
 
 /******* OUTPUT SIZE COUNTING *************************************************/
-size_t traverse_frame(const frame_header_t *const header, const u8 *src,
-                      size_t src_len);
+static void traverse_frame(const frame_header_t *const header, istream_t *const in);
 
 /// Get the decompressed size of an input stream so memory can be allocated in
 /// advance.
@@ -1372,115 +1248,75 @@ size_t traverse_frame(const frame_header_t *const header, const u8 *src,
 /// implementation, as this API allows for the decompression of multiple
 /// concatenated frames.
 size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) {
-  const u8 *ip = (const u8 *) src;
-  size_t ip_len = src_len;
-  size_t dst_size = 0;
-
-  // Each frame header only gives us the size of its frame, so iterate over all
-  // frames
-  while (ip_len > 0) {
-    if (ip_len < 4) {
-      INP_SIZE();
-    }
-
-    const u32 magic_number = read_bits_LE(ip, 32, 0);
-
-    ip += 4;
-    ip_len -= 4;
-    if (magic_number >= 0x184D2A50U && magic_number <= 0x184D2A5F) {
-        // skippable frame, this has no impact on output size
-        if (ip_len < 4) {
-            INP_SIZE();
-        }
-        const size_t frame_size = read_bits_LE(ip, 32, 32);
+    istream_t in = IO_make_istream(src, src_len);
+    size_t dst_size = 0;
+
+    // Each frame header only gives us the size of its frame, so iterate over
+    // all
+    // frames
+    while (IO_istream_len(&in) > 0) {
+        const u32 magic_number = IO_read_bits(&in, 32);
+
+        if ((magic_number & ~0xFU) == 0x184D2A50U) {
+            // skippable frame, this has no impact on output size
+            const size_t frame_size = IO_read_bits(&in, 32);
+            IO_advance_input(&in, frame_size);
+        } else if (magic_number == 0xFD2FB528U) {
+            // ZSTD frame
+            frame_header_t header;
+            parse_frame_header(&header, &in);
+
+            if (header.frame_content_size == 0 && !header.single_segment_flag) {
+                // Content size not provided, we can't tell
+                return -1;
+            }
 
-        if (ip_len < 4 + frame_size) {
-            INP_SIZE();
-        }
+            dst_size += header.frame_content_size;
 
-        // skip over frame
-        ip += 4 + frame_size;
-        ip_len -= 4 + frame_size;
-    } else if (magic_number == 0xFD2FB528U) {
-        // ZSTD frame
-        frame_header_t header;
-        parse_frame_header(&header, ip, ip_len);
-
-        if (header.frame_content_size == 0 && !header.single_segment_flag) {
-            // Content size not provided, we can't tell
-            return -1;
+            // Consume the input from the frame to reach the start of the next
+            traverse_frame(&header, &in);
+        } else {
+            // not a real frame
+            ERROR("Invalid magic number");
         }
-
-        dst_size += header.frame_content_size;
-
-        // we need to traverse the frame to find when the next one starts
-        const size_t traversed = traverse_frame(&header, ip, ip_len);
-        ip += traversed;
-        ip_len -= traversed;
-    } else {
-        // not a real frame
-        ERROR("Invalid magic number");
     }
-  }
 
-  return dst_size;
+    return dst_size;
 }
 
 /// Iterate over each block in a frame to find the end of it, to get to the
 /// start of the next frame
-size_t traverse_frame(const frame_header_t *const header, const u8 *src,
-                      size_t src_len) {
-    const u8 *const src_beg = src;
-    const u8 *const src_end = src + src_len;
-    src += header->header_size;
-    src_len += header->header_size;
-
+static void traverse_frame(const frame_header_t *const header, istream_t *const in) {
     int last_block = 0;
 
     do {
-        if (src + 3 > src_end) {
-            INP_SIZE();
-        }
         // Parse the block header
-        last_block = src[0] & 1;
-        const int block_type = (src[0] >> 1) & 3;
-        const size_t block_len = read_bits_LE(src, 21, 3);
+        last_block = IO_read_bits(in, 1);
+        const int block_type = IO_read_bits(in, 2);
+        const size_t block_len = IO_read_bits(in, 21);
 
-        src += 3;
         switch (block_type) {
         case 0: // Raw block, block_len bytes
-            if (src + block_len > src_end) {
-                INP_SIZE();
-            }
-            src += block_len;
+            IO_advance_input(in, block_len);
             break;
         case 1: // RLE block, 1 byte
-            if (src + 1 > src_end) {
-                INP_SIZE();
-            }
-            src++;
+            IO_advance_input(in, 1);
             break;
         case 2: // Compressed block, compressed size is block_len
-            if (src + block_len > src_end) {
-                INP_SIZE();
-            }
-            src += block_len;
+            IO_advance_input(in, block_len);
             break;
         case 3:
             // Reserved block type
             CORRUPTION();
             break;
+        default:
+            IMPOSSIBLE();
         }
     } while (!last_block);
 
     if (header->content_checksum_flag) {
-        if (src + 4 > src_end) {
-            INP_SIZE();
-        }
-        src += 4;
+        IO_advance_input(in, 4);
     }
-
-    return src - src_beg;
 }
 
 /******* END OUTPUT SIZE COUNTING *********************************************/
@@ -1495,68 +1331,46 @@ static void parse_dictionary(dictionary_t *const dict, const u8 *src,
     if (src_len < 8) {
         INP_SIZE();
     }
-    const u32 magic_number = read_bits_LE(src, 32, 0);
+
+    istream_t in = IO_make_istream(src, src_len);
+
+    const u32 magic_number = IO_read_bits(&in, 32);
     if (magic_number != 0xEC30A437) {
         // raw content dict
         init_raw_content_dict(dict, src, src_len);
         return;
     }
-    dict->dictionary_id = read_bits_LE(src, 32, 32);
 
-    src += 8;
-    src_len -= 8;
+    dict->dictionary_id = IO_read_bits(&in, 32);
 
     // Parse the provided entropy tables in order
-    {
-        const size_t read =
-                decode_huf_table(src, src_len, &dict->literals_dtable);
-        src += read;
-        src_len -= read;
-    }
-    {
-        const size_t read = decode_seq_table(src, src_len, &dict->of_dtable,
-                                             seq_offset, seq_fse);
-        src += read;
-        src_len -= read;
-    }
-    {
-        const size_t read = decode_seq_table(src, src_len, &dict->ml_dtable,
-                                             seq_match_length, seq_fse);
-        src += read;
-        src_len -= read;
-    }
-    {
-        const size_t read = decode_seq_table(src, src_len, &dict->ll_dtable,
-                                             seq_literal_length, seq_fse);
-        src += read;
-        src_len -= read;
-    }
+    decode_huf_table(&in, &dict->literals_dtable);
+    decode_seq_table(&in, &dict->of_dtable, seq_offset, seq_fse);
+    decode_seq_table(&in, &dict->ml_dtable, seq_match_length, seq_fse);
+    decode_seq_table(&in, &dict->ll_dtable, seq_literal_length, seq_fse);
 
-    if (src_len < 12) {
-        INP_SIZE();
-    }
     // Read in the previous offset history
-    dict->previous_offsets[1] = read_bits_LE(src, 32, 0);
-    dict->previous_offsets[2] = read_bits_LE(src, 32, 32);
-    dict->previous_offsets[3] = read_bits_LE(src, 32, 64);
-
-    src += 12;
-    src_len -= 12;
+    dict->previous_offsets[0] = IO_read_bits(&in, 32);
+    dict->previous_offsets[1] = IO_read_bits(&in, 32);
+    dict->previous_offsets[2] = IO_read_bits(&in, 32);
 
     // Ensure the provided offsets aren't too large
-    for (int i = 1; i <= 3; i++) {
+    for (int i = 0; i < 3; i++) {
         if (dict->previous_offsets[i] > src_len) {
             ERROR("Dictionary corrupted");
         }
     }
+
     // The rest is the content
-    dict->content = malloc(src_len);
+    dict->content_size = IO_istream_len(&in);
+    dict->content = malloc(dict->content_size);
     if (!dict->content) {
         BAD_ALLOC();
     }
 
-    dict->content_size = src_len;
-    memcpy(dict->content, src, src_len);
+    const u8 *const content = IO_read_bytes(&in, dict->content_size);
+
+    memcpy(dict->content, content, dict->content_size);
 }
 
 /// If parse_dictionary is given a raw content dictionary, it delegates here
@@ -1586,6 +1400,143 @@ static void free_dictionary(dictionary_t *const dict) {
 }
 /******* END DICTIONARY PARSING ***********************************************/
 
+/******* IO STREAM OPERATIONS *************************************************/
+#define UNALIGNED() ERROR("Attempting to operate on a non-byte aligned stream")
+/// Reads `num` bits from a bitstream, and updates the internal offset
+static inline u64 IO_read_bits(istream_t *const in, const int num) {
+    if (num > 64) {
+        return -1;
+    }
+
+    const size_t bytes = (num + in->bit_offset + 7) / 8;
+    const size_t full_bytes = (num + in->bit_offset) / 8;
+    if (bytes > in->len) {
+        INP_SIZE();
+    }
+
+    const u64 result = read_bits_LE(in->ptr, num, in->bit_offset);
+
+    in->bit_offset = (num + in->bit_offset) % 8;
+    in->ptr += full_bytes;
+    in->len -= full_bytes;
+
+    return result;
+}
+
+/// If a non-zero number of bits have been read from the current byte, advance
+/// the offset to the next byte
+static inline void IO_rewind_bits(istream_t *const in, int num) {
+    if (num < 0) {
+        ERROR("Attempting to rewind stream by a negative number of bits");
+    }
+
+    const int new_offset = in->bit_offset - num;
+    const i64 bytes = (new_offset - 7) / 8;
+
+    in->ptr += bytes;
+    in->len -= bytes;
+    in->bit_offset = ((new_offset % 8) + 8) % 8;
+}
+
+/// If the remaining bits in a byte will be unused, advance to the end of the
+/// byte
+static inline void IO_align_stream(istream_t *const in) {
+    if (in->bit_offset != 0) {
+        if (in->len == 0) {
+            INP_SIZE();
+        }
+        in->ptr++;
+        in->len--;
+        in->bit_offset = 0;
+    }
+}
+
+/// Write the given byte into the output stream
+static inline void IO_write_byte(ostream_t *const out, u8 symb) {
+    if (out->len == 0) {
+        OUT_SIZE();
+    }
+
+    out->ptr[0] = symb;
+    out->ptr++;
+    out->len--;
+}
+
+/// Returns the number of bytes left to be read in this stream.  The stream must
+/// be byte aligned.
+static inline size_t IO_istream_len(const istream_t *const in) {
+    return in->len;
+}
+
+/// Returns a pointer where `len` bytes can be read, and advances the internal
+/// state.  The stream must be byte aligned.
+static inline const u8 *IO_read_bytes(istream_t *const in, size_t len) {
+    if (len > in->len) {
+        INP_SIZE();
+    }
+    if (in->bit_offset != 0) {
+        UNALIGNED();
+    }
+    const u8 *const ptr = in->ptr;
+    in->ptr += len;
+    in->len -= len;
+
+    return ptr;
+}
+/// Returns a pointer to write `len` bytes to, and advances the internal state
+static inline u8 *IO_write_bytes(ostream_t *const out, size_t len) {
+    if (len > out->len) {
+        INP_SIZE();
+    }
+    u8 *const ptr = out->ptr;
+    out->ptr += len;
+    out->len -= len;
+
+    return ptr;
+}
+
+/// Advance the inner state by `len` bytes
+static inline void IO_advance_input(istream_t *const in, size_t len) {
+    if (len > in->len) {
+         INP_SIZE();
+    }
+    if (in->bit_offset != 0) {
+        UNALIGNED();
+    }
+
+    in->ptr += len;
+    in->len -= len;
+}
+
+/// Returns an `ostream_t` constructed from the given pointer and length
+static inline ostream_t IO_make_ostream(u8 *out, size_t len) {
+    return (ostream_t) { out, len };
+}
+
+/// Returns an `istream_t` constructed from the given pointer and length
+static inline istream_t IO_make_istream(const u8 *in, size_t len) {
+    return (istream_t) { in, 0, len };
+}
+
+/// Returns an `istream_t` with the same base as `in`, and length `len`
+/// Then, advance `in` to account for the consumed bytes
+/// `in` must be byte aligned
+static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) {
+    if (len > in->len) {
+        INP_SIZE();
+    }
+    if (in->bit_offset != 0) {
+        UNALIGNED();
+    }
+    const istream_t sub = { in->ptr, in->bit_offset, len };
+
+    in->ptr += len;
+    in->len -= len;
+
+    return sub;
+}
+/******* END IO STREAM OPERATIONS *********************************************/
+
 /******* BITSTREAM OPERATIONS *************************************************/
 /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits
 static inline u64 read_bits_LE(const u8 *src, const int num,
@@ -1676,28 +1627,29 @@ static inline void HUF_init_state(const HUF_dtable *const dtable,
     *state = STREAM_read_bits(src, bits, offset);
 }
 
-static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, u8 *dst,
-                                     const size_t dst_len, const u8 *src,
-                                     size_t src_len) {
-    const u8 *const dst_max = dst + dst_len;
-    const u8 *const odst = dst;
+static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
+                                     ostream_t *const out,
+                                     istream_t *const in) {
+    const size_t len = IO_istream_len(in);
+    if (len == 0) {
+        INP_SIZE();
+    }
+    const u8 *const src = IO_read_bytes(in, len);
 
     // To maintain similarity with FSE, start from the end
     // Find the last 1 bit
-    const int padding = 8 - log2inf(src[src_len - 1]);
+    const int padding = 8 - log2inf(src[len - 1]);
 
-    i64 offset = src_len * 8 - padding;
+    i64 offset = len * 8 - padding;
     u16 state;
 
     HUF_init_state(dtable, &state, src, &offset);
 
-    while (dst < dst_max && offset > -dtable->max_bits) {
+    size_t symbols_written = 0;
+    while (offset > -dtable->max_bits) {
         // Iterate over the stream, decoding one symbol at a time
-        *dst++ = HUF_decode_symbol(dtable, &state, src, &offset);
-    }
-    // If we stopped before consuming all the input, we didn't have enough space
-    if (dst == dst_max && offset > -dtable->max_bits) {
-        OUT_SIZE();
+        IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &offset));
+        symbols_written++;
     }
 
     // When all symbols have been decoded, the final state value shouldn't have
@@ -1709,50 +1661,30 @@ static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, u8 *dst,
         CORRUPTION();
     }
 
-    return dst - odst;
+    return symbols_written;
 }
 
-static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, u8 *dst,
-                                     const size_t dst_len, const u8 *const src,
-                                     const size_t src_len) {
-    if (src_len < 6) {
-        INP_SIZE();
-    }
+static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
+                                     ostream_t *const out, istream_t *const in) {
+    const size_t csize1 = IO_read_bits(in, 16);
+    const size_t csize2 = IO_read_bits(in, 16);
+    const size_t csize3 = IO_read_bits(in, 16);
 
-    const u8 *const src1 = src + 6;
-    const u8 *const src2 = src1 + read_bits_LE(src, 16, 0);
-    const u8 *const src3 = src2 + read_bits_LE(src, 16, 16);
-    const u8 *const src4 = src3 + read_bits_LE(src, 16, 32);
-    const u8 *const src_end = src + src_len;
-
-    // We can't test with all 4 sizes because the 4th size is a function of the
-    // other 3 and the provided length
-    if (src4 - src >= src_len) {
-        INP_SIZE();
-    }
-
-    const size_t segment_size = (dst_len + 3) / 4;
-    u8 *const dst1 = dst;
-    u8 *const dst2 = dst1 + segment_size;
-    u8 *const dst3 = dst2 + segment_size;
-    u8 *const dst4 = dst3 + segment_size;
-    u8 *const dst_end = dst + dst_len;
-
-    size_t total_out = 0;
+    istream_t in1 = IO_make_sub_istream(in, csize1);
+    istream_t in2 = IO_make_sub_istream(in, csize2);
+    istream_t in3 = IO_make_sub_istream(in, csize3);
+    istream_t in4 = IO_make_sub_istream(in, IO_istream_len(in));
 
+    size_t total_output = 0;
     // Decode each stream independently for simplicity
     // If we wanted to we could decode all 4 at the same time for speed,
     // utilizing more execution units
-    total_out += HUF_decompress_1stream(dtable, dst1, segment_size, src1,
-                                        src2 - src1);
-    total_out += HUF_decompress_1stream(dtable, dst2, segment_size, src2,
-                                        src3 - src2);
-    total_out += HUF_decompress_1stream(dtable, dst3, segment_size, src3,
-                                        src4 - src3);
-    total_out += HUF_decompress_1stream(dtable, dst4, dst_end - dst4, src4,
-                                        src_end - src4);
-
-    return total_out;
+    total_output += HUF_decompress_1stream(dtable, out, &in1);
+    total_output += HUF_decompress_1stream(dtable, out, &in2);
+    total_output += HUF_decompress_1stream(dtable, out, &in3);
+    total_output += HUF_decompress_1stream(dtable, out, &in4);
+
+    return total_output;
 }
 
 static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
@@ -1827,6 +1759,10 @@ static void HUF_init_dtable_usingweights(HUF_dtable *const table,
 
     u64 weight_sum = 0;
     for (int i = 0; i < num_symbs; i++) {
+        // Weights are in the same range as bit count
+        if (weights[i] > HUF_MAX_BITS) {
+            CORRUPTION();
+        }
         weight_sum += weights[i] > 0 ? (u64)1 << (weights[i] - 1) : 0;
     }
 
@@ -1913,20 +1849,17 @@ static inline void FSE_init_state(const FSE_dtable *const dtable,
 }
 
 static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
-                                          u8 *dst, const size_t dst_len,
-                                          const u8 *const src,
-                                          const size_t src_len) {
-    if (src_len == 0) {
+                                          ostream_t *const out,
+                                          istream_t *const in) {
+    const size_t len = IO_istream_len(in);
+    if (len == 0) {
         INP_SIZE();
     }
-
-    const u8 *const dst_max = dst + dst_len;
-    const u8 *const odst = dst;
+    const u8 *const src = IO_read_bytes(in, len);
 
     // Find the last 1 bit
-    const int padding = 8 - log2inf(src[src_len - 1]);
-
-    i64 offset = src_len * 8 - padding;
+    const int padding = 8 - log2inf(src[len - 1]);
+    i64 offset = len * 8 - padding;
 
     // The end of the stream contains the 2 states, in this order
     u16 state1, state2;
@@ -1936,30 +1869,28 @@ static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
     // Decode until we overflow the stream
     // Since we decode in reverse order, overflowing the stream is offset going
     // negative
+    size_t symbols_written = 0;
     while (1) {
-        if (dst > dst_max - 2) {
-            OUT_SIZE();
-        }
-        *dst++ = FSE_decode_symbol(dtable, &state1, src, &offset);
+        IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset));
+        symbols_written++;
         if (offset < 0) {
             // There's still a symbol to decode in state2
-            *dst++ = FSE_peek_symbol(dtable, state2);
+            IO_write_byte(out, FSE_peek_symbol(dtable, state2));
+            symbols_written++;
             break;
         }
 
-        if (dst > dst_max - 2) {
-            OUT_SIZE();
-        }
-        *dst++ = FSE_decode_symbol(dtable, &state2, src, &offset);
+        IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset));
+        symbols_written++;
         if (offset < 0) {
             // There's still a symbol to decode in state1
-            *dst++ = FSE_peek_symbol(dtable, state1);
+            IO_write_byte(out, FSE_peek_symbol(dtable, state1));
+            symbols_written++;
             break;
         }
     }
 
-    // Number of symbols read
-    return dst - odst;
+    return symbols_written;
 }
 
 static void FSE_init_dtable(FSE_dtable *const dtable,
@@ -2042,17 +1973,13 @@ static void FSE_init_dtable(FSE_dtable *const dtable,
 
 /// Decode an FSE header as defined in the Zstandard format specification and
 /// use the decoded frequencies to initialize a decoding table.
-static size_t FSE_decode_header(FSE_dtable *const dtable, const u8 *const src,
-                                const size_t src_len,
+static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
                                 const int max_accuracy_log) {
     if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) {
         ERROR("FSE accuracy too large");
     }
-    if (src_len < 1) {
-        INP_SIZE();
-    }
 
-    const int accuracy_log = 5 + read_bits_LE(src, 4, 0);
+    const int accuracy_log = 5 + IO_read_bits(in, 4);
     if (accuracy_log > max_accuracy_log) {
         ERROR("FSE accuracy too large");
     }
@@ -2062,14 +1989,11 @@ static size_t FSE_decode_header(FSE_dtable *const dtable, const u8 *const src,
     i16 frequencies[FSE_MAX_SYMBS];
 
     int symb = 0;
-    // Offset of 4 because 4 bits were already read in for accuracy
-    size_t offset = 4;
     while (remaining > 1 && symb < FSE_MAX_SYMBS) {
         // Log of the number of possible values we could read
         int bits = log2inf(remaining) + 1;
 
-        u16 val = read_bits_LE(src, bits, offset);
-        offset += bits;
+        u16 val = IO_read_bits(in, bits);
 
         // Try to mask out the lower bits to see if it qualifies for the "small
         // value" threshold
@@ -2077,7 +2001,7 @@ static size_t FSE_decode_header(FSE_dtable *const dtable, const u8 *const src,
         const u16 threshold = ((u16)1 << bits) - 1 - remaining;
 
         if ((val & lower_mask) < threshold) {
-            offset--;
+            IO_rewind_bits(in, 1);
             val = val & lower_mask;
         } else if (val > lower_mask) {
             val = val - threshold;
@@ -2093,22 +2017,21 @@ static size_t FSE_decode_header(FSE_dtable *const dtable, const u8 *const src,
         // Handle the special probability = 0 case
         if (proba == 0) {
             // Read the next two bits to see how many more 0s
-            int repeat = read_bits_LE(src, 2, offset);
-            offset += 2;
+            int repeat = IO_read_bits(in, 2);
 
             while (1) {
                 for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) {
                     frequencies[symb++] = 0;
                 }
                 if (repeat == 3) {
-                    repeat = read_bits_LE(src, 2, offset);
-                    offset += 2;
+                    repeat = IO_read_bits(in, 2);
                 } else {
                     break;
                 }
             }
         }
     }
+    IO_align_stream(in);
 
     if (remaining != 1 || symb >= FSE_MAX_SYMBS) {
         CORRUPTION();
@@ -2116,8 +2039,6 @@ static size_t FSE_decode_header(FSE_dtable *const dtable, const u8 *const src,
 
     // Initialize the decoding table using the determined weights
     FSE_init_dtable(dtable, frequencies, symb, accuracy_log);
-
-    return (offset + 7) / 8;
 }
 
 static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) {