From: Joseph Sutton Date: Fri, 4 Mar 2022 03:21:19 +0000 (+1300) Subject: tests/krb5: Add xpress (de)compression functions X-Git-Tag: talloc-2.4.0~1105 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=fa90633b8109696c923e4559a17b82761f4dc486;p=thirdparty%2Fsamba.git tests/krb5: Add xpress (de)compression functions Signed-off-by: Joseph Sutton Reviewed-by: Andrew Bartlett --- diff --git a/python/samba/tests/krb5/xpress.py b/python/samba/tests/krb5/xpress.py new file mode 100644 index 00000000000..b0fbe26fafb --- /dev/null +++ b/python/samba/tests/krb5/xpress.py @@ -0,0 +1,128 @@ +# Unix SMB/CIFS implementation. +# Copyright (C) Catalyst.Net Ltd 2022 +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>.
#

from samba.dcerpc import claims


def decompress(data, compression_type, uncompressed_size):
    """Decompress a claims blob according to its compression type.

    :param data: the (possibly compressed) payload.
    :param compression_type: one of the claims.CLAIMS_COMPRESSION_FORMAT_*
        values; only NONE and XPRESS_HUFF are handled.
    :param uncompressed_size: the expected size of the decompressed
        payload. Not consulted when the data is not compressed.
    :return: *data* itself for CLAIMS_COMPRESSION_FORMAT_NONE, otherwise
        the result of lz77_huffman_decompress().
    :raises AssertionError: for any other compression type.
    """
    if compression_type == claims.CLAIMS_COMPRESSION_FORMAT_NONE:
        return data
    elif compression_type == claims.CLAIMS_COMPRESSION_FORMAT_XPRESS_HUFF:
        return lz77_huffman_decompress(data, uncompressed_size)
    else:
        raise AssertionError(f'compression type {compression_type} '
                             f'not supported')


def lz77_huffman_decompress(data, decompressed_size):
    """Decompress an LZ77+Huffman encoded buffer.

    The first 256 bytes of *data* hold 512 four-bit Huffman code lengths
    (two per byte, low nibble first). Symbols below 256 are literal
    bytes; symbols from 256 upwards encode a match, with the low four
    bits giving the match length and the remaining bits giving the
    number of offset bits that follow in the bitstream. The bitstream
    itself is consumed in 16-bit little-endian units.

    :param data: the compressed bytes.
    :param decompressed_size: the number of bytes the stream is expected
        to decompress to.
    :return: the decompressed bytes, once symbol 256 is decoded with all
        of *data* consumed and exactly *decompressed_size* bytes produced.
    :raises AssertionError: if the code lengths do not describe a
        complete 15-bit decoding table, if an extended match length is
        too small, if a match refers to bytes before the start of the
        output, or if the stream would produce more than
        *decompressed_size* bytes.
    :raises IndexError: if decoding reads past the end of *data* (for
        example when the input is truncated).
    """
    def get_16_bits(data, pos):
        # Little-endian 16-bit read.
        return data[pos] + (data[pos + 1] << 8)

    output = []

    symbol_bit_lengths = []
    for pair in data[:256]:
        symbol_bit_lengths.append(pair & 0xf)
        symbol_bit_lengths.append(pair >> 4)

    # Build the decoding table: a lookup from the next 15 bits of the
    # stream to the symbol whose code is a prefix of those bits. Codes
    # are assigned in order of increasing bit length, then increasing
    # symbol value. The code lengths never change after they are read
    # above, so the table only needs to be built once.
    decoding_table = []
    for bit_len in range(1, 16):
        for symbol in range(0, 512):
            encoded_bit_length = symbol_bit_lengths[symbol]
            if encoded_bit_length == bit_len:
                count = (1 << (15 - bit_len))
                decoding_table.extend([symbol] * count)

    if len(decoding_table) != 2 ** 15:
        raise AssertionError(f'Error constructing decoding table (len = '
                             f'{len(decoding_table)}')

    # Loop until a decompression terminating condition.
    #
    # NOTE(review): every block restarts reading at offset 256 using the
    # table parsed from the first 256 bytes, so streams that need more
    # than one 64KiB block do not look decodable by this function --
    # confirm against the format specification before relying on it for
    # large payloads. The size checks below make such input fail with an
    # error instead of looping forever.
    while True:
        # Start at the end of the Huffman table.
        current_pos = 256

        # Prime the 32-bit window with the first two 16-bit units.
        next_bits = get_16_bits(data, current_pos)
        current_pos += 2

        next_bits <<= 16
        next_bits |= get_16_bits(data, current_pos)
        current_pos += 2

        # Number of valid bits in the window beyond the first 16; when
        # it drops below zero another 16-bit unit is shifted in.
        extra_bit_count = 16
        block_end = len(output) + 65536

        # Loop until a block terminating condition.
        while len(output) < block_end:
            huffman_symbol = decoding_table[next_bits >> (32 - 15)]

            huffman_symbol_bit_len = symbol_bit_lengths[huffman_symbol]
            next_bits <<= huffman_symbol_bit_len
            next_bits &= 0xffffffff
            extra_bit_count -= huffman_symbol_bit_len

            if extra_bit_count < 0:
                next_bits |= get_16_bits(data, current_pos) << -extra_bit_count
                extra_bit_count += 16
                current_pos += 2

            if huffman_symbol < 256:
                # Once the output is larger than expected the success
                # condition below can never hold, so give up now rather
                # than growing the output without bound.
                if len(output) >= decompressed_size:
                    raise AssertionError(
                        f'decompressed data exceeds the expected size '
                        f'({decompressed_size})')

                output.append(huffman_symbol)

            elif (huffman_symbol == 256 and current_pos == len(data)
                  and len(output) == decompressed_size):
                return bytes(output)
            else:
                huffman_symbol -= 256

                match_len = huffman_symbol & 0xf
                match_offset_bit_len = huffman_symbol >> 4

                if match_len == 15:
                    # Longer matches store their length in an extra
                    # byte, or in an extra 16 bits if that byte is 255.
                    match_len = data[current_pos]
                    current_pos += 1

                    if match_len == 255:
                        match_len = get_16_bits(data, current_pos)
                        current_pos += 2

                        if match_len < 15:
                            raise AssertionError(f'match_len is too small! '
                                                 f'({match_len} < 15)')
                        match_len -= 15
                    match_len += 15
                match_len += 3

                # The offset is the next match_offset_bit_len bits of
                # the stream with an implicit leading one bit.
                match_offset = next_bits >> (32 - match_offset_bit_len)
                match_offset |= 1 << match_offset_bit_len

                next_bits <<= match_offset_bit_len
                next_bits &= 0xffffffff

                extra_bit_count -= match_offset_bit_len
                if extra_bit_count < 0:
                    next_bits |= (
                        get_16_bits(data, current_pos) << -extra_bit_count)
                    extra_bit_count += 16
                    current_pos += 2

                # A negative list index would silently wrap around to
                # the end of the output, so reject such matches.
                if match_offset > len(output):
                    raise AssertionError(
                        f'match offset {match_offset} reaches before the '
                        f'start of the output ({len(output)} bytes so far)')

                if len(output) + match_len > decompressed_size:
                    raise AssertionError(
                        f'decompressed data exceeds the expected size '
                        f'({decompressed_size})')

                # Copy byte by byte: the source range may overlap the
                # bytes being appended.
                match_start = len(output) - match_offset
                for i in range(match_start, match_start + match_len):
                    output.append(output[i])

    raise AssertionError('Should not get here')