]> git.ipfire.org Git - thirdparty/samba.git/commitdiff
lib/compression: LZ77 + Huffman compression
authorDouglas Bagnall <douglas.bagnall@catalyst.net.nz>
Thu, 17 Nov 2022 10:14:58 +0000 (23:14 +1300)
committerJoseph Sutton <jsutton@samba.org>
Thu, 1 Dec 2022 22:56:39 +0000 (22:56 +0000)
This compresses files as described in MS-XCA 2.2, and as decompressed
by the decompressor in the previous commit.

As with the decompressor, there are two public functions -- one that
uses a talloc context, and one that uses pre-allocated memory. The
compressor requires a tightly bound amount of auxillary memory
(>220kB) in a few different buffers, which is all gathered together in
the public struct lzxhuff_compressor_mem. An instantiated but not
initialised copy of this struct is required by the non-talloc
function; it can be used over and over again.

Our compression speed is about the same as the decompression speed
(between 20 and 500 MB/s on this laptop, depending on the data), and
our compression ratio is very similar to that of Windows.

Signed-off-by: Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
Reviewed-by: Joseph Sutton <josephsutton@catalyst.net.nz>
lib/compression/lzxpress_huffman.c
lib/compression/lzxpress_huffman.h
lib/compression/tests/test_lzx_huffman.c

index 105b1bda0d4523ab380abc9354d823d169078130..7a48ca73425f77ebcfa489f8885963476677bef9 100644 (file)
 
 #define LZXPRESS_ERROR -1LL
 
+/*
+ * We won't encode a match length longer than MAX_MATCH_LENGTH.
+ *
+ * Reports are that Windows has a limit at 64M.
+ */
+#define MAX_MATCH_LENGTH (64 * 1024 * 1024)
+
 
 struct bitstream {
        const uint8_t *bytes;
@@ -45,6 +52,1106 @@ struct bitstream {
 };
 
 
+#if ! defined __has_builtin
+#define __has_builtin(x) 0
+#endif
+
+/*
+ * bitlen_nonzero_16() returns the bit number of the most significant bit, or
+ * put another way, the integer log base 2. Log(0) is undefined; the argument
+ * has to be non-zero!
+ * 1     -> 0
+ * 2,3   -> 1
+ * 4-7   -> 2
+ * 1024  -> 10, etc
+ *
+ * Probably this is handled by a compiler intrinsic function that maps to a
+ * dedicated machine instruction.
+ */
+
+static inline int bitlen_nonzero_16(uint16_t x)
+{
+#if  __has_builtin(__builtin_clz)
+
+       /* __builtin_clz returns the number of leading zeros */
+       return (sizeof(unsigned int) * CHAR_BIT) - 1
+               - __builtin_clz((unsigned int) x);
+
+#else
+
+       int count = -1;
+       while(x) {
+               x >>= 1;
+               count++;
+       }
+       return count;
+
+#endif
+}
+
+
+struct lzxhuff_compressor_context {
+       const uint8_t *input_bytes;
+       size_t input_size;
+       size_t input_pos;
+       size_t prev_block_pos;
+       uint8_t *output;
+       size_t available_size;
+       size_t output_pos;
+};
+
+static int compare_huffman_node_count(struct huffman_node *a,
+                                     struct huffman_node *b)
+{
+       return a->count - b->count;
+}
+
+static int compare_huffman_node_depth(struct huffman_node *a,
+                                     struct huffman_node *b)
+{
+       int c = a->depth - b->depth;
+       if (c != 0) {
+               return c;
+       }
+       return (int)a->symbol - (int)b->symbol;
+}
+
+
+#define HASH_MASK ((1 << LZX_HUFF_COMP_HASH_BITS) - 1)
+
+static inline uint16_t three_byte_hash(const uint8_t *bytes)
+{
+       /*
+        * MS-XCA says "three byte hash", but does not specify it.
+        *
+        * This one is just cobbled together, but has quite good distribution
+        * in the 12-14 bit forms, which is what we care about most.
+        * e.g: 13 bit: median 2048, min 2022, max 2074, stddev 6.0
+        */
+       uint16_t a = bytes[0];
+       uint16_t b = bytes[1] ^ 0x2e;
+       uint16_t c = bytes[2] ^ 0x55;
+       uint16_t ca = c - a;
+       uint16_t d = ((a + b) << 8) ^ (ca << 5) ^ (c + b) ^ (0xcab + a);
+       return d & HASH_MASK;
+}
+
+
+static inline uint16_t encode_match(size_t len, size_t offset)
+{
+       uint16_t code = 256;
+       code |= MIN(len - 3, 15);
+       code |= bitlen_nonzero_16(offset) << 4;
+       return code;
+}
+
+
+static bool depth_walk(struct huffman_node *n, uint32_t depth)
+{
+       bool ok;
+       if (n->left == NULL) {
+               /* this is a leaf, record the depth */
+               n->depth = depth;
+               return true;
+       }
+       if (depth > 14) {
+               return false;
+       }
+       ok = (depth_walk(n->left, depth + 1) &&
+             depth_walk(n->right, depth + 1));
+
+       return ok;
+}
+
+
+static bool check_and_record_depths(struct huffman_node *root)
+{
+       return depth_walk(root, 0);
+}
+
+
+static bool encode_values(struct huffman_node *leaves,
+                         size_t n_leaves,
+                         uint16_t symbol_values[512])
+{
+       size_t i;
+       /*
+        * See, we have a leading 1 in our internal code representation, which
+        * indicates the code length.
+        */
+       uint32_t code = 1;
+       uint32_t code_len = 0;
+       memset(symbol_values, 0, sizeof(uint16_t) * 512);
+       for (i = 0; i < n_leaves; i++) {
+               code <<= leaves[i].depth - code_len;
+               code_len = leaves[i].depth;
+
+               symbol_values[leaves[i].symbol] = code;
+               code++;
+       }
+       /*
+        * The last code should be 11111... with code_len + 1 ones. The final
+        * code++ will wrap this round to 1000... with code_len + 1 zeroes.
+        */
+
+       if (code != 2 << code_len) {
+               return false;
+       }
+       return true;
+}
+
+
+static int generate_huffman_codes(struct huffman_node *leaf_nodes,
+                                 struct huffman_node *internal_nodes,
+                                 uint16_t symbol_values[512])
+{
+       size_t head_leaf = 0;
+       size_t head_branch = 0;
+       size_t tail_branch = 0;
+       struct huffman_node *huffman_root = NULL;
+       size_t i, j;
+       size_t n_leaves = 0;
+
+       /*
+        * Before we sort the nodes, we can eliminate the unused ones.
+        */
+       for (i = 0; i < 512; i++) {
+               if (leaf_nodes[i].count) {
+                       leaf_nodes[n_leaves] = leaf_nodes[i];
+                       n_leaves++;
+               }
+       }
+       if (n_leaves == 0) {
+               return LZXPRESS_ERROR;
+       }
+       if (n_leaves == 1) {
+               /*
+                * There is *almost* no way this should happen, and it would
+                * ruin the tree (because the shortest possible codes are 1
+                * bit long, and there are two of them).
+                *
+                * The only way to get here is in an internal block in a
+                * 3-or-more block message (i.e. > 128k), which consists
+                * entirely of a match starting in the previous block (if it
+                * was the end block, it would have the EOF symbol).
+                *
+                * What we do is add a dummy symbol which is this one XOR 256.
+                * It won't be used in the stream but will balance the tree.
+                */
+               leaf_nodes[1] = leaf_nodes[0];
+               leaf_nodes[1].symbol ^= 0x100;
+               n_leaves = 2;
+       }
+
+       /* note, in sort we're using internal_nodes as auxillary space */
+       stable_sort(leaf_nodes,
+                   internal_nodes,
+                   n_leaves,
+                   sizeof(struct huffman_node),
+                   (samba_compare_fn_t)compare_huffman_node_count);
+
+       /*
+        * This outer loop is for re-quantizing the counts if the tree is too
+        * tall (>15), which we need to do because the final encoding can't
+        * express a tree that deep.
+        *
+        * In theory, this should be a 'while (true)' loop, but we chicken
+        * out with 10 iterations, just in case.
+        *
+        * In practice it will almost always resolve in the first round; if
+        * not then, in the second or third. Remember we'll looking at 64k or
+        * less, so the rarest we can have is 1 in 64k; each round of
+        * quantization effecively doubles its frequency to 1 in 32k, 1 in
+        * 16k, etc, until we're treating the rare symbol as actually quite
+        * common.
+        */
+       for (j = 0; j < 10; j++) {
+               bool less_than_15_bits;
+               while (true) {
+                       struct huffman_node *a = NULL;
+                       struct huffman_node *b = NULL;
+                       size_t leaf_len = n_leaves - head_leaf;
+                       size_t internal_len = tail_branch - head_branch;
+
+                       if (leaf_len + internal_len == 1) {
+                               /*
+                                * We have the complete tree. The root will be
+                                * an internal node unless there is just one
+                                * symbol, which is already impossible.
+                                */
+                               if (unlikely(leaf_len == 1)) {
+                                       return LZXPRESS_ERROR;
+                               } else {
+                                       huffman_root = \
+                                               &internal_nodes[head_branch];
+                               }
+                               break;
+                       }
+                       /*
+                        * We know here we have at least two nodes, and we
+                        * want to select the two lowest scoring ones. Those
+                        * have to be either a) the head of each queue, or b)
+                        * the first two nodes of either queue.
+                        *
+                        * The complicating factors are: a) we need to check
+                        * the length of each queue, and b) in the case of
+                        * ties, we prefer to pair leaves with leaves.
+                        *
+                        * Note a complication we don't have: the leaf node
+                        * queue never grows, and the subtree queue starts
+                        * empty and cannot grow beyond n - 1. It feeds on
+                        * itself. We don't need to think about overflow.
+                        */
+                       if (leaf_len == 0) {
+                               /* two from subtrees */
+                               a = &internal_nodes[head_branch];
+                               b = &internal_nodes[head_branch + 1];
+                               head_branch += 2;
+                       } else if (internal_len == 0) {
+                               /* two from nodes */
+                               a = &leaf_nodes[head_leaf];
+                               b = &leaf_nodes[head_leaf + 1];
+                               head_leaf += 2;
+                       } else if (leaf_len == 1 && internal_len == 1) {
+                               /* one of each */
+                               a = &leaf_nodes[head_leaf];
+                               b = &internal_nodes[head_branch];
+                               head_branch++;
+                               head_leaf++;
+                       } else {
+                               /*
+                                * Take the lowest head, twice, checking for
+                                * length after taking the first one.
+                                */
+                               if (leaf_nodes[head_leaf].count >
+                                   internal_nodes[head_branch].count) {
+                                       a = &internal_nodes[head_branch];
+                                       head_branch++;
+                                       if (internal_len == 1) {
+                                               b = &leaf_nodes[head_leaf];
+                                               head_leaf++;
+                                               goto done;
+                                       }
+                               } else {
+                                       a = &leaf_nodes[head_leaf];
+                                       head_leaf++;
+                                       if (leaf_len == 1) {
+                                               b = &internal_nodes[head_branch];
+                                               head_branch++;
+                                               goto done;
+                                       }
+                               }
+                               /* the other node */
+                               if (leaf_nodes[head_leaf].count >
+                                   internal_nodes[head_branch].count) {
+                                       b = &internal_nodes[head_branch];
+                                       head_branch++;
+                               } else {
+                                       b = &leaf_nodes[head_leaf];
+                                       head_leaf++;
+                               }
+                       }
+               done:
+                       /*
+                        * Now we add a new node to the subtrees list that
+                        * combines the score of node_a and node_b, and points
+                        * to them as children.
+                        */
+                       internal_nodes[tail_branch].count = a->count + b->count;
+                       internal_nodes[tail_branch].left = a;
+                       internal_nodes[tail_branch].right = b;
+                       tail_branch++;
+                       if (tail_branch == n_leaves) {
+                               /*
+                                * We're not getting here, no way, never ever.
+                                * Unless we made a terible mistake.
+                                *
+                                * That is, in a binary tree with n leaves,
+                                * there are ALWAYS n-1 internal nodes.
+                                */
+                               return LZXPRESS_ERROR;
+                       }
+               }
+               /*
+                * We have a tree, and need to turn it into a lookup table,
+                * and see if it is shallow enough (<= 15).
+                */
+               less_than_15_bits = check_and_record_depths(huffman_root);
+               if (less_than_15_bits) {
+                       /*
+                        * Now the leaf nodes know how deep they are, and we
+                        * no longer need the internal nodes.
+                        *
+                        * We need to sort the nodes of equal depth, so that
+                        * they are sorted by depth first, and symbol value
+                        * second. The internal_nodes can again be auxillary
+                        * memory.
+                        */
+                       stable_sort(
+                               leaf_nodes,
+                               internal_nodes,
+                               n_leaves,
+                               sizeof(struct huffman_node),
+                               (samba_compare_fn_t)compare_huffman_node_depth);
+
+                       encode_values(leaf_nodes, n_leaves, symbol_values);
+
+                       return n_leaves;
+               }
+
+               /*
+                * requantize by halfing and rounding up, so that small counts
+                * become relatively bigger. This will lead to a flatter tree.
+                */
+               for (i = 0; i < n_leaves; i++) {
+                       leaf_nodes[i].count >>= 1;
+                       leaf_nodes[i].count += 1;
+               }
+               head_leaf = 0;
+               head_branch = 0;
+               tail_branch = 0;
+       }
+       return LZXPRESS_ERROR;
+}
+
+/*
+ * LZX_HUFF_COMP_HASH_SEARCH_ATTEMPTS is how far ahead to search in the
+ * circular hash table for a match, before we give up. A bigger number will
+ * generally lead to better but slower compression, but a stupidly big number
+ * will just be worse.
+ *
+ * If you're fiddling with this, consider also fiddling with
+ * LZX_HUFF_COMP_HASH_BITS.
+ */
+#define LZX_HUFF_COMP_HASH_SEARCH_ATTEMPTS 5
+
+static inline void store_match(uint16_t *hash_table,
+                              uint16_t h,
+                              uint16_t offset)
+{
+       int i;
+       uint16_t o = hash_table[h];
+       uint16_t h2;
+       uint16_t worst_h;
+       int worst_score;
+
+       if (o == 0xffff) {
+               /* there is nothing there yet */
+               hash_table[h] = offset;
+               return;
+       }
+       for (i = 1; i < LZX_HUFF_COMP_HASH_SEARCH_ATTEMPTS; i++) {
+               h2 = (h + i) & HASH_MASK;
+               if (hash_table[h2] == 0xffff) {
+                       hash_table[h2] = offset;
+                       return;
+               }
+       }
+       /*
+        * There are no slots, but we really want to store this, so we'll kick
+        * out the one with the longest distance.
+        */
+       worst_h = h;
+       worst_score = offset - o;
+       for (i = 1; i < LZX_HUFF_COMP_HASH_SEARCH_ATTEMPTS; i++) {
+               int score;
+               h2 = (h + i) & HASH_MASK;
+               o = hash_table[h2];
+               score = offset - o;
+               if (score > worst_score) {
+                       worst_score = score;
+                       worst_h = h2;
+               }
+       }
+       hash_table[worst_h] = offset;
+}
+
+
+/*
+ * Yes, struct match looks a lot like a DATA_BLOB.
+ */
+struct match {
+       const uint8_t *there;
+       size_t length;
+};
+
+
+static inline struct match lookup_match(uint16_t *hash_table,
+                                       uint16_t h,
+                                       const uint8_t *data,
+                                       const uint8_t *here,
+                                       size_t max_len)
+{
+       int i;
+       uint16_t o = hash_table[h];
+       uint16_t h2;
+       size_t len;
+       const uint8_t *there = NULL;
+       struct match best = {0};
+
+       for (i = 0; i < LZX_HUFF_COMP_HASH_SEARCH_ATTEMPTS; i++) {
+               h2 = (h + i) & HASH_MASK;
+               o = hash_table[h2];
+               if (o == 0xffff) {
+                       /*
+                        * in setting this, we would never have stepped over
+                        * an 0xffff, so we won't now.
+                        */
+                       break;
+               }
+               there = data + o;
+               if (here - there > 65534 || there > here) {
+                       continue;
+               }
+
+               /*
+                * When we already have a long match, we can try to avoid
+                * measuring out another long, but shorter match.
+                */
+               if (best.length > 1000 &&
+                   there[best.length - 1] != best.there[best.length - 1]) {
+                       continue;
+               }
+
+               for (len = 0;
+                    len < max_len && here[len] == there[len];
+                    len++) {
+                       /* counting */
+               }
+               if (len > 2) {
+                       /*
+                        * As a tiebreaker, we prefer the closer match which
+                        * is likely to encode smaller (and certainly no worse).
+                        */
+                       if (len > best.length ||
+                           (len == best.length && there > best.there)) {
+                               best.length = len;
+                               best.there = there;
+                       }
+               }
+       }
+       return best;
+}
+
+
+
+static ssize_t lz77_encode_block(struct lzxhuff_compressor_context *cmp_ctx,
+                                struct lzxhuff_compressor_mem *cmp_mem,
+                                uint16_t *hash_table,
+                                uint16_t *prev_hash_table)
+{
+       uint16_t *intermediate = cmp_mem->intermediate;
+       struct huffman_node *leaf_nodes = cmp_mem->leaf_nodes;
+       uint16_t *symbol_values = cmp_mem->symbol_values;
+       size_t i, j, intermediate_len;
+       const uint8_t *data = cmp_ctx->input_bytes + cmp_ctx->input_pos;
+       const uint8_t *prev_block = NULL;
+       size_t remaining_size = cmp_ctx->input_size - cmp_ctx->input_pos;
+       size_t block_end = MIN(65536, remaining_size);
+       struct match match;
+       int n_symbols;
+
+       if (cmp_ctx->input_size < cmp_ctx->input_pos) {
+               return LZXPRESS_ERROR;
+       }
+
+       if (cmp_ctx->prev_block_pos != cmp_ctx->input_pos) {
+               prev_block = cmp_ctx->input_bytes + cmp_ctx->prev_block_pos;
+       } else if (prev_hash_table != NULL) {
+               /* we've got confused! hash and block should go together */
+               return LZXPRESS_ERROR;
+       }
+
+       /*
+        * leaf_nodes is used to count the symbols seen, for later Huffman
+        * encoding.
+        */
+       for (i = 0; i < 512; i++) {
+               leaf_nodes[i] = (struct huffman_node) {
+                       .symbol = i
+               };
+       }
+
+       j = 0;
+
+       if (remaining_size < 41) {
+               /*
+                * There is no point doing a hash table and looking for
+                * matches in this tiny block (remembering we are committed to
+                * using 32 bits, so there's a good chance we wouldn't even
+                * save a byte). The threshold of 41 matches Windows.
+                * If remaining_size < 3, we *can't* do the hash.
+                */
+               i = 0;
+       } else {
+               /*
+                * We use 0xffff as the unset value for table, because it is
+                * not a valid match offset (and 0x0 is).
+                */
+               memset(hash_table, 0xff, sizeof(cmp_mem->hash_table1));
+
+               for (i = 0; i <= block_end - 3; i++) {
+                       uint16_t code;
+                       const uint8_t *here = data + i;
+                       uint16_t h = three_byte_hash(here);
+                       size_t max_len = MIN(remaining_size - i, MAX_MATCH_LENGTH);
+                       match = lookup_match(hash_table,
+                                            h,
+                                            data,
+                                            here,
+                                            max_len);
+
+                       if (match.there == NULL && prev_hash_table != NULL) {
+                               /*
+                                * If this is not the first block,
+                                * backreferences can look into the previous
+                                * block (but only as far as 65535 bytes, so
+                                * the end of this block cannot see the start
+                                * of the last one).
+                                */
+                               match = lookup_match(prev_hash_table,
+                                                    h,
+                                                    prev_block,
+                                                    here,
+                                                    remaining_size - i);
+                       }
+
+                       store_match(hash_table, h, i);
+
+                       if (match.there == NULL) {
+                               /* add a literal and move on. */
+                               uint8_t c = data[i];
+                               leaf_nodes[c].count++;
+                               intermediate[j] = c;
+                               j++;
+                               continue;
+                       }
+
+                       /* a real match */
+                       if (match.length <= 65538) {
+                               intermediate[j] = 0xffff;
+                               intermediate[j + 1] = match.length - 3;
+                               intermediate[j + 2] = here - match.there;
+                               j += 3;
+                       } else {
+                               size_t m = match.length - 3;
+                               intermediate[j] = 0xfffe;
+                               intermediate[j + 1] = m & 0xffff;
+                               intermediate[j + 2] = m >> 16;
+                               intermediate[j + 3] = here - match.there;
+                               j += 4;
+                       }
+                       code = encode_match(match.length, here - match.there);
+                       leaf_nodes[code].count++;
+                       i += match.length - 1; /* `- 1` for the loop i++ */
+                       /*
+                        * A match can take us past the intended block length,
+                        * extending the block. We don't need to do anything
+                        * special for this case -- the loops will naturally
+                        * do the right thing.
+                        */
+               }
+       }
+
+       /*
+        * There might be some bytes at the end.
+        */
+       for (; i < block_end; i++) {
+               leaf_nodes[data[i]].count++;
+               intermediate[j] = data[i];
+               j++;
+       }
+
+       if (i == remaining_size) {
+               /* add a trailing EOF marker (256) */
+               intermediate[j] = 0xffff;
+               intermediate[j + 1] = 0;
+               intermediate[j + 2] = 1;
+               j += 3;
+               leaf_nodes[256].count++;
+       }
+
+       intermediate_len = j;
+
+       cmp_ctx->prev_block_pos = cmp_ctx->input_pos;
+       cmp_ctx->input_pos += i;
+
+       /* fill in the symbols table */
+       n_symbols = generate_huffman_codes(leaf_nodes,
+                                          cmp_mem->internal_nodes,
+                                          symbol_values);
+       if (n_symbols < 0) {
+               return n_symbols;
+       }
+
+       return intermediate_len;
+}
+
+
+
+static ssize_t write_huffman_table(uint16_t symbol_values[512],
+                                  uint8_t *output,
+                                  size_t available_size)
+{
+       size_t i;
+
+       if (available_size < 256) {
+               return LZXPRESS_ERROR;
+       }
+
+       for (i = 0; i < 256; i++) {
+               uint8_t b = 0;
+               uint16_t even = symbol_values[i * 2];
+               uint16_t odd = symbol_values[i * 2 + 1];
+               if (even != 0) {
+                       b = bitlen_nonzero_16(even);
+               }
+               if (odd != 0) {
+                       b |= bitlen_nonzero_16(odd) << 4;
+               }
+               output[i] = b;
+       }
+       return i;
+}
+
+
+struct write_context {
+       uint8_t *dest;
+       size_t dest_len;
+       size_t head;                 /* where lengths go */
+       size_t next_code;            /* where symbol stream goes */
+       size_t pending_next_code;    /* will be next_code */
+       int bit_len;
+       uint32_t bits;
+};
+
+/*
+ * Write out 16 bits, little-endian, for write_huffman_codes()
+ *
+ * As you'll notice, there's a bit to do.
+ *
+ * We are collecting up bits in a uint32_t, then when there are 16 of them we
+ * write out a word into the stream, using a trio of offsets (wc->next_code,
+ * wc->pending_next_code, and wc->head) which dance around ensuring that the
+ * bitstream and the interspersed lengths are in the right places relative to
+ * each other.
+ */
+
+static inline bool write_bits(struct write_context *wc,
+                             uint16_t code, uint16_t length)
+{
+       wc->bits <<= length;
+       wc->bits |= code;
+       wc->bit_len += length;
+       if (wc->bit_len > 16) {
+               uint32_t w = wc->bits >> (wc->bit_len - 16);
+               wc->bit_len -= 16;
+               if (wc->next_code + 2 > wc->dest_len) {
+                       return false;
+               }
+               wc->dest[wc->next_code] = w & 0xff;
+               wc->dest[wc->next_code + 1] = (w >> 8) & 0xff;
+               wc->next_code = wc->pending_next_code;
+               wc->pending_next_code = wc->head;
+               wc->head += 2;
+       }
+       return true;
+}
+
+
+static inline bool write_code(struct write_context *wc, uint16_t code)
+{
+       int code_bit_len = bitlen_nonzero_16(code);
+       code &= (1 << code_bit_len) - 1;
+       return  write_bits(wc, code, code_bit_len);
+}
+
+static inline bool write_byte(struct write_context *wc, uint8_t byte)
+{
+       if (wc->head + 1 > wc->dest_len) {
+               return false;
+       }
+       wc->dest[wc->head] = byte;
+       wc->head++;
+       return true;
+}
+
+
+static inline bool write_long_len(struct write_context *wc, size_t len)
+{
+       if (len < 65535) {
+               if (wc->head + 3 > wc->dest_len) {
+                       return false;
+               }
+               wc->dest[wc->head] = 255;
+               wc->dest[wc->head + 1] = len & 255;
+               wc->dest[wc->head + 2] = len >> 8;
+               wc->head += 3;
+       } else {
+               if (wc->head + 7 > wc->dest_len) {
+                       return false;
+               }
+               wc->dest[wc->head] = 255;
+               wc->dest[wc->head + 1] = 0;
+               wc->dest[wc->head + 2] = 0;
+               wc->dest[wc->head + 3] = len & 255;
+               wc->dest[wc->head + 4] = (len >> 8) & 255;
+               wc->dest[wc->head + 5] = (len >> 16) & 255;
+               wc->dest[wc->head + 6] = (len >> 24) & 255;
+               wc->head += 7;
+       }
+       return true;
+}
+
+static ssize_t write_compressed_bytes(uint16_t symbol_values[512],
+                                     uint16_t *intermediate,
+                                     size_t intermediate_len,
+                                     uint8_t *dest,
+                                     size_t dest_len)
+{
+       bool ok;
+       size_t i;
+       size_t end;
+       struct write_context wc = {
+               .head = 4,
+               .pending_next_code = 2,
+               .dest = dest,
+               .dest_len = dest_len
+       };
+       for (i = 0; i < intermediate_len; i++) {
+               uint16_t c = intermediate[i];
+               size_t len;
+               uint16_t distance;
+               uint16_t code_len = 0;
+               uint16_t code_dist = 0;
+               if (c < 256) {
+                       ok = write_code(&wc, symbol_values[c]);
+                       if (!ok) {
+                               return LZXPRESS_ERROR;
+                       }
+                       continue;
+               }
+
+               if (c == 0xfffe) {
+                       if (i > intermediate_len - 4) {
+                               return LZXPRESS_ERROR;
+                       }
+
+                       len = intermediate[i + 1];
+                       len |= intermediate[i + 2] << 16;
+                       distance = intermediate[i + 3];
+                       i += 3;
+               } else if (c == 0xffff) {
+                       if (i > intermediate_len - 3) {
+                               return LZXPRESS_ERROR;
+                       }
+                       len = intermediate[i + 1];
+                       distance = intermediate[i + 2];
+                       i += 2;
+               } else {
+                       return LZXPRESS_ERROR;
+               }
+               /* len has already had 3 subtracted */
+               if (len >= 15) {
+                       /*
+                        * We are going to need to write extra length
+                        * bytes into the stream, but we don't do it
+                        * now, we do it after the code has been
+                        * written (and before the distance bits).
+                        */
+                       code_len = 15;
+               } else {
+                       code_len = len;
+               }
+               code_dist = bitlen_nonzero_16(distance);
+               c = 256 | (code_dist << 4) | code_len;
+               if (c > 511) {
+                       return LZXPRESS_ERROR;
+               }
+
+               ok = write_code(&wc, symbol_values[c]);
+               if (!ok) {
+                       return LZXPRESS_ERROR;
+               }
+
+               if (code_len == 15) {
+                       if (len >= 270) {
+                               ok = write_long_len(&wc, len);
+                       } else {
+                               ok = write_byte(&wc, len - 15);
+                       }
+                       if (! ok) {
+                               return LZXPRESS_ERROR;
+                       }
+               }
+               if (code_dist != 0) {
+                       uint16_t dist_bits = distance - (1 << code_dist);
+                       ok = write_bits(&wc, dist_bits, code_dist);
+                       if (!ok) {
+                               return LZXPRESS_ERROR;
+                       }
+               }
+       }
+       /*
+        * There are some intricacies around flushing the bits and returning
+        * the length.
+        *
+        * If the returned length is not exactly right and there is another
+        * block, that block will read its huffman table from the wrong place,
+        * and have all the symbol codes out by a multiple of 4.
+        */
+       end = wc.head;
+       if (wc.bit_len == 0) {
+               end -= 2;
+       }
+       ok = write_bits(&wc, 0, 16 - wc.bit_len);
+       if (!ok) {
+               return LZXPRESS_ERROR;
+       }
+       for (i = 0; i < 2; i++) {
+               /*
+                * Flush out the bits with zeroes. It doesn't matter if we do
+                * a round too many, as we have buffer space, and have already
+                * determined the returned length (end).
+                */
+               ok = write_bits(&wc, 0, 16);
+               if (!ok) {
+                       return LZXPRESS_ERROR;
+               }
+       }
+       return end;
+}
+
+
+static ssize_t lzx_huffman_compress_block(struct lzxhuff_compressor_context *cmp_ctx,
+                                         struct lzxhuff_compressor_mem *cmp_mem,
+                                         size_t block_no)
+{
+       ssize_t intermediate_size;
+       uint16_t *hash_table = NULL;
+       uint16_t *back_window_hash_table = NULL;
+       ssize_t bytes_written;
+
+       if (cmp_ctx->available_size - cmp_ctx->output_pos < 260) {
+               /* huffman block + 4 bytes */
+               return LZXPRESS_ERROR;
+       }
+
+       /*
+        * For LZ77 compression, we keep a hash table for the previous block,
+        * via alternation after the first block.
+        *
+        * LZ77 writes into the intermediate buffer in the cmp_mem context.
+        */
+       if (block_no == 0) {
+               hash_table = cmp_mem->hash_table1;
+               back_window_hash_table = NULL;
+       } else if (block_no & 1) {
+               hash_table = cmp_mem->hash_table2;
+               back_window_hash_table = cmp_mem->hash_table1;
+       } else {
+               hash_table = cmp_mem->hash_table1;
+               back_window_hash_table = cmp_mem->hash_table2;
+       }
+
+       intermediate_size = lz77_encode_block(cmp_ctx,
+                                             cmp_mem,
+                                             hash_table,
+                                             back_window_hash_table);
+
+       if (intermediate_size < 0) {
+               return intermediate_size;
+       }
+
+       /*
+        * Write the 256 byte Huffman table, based on the counts gained in
+        * LZ77 phase.
+        */
+       bytes_written = write_huffman_table(
+               cmp_mem->symbol_values,
+               cmp_ctx->output + cmp_ctx->output_pos,
+               cmp_ctx->available_size - cmp_ctx->output_pos);
+
+       if (bytes_written != 256) {
+               return LZXPRESS_ERROR;
+       }
+       cmp_ctx->output_pos += 256;
+
+       /*
+        * Write the compressed bytes using the LZ77 matches and Huffman codes
+        * worked out in the previous steps.
+        */
+       bytes_written = write_compressed_bytes(
+               cmp_mem->symbol_values,
+               cmp_mem->intermediate,
+               intermediate_size,
+               cmp_ctx->output + cmp_ctx->output_pos,
+               cmp_ctx->available_size - cmp_ctx->output_pos);
+
+       if (bytes_written < 0) {
+               return bytes_written;
+       }
+
+       cmp_ctx->output_pos += bytes_written;
+       return bytes_written;
+}
+
+
+/*
+ * lzxpress_huffman_compress_talloc()
+ *
+ * This is the convenience function that allocates the compressor context and
+ * output memory for you. The return value is the number of bytes written to
+ * the location indicated by the output pointer.
+ *
+ * The maximum input_size is effectively around 227MB due to the need to guess
+ * an upper bound on the output size that hits an internal limitation in
+ * talloc.
+ *
+ * @param mem_ctx      TALLOC_CTX parent for the compressed buffer.
+ * @param input_bytes  memory to be compressed.
+ * @param input_size   length of the input buffer.
+ * @param output       destination pointer for the compressed data.
+ *
+ * @return the number of bytes written or -1 on error.
+ */
+
+ssize_t lzxpress_huffman_compress_talloc(TALLOC_CTX *mem_ctx,
+                                        const uint8_t *input_bytes,
+                                        size_t input_size,
+                                        uint8_t **output)
+{
+       struct lzxhuff_compressor_mem *cmp = NULL;
+       /*
+        * In the worst case, the output size should be about the same as the
+        * input size, plus the 256 byte header per 64k block. We aim for
+        * ample, but within the order of magnitude.
+        */
+       size_t alloc_size = input_size + (input_size / 8) + 270;
+       ssize_t output_size;
+
+       *output = talloc_array(mem_ctx, uint8_t, alloc_size);
+       if (*output == NULL) {
+               return LZXPRESS_ERROR;
+       }
+
+       cmp = talloc(mem_ctx, struct lzxhuff_compressor_mem);
+       if (cmp == NULL) {
+               TALLOC_FREE(*output);
+               return LZXPRESS_ERROR;
+       }
+
+       output_size = lzxpress_huffman_compress(cmp,
+                                               input_bytes,
+                                               input_size,
+                                               *output,
+                                               alloc_size);
+
+       talloc_free(cmp);
+
+       if (output_size < 0) {
+               TALLOC_FREE(*output);
+               return LZXPRESS_ERROR;
+       }
+
+       *output = talloc_realloc(mem_ctx, *output, uint8_t, output_size);
+       if (*output == NULL) {
+               return LZXPRESS_ERROR;
+       }
+
+       return output_size;
+}
+
+/*
+ * lzxpress_huffman_compress()
+ *
+ * This is the inconvenience function, slightly faster and fiddlier than
+ * lzxpress_huffman_compress_talloc().
+ *
+ * To use this, you need to have allocated (but not initialised) a `struct
+ * lzxhuff_compressor_context`, and an output buffer. If the buffer is not big
+ * enough (per `output_size`), you'll get a negative return value, otherwise
+ * the number of bytes actually consumed, which will always be at least 260.
+ *
+ * The `struct lzxhuff_compressor_context` is reusable -- it is basically a
+ * collection of uninitialised memory buffers. The total size is less than
+ * 150k, so stack allocation is plausible.
+ *
+ * input_size and available_size are limited to the minimum of UINT32_MAX and
+ * SSIZE_MAX. On 64 bit machines that will be UINT32_MAX, or 4GB.
+ *
+ * @param cmp_mem         a struct lzxhuff_compressor_mem.
+ * @param input_bytes     memory to be compressed.
+ * @param input_size      length of the input buffer.
+ * @param output          destination for the compressed data.
+ * @param available_size  allocated output bytes.
+ *
+ * @return the number of bytes written or -1 on error.
+ */
+ssize_t lzxpress_huffman_compress(struct lzxhuff_compressor_mem *cmp_mem,
+                                 const uint8_t *input_bytes,
+                                 size_t input_size,
+                                 uint8_t *output,
+                                 size_t available_size)
+{
+       size_t i = 0;
+       struct lzxhuff_compressor_context cmp_ctx = {
+               .input_bytes = input_bytes,
+               .input_size = input_size,
+               .input_pos = 0,
+               .prev_block_pos = 0,
+               .output = output,
+               .available_size = available_size,
+               .output_pos = 0
+       };
+
+       if (input_size == 0) {
+               /*
+                * We can't deal with this for a number of reasons (e.g. it
+                * breaks the Huffman tree), and the output will be infinitely
+                * bigger than the input. The caller needs to go and think
+                * about what they're trying to do here.
+                */
+               return LZXPRESS_ERROR;
+       }
+
+       if (input_size > SSIZE_MAX ||
+           input_size > UINT32_MAX ||
+           available_size > SSIZE_MAX ||
+           available_size > UINT32_MAX ||
+           available_size == 0) {
+               /*
+                * We use negative ssize_t to return errors, which is limiting
+                * on 32 bit machines; otherwise we adhere to Microsoft's 4GB
+                * limit.
+                *
+                * lzxpress_huffman_compress_talloc() will not get this far,
+                * having already have failed on talloc's 256 MB limit.
+                */
+               return LZXPRESS_ERROR;
+       }
+
+       if (cmp_mem == NULL ||
+           output == NULL ||
+           input_bytes == NULL) {
+               return LZXPRESS_ERROR;
+       }
+
+       while (cmp_ctx.input_pos < cmp_ctx.input_size) {
+               ssize_t ret;
+               ret = lzx_huffman_compress_block(&cmp_ctx,
+                                                cmp_mem,
+                                                i);
+               if (ret < 0) {
+                       return ret;
+               }
+               i++;
+       }
+
+       return cmp_ctx.output_pos;
+}
+
+
 /**
  * Determines the sort order of one prefix_code_symbol relative to another
  */
index c484181d504e5c4ba0257af47a7b98dc8809d62d..04de448bcce7da22b4738618acabd5c347e7afbe 100644 (file)
 #define HAVE_LZXPRESS_HUFFMAN_H
 
 
+struct huffman_node {
+       struct huffman_node *left;
+       struct huffman_node *right;
+       uint32_t count;
+       uint16_t symbol;
+       int8_t depth;
+};
+
+
+/*
+ * LZX_HUFF_COMP_HASH_BITS is how big to make the hash tables
+ * (12 means 4096, etc).
+ *
+ * A larger number (up to 16) will be faster on long messages (fewer
+ * collisions), but probably slower on short ones (more prep).
+ */
+#define LZX_HUFF_COMP_HASH_BITS 14
+
+
+/*
+ * This struct just coalesces all the memory you need for LZ77 + Huffman
+ * compresssion together in one bundle.
+ *
+ * There are a few different things you want, you usually want them all, so
+ * this makes it easy to allocate them all at once.
+ */
+
+struct lzxhuff_compressor_mem {
+       struct huffman_node leaf_nodes[512];
+       struct huffman_node internal_nodes[512];
+       uint16_t symbol_values[512];
+       uint16_t intermediate[65536 + 6];
+       uint16_t hash_table1[1 << LZX_HUFF_COMP_HASH_BITS];
+       uint16_t hash_table2[1 << LZX_HUFF_COMP_HASH_BITS];
+};
+
+
+ssize_t lzxpress_huffman_compress(struct lzxhuff_compressor_mem *cmp,
+                                 const uint8_t *input_bytes,
+                                 size_t input_size,
+                                 uint8_t *output,
+                                 size_t available_size);
+
+
+ssize_t lzxpress_huffman_compress_talloc(TALLOC_CTX *mem_ctx,
+                                        const uint8_t *input_bytes,
+                                        size_t input_size,
+                                        uint8_t **output);
+
 ssize_t lzxpress_huffman_decompress(const uint8_t *input,
                                    size_t input_size,
                                    uint8_t *output,
index 43275d2d32fd68676b9531799aa8b165bdc82c10..250b121eed356bd76fe912e3f07e11b863f8772e 100644 (file)
@@ -322,6 +322,29 @@ static void test_lzxpress_huffman_decompress(void **state)
        }
 }
 
+static void test_lzxpress_huffman_compress(void **state)
+{
+       size_t i;
+       ssize_t written;
+       uint8_t *dest = NULL;
+       TALLOC_CTX *mem_ctx = talloc_new(NULL);
+       for (i = 0; bidirectional_pairs[i].name != NULL; i++) {
+               struct lzx_pair p = bidirectional_pairs[i];
+               debug_message("%s compressed %zu decomp %zu\n", p.name,
+                             p.compressed.length,
+                             p.decompressed.length);
+
+               written = lzxpress_huffman_compress_talloc(mem_ctx,
+                                                          p.decompressed.data,
+                                                          p.decompressed.length,
+                                                          &dest);
+
+               assert_int_equal(written, p.compressed.length);
+               assert_memory_equal(dest, p.compressed.data, p.compressed.length);
+               talloc_free(dest);
+       }
+}
+
 
 static DATA_BLOB datablob_from_file(TALLOC_CTX *mem_ctx,
                                    const char *filename)
@@ -356,6 +379,7 @@ static DATA_BLOB datablob_from_file(TALLOC_CTX *mem_ctx,
 }
 
 
+
 static void test_lzxpress_huffman_decompress_files(void **state)
 {
        size_t i;
@@ -473,6 +497,673 @@ static void test_lzxpress_huffman_decompress_more_compressed_files(void **state)
 }
 
 
+/*
+ * attempt_round_trip() tests whether a data blob can survive a compression
+ * and decompression cycle. If save_name is not NULL and LZXHUFF_DEBUG_FILES
+ * evals to true, the various stages are saved in files with that name and the
+ * '-original', '-compressed', and '-decompressed' suffixes. If ref_compressed
+ * has data, it'll print a message saying whether the compressed data matches
+ * that.
+ */
+
+static ssize_t attempt_round_trip(TALLOC_CTX *mem_ctx,
+                                 DATA_BLOB original,
+                                 const char *save_name,
+                                 DATA_BLOB ref_compressed)
+{
+       TALLOC_CTX *tmp_ctx = talloc_new(mem_ctx);
+       DATA_BLOB compressed = data_blob_talloc(tmp_ctx, NULL,
+                                               original.length * 4 / 3 + 260);
+       DATA_BLOB decompressed = data_blob_talloc(tmp_ctx, NULL,
+                                               original.length);
+       ssize_t comp_written, decomp_written;
+
+       comp_written = lzxpress_huffman_compress_talloc(tmp_ctx,
+                                                       original.data,
+                                                       original.length,
+                                                       &compressed.data);
+
+       if (comp_written <= 0) {
+               talloc_free(tmp_ctx);
+               return -1;
+       }
+
+       if (ref_compressed.data != NULL) {
+               /*
+                * This is informational, not an assertion; there are
+                * ~infinite legitimate ways to compress the data, many as
+                * good as each other (think of compression as a language, not
+                * a format).
+                */
+               debug_message("compressed size %zd vs reference %zu\n",
+                             comp_written, ref_compressed.length);
+
+               if (comp_written == compressed.length &&
+                   memcmp(compressed.data, ref_compressed.data, comp_written) == 0) {
+                       debug_message("\033[1;32mbyte identical!\033[0m\n");
+               }
+       }
+
+       decomp_written = lzxpress_huffman_decompress(compressed.data,
+                                                    comp_written,
+                                                    decompressed.data,
+                                                    original.length);
+       if (save_name != NULL && LZXHUFF_DEBUG_FILES) {
+               char s[300];
+               FILE *fh = NULL;
+
+               snprintf(s, sizeof(s), "%s-original", save_name);
+               fprintf(stderr, "Saving %zu bytes to %s\n", original.length, s);
+               fh = fopen(s, "w");
+               fwrite(original.data, 1, original.length, fh);
+               fclose(fh);
+
+               snprintf(s, sizeof(s), "%s-compressed", save_name);
+               fprintf(stderr, "Saving %zu bytes to %s\n", comp_written, s);
+               fh = fopen(s, "w");
+               fwrite(compressed.data, 1, comp_written, fh);
+               fclose(fh);
+               /*
+                * We save the decompressed file using original.length, not
+                * the returned size. If these differ, the returned size will
+                * be -1. By saving the whole buffer we can see at what point
+                * it went haywire.
+                */
+               snprintf(s, sizeof(s), "%s-decompressed", save_name);
+               fprintf(stderr, "Saving %zu bytes to %s\n", original.length, s);
+               fh = fopen(s, "w");
+               fwrite(decompressed.data, 1, original.length, fh);
+               fclose(fh);
+       }
+
+       if (original.length != decomp_written ||
+           memcmp(decompressed.data,
+                  original.data,
+                  original.length) != 0) {
+               debug_message("\033[1;31mgot %zd, expected %zu\033[0m\n",
+                             decomp_written,
+                             original.length);
+               talloc_free(tmp_ctx);
+               return -1;
+       }
+       talloc_free(tmp_ctx);
+       return comp_written;
+}
+
+
+static void test_lzxpress_huffman_round_trip(void **state)
+{
+       size_t i;
+       int score = 0;
+       ssize_t compressed_total = 0;
+       ssize_t reference_total = 0;
+       TALLOC_CTX *mem_ctx = talloc_new(NULL);
+       for (i = 0; file_names[i] != NULL; i++) {
+               char filename[200];
+               char *debug_files = NULL;
+               TALLOC_CTX *tmp_ctx = talloc_new(mem_ctx);
+               ssize_t comp_size;
+               struct lzx_pair p = {
+                       .name = file_names[i]
+               };
+               debug_message("-------------------\n");
+               debug_message("%s\n", p.name);
+
+               snprintf(filename, sizeof(filename),
+                        "%s/%s.decomp", DECOMP_DIR, p.name);
+
+               p.decompressed = datablob_from_file(tmp_ctx, filename);
+               assert_non_null(p.decompressed.data);
+
+               snprintf(filename, sizeof(filename),
+                        "%s/%s.lzhuff", COMP_DIR, p.name);
+
+               p.compressed = datablob_from_file(tmp_ctx, filename);
+               if (p.compressed.data == NULL) {
+                       debug_message(
+                               "Could not load %s reference file %s\n",
+                               p.name, filename);
+                       debug_message("%s decompressed %zu\n", p.name,
+                                     p.decompressed.length);
+               } else {
+                       debug_message("%s: reference compressed %zu decomp %zu\n",
+                                     p.name,
+                                     p.compressed.length,
+                                     p.decompressed.length);
+               }
+               if (1) {
+                       /*
+                        * We're going to save copies in /tmp.
+                        */
+                       snprintf(filename, sizeof(filename),
+                                "/tmp/lzxhuffman-%s", p.name);
+                       debug_files = filename;
+               }
+
+               comp_size = attempt_round_trip(mem_ctx, p.decompressed,
+                                              debug_files,
+                                              p.compressed);
+               if (comp_size > 0) {
+                       debug_message("\033[1;32mround trip!\033[0m\n");
+                       score++;
+                       if (p.compressed.length) {
+                               compressed_total += comp_size;
+                               reference_total += p.compressed.length;
+                       }
+               }
+               talloc_free(tmp_ctx);
+       }
+       debug_message("%d/%zu correct\n", score, i);
+       print_message("\033[1;34mtotal compressed size: %zu\033[0m\n",
+                     compressed_total);
+       print_message("total reference size:  %zd \n", reference_total);
+       print_message("diff:                  %7zd \n",
+                     reference_total - compressed_total);
+       print_message("ratio: \033[1;3%dm%.2f\033[0m \n",
+                     2 + (compressed_total >= reference_total),
+                     ((double)compressed_total) / reference_total);
+       /*
+        * Assert that the compression is *about* as good as Windows. Of course
+        * it doesn't matter if we do better, but mysteriously getting better
+        * is usually a sign that something is wrong.
+        *
+        * At the time of writing, compressed_total is 2674004, or 10686 more
+        * than the Windows reference total. That's < 0.5% difference, we're
+        * asserting at 2%.
+        */
+       assert_true(labs(compressed_total - reference_total) <
+                   compressed_total / 50);
+
+       assert_int_equal(score, i);
+       talloc_free(mem_ctx);
+}
+
+/*
+ * Bob Jenkins' Small Fast RNG.
+ *
+ * We don't need it to be this good, but we do need it to be reproduceable
+ * across platforms, which rand() etc aren't.
+ *
+ * http://burtleburtle.net/bob/rand/smallprng.html
+ */
+
+struct jsf_rng {
+       uint32_t a;
+       uint32_t b;
+       uint32_t c;
+       uint32_t d;
+};
+
+#define ROTATE32(x, k) (((x) << (k)) | ((x) >> (32 - (k))))
+
+static uint32_t jsf32(struct jsf_rng *x) {
+       uint32_t e = x->a - ROTATE32(x->b, 27);
+       x->a = x->b ^ ROTATE32(x->c, 17);
+       x->b = x->c + x->d;
+       x->c = x->d + e;
+       x->d = e + x->a;
+       return x->d;
+}
+
+static void jsf32_init(struct jsf_rng *x, uint32_t seed) {
+       size_t i;
+       x->a = 0xf1ea5eed;
+       x->b = x->c = x->d = seed;
+       for (i = 0; i < 20; ++i) {
+               jsf32(x);
+       }
+}
+
+
+static void test_lzxpress_huffman_long_gpl_round_trip(void **state)
+{
+       /*
+        * We use a kind of model-free Markov model to generate a massively
+        * extended pastiche of the GPLv3 (chosen because it is right there in
+        * "COPYING" and won't change often).
+        *
+        * The point is to check a round trip of a very long message with
+        * multiple repetitions on many scales, without having to add a very
+        * large file.
+        */
+       size_t i, j, k;
+       uint8_t c;
+       TALLOC_CTX *mem_ctx = talloc_new(NULL);
+       DATA_BLOB gpl = datablob_from_file(mem_ctx, "COPYING");
+       DATA_BLOB original = data_blob_talloc(mem_ctx, NULL, 5 * 1024 * 1024);
+       DATA_BLOB ref = {0};
+       ssize_t comp_size;
+       struct jsf_rng rng;
+
+       if (gpl.data == NULL) {
+               print_message("could not read COPYING\n");
+               fail();
+       }
+
+       jsf32_init(&rng, 1);
+
+       j = 1;
+       original.data[0] = gpl.data[0];
+       for (i = 1; i < original.length; i++) {
+               size_t m;
+               char p = original.data[i - 1];
+               c = gpl.data[j];
+               original.data[i] = c;
+               j++;
+               m = (j + jsf32(&rng)) % (gpl.length - 50);
+               for (k = m; k < m + 30; k++) {
+                       if (p == gpl.data[k] &&
+                           c == gpl.data[k + 1]) {
+                               j = k + 2;
+                               break;
+                       }
+               }
+               if (j == gpl.length) {
+                       j = 1;
+               }
+       }
+
+       comp_size = attempt_round_trip(mem_ctx, original, "/tmp/gpl", ref);
+       assert_true(comp_size > 0);
+       assert_true(comp_size < original.length);
+
+       talloc_free(mem_ctx);
+}
+
+
+static void test_lzxpress_huffman_long_random_graph_round_trip(void **state)
+{
+       size_t i;
+       TALLOC_CTX *mem_ctx = talloc_new(NULL);
+       DATA_BLOB original = data_blob_talloc(mem_ctx, NULL, 5 * 1024 * 1024);
+       DATA_BLOB ref = {0};
+       /*
+        * There's a random trigram graph, with each pair of sequential bytes
+        * pointing to a successor. This would probably fall into a fairly
+        * simple loop, but we introduce damage into the system, randomly
+        * flipping about 1 bit in 64.
+        *
+        * The result is semi-structured and compressable.
+        */
+       uint8_t *d = original.data;
+       uint8_t *table = talloc_array(mem_ctx, uint8_t, 65536);
+       uint32_t *table32 = (void*)table;
+       ssize_t comp_size;
+       struct jsf_rng rng;
+
+       jsf32_init(&rng, 1);
+       for (i = 0; i < (65536 / 4); i++) {
+               table32[i] = jsf32(&rng);
+       }
+
+       d[0] = 'a';
+       d[1] = 'b';
+
+       for (i = 2; i < original.length; i++) {
+               uint16_t k = (d[i - 2] << 8) | d[i - 1];
+               uint32_t damage = jsf32(&rng) & jsf32(&rng) & jsf32(&rng);
+               damage &= (damage >> 16);
+               k ^= damage & 0xffff;
+               d[i] = table[k];
+       }
+
+       comp_size = attempt_round_trip(mem_ctx, original, "/tmp/random-graph", ref);
+       assert_true(comp_size > 0);
+       assert_true(comp_size < original.length);
+
+       talloc_free(mem_ctx);
+}
+
+
+static void test_lzxpress_huffman_chaos_graph_round_trip(void **state)
+{
+       size_t i;
+       TALLOC_CTX *mem_ctx = talloc_new(NULL);
+       DATA_BLOB original = data_blob_talloc(mem_ctx, NULL, 5 * 1024 * 1024);
+       DATA_BLOB ref = {0};
+       /*
+        * There's a random trigram graph, with each pair of sequential bytes
+        * pointing to a successor. This would probably fall into a fairly
+        * simple loop, but we keep changing the graph. The result is long
+        * periods of stability separatd by bursts of noise.
+        */
+       uint8_t *d = original.data;
+       uint8_t *table = talloc_array(mem_ctx, uint8_t, 65536);
+       uint32_t *table32 = (void*)table;
+       ssize_t comp_size;
+       struct jsf_rng rng;
+
+       jsf32_init(&rng, 1);
+       for (i = 0; i < (65536 / 4); i++) {
+               table32[i] = jsf32(&rng);
+       }
+
+       d[0] = 'a';
+       d[1] = 'b';
+
+       for (i = 2; i < original.length; i++) {
+               uint16_t k = (d[i - 2] << 8) | d[i - 1];
+               uint32_t damage = jsf32(&rng);
+               d[i] = table[k];
+               if ((damage >> 29) == 0) {
+                       uint16_t index = damage & 0xffff;
+                       uint8_t value = (damage >> 16) & 0xff;
+                       table[index] = value;
+               }
+       }
+
+       comp_size = attempt_round_trip(mem_ctx, original, "/tmp/chaos-graph", ref);
+       assert_true(comp_size > 0);
+       assert_true(comp_size < original.length);
+
+       talloc_free(mem_ctx);
+}
+
+
+static void test_lzxpress_huffman_sparse_random_graph_round_trip(void **state)
+{
+       size_t i;
+       TALLOC_CTX *mem_ctx = talloc_new(NULL);
+       DATA_BLOB original = data_blob_talloc(mem_ctx, NULL, 5 * 1024 * 1024);
+       DATA_BLOB ref = {0};
+       /*
+        * There's a random trigram graph, with each pair of sequential bytes
+        * pointing to a successor. This will fall into a fairly simple loops,
+        * but we introduce damage into the system, randomly mangling about 1
+        * byte in 65536.
+        *
+        * The result has very long repetitive runs, which should lead to
+        * oversized blocks.
+        */
+       uint8_t *d = original.data;
+       uint8_t *table = talloc_array(mem_ctx, uint8_t, 65536);
+       uint32_t *table32 = (void*)table;
+       ssize_t comp_size;
+       struct jsf_rng rng;
+
+       jsf32_init(&rng, 3);
+       for (i = 0; i < (65536 / 4); i++) {
+               table32[i] = jsf32(&rng);
+       }
+
+       d[0] = 'a';
+       d[1] = 'b';
+
+       for (i = 2; i < original.length; i++) {
+               uint16_t k = (d[i - 2] << 8) | d[i - 1];
+               uint32_t damage = jsf32(&rng);
+               if ((damage & 0xffff0000) == 0) {
+                       k ^= damage & 0xffff;
+               }
+               d[i] = table[k];
+       }
+
+       comp_size = attempt_round_trip(mem_ctx, original, "/tmp/sparse-random-graph", ref);
+       assert_true(comp_size > 0);
+       assert_true(comp_size < original.length);
+
+       talloc_free(mem_ctx);
+}
+
+
+static void test_lzxpress_huffman_random_noise_round_trip(void **state)
+{
+       size_t i;
+       size_t len = 1024 * 1024;
+       TALLOC_CTX *mem_ctx = talloc_new(NULL);
+       DATA_BLOB original = data_blob_talloc(mem_ctx, NULL, len);
+       DATA_BLOB ref = {0};
+       ssize_t comp_size;
+       /*
+        * We are filling this up with incompressible noise, but we can assert
+        * quite tight bounds on how badly it will fail to compress.
+        *
+        * Specifically, with randomly distributed codes, the Huffman table
+        * should come out as roughly even, averaging 8 bit codes. Then there
+        * will be a 256 byte table every 64k, which is a 1/256 overhead (i.e.
+        * the compressed length will be 257/256 the original *on average*).
+        * We assert it is less than 1 in 200 but more than 1 in 300.
+        */
+       uint32_t *d32 = (uint32_t*)((void*)original.data);
+       struct jsf_rng rng;
+       jsf32_init(&rng, 2);
+
+       for (i = 0; i < (len / 4); i++) {
+               d32[i] = jsf32(&rng);
+       }
+
+       comp_size = attempt_round_trip(mem_ctx, original, "/tmp/random-noise", ref);
+       assert_true(comp_size > 0);
+       assert_true(comp_size > original.length + original.length / 300);
+       assert_true(comp_size < original.length + original.length / 200);
+       debug_message("original size %zu; compressed size %zd; ratio %.3f\n",
+                     len, comp_size, ((double)comp_size) / len);
+
+       talloc_free(mem_ctx);
+}
+
+
+static void test_lzxpress_huffman_overlong_matches(void **state)
+{
+       size_t i, j;
+       TALLOC_CTX *mem_ctx = talloc_new(NULL);
+       DATA_BLOB original = data_blob_talloc(mem_ctx, NULL, 1024 * 1024);
+       DATA_BLOB ref = {0};
+       uint8_t *d = original.data;
+       char filename[300];
+       /*
+        * We are testing with something like "aaaaaaaaaaaaaaaaaaaaaaabbbbb"
+        * where typically the number of "a"s is > 65536, and the number of
+        * "b"s is < 42.
+        */
+       ssize_t na[] = {65535, 65536, 65537, 65559, 65575, 200000, -1};
+       ssize_t nb[] = {1, 2, 20, 39, 40, 41, 42, -1};
+       int score = 0;
+       ssize_t comp_size;
+
+       for (i = 0; na[i] >= 0; i++) {
+               ssize_t a = na[i];
+               memset(d, 'a', a);
+               for (j = 0; nb[j] >= 0; j++) {
+                       ssize_t b = nb[j];
+                       memset(d + a, 'b', b);
+                       original.length = a + b;
+                       snprintf(filename, sizeof(filename),
+                                "/tmp/overlong-%zd-%zd", a, b);
+                       comp_size = attempt_round_trip(mem_ctx,
+                                                      original,
+                                                      filename, ref);
+                       if (comp_size > 0) {
+                               score++;
+                       }
+               }
+       }
+       debug_message("%d/%zu correct\n", score, i * j);
+       assert_int_equal(score, i * j);
+       talloc_free(mem_ctx);
+}
+
+
+static void test_lzxpress_huffman_overlong_matches_abc(void **state)
+{
+       size_t i, j, k;
+       TALLOC_CTX *mem_ctx = talloc_new(NULL);
+       DATA_BLOB original = data_blob_talloc(mem_ctx, NULL, 1024 * 1024);
+       DATA_BLOB ref = {0};
+       uint8_t *d = original.data;
+       char filename[300];
+       /*
+        * We are testing with something like "aaaabbbbcc" where typically
+        * the number of "a"s + "b"s is around 65536, and the number of "c"s
+        * is < 43.
+        */
+       ssize_t nab[] = {1, 21, 32767, 32768, 32769, -1};
+       ssize_t nc[] = {1, 2, 20, 39, 40, 41, 42, -1};
+       int score = 0;
+       ssize_t comp_size;
+
+       for (i = 0; nab[i] >= 0; i++) {
+               ssize_t a = nab[i];
+               memset(d, 'a', a);
+               for (j = 0; nab[j] >= 0; j++) {
+                       ssize_t b = nab[j];
+                       memset(d + a, 'b', b);
+                       for (k = 0; nc[k] >= 0; k++) {
+                               ssize_t c = nc[k];
+                               memset(d + a + b, 'c', c);
+                               original.length = a + b + c;
+                               snprintf(filename, sizeof(filename),
+                                        "/tmp/overlong-abc-%zd-%zd-%zd",
+                                        a, b, c);
+                               comp_size = attempt_round_trip(mem_ctx,
+                                                              original,
+                                                              filename, ref);
+                               if (comp_size > 0) {
+                                       score++;
+                               }
+                       }
+               }
+       }
+       debug_message("%d/%zu correct\n", score, i * j * k);
+       assert_int_equal(score, i * j * k);
+       talloc_free(mem_ctx);
+}
+
+
+static void test_lzxpress_huffman_extremely_compressible_middle(void **state)
+{
+       size_t len = 192 * 1024;
+       TALLOC_CTX *mem_ctx = talloc_new(NULL);
+       DATA_BLOB original = data_blob_talloc(mem_ctx, NULL, len);
+       DATA_BLOB ref = {0};
+       ssize_t comp_size;
+       /*
+        * When a middle block (i.e. not the first and not the last of >= 3),
+        * can be entirely expressed as a match starting in the previous
+        * block, the Huffman tree would end up with 1 element, which does not
+        * work for the code construction. It really wants to use both bits.
+        * So we need to ensure we have some way of dealing with this.
+        */
+       memset(original.data, 'a', 0x10000 - 1);
+       memset(original.data + 0x10000 - 1, 'b', 0x10000 + 1);
+       memset(original.data + 0x20000, 'a', 0x10000);
+       comp_size = attempt_round_trip(mem_ctx, original, "/tmp/compressible-middle", ref);
+       assert_true(comp_size > 0);
+       assert_true(comp_size < 1024);
+       debug_message("original size %zu; compressed size %zd; ratio %.3f\n",
+                     len, comp_size, ((double)comp_size) / len);
+
+       talloc_free(mem_ctx);
+}
+
+
+static void test_lzxpress_huffman_max_length_limit(void **state)
+{
+       size_t len = 65 * 1024 * 1024;
+       TALLOC_CTX *mem_ctx = talloc_new(NULL);
+       DATA_BLOB original = data_blob_talloc_zero(mem_ctx, len);
+       DATA_BLOB ref = {0};
+       ssize_t comp_size;
+       /*
+        * Reputedly Windows has a 64MB limit in the maximum match length it
+        * will encode. We follow this, and test that here with nearly 65 MB
+        * of zeros between two letters; this should be encoded in three
+        * blocks:
+        *
+        * 1. 'a', 64M × '\0'
+        * 2. (1M - 2) × '\0' -- finishing off what would have been the same match
+        * 3. 'b' EOF
+        *
+        * Which we can assert by saying the length is > 768, < 1024.
+        */
+       original.data[0] = 'a';
+       original.data[len - 1] = 'b';
+       comp_size = attempt_round_trip(mem_ctx, original, "/tmp/max-length-limit", ref);
+       assert_true(comp_size > 0x300);
+       assert_true(comp_size < 0x400);
+       debug_message("original size %zu; compressed size %zd; ratio %.3f\n",
+                     len, comp_size, ((double)comp_size) / len);
+
+       talloc_free(mem_ctx);
+}
+
+
+static void test_lzxpress_huffman_short_boring_strings(void **state)
+{
+       size_t len = 64 * 1024;
+       TALLOC_CTX *mem_ctx = talloc_new(NULL);
+       DATA_BLOB original = data_blob_talloc(mem_ctx, NULL, len);
+       DATA_BLOB ref = {0};
+       ssize_t comp_size;
+       ssize_t lengths[] = {
+               1, 2, 20, 39, 40, 41, 42, 256, 270, 273, 274, 1000, 64000, -1};
+       char filename[300];
+       size_t i;
+       /*
+        * How do short repetitive strings work? We're poking at the limit
+        * around which LZ77 comprssion is turned on.
+        *
+        * For this test we don't change the blob memory between runs, just
+        * the declared length.
+        */
+       memset(original.data, 'a', len);
+       for (i = 0; lengths[i] >= 0; i++) {
+               original.length = lengths[i];
+               snprintf(filename, sizeof(filename),
+                        "/tmp/short-boring-%zu",
+                        original.length);
+               comp_size = attempt_round_trip(mem_ctx, original, filename, ref);
+               if (original.length < 41) {
+                       assert_true(comp_size > 256 + original.length / 8);
+               } else if (original.length < 274) {
+                       assert_true(comp_size == 261);
+               } else {
+                       assert_true(comp_size == 263);
+               }
+               assert_true(comp_size < 261 + original.length / 8);
+       }
+       /* let's just show we didn't change the original */
+       for (i = 0; i < len; i++) {
+               if (original.data[i] != 'a') {
+                       fail_msg("input data[%zu] was changed! (%2x, expected %2x)\n",
+                                i, original.data[i], 'a');
+               }
+       }
+
+       talloc_free(mem_ctx);
+}
+
+
+static void test_lzxpress_huffman_compress_empty_or_null(void **state)
+{
+       /*
+        * We expect these to fail with a -1, except the last one, which does
+        * the real thing.
+        */
+       ssize_t ret;
+       const uint8_t *input = bidirectional_pairs[0].decompressed.data;
+       size_t ilen = bidirectional_pairs[0].decompressed.length;
+       size_t olen = bidirectional_pairs[0].compressed.length;
+       uint8_t output[olen];
+       struct lzxhuff_compressor_mem cmp_mem;
+
+       ret = lzxpress_huffman_compress(&cmp_mem, input, 0, output, olen);
+       assert_int_equal(ret, -1LL);
+       ret = lzxpress_huffman_compress(&cmp_mem, input, ilen, output, 0);
+       assert_int_equal(ret, -1LL);
+
+       ret = lzxpress_huffman_compress(&cmp_mem, NULL, ilen, output, olen);
+       assert_int_equal(ret, -1LL);
+       ret = lzxpress_huffman_compress(&cmp_mem, input, ilen, NULL, olen);
+       assert_int_equal(ret, -1LL);
+       ret = lzxpress_huffman_compress(NULL, input, ilen, output, olen);
+       assert_int_equal(ret, -1LL);
+
+       ret = lzxpress_huffman_compress(&cmp_mem, input, ilen, output, olen);
+       assert_int_equal(ret, olen);
+}
+
+
 static void test_lzxpress_huffman_decompress_empty_or_null(void **state)
 {
        /*
@@ -501,10 +1192,24 @@ static void test_lzxpress_huffman_decompress_empty_or_null(void **state)
 
 int main(void) {
        const struct CMUnitTest tests[] = {
+               cmocka_unit_test(test_lzxpress_huffman_short_boring_strings),
+               cmocka_unit_test(test_lzxpress_huffman_max_length_limit),
+               cmocka_unit_test(test_lzxpress_huffman_extremely_compressible_middle),
+               cmocka_unit_test(test_lzxpress_huffman_long_random_graph_round_trip),
+               cmocka_unit_test(test_lzxpress_huffman_chaos_graph_round_trip),
+               cmocka_unit_test(test_lzxpress_huffman_sparse_random_graph_round_trip),
+               cmocka_unit_test(test_lzxpress_huffman_round_trip),
                cmocka_unit_test(test_lzxpress_huffman_decompress_files),
                cmocka_unit_test(test_lzxpress_huffman_decompress_more_compressed_files),
+               cmocka_unit_test(test_lzxpress_huffman_compress),
                cmocka_unit_test(test_lzxpress_huffman_decompress),
+               cmocka_unit_test(test_lzxpress_huffman_long_gpl_round_trip),
+               cmocka_unit_test(test_lzxpress_huffman_long_random_graph_round_trip),
+               cmocka_unit_test(test_lzxpress_huffman_random_noise_round_trip),
+               cmocka_unit_test(test_lzxpress_huffman_overlong_matches_abc),
+               cmocka_unit_test(test_lzxpress_huffman_overlong_matches),
                cmocka_unit_test(test_lzxpress_huffman_decompress_empty_or_null),
+               cmocka_unit_test(test_lzxpress_huffman_compress_empty_or_null),
        };
        if (!isatty(1)) {
                cmocka_set_message_output(CM_OUTPUT_SUBUNIT);