]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-132983: Add the `compression.zstd` package and tests (#133365)
authorEmma Smith <emma@emmatyping.dev>
Tue, 6 May 2025 00:38:08 +0000 (17:38 -0700)
committerGitHub <noreply@github.com>
Tue, 6 May 2025 00:38:08 +0000 (01:38 +0100)
Co-authored-by: Adam Turner <9087854+AA-Turner@users.noreply.github.com>
Co-authored-by: Gregory P. Smith <greg@krypto.org>
Co-authored-by: Tomas R. <tomas.roun8@gmail.com>
Co-authored-by: Rogdham <contact@rogdham.net>
15 files changed:
Lib/compression/zstd/__init__.py [new file with mode: 0644]
Lib/compression/zstd/_zstdfile.py [new file with mode: 0644]
Lib/shutil.py
Lib/tarfile.py
Lib/test/support/__init__.py
Lib/test/test_shutil.py
Lib/test/test_tarfile.py
Lib/test/test_zipfile/test_core.py
Lib/test/test_zstd.py [new file with mode: 0644]
Lib/zipfile/__init__.py
Makefile.pre.in
Modules/_zstd/_zstdmodule.c
Modules/_zstd/clinic/_zstdmodule.c.h
Modules/_zstd/compressor.c
Modules/_zstd/decompressor.c

diff --git a/Lib/compression/zstd/__init__.py b/Lib/compression/zstd/__init__.py
new file mode 100644 (file)
index 0000000..4f734eb
--- /dev/null
@@ -0,0 +1,234 @@
+"""Python bindings to the Zstandard (zstd) compression library (RFC-8878)."""
+
+__all__ = (
+    # compression.zstd
+    "COMPRESSION_LEVEL_DEFAULT",
+    "compress",
+    "CompressionParameter",
+    "decompress",
+    "DecompressionParameter",
+    "finalize_dict",
+    "get_frame_info",
+    "Strategy",
+    "train_dict",
+
+    # compression.zstd._zstdfile
+    "open",
+    "ZstdFile",
+
+    # _zstd
+    "get_frame_size",
+    "zstd_version",
+    "zstd_version_info",
+    "ZstdCompressor",
+    "ZstdDecompressor",
+    "ZstdDict",
+    "ZstdError",
+)
+
+import _zstd
+import enum
+from _zstd import *
+from compression.zstd._zstdfile import ZstdFile, open, _nbytes
+
+COMPRESSION_LEVEL_DEFAULT = _zstd._compressionLevel_values[0]
+"""The default compression level for Zstandard, currently '3'."""
+
+
+class FrameInfo:
+    """Information about a Zstandard frame."""
+    __slots__ = 'decompressed_size', 'dictionary_id'
+
+    def __init__(self, decompressed_size, dictionary_id):
+        super().__setattr__('decompressed_size', decompressed_size)
+        super().__setattr__('dictionary_id', dictionary_id)
+
+    def __repr__(self):
+        return (f'FrameInfo(decompressed_size={self.decompressed_size}, '
+                f'dictionary_id={self.dictionary_id})')
+
+    def __setattr__(self, name, _):
+        raise AttributeError(f"can't set attribute {name!r}")
+
+
+def get_frame_info(frame_buffer):
+    """Get Zstandard frame information from a frame header.
+
+    *frame_buffer* is a bytes-like object. It should start from the beginning
+    of a frame, and needs to include at least the frame header (6 to 18 bytes).
+
+    The returned FrameInfo object has two attributes.
+    'decompressed_size' is the size in bytes of the data in the frame when
+    decompressed, or None when the decompressed size is unknown.
+    'dictionary_id' is an int in the range (0, 2**32). The special value 0
+    means that the dictionary ID was not recorded in the frame header,
+    the frame may or may not need a dictionary to be decoded,
+    and the ID of such a dictionary is not specified.
+    """
+    return FrameInfo(*_zstd._get_frame_info(frame_buffer))
+
+
+def train_dict(samples, dict_size):
+    """Return a ZstdDict representing a trained Zstandard dictionary.
+
+    *samples* is an iterable of samples, where a sample is a bytes-like
+    object representing a file.
+
+    *dict_size* is the dictionary's maximum size, in bytes.
+    """
+    if not isinstance(dict_size, int):
+        ds_cls = type(dict_size).__qualname__
+        raise TypeError(f'dict_size must be an int object, not {ds_cls!r}.')
+
+    samples = tuple(samples)
+    chunks = b''.join(samples)
+    chunk_sizes = tuple(_nbytes(sample) for sample in samples)
+    if not chunks:
+        raise ValueError("samples contained no data; can't train dictionary.")
+    dict_content = _zstd._train_dict(chunks, chunk_sizes, dict_size)
+    return ZstdDict(dict_content)
+
+
+def finalize_dict(zstd_dict, /, samples, dict_size, level):
+    """Return a ZstdDict representing a finalized Zstandard dictionary.
+
+    Given a custom content as a basis for dictionary, and a set of samples,
+    finalize *zstd_dict* by adding headers and statistics according to the
+    Zstandard dictionary format.
+
+    You may compose an effective dictionary content by hand, which is used as
+    basis dictionary, and use some samples to finalize a dictionary. The basis
+    dictionary may be a "raw content" dictionary. See *is_raw* in ZstdDict.
+
+    *samples* is an iterable of samples, where a sample is a bytes-like object
+    representing a file.
+    *dict_size* is the dictionary's maximum size, in bytes.
+    *level* is the expected compression level. The statistics for each
+    compression level differ, so tuning the dictionary to the compression level
+    can provide improvements.
+    """
+
+    if not isinstance(zstd_dict, ZstdDict):
+        raise TypeError('zstd_dict argument should be a ZstdDict object.')
+    if not isinstance(dict_size, int):
+        raise TypeError('dict_size argument should be an int object.')
+    if not isinstance(level, int):
+        raise TypeError('level argument should be an int object.')
+
+    samples = tuple(samples)
+    chunks = b''.join(samples)
+    chunk_sizes = tuple(_nbytes(sample) for sample in samples)
+    if not chunks:
+        raise ValueError("The samples are empty content, can't finalize the"
+                         "dictionary.")
+    dict_content = _zstd._finalize_dict(zstd_dict.dict_content,
+                                        chunks, chunk_sizes,
+                                        dict_size, level)
+    return ZstdDict(dict_content)
+
+def compress(data, level=None, options=None, zstd_dict=None):
+    """Return Zstandard compressed *data* as bytes.
+
+    *level* is an int specifying the compression level to use, defaulting to
+    COMPRESSION_LEVEL_DEFAULT ('3').
+    *options* is a dict object that contains advanced compression
+    parameters. See CompressionParameter for more on options.
+    *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See
+    the function train_dict for how to train a ZstdDict on sample data.
+
+    For incremental compression, use a ZstdCompressor instead.
+    """
+    comp = ZstdCompressor(level=level, options=options, zstd_dict=zstd_dict)
+    return comp.compress(data, mode=ZstdCompressor.FLUSH_FRAME)
+
+def decompress(data, zstd_dict=None, options=None):
+    """Decompress one or more frames of Zstandard compressed *data*.
+
+    *zstd_dict* is a ZstdDict object, a pre-trained Zstandard dictionary. See
+    the function train_dict for how to train a ZstdDict on sample data.
+    *options* is a dict object that contains advanced compression
+    parameters. See DecompressionParameter for more on options.
+
+    For incremental decompression, use a ZstdDecompressor instead.
+    """
+    results = []
+    while True:
+        decomp = ZstdDecompressor(options=options, zstd_dict=zstd_dict)
+        results.append(decomp.decompress(data))
+        if not decomp.eof:
+            raise ZstdError("Compressed data ended before the "
+                            "end-of-stream marker was reached")
+        data = decomp.unused_data
+        if not data:
+            break
+    return b"".join(results)
+
+
+class CompressionParameter(enum.IntEnum):
+    """Compression parameters."""
+
+    compression_level = _zstd._ZSTD_c_compressionLevel
+    window_log = _zstd._ZSTD_c_windowLog
+    hash_log = _zstd._ZSTD_c_hashLog
+    chain_log = _zstd._ZSTD_c_chainLog
+    search_log = _zstd._ZSTD_c_searchLog
+    min_match = _zstd._ZSTD_c_minMatch
+    target_length = _zstd._ZSTD_c_targetLength
+    strategy = _zstd._ZSTD_c_strategy
+
+    enable_long_distance_matching = _zstd._ZSTD_c_enableLongDistanceMatching
+    ldm_hash_log = _zstd._ZSTD_c_ldmHashLog
+    ldm_min_match = _zstd._ZSTD_c_ldmMinMatch
+    ldm_bucket_size_log = _zstd._ZSTD_c_ldmBucketSizeLog
+    ldm_hash_rate_log = _zstd._ZSTD_c_ldmHashRateLog
+
+    content_size_flag = _zstd._ZSTD_c_contentSizeFlag
+    checksum_flag = _zstd._ZSTD_c_checksumFlag
+    dict_id_flag = _zstd._ZSTD_c_dictIDFlag
+
+    nb_workers = _zstd._ZSTD_c_nbWorkers
+    job_size = _zstd._ZSTD_c_jobSize
+    overlap_log = _zstd._ZSTD_c_overlapLog
+
+    def bounds(self):
+        """Return the (lower, upper) int bounds of a compression parameter.
+
+        Both the lower and upper bounds are inclusive.
+        """
+        return _zstd._get_param_bounds(self.value, is_compress=True)
+
+
+class DecompressionParameter(enum.IntEnum):
+    """Decompression parameters."""
+
+    window_log_max = _zstd._ZSTD_d_windowLogMax
+
+    def bounds(self):
+        """Return the (lower, upper) int bounds of a decompression parameter.
+
+        Both the lower and upper bounds are inclusive.
+        """
+        return _zstd._get_param_bounds(self.value, is_compress=False)
+
+
+class Strategy(enum.IntEnum):
+    """Compression strategies, listed from fastest to strongest.
+
+    Note that new strategies might be added in the future.
+    Only the order (from fast to strong) is guaranteed,
+    the numeric value might change.
+    """
+
+    fast = _zstd._ZSTD_fast
+    dfast = _zstd._ZSTD_dfast
+    greedy = _zstd._ZSTD_greedy
+    lazy = _zstd._ZSTD_lazy
+    lazy2 = _zstd._ZSTD_lazy2
+    btlazy2 = _zstd._ZSTD_btlazy2
+    btopt = _zstd._ZSTD_btopt
+    btultra = _zstd._ZSTD_btultra
+    btultra2 = _zstd._ZSTD_btultra2
+
+
+# Check validity of the CompressionParameter & DecompressionParameter types
+_zstd._set_parameter_types(CompressionParameter, DecompressionParameter)
diff --git a/Lib/compression/zstd/_zstdfile.py b/Lib/compression/zstd/_zstdfile.py
new file mode 100644 (file)
index 0000000..fbc9e02
--- /dev/null
@@ -0,0 +1,349 @@
+import io
+from os import PathLike
+from _zstd import (ZstdCompressor, ZstdDecompressor, _ZSTD_DStreamSizes,
+                   ZstdError)
+from compression._common import _streams
+
+__all__ = ("ZstdFile", "open")
+
+_ZSTD_DStreamOutSize = _ZSTD_DStreamSizes[1]
+
+_MODE_CLOSED = 0
+_MODE_READ = 1
+_MODE_WRITE = 2
+
+
+def _nbytes(dat, /):
+    if isinstance(dat, (bytes, bytearray)):
+        return len(dat)
+    with memoryview(dat) as mv:
+        return mv.nbytes
+
+
+class ZstdFile(_streams.BaseStream):
+    """A file-like object providing transparent Zstandard (de)compression.
+
+    A ZstdFile can act as a wrapper for an existing file object, or refer
+    directly to a named file on disk.
+
+    ZstdFile provides a *binary* file interface. Data is read and returned as
+    bytes, and may only be written to objects that support the Buffer Protocol.
+    """
+
+    FLUSH_BLOCK = ZstdCompressor.FLUSH_BLOCK
+    FLUSH_FRAME = ZstdCompressor.FLUSH_FRAME
+
+    def __init__(self, file, /, mode="r", *,
+                 level=None, options=None, zstd_dict=None):
+        """Open a Zstandard compressed file in binary mode.
+
+        *file* can be either a file-like object, or a file name to open.
+
+        *mode* can be "r" for reading (default), "w" for (over)writing, "x" for
+        creating exclusively, or "a" for appending. These can equivalently be
+        given as "rb", "wb", "xb" and "ab" respectively.
+
+        *level* is an optional int specifying the compression level to use,
+        or COMPRESSION_LEVEL_DEFAULT if not given.
+
+        *options* is an optional dict for advanced compression parameters.
+        See CompressionParameter and DecompressionParameter for the possible
+        options.
+
+        *zstd_dict* is an optional ZstdDict object, a pre-trained Zstandard
+        dictionary. See train_dict() to train ZstdDict on sample data.
+        """
+        self._fp = None
+        self._close_fp = False
+        self._mode = _MODE_CLOSED
+        self._buffer = None
+
+        if not isinstance(mode, str):
+            raise ValueError("mode must be a str")
+        if options is not None and not isinstance(options, dict):
+            raise TypeError("options must be a dict or None")
+        mode = mode.removesuffix("b")  # handle rb, wb, xb, ab
+        if mode == "r":
+            if level is not None:
+                raise TypeError("level is illegal in read mode")
+            self._mode = _MODE_READ
+        elif mode in {"w", "a", "x"}:
+            if level is not None and not isinstance(level, int):
+                raise TypeError("level must be int or None")
+            self._mode = _MODE_WRITE
+            self._compressor = ZstdCompressor(level=level, options=options,
+                                              zstd_dict=zstd_dict)
+            self._pos = 0
+        else:
+            raise ValueError(f"Invalid mode: {mode!r}")
+
+        if isinstance(file, (str, bytes, PathLike)):
+            self._fp = io.open(file, f'{mode}b')
+            self._close_fp = True
+        elif ((mode == 'r' and hasattr(file, "read"))
+                or (mode != 'r' and hasattr(file, "write"))):
+            self._fp = file
+        else:
+            raise TypeError("file must be a file-like object "
+                            "or a str, bytes, or PathLike object")
+
+        if self._mode == _MODE_READ:
+            raw = _streams.DecompressReader(
+                self._fp,
+                ZstdDecompressor,
+                trailing_error=ZstdError,
+                zstd_dict=zstd_dict,
+                options=options,
+            )
+            self._buffer = io.BufferedReader(raw)
+
+    def close(self):
+        """Flush and close the file.
+
+        May be called multiple times. Once the file has been closed,
+        any other operation on it will raise ValueError.
+        """
+        if self._fp is None:
+            return
+        try:
+            if self._mode == _MODE_READ:
+                if getattr(self, '_buffer', None):
+                    self._buffer.close()
+                    self._buffer = None
+            elif self._mode == _MODE_WRITE:
+                self.flush(self.FLUSH_FRAME)
+                self._compressor = None
+        finally:
+            self._mode = _MODE_CLOSED
+            try:
+                if self._close_fp:
+                    self._fp.close()
+            finally:
+                self._fp = None
+                self._close_fp = False
+
+    def write(self, data, /):
+        """Write a bytes-like object *data* to the file.
+
+        Returns the number of uncompressed bytes written, which is
+        always the length of data in bytes. Note that due to buffering,
+        the file on disk may not reflect the data written until .flush()
+        or .close() is called.
+        """
+        self._check_can_write()
+
+        length = _nbytes(data)
+
+        compressed = self._compressor.compress(data)
+        self._fp.write(compressed)
+        self._pos += length
+        return length
+
+    def flush(self, mode=FLUSH_BLOCK):
+        """Flush remaining data to the underlying stream.
+
+        The mode argument can be FLUSH_BLOCK or FLUSH_FRAME. Abuse of this
+        method will reduce compression ratio, use it only when necessary.
+
+        If the program is interrupted afterwards, all data can be recovered.
+        To ensure the data reaches the disk, also call os.fsync(fd).
+
+        This method does nothing in reading mode.
+        """
+        if self._mode == _MODE_READ:
+            return
+        self._check_not_closed()
+        if mode not in {self.FLUSH_BLOCK, self.FLUSH_FRAME}:
+            raise ValueError("Invalid mode argument, expected either "
+                             "ZstdFile.FLUSH_FRAME or "
+                             "ZstdFile.FLUSH_BLOCK")
+        if self._compressor.last_mode == mode:
+            return
+        # Flush zstd block/frame, and write.
+        data = self._compressor.flush(mode)
+        self._fp.write(data)
+        if hasattr(self._fp, "flush"):
+            self._fp.flush()
+
+    def read(self, size=-1):
+        """Read up to size uncompressed bytes from the file.
+
+        If size is negative or omitted, read until EOF is reached.
+        Returns b"" if the file is already at EOF.
+        """
+        if size is None:
+            size = -1
+        self._check_can_read()
+        return self._buffer.read(size)
+
+    def read1(self, size=-1):
+        """Read up to size uncompressed bytes, while trying to avoid
+        making multiple reads from the underlying stream. Reads up to a
+        buffer's worth of data if size is negative.
+
+        Returns b"" if the file is at EOF.
+        """
+        self._check_can_read()
+        if size < 0:
+            # Note this should *not* be io.DEFAULT_BUFFER_SIZE.
+            # ZSTD_DStreamOutSize is the minimum amount to read guaranteeing
+            # a full block is read.
+            size = _ZSTD_DStreamOutSize
+        return self._buffer.read1(size)
+
+    def readinto(self, b):
+        """Read bytes into b.
+
+        Returns the number of bytes read (0 for EOF).
+        """
+        self._check_can_read()
+        return self._buffer.readinto(b)
+
+    def readinto1(self, b):
+        """Read bytes into b, while trying to avoid making multiple reads
+        from the underlying stream.
+
+        Returns the number of bytes read (0 for EOF).
+        """
+        self._check_can_read()
+        return self._buffer.readinto1(b)
+
+    def readline(self, size=-1):
+        """Read a line of uncompressed bytes from the file.
+
+        The terminating newline (if present) is retained. If size is
+        non-negative, no more than size bytes will be read (in which
+        case the line may be incomplete). Returns b'' if already at EOF.
+        """
+        self._check_can_read()
+        return self._buffer.readline(size)
+
+    def seek(self, offset, whence=io.SEEK_SET):
+        """Change the file position.
+
+        The new position is specified by offset, relative to the
+        position indicated by whence. Possible values for whence are:
+
+            0: start of stream (default): offset must not be negative
+            1: current stream position
+            2: end of stream; offset must not be positive
+
+        Returns the new file position.
+
+        Note that seeking is emulated, so depending on the arguments,
+        this operation may be extremely slow.
+        """
+        self._check_can_read()
+
+        # BufferedReader.seek() checks seekable
+        return self._buffer.seek(offset, whence)
+
+    def peek(self, size=-1):
+        """Return buffered data without advancing the file position.
+
+        Always returns at least one byte of data, unless at EOF.
+        The exact number of bytes returned is unspecified.
+        """
+        # Relies on the undocumented fact that BufferedReader.peek() always
+        # returns at least one byte (except at EOF)
+        self._check_can_read()
+        return self._buffer.peek(size)
+
+    def __next__(self):
+        if ret := self._buffer.readline():
+            return ret
+        raise StopIteration
+
+    def tell(self):
+        """Return the current file position."""
+        self._check_not_closed()
+        if self._mode == _MODE_READ:
+            return self._buffer.tell()
+        elif self._mode == _MODE_WRITE:
+            return self._pos
+
+    def fileno(self):
+        """Return the file descriptor for the underlying file."""
+        self._check_not_closed()
+        return self._fp.fileno()
+
+    @property
+    def name(self):
+        self._check_not_closed()
+        return self._fp.name
+
+    @property
+    def mode(self):
+        return 'wb' if self._mode == _MODE_WRITE else 'rb'
+
+    @property
+    def closed(self):
+        """True if this file is closed."""
+        return self._mode == _MODE_CLOSED
+
+    def seekable(self):
+        """Return whether the file supports seeking."""
+        return self.readable() and self._buffer.seekable()
+
+    def readable(self):
+        """Return whether the file was opened for reading."""
+        self._check_not_closed()
+        return self._mode == _MODE_READ
+
+    def writable(self):
+        """Return whether the file was opened for writing."""
+        self._check_not_closed()
+        return self._mode == _MODE_WRITE
+
+
+def open(file, /, mode="rb", *, level=None, options=None, zstd_dict=None,
+         encoding=None, errors=None, newline=None):
+    """Open a Zstandard compressed file in binary or text mode.
+
+    file can be either a file name (given as a str, bytes, or PathLike object),
+    in which case the named file is opened, or it can be an existing file object
+    to read from or write to.
+
+    The mode parameter can be "r", "rb" (default), "w", "wb", "x", "xb", "a",
+    "ab" for binary mode, or "rt", "wt", "xt", "at" for text mode.
+
+    The level, options, and zstd_dict parameters specify the settings the same
+    as ZstdFile.
+
+    When using read mode (decompression), the options parameter is a dict
+    representing advanced decompression options. The level parameter is not
+    supported in this case. When using write mode (compression), only one of
+    level, an int representing the compression level, or options, a dict
+    representing advanced compression options, may be passed. In both modes,
+    zstd_dict is a ZstdDict instance containing a trained Zstandard dictionary.
+
+    For binary mode, this function is equivalent to the ZstdFile constructor:
+    ZstdFile(filename, mode, ...). In this case, the encoding, errors and
+    newline parameters must not be provided.
+
+    For text mode, a ZstdFile object is created, and wrapped in an
+    io.TextIOWrapper instance with the specified encoding, error handling
+    behavior, and line ending(s).
+    """
+
+    text_mode = "t" in mode
+    mode = mode.replace("t", "")
+
+    if text_mode:
+        if "b" in mode:
+            raise ValueError(f"Invalid mode: {mode!r}")
+    else:
+        if encoding is not None:
+            raise ValueError("Argument 'encoding' not supported in binary mode")
+        if errors is not None:
+            raise ValueError("Argument 'errors' not supported in binary mode")
+        if newline is not None:
+            raise ValueError("Argument 'newline' not supported in binary mode")
+
+    binary_file = ZstdFile(file, mode, level=level, options=options,
+                           zstd_dict=zstd_dict)
+
+    if text_mode:
+        return io.TextIOWrapper(binary_file, encoding, errors, newline)
+    else:
+        return binary_file
index 510ae8c6f22d59725a49a60856b6ff784afb6460..ca0a2ea2f7fa8a0e2e66a23a9cb5ba9a4331280b 100644 (file)
@@ -32,6 +32,13 @@ try:
 except ImportError:
     _LZMA_SUPPORTED = False
 
+try:
+    from compression import zstd
+    del zstd
+    _ZSTD_SUPPORTED = True
+except ImportError:
+    _ZSTD_SUPPORTED = False
+
 _WINDOWS = os.name == 'nt'
 posix = nt = None
 if os.name == 'posix':
@@ -1006,6 +1013,8 @@ def _make_tarball(base_name, base_dir, compress="gzip", verbose=0, dry_run=0,
         tar_compression = 'bz2'
     elif _LZMA_SUPPORTED and compress == 'xz':
         tar_compression = 'xz'
+    elif _ZSTD_SUPPORTED and compress == 'zst':
+        tar_compression = 'zst'
     else:
         raise ValueError("bad value for 'compress', or compression format not "
                          "supported : {0}".format(compress))
@@ -1134,6 +1143,10 @@ if _LZMA_SUPPORTED:
     _ARCHIVE_FORMATS['xztar'] = (_make_tarball, [('compress', 'xz')],
                                 "xz'ed tar-file")
 
+if _ZSTD_SUPPORTED:
+    _ARCHIVE_FORMATS['zstdtar'] = (_make_tarball, [('compress', 'zst')],
+                                  "zstd'ed tar-file")
+
 def get_archive_formats():
     """Returns a list of supported formats for archiving and unarchiving.
 
@@ -1174,7 +1187,7 @@ def make_archive(base_name, format, root_dir=None, base_dir=None, verbose=0,
 
     'base_name' is the name of the file to create, minus any format-specific
     extension; 'format' is the archive format: one of "zip", "tar", "gztar",
-    "bztar", or "xztar".  Or any other registered format.
+    "bztar", "zstdtar", or "xztar".  Or any other registered format.
 
     'root_dir' is a directory that will be the root directory of the
     archive; ie. we typically chdir into 'root_dir' before creating the
@@ -1359,6 +1372,10 @@ if _LZMA_SUPPORTED:
     _UNPACK_FORMATS['xztar'] = (['.tar.xz', '.txz'], _unpack_tarfile, [],
                                 "xz'ed tar-file")
 
+if _ZSTD_SUPPORTED:
+    _UNPACK_FORMATS['zstdtar'] = (['.tar.zst', '.tzst'], _unpack_tarfile, [],
+                                  "zstd'ed tar-file")
+
 def _find_unpack_format(filename):
     for name, info in _UNPACK_FORMATS.items():
         for extension in info[0]:
index 28581f3e7a2692dd5d8fbfe7e0116312565bde0b..c0f5a609b9f42f395df1b8a950249ee2b3100a03 100644 (file)
@@ -399,7 +399,17 @@ class _Stream:
                     self.exception = lzma.LZMAError
                 else:
                     self.cmp = lzma.LZMACompressor(preset=preset)
-
+            elif comptype == "zst":
+                try:
+                    from compression import zstd
+                except ImportError:
+                    raise CompressionError("compression.zstd module is not available") from None
+                if mode == "r":
+                    self.dbuf = b""
+                    self.cmp = zstd.ZstdDecompressor()
+                    self.exception = zstd.ZstdError
+                else:
+                    self.cmp = zstd.ZstdCompressor()
             elif comptype != "tar":
                 raise CompressionError("unknown compression type %r" % comptype)
 
@@ -591,6 +601,8 @@ class _StreamProxy(object):
             return "bz2"
         elif self.buf.startswith((b"\x5d\x00\x00\x80", b"\xfd7zXZ")):
             return "xz"
+        elif self.buf.startswith(b"\x28\xb5\x2f\xfd"):
+            return "zst"
         else:
             return "tar"
 
@@ -1817,11 +1829,13 @@ class TarFile(object):
            'r:gz'       open for reading with gzip compression
            'r:bz2'      open for reading with bzip2 compression
            'r:xz'       open for reading with lzma compression
+           'r:zst'      open for reading with zstd compression
            'a' or 'a:'  open for appending, creating the file if necessary
            'w' or 'w:'  open for writing without compression
            'w:gz'       open for writing with gzip compression
            'w:bz2'      open for writing with bzip2 compression
            'w:xz'       open for writing with lzma compression
+           'w:zst'      open for writing with zstd compression
 
            'x' or 'x:'  create a tarfile exclusively without compression, raise
                         an exception if the file is already created
@@ -1831,16 +1845,20 @@ class TarFile(object):
                         if the file is already created
            'x:xz'       create an lzma compressed tarfile, raise an exception
                         if the file is already created
+           'x:zst'      create a zstd compressed tarfile, raise an exception
+                        if the file is already created
 
            'r|*'        open a stream of tar blocks with transparent compression
            'r|'         open an uncompressed stream of tar blocks for reading
            'r|gz'       open a gzip compressed stream of tar blocks
            'r|bz2'      open a bzip2 compressed stream of tar blocks
            'r|xz'       open an lzma compressed stream of tar blocks
+           'r|zst'      open a zstd compressed stream of tar blocks
            'w|'         open an uncompressed stream for writing
            'w|gz'       open a gzip compressed stream for writing
            'w|bz2'      open a bzip2 compressed stream for writing
            'w|xz'       open an lzma compressed stream for writing
+           'w|zst'      open a zstd compressed stream for writing
         """
 
         if not name and not fileobj:
@@ -2006,12 +2024,48 @@ class TarFile(object):
         t._extfileobj = False
         return t
 
+    @classmethod
+    def zstopen(cls, name, mode="r", fileobj=None, level=None, options=None,
+                zstd_dict=None, **kwargs):
+        """Open zstd compressed tar archive name for reading or writing.
+           Appending is not allowed.
+        """
+        if mode not in ("r", "w", "x"):
+            raise ValueError("mode must be 'r', 'w' or 'x'")
+
+        try:
+            from compression.zstd import ZstdFile, ZstdError
+        except ImportError:
+            raise CompressionError("compression.zstd module is not available") from None
+
+        fileobj = ZstdFile(
+            fileobj or name,
+            mode,
+            level=level,
+            options=options,
+            zstd_dict=zstd_dict
+        )
+
+        try:
+            t = cls.taropen(name, mode, fileobj, **kwargs)
+        except (ZstdError, EOFError) as e:
+            fileobj.close()
+            if mode == 'r':
+                raise ReadError("not a zstd file") from e
+            raise
+        except Exception:
+            fileobj.close()
+            raise
+        t._extfileobj = False
+        return t
+
     # All *open() methods are registered here.
     OPEN_METH = {
         "tar": "taropen",   # uncompressed tar
         "gz":  "gzopen",    # gzip compressed tar
         "bz2": "bz2open",   # bzip2 compressed tar
-        "xz":  "xzopen"     # lzma compressed tar
+        "xz":  "xzopen",    # lzma compressed tar
+        "zst": "zstopen"    # zstd compressed tar
     }
 
     #--------------------------------------------------------------------------
@@ -2963,6 +3017,9 @@ def main():
             '.tbz': 'bz2',
             '.tbz2': 'bz2',
             '.tb2': 'bz2',
+            # zstd
+            '.zst': 'zst',
+            '.tzst': 'zst',
         }
         tar_mode = 'w:' + compressions[ext] if ext in compressions else 'w'
         tar_files = args.create
index 041f1250003b686bf9246adc0f3616748075f55a..c74c3a3190947ba4604c04e9c22a3ca10e2f3475 100644 (file)
@@ -33,7 +33,7 @@ __all__ = [
     "is_resource_enabled", "requires", "requires_freebsd_version",
     "requires_gil_enabled", "requires_linux_version", "requires_mac_ver",
     "check_syntax_error",
-    "requires_gzip", "requires_bz2", "requires_lzma",
+    "requires_gzip", "requires_bz2", "requires_lzma", "requires_zstd",
     "bigmemtest", "bigaddrspacetest", "cpython_only", "get_attribute",
     "requires_IEEE_754", "requires_zlib",
     "has_fork_support", "requires_fork",
@@ -527,6 +527,13 @@ def requires_lzma(reason='requires lzma'):
         lzma = None
     return unittest.skipUnless(lzma, reason)
 
+def requires_zstd(reason='requires zstd'):
+    try:
+        from compression import zstd
+    except ImportError:
+        zstd = None
+    return unittest.skipUnless(zstd, reason)
+
 def has_no_debug_ranges():
     try:
         import _testcapi
index ed01163074a507c32aa60a4031230be3f29c4633..87991fbda4c7df0568f671a001f0c797ada11213 100644 (file)
@@ -2153,6 +2153,10 @@ class TestArchives(BaseTest, unittest.TestCase):
     def test_unpack_archive_bztar(self):
         self.check_unpack_tarball('bztar')
 
+    @support.requires_zstd()
+    def test_unpack_archive_zstdtar(self):
+        self.check_unpack_tarball('zstdtar')
+
     @support.requires_lzma()
     @unittest.skipIf(AIX and not _maxdataOK(), "AIX MAXDATA must be 0x20000000 or larger")
     def test_unpack_archive_xztar(self):
index fcbaf854cc294f4ff787318d758171dfbb14e110..2d9649237a9382bdc8127f19b13abbd9a551af6b 100644 (file)
@@ -38,6 +38,10 @@ try:
     import lzma
 except ImportError:
     lzma = None
+try:
+    from compression import zstd
+except ImportError:
+    zstd = None
 
 def sha256sum(data):
     return sha256(data).hexdigest()
@@ -48,6 +52,7 @@ tarname = support.findfile("testtar.tar", subdir="archivetestdata")
 gzipname = os.path.join(TEMPDIR, "testtar.tar.gz")
 bz2name = os.path.join(TEMPDIR, "testtar.tar.bz2")
 xzname = os.path.join(TEMPDIR, "testtar.tar.xz")
+zstname = os.path.join(TEMPDIR, "testtar.tar.zst")
 tmpname = os.path.join(TEMPDIR, "tmp.tar")
 dotlessname = os.path.join(TEMPDIR, "testtar")
 
@@ -90,6 +95,12 @@ class LzmaTest:
     open = lzma.LZMAFile if lzma else None
     taropen = tarfile.TarFile.xzopen
 
+@support.requires_zstd()
+class ZstdTest:
+    tarname = zstname
+    suffix = 'zst'
+    open = zstd.ZstdFile if zstd else None
+    taropen = tarfile.TarFile.zstopen
 
 class ReadTest(TarTest):
 
@@ -271,6 +282,8 @@ class Bz2UstarReadTest(Bz2Test, UstarReadTest):
 class LzmaUstarReadTest(LzmaTest, UstarReadTest):
     pass
 
+class ZstdUstarReadTest(ZstdTest, UstarReadTest):
+    pass
 
 class ListTest(ReadTest, unittest.TestCase):
 
@@ -375,6 +388,8 @@ class Bz2ListTest(Bz2Test, ListTest):
 class LzmaListTest(LzmaTest, ListTest):
     pass
 
+class ZstdListTest(ZstdTest, ListTest):
+    pass
 
 class CommonReadTest(ReadTest):
 
@@ -837,6 +852,8 @@ class Bz2MiscReadTest(Bz2Test, MiscReadTestBase, unittest.TestCase):
 class LzmaMiscReadTest(LzmaTest, MiscReadTestBase, unittest.TestCase):
     pass
 
+class ZstdMiscReadTest(ZstdTest, MiscReadTestBase, unittest.TestCase):
+    pass
 
 class StreamReadTest(CommonReadTest, unittest.TestCase):
 
@@ -909,6 +926,9 @@ class Bz2StreamReadTest(Bz2Test, StreamReadTest):
 class LzmaStreamReadTest(LzmaTest, StreamReadTest):
     pass
 
+class ZstdStreamReadTest(ZstdTest, StreamReadTest):
+    pass
+
 class TarStreamModeReadTest(StreamModeTest, unittest.TestCase):
 
     def test_stream_mode_no_cache(self):
@@ -925,6 +945,9 @@ class Bz2StreamModeReadTest(Bz2Test, TarStreamModeReadTest):
 class LzmaStreamModeReadTest(LzmaTest, TarStreamModeReadTest):
     pass
 
+class ZstdStreamModeReadTest(ZstdTest, TarStreamModeReadTest):
+    pass
+
 class DetectReadTest(TarTest, unittest.TestCase):
     def _testfunc_file(self, name, mode):
         try:
@@ -986,6 +1009,8 @@ class Bz2DetectReadTest(Bz2Test, DetectReadTest):
 class LzmaDetectReadTest(LzmaTest, DetectReadTest):
     pass
 
+class ZstdDetectReadTest(ZstdTest, DetectReadTest):
+    pass
 
 class GzipBrokenHeaderCorrectException(GzipTest, unittest.TestCase):
     """
@@ -1666,6 +1691,8 @@ class Bz2WriteTest(Bz2Test, WriteTest):
 class LzmaWriteTest(LzmaTest, WriteTest):
     pass
 
+class ZstdWriteTest(ZstdTest, WriteTest):
+    pass
 
 class StreamWriteTest(WriteTestBase, unittest.TestCase):
 
@@ -1727,6 +1754,9 @@ class Bz2StreamWriteTest(Bz2Test, StreamWriteTest):
 class LzmaStreamWriteTest(LzmaTest, StreamWriteTest):
     decompressor = lzma.LZMADecompressor if lzma else None
 
+class ZstdStreamWriteTest(ZstdTest, StreamWriteTest):
+    decompressor = zstd.ZstdDecompressor if zstd else None
+
 class _CompressedWriteTest(TarTest):
     # This is not actually a standalone test.
     # It does not inherit WriteTest because it only makes sense with gz,bz2
@@ -2042,6 +2072,14 @@ class LzmaCreateTest(LzmaTest, CreateTest):
             tobj.add(self.file_path)
 
 
+class ZstdCreateTest(ZstdTest, CreateTest):
+
+    # Unlike gz and bz2, zstd uses the level keyword instead of compresslevel.
+    # It does not allow for level to be specified when reading.
+    def test_create_with_level(self):
+        with tarfile.open(tmpname, self.mode, level=1) as tobj:
+            tobj.add(self.file_path)
+
 class CreateWithXModeTest(CreateTest):
 
     prefix = "x"
@@ -2523,6 +2561,8 @@ class Bz2AppendTest(Bz2Test, AppendTestBase, unittest.TestCase):
 class LzmaAppendTest(LzmaTest, AppendTestBase, unittest.TestCase):
     pass
 
+class ZstdAppendTest(ZstdTest, AppendTestBase, unittest.TestCase):
+    pass
 
 class LimitsTest(unittest.TestCase):
 
@@ -2835,7 +2875,7 @@ class CommandLineTest(unittest.TestCase):
                  support.findfile('tokenize_tests-no-coding-cookie-'
                                   'and-utf8-bom-sig-only.txt',
                                   subdir='tokenizedata')]
-        for filetype in (GzipTest, Bz2Test, LzmaTest):
+        for filetype in (GzipTest, Bz2Test, LzmaTest, ZstdTest):
             if not filetype.open:
                 continue
             try:
@@ -4257,7 +4297,7 @@ def setUpModule():
         data = fobj.read()
 
     # Create compressed tarfiles.
-    for c in GzipTest, Bz2Test, LzmaTest:
+    for c in GzipTest, Bz2Test, LzmaTest, ZstdTest:
         if c.open:
             os_helper.unlink(c.tarname)
             testtarnames.append(c.tarname)
index 4c9d9f4b56235de9f8a00f05684cdaf518d7fe48..ae898150658565c042195be435fc7e9bd48f3857 100644 (file)
@@ -23,7 +23,7 @@ from test import archiver_tests
 from test.support import script_helper, os_helper
 from test.support import (
     findfile, requires_zlib, requires_bz2, requires_lzma,
-    captured_stdout, captured_stderr, requires_subprocess,
+    requires_zstd, captured_stdout, captured_stderr, requires_subprocess,
     cpython_only
 )
 from test.support.os_helper import (
@@ -702,6 +702,10 @@ class LzmaTestsWithSourceFile(AbstractTestsWithSourceFile,
                               unittest.TestCase):
     compression = zipfile.ZIP_LZMA
 
+@requires_zstd()
+class ZstdTestsWithSourceFile(AbstractTestsWithSourceFile,
+                              unittest.TestCase):
+    compression = zipfile.ZIP_ZSTANDARD
 
 class AbstractTestZip64InSmallFiles:
     # These tests test the ZIP64 functionality without using large files,
@@ -1279,6 +1283,10 @@ class LzmaTestZip64InSmallFiles(AbstractTestZip64InSmallFiles,
                                 unittest.TestCase):
     compression = zipfile.ZIP_LZMA
 
+@requires_zstd()
+class ZstdTestZip64InSmallFiles(AbstractTestZip64InSmallFiles,
+                                unittest.TestCase):
+    compression = zipfile.ZIP_ZSTANDARD
 
 class AbstractWriterTests:
 
@@ -1348,6 +1356,9 @@ class Bzip2WriterTests(AbstractWriterTests, unittest.TestCase):
 class LzmaWriterTests(AbstractWriterTests, unittest.TestCase):
     compression = zipfile.ZIP_LZMA
 
+@requires_zstd()
+class ZstdWriterTests(AbstractWriterTests, unittest.TestCase):
+    compression = zipfile.ZIP_ZSTANDARD
 
 class PyZipFileTests(unittest.TestCase):
     def assertCompiledIn(self, name, namelist):
@@ -2678,6 +2689,17 @@ class LzmaBadCrcTests(AbstractBadCrcTests, unittest.TestCase):
         b'ePK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00'
         b'\x00>\x00\x00\x00\x00\x00')
 
+@requires_zstd()
+class ZstdBadCrcTests(AbstractBadCrcTests, unittest.TestCase):
+    compression = zipfile.ZIP_ZSTANDARD
+    zip_with_bad_crc = (
+        b'PK\x03\x04?\x00\x00\x00]\x00\x00\x00!\x00V\xb1\x17J\x14\x00'
+        b'\x00\x00\x0b\x00\x00\x00\x05\x00\x00\x00afile(\xb5/\xfd\x00'
+        b'XY\x00\x00Hello WorldPK\x01\x02?\x03?\x00\x00\x00]\x00\x00\x00'
+        b'!\x00V\xb0\x17J\x14\x00\x00\x00\x0b\x00\x00\x00\x05\x00\x00\x00'
+        b'\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00\x00afilePK'
+        b'\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00\x007\x00\x00\x00'
+        b'\x00\x00')
 
 class DecryptionTests(unittest.TestCase):
     """Check that ZIP decryption works. Since the library does not
@@ -2905,6 +2927,10 @@ class LzmaTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles,
                                      unittest.TestCase):
     compression = zipfile.ZIP_LZMA
 
+@requires_zstd()
+class ZstdTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles,
+                                     unittest.TestCase):
+    compression = zipfile.ZIP_ZSTANDARD
 
 # Provide the tell() method but not seek()
 class Tellable:
diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py
new file mode 100644 (file)
index 0000000..f4a2537
--- /dev/null
@@ -0,0 +1,2507 @@
+import array
+import gc
+import io
+import pathlib
+import random
+import re
+import os
+import unittest
+import tempfile
+import threading
+
+from test.support.import_helper import import_module
+from test.support import threading_helper
+from test.support import _1M
+from test.support import Py_GIL_DISABLED
+
+_zstd = import_module("_zstd")
+zstd = import_module("compression.zstd")
+
+from compression.zstd import (
+    open,
+    compress,
+    decompress,
+    ZstdCompressor,
+    ZstdDecompressor,
+    ZstdDict,
+    ZstdError,
+    zstd_version,
+    zstd_version_info,
+    COMPRESSION_LEVEL_DEFAULT,
+    get_frame_info,
+    get_frame_size,
+    finalize_dict,
+    train_dict,
+    CompressionParameter,
+    DecompressionParameter,
+    Strategy,
+    ZstdFile,
+)
+
+_1K = 1024
+_130_1K = 130 * _1K
+DICT_SIZE1 = 3*_1K
+
+DAT_130K_D = None
+DAT_130K_C = None
+
+DECOMPRESSED_DAT = None
+COMPRESSED_DAT = None
+
+DECOMPRESSED_100_PLUS_32KB = None
+COMPRESSED_100_PLUS_32KB = None
+
+SKIPPABLE_FRAME = None
+
+THIS_FILE_BYTES = None
+THIS_FILE_STR = None
+COMPRESSED_THIS_FILE = None
+
+COMPRESSED_BOGUS = None
+
+SAMPLES = None
+
+TRAINED_DICT = None
+
+SUPPORT_MULTITHREADING = False
+
+def setUpModule():
+    global SUPPORT_MULTITHREADING
+    SUPPORT_MULTITHREADING = CompressionParameter.nb_workers.bounds() != (0, 0)
+    # Uncompressed size is 130 KB, more than one zstd block.
+    # The frame epilogue carries a 4-byte checksum.
+    global DAT_130K_D
+    DAT_130K_D = bytes([random.randint(0, 127) for _ in range(130*_1K)])
+
+    global DAT_130K_C
+    DAT_130K_C = compress(DAT_130K_D, options={CompressionParameter.checksum_flag:1})
+
+    global DECOMPRESSED_DAT
+    DECOMPRESSED_DAT = b'abcdefg123456' * 1000
+
+    global COMPRESSED_DAT
+    COMPRESSED_DAT = compress(DECOMPRESSED_DAT)
+
+    global DECOMPRESSED_100_PLUS_32KB
+    DECOMPRESSED_100_PLUS_32KB = b'a' * (100 + 32*_1K)
+
+    global COMPRESSED_100_PLUS_32KB
+    COMPRESSED_100_PLUS_32KB = compress(DECOMPRESSED_100_PLUS_32KB)
+
+    global SKIPPABLE_FRAME
+    SKIPPABLE_FRAME = (0x184D2A50).to_bytes(4, byteorder='little') + \
+                      (32*_1K).to_bytes(4, byteorder='little') + \
+                      b'a' * (32*_1K)
+
+    global THIS_FILE_BYTES, THIS_FILE_STR
+    with io.open(os.path.abspath(__file__), 'rb') as f:
+        THIS_FILE_BYTES = f.read()
+        THIS_FILE_BYTES = re.sub(rb'\r?\n', rb'\n', THIS_FILE_BYTES)
+        THIS_FILE_STR = THIS_FILE_BYTES.decode('utf-8')
+
+    global COMPRESSED_THIS_FILE
+    COMPRESSED_THIS_FILE = compress(THIS_FILE_BYTES)
+
+    global COMPRESSED_BOGUS
+    COMPRESSED_BOGUS = DECOMPRESSED_DAT
+
+    # dict data
+    words = [b'red', b'green', b'yellow', b'black', b'withe', b'blue',
+             b'lilac', b'purple', b'navy', b'glod', b'silver', b'olive',
+             b'dog', b'cat', b'tiger', b'lion', b'fish', b'bird']
+    lst = []
+    for i in range(300):
+        sample = [b'%s = %d' % (random.choice(words), random.randrange(100))
+                  for j in range(20)]
+        sample = b'\n'.join(sample)
+
+        lst.append(sample)
+    global SAMPLES
+    SAMPLES = lst
+    assert len(SAMPLES) > 10
+
+    global TRAINED_DICT
+    TRAINED_DICT = train_dict(SAMPLES, 3*_1K)
+    assert len(TRAINED_DICT.dict_content) <= 3*_1K
+
+
+class FunctionsTestCase(unittest.TestCase):
+
+    def test_version(self):
+        s = ".".join((str(i) for i in zstd_version_info))
+        self.assertEqual(s, zstd_version)
+
+    def test_compressionLevel_values(self):
+        min, max = CompressionParameter.compression_level.bounds()
+        self.assertIs(type(COMPRESSION_LEVEL_DEFAULT), int)
+        self.assertIs(type(min), int)
+        self.assertIs(type(max), int)
+        self.assertLess(min, max)
+
+    def test_roundtrip_default(self):
+        raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]
+        dat1 = compress(raw_dat)
+        dat2 = decompress(dat1)
+        self.assertEqual(dat2, raw_dat)
+
+    def test_roundtrip_level(self):
+        raw_dat = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]
+        level_min, level_max = CompressionParameter.compression_level.bounds()
+
+        for level in range(max(-20, level_min), level_max + 1):
+            dat1 = compress(raw_dat, level)
+            dat2 = decompress(dat1)
+            self.assertEqual(dat2, raw_dat)
+
+    def test_get_frame_info(self):
+        # no dict
+        info = get_frame_info(COMPRESSED_100_PLUS_32KB[:20])
+        self.assertEqual(info.decompressed_size, 32 * _1K + 100)
+        self.assertEqual(info.dictionary_id, 0)
+
+        # use dict
+        dat = compress(b"a" * 345, zstd_dict=TRAINED_DICT)
+        info = get_frame_info(dat)
+        self.assertEqual(info.decompressed_size, 345)
+        self.assertEqual(info.dictionary_id, TRAINED_DICT.dict_id)
+
+        with self.assertRaisesRegex(ZstdError, "not less than the frame header"):
+            get_frame_info(b"aaaaaaaaaaaaaa")
+
+    def test_get_frame_size(self):
+        size = get_frame_size(COMPRESSED_100_PLUS_32KB)
+        self.assertEqual(size, len(COMPRESSED_100_PLUS_32KB))
+
+        with self.assertRaisesRegex(ZstdError, "not less than this complete frame"):
+            get_frame_size(b"aaaaaaaaaaaaaa")
+
+    def test_decompress_2x130_1K(self):
+        decompressed_size = get_frame_info(DAT_130K_C).decompressed_size
+        self.assertEqual(decompressed_size, _130_1K)
+
+        dat = decompress(DAT_130K_C + DAT_130K_C)
+        self.assertEqual(len(dat), 2 * _130_1K)
+
+
+class CompressorTestCase(unittest.TestCase):
+
+    def test_simple_compress_bad_args(self):
+        # ZstdCompressor
+        self.assertRaises(TypeError, ZstdCompressor, [])
+        self.assertRaises(TypeError, ZstdCompressor, level=3.14)
+        self.assertRaises(TypeError, ZstdCompressor, level="abc")
+        self.assertRaises(TypeError, ZstdCompressor, options=b"abc")
+
+        self.assertRaises(TypeError, ZstdCompressor, zstd_dict=123)
+        self.assertRaises(TypeError, ZstdCompressor, zstd_dict=b"abcd1234")
+        self.assertRaises(TypeError, ZstdCompressor, zstd_dict={1: 2, 3: 4})
+
+        with self.assertRaises(ValueError):
+            ZstdCompressor(2**31)
+        with self.assertRaises(ValueError):
+            ZstdCompressor(options={2**31: 100})
+
+        with self.assertRaises(ZstdError):
+            ZstdCompressor(options={CompressionParameter.window_log: 100})
+        with self.assertRaises(ZstdError):
+            ZstdCompressor(options={3333: 100})
+
+        # Method bad arguments
+        zc = ZstdCompressor()
+        self.assertRaises(TypeError, zc.compress)
+        self.assertRaises((TypeError, ValueError), zc.compress, b"foo", b"bar")
+        self.assertRaises(TypeError, zc.compress, "str")
+        self.assertRaises((TypeError, ValueError), zc.flush, b"foo")
+        self.assertRaises(TypeError, zc.flush, b"blah", 1)
+
+        self.assertRaises(ValueError, zc.compress, b'', -1)
+        self.assertRaises(ValueError, zc.compress, b'', 3)
+        self.assertRaises(ValueError, zc.flush, zc.CONTINUE) # 0
+        self.assertRaises(ValueError, zc.flush, 3)
+
+        zc.compress(b'')
+        zc.compress(b'', zc.CONTINUE)
+        zc.compress(b'', zc.FLUSH_BLOCK)
+        zc.compress(b'', zc.FLUSH_FRAME)
+        empty = zc.flush()
+        zc.flush(zc.FLUSH_BLOCK)
+        zc.flush(zc.FLUSH_FRAME)
+
+    def test_compress_parameters(self):
+        d = {CompressionParameter.compression_level : 10,
+
+             CompressionParameter.window_log : 12,
+             CompressionParameter.hash_log : 10,
+             CompressionParameter.chain_log : 12,
+             CompressionParameter.search_log : 12,
+             CompressionParameter.min_match : 4,
+             CompressionParameter.target_length : 12,
+             CompressionParameter.strategy : Strategy.lazy,
+
+             CompressionParameter.enable_long_distance_matching : 1,
+             CompressionParameter.ldm_hash_log : 12,
+             CompressionParameter.ldm_min_match : 11,
+             CompressionParameter.ldm_bucket_size_log : 5,
+             CompressionParameter.ldm_hash_rate_log : 12,
+
+             CompressionParameter.content_size_flag : 1,
+             CompressionParameter.checksum_flag : 1,
+             CompressionParameter.dict_id_flag : 0,
+
+             CompressionParameter.nb_workers : 2 if SUPPORT_MULTITHREADING else 0,
+             CompressionParameter.job_size : 5*_1M if SUPPORT_MULTITHREADING else 0,
+             CompressionParameter.overlap_log : 9 if SUPPORT_MULTITHREADING else 0,
+             }
+        ZstdCompressor(options=d)
+
+        # larger than signed int, ValueError
+        d1 = d.copy()
+        d1[CompressionParameter.ldm_bucket_size_log] = 2**31
+        self.assertRaises(ValueError, ZstdCompressor, options=d1)
+
+        # out-of-bounds compression_level values are clamped
+        level_min, level_max = CompressionParameter.compression_level.bounds()
+        compress(b'', level_max+1)
+        compress(b'', level_min-1)
+
+        compress(b'', options={CompressionParameter.compression_level:level_max+1})
+        compress(b'', options={CompressionParameter.compression_level:level_min-1})
+
+        # zstd lib doesn't support MT compression
+        if not SUPPORT_MULTITHREADING:
+            with self.assertRaises(ZstdError):
+                ZstdCompressor(options={CompressionParameter.nb_workers:4})
+            with self.assertRaises(ZstdError):
+                ZstdCompressor(options={CompressionParameter.job_size:4})
+            with self.assertRaises(ZstdError):
+                ZstdCompressor(options={CompressionParameter.overlap_log:4})
+
+        # out of bounds error msg
+        option = {CompressionParameter.window_log:100}
+        with self.assertRaisesRegex(ZstdError,
+                (r'Error when setting zstd compression parameter "window_log", '
+                 r'it should \d+ <= value <= \d+, provided value is 100\. '
+                 r'\(zstd v\d\.\d\.\d, (?:32|64)-bit build\)')):
+            compress(b'', options=option)
+
+    def test_unknown_compression_parameter(self):
+        KEY = 100001234
+        option = {CompressionParameter.compression_level: 10,
+                  KEY: 200000000}
+        pattern = r'Zstd compression parameter.*?"unknown parameter \(key %d\)"' \
+                  % KEY
+        with self.assertRaisesRegex(ZstdError, pattern):
+            ZstdCompressor(options=option)
+
+    @unittest.skipIf(not SUPPORT_MULTITHREADING,
+                     "zstd build doesn't support multi-threaded compression")
+    def test_zstd_multithread_compress(self):
+        size = 40*_1M
+        b = THIS_FILE_BYTES * (size // len(THIS_FILE_BYTES))
+
+        options = {CompressionParameter.compression_level : 4,
+                   CompressionParameter.nb_workers : 2}
+
+        # compress()
+        dat1 = compress(b, options=options)
+        dat2 = decompress(dat1)
+        self.assertEqual(dat2, b)
+
+        # ZstdCompressor
+        c = ZstdCompressor(options=options)
+        dat1 = c.compress(b, c.CONTINUE)
+        dat2 = c.compress(b, c.FLUSH_BLOCK)
+        dat3 = c.compress(b, c.FLUSH_FRAME)
+        dat4 = decompress(dat1+dat2+dat3)
+        self.assertEqual(dat4, b * 3)
+
+        # ZstdFile
+        with ZstdFile(io.BytesIO(), 'w', options=options) as f:
+            f.write(b)
+
+    def test_compress_flushblock(self):
+        point = len(THIS_FILE_BYTES) // 2
+
+        c = ZstdCompressor()
+        self.assertEqual(c.last_mode, c.FLUSH_FRAME)
+        dat1 = c.compress(THIS_FILE_BYTES[:point])
+        self.assertEqual(c.last_mode, c.CONTINUE)
+        dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_BLOCK)
+        self.assertEqual(c.last_mode, c.FLUSH_BLOCK)
+        dat2 = c.flush()
+        pattern = "Compressed data ended before the end-of-stream marker"
+        with self.assertRaisesRegex(ZstdError, pattern):
+            decompress(dat1)
+
+        dat3 = decompress(dat1 + dat2)
+
+        self.assertEqual(dat3, THIS_FILE_BYTES)
+
+    def test_compress_flushframe(self):
+        # test compress & decompress
+        point = len(THIS_FILE_BYTES) // 2
+
+        c = ZstdCompressor()
+
+        dat1 = c.compress(THIS_FILE_BYTES[:point])
+        self.assertEqual(c.last_mode, c.CONTINUE)
+
+        dat1 += c.compress(THIS_FILE_BYTES[point:], c.FLUSH_FRAME)
+        self.assertEqual(c.last_mode, c.FLUSH_FRAME)
+
+        nt = get_frame_info(dat1)
+        self.assertEqual(nt.decompressed_size, None) # no content size
+
+        dat2 = decompress(dat1)
+
+        self.assertEqual(dat2, THIS_FILE_BYTES)
+
+        # single .FLUSH_FRAME mode has content size
+        c = ZstdCompressor()
+        dat = c.compress(THIS_FILE_BYTES, mode=c.FLUSH_FRAME)
+        self.assertEqual(c.last_mode, c.FLUSH_FRAME)
+
+        nt = get_frame_info(dat)
+        self.assertEqual(nt.decompressed_size, len(THIS_FILE_BYTES))
+
+    def test_compress_empty(self):
+        # output empty content frame
+        self.assertNotEqual(compress(b''), b'')
+
+        c = ZstdCompressor()
+        self.assertNotEqual(c.compress(b'', c.FLUSH_FRAME), b'')
+
+class DecompressorTestCase(unittest.TestCase):
+
+    def test_simple_decompress_bad_args(self):
+        # ZstdDecompressor
+        self.assertRaises(TypeError, ZstdDecompressor, ())
+        self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=123)
+        self.assertRaises(TypeError, ZstdDecompressor, zstd_dict=b'abc')
+        self.assertRaises(TypeError, ZstdDecompressor, zstd_dict={1:2, 3:4})
+
+        self.assertRaises(TypeError, ZstdDecompressor, options=123)
+        self.assertRaises(TypeError, ZstdDecompressor, options='abc')
+        self.assertRaises(TypeError, ZstdDecompressor, options=b'abc')
+
+        with self.assertRaises(ValueError):
+            ZstdDecompressor(options={2**31 : 100})
+
+        with self.assertRaises(ZstdError):
+            ZstdDecompressor(options={DecompressionParameter.window_log_max:100})
+        with self.assertRaises(ZstdError):
+            ZstdDecompressor(options={3333 : 100})
+
+        empty = compress(b'')
+        lzd = ZstdDecompressor()
+        self.assertRaises(TypeError, lzd.decompress)
+        self.assertRaises(TypeError, lzd.decompress, b"foo", b"bar")
+        self.assertRaises(TypeError, lzd.decompress, "str")
+        lzd.decompress(empty)
+
+    def test_decompress_parameters(self):
+        d = {DecompressionParameter.window_log_max : 15}
+        ZstdDecompressor(options=d)
+
+        # larger than signed int, ValueError
+        d1 = d.copy()
+        d1[DecompressionParameter.window_log_max] = 2**31
+        self.assertRaises(ValueError, ZstdDecompressor, None, d1)
+
+        # out of bounds error msg
+        options = {DecompressionParameter.window_log_max:100}
+        with self.assertRaisesRegex(ZstdError,
+                (r'Error when setting zstd decompression parameter "window_log_max", '
+                 r'it should \d+ <= value <= \d+, provided value is 100\. '
+                 r'\(zstd v\d\.\d\.\d, (?:32|64)-bit build\)')):
+            decompress(b'', options=options)
+
+    def test_unknown_decompression_parameter(self):
+        KEY = 100001234
+        options = {DecompressionParameter.window_log_max: DecompressionParameter.window_log_max.bounds()[1],
+                  KEY: 200000000}
+        pattern = r'Zstd decompression parameter.*?"unknown parameter \(key %d\)"' \
+                  % KEY
+        with self.assertRaisesRegex(ZstdError, pattern):
+            ZstdDecompressor(options=options)
+
+    def test_decompress_epilogue_flags(self):
+        # DAT_130K_C has a 4-byte checksum in its frame epilogue
+
+        # full unlimited
+        d = ZstdDecompressor()
+        dat = d.decompress(DAT_130K_C)
+        self.assertEqual(len(dat), _130_1K)
+        self.assertFalse(d.needs_input)
+
+        with self.assertRaises(EOFError):
+            dat = d.decompress(b'')
+
+        # full limited
+        d = ZstdDecompressor()
+        dat = d.decompress(DAT_130K_C, _130_1K)
+        self.assertEqual(len(dat), _130_1K)
+        self.assertFalse(d.needs_input)
+
+        with self.assertRaises(EOFError):
+            dat = d.decompress(b'', 0)
+
+        # [:-4] unlimited
+        d = ZstdDecompressor()
+        dat = d.decompress(DAT_130K_C[:-4])
+        self.assertEqual(len(dat), _130_1K)
+        self.assertTrue(d.needs_input)
+
+        dat = d.decompress(b'')
+        self.assertEqual(len(dat), 0)
+        self.assertTrue(d.needs_input)
+
+        # [:-4] limited
+        d = ZstdDecompressor()
+        dat = d.decompress(DAT_130K_C[:-4], _130_1K)
+        self.assertEqual(len(dat), _130_1K)
+        self.assertFalse(d.needs_input)
+
+        dat = d.decompress(b'', 0)
+        self.assertEqual(len(dat), 0)
+        self.assertFalse(d.needs_input)
+
+        # [:-3] unlimited
+        d = ZstdDecompressor()
+        dat = d.decompress(DAT_130K_C[:-3])
+        self.assertEqual(len(dat), _130_1K)
+        self.assertTrue(d.needs_input)
+
+        dat = d.decompress(b'')
+        self.assertEqual(len(dat), 0)
+        self.assertTrue(d.needs_input)
+
+        # [:-3] limited
+        d = ZstdDecompressor()
+        dat = d.decompress(DAT_130K_C[:-3], _130_1K)
+        self.assertEqual(len(dat), _130_1K)
+        self.assertFalse(d.needs_input)
+
+        dat = d.decompress(b'', 0)
+        self.assertEqual(len(dat), 0)
+        self.assertFalse(d.needs_input)
+
+        # [:-1] unlimited
+        d = ZstdDecompressor()
+        dat = d.decompress(DAT_130K_C[:-1])
+        self.assertEqual(len(dat), _130_1K)
+        self.assertTrue(d.needs_input)
+
+        dat = d.decompress(b'')
+        self.assertEqual(len(dat), 0)
+        self.assertTrue(d.needs_input)
+
+        # [:-1] limited
+        d = ZstdDecompressor()
+        dat = d.decompress(DAT_130K_C[:-1], _130_1K)
+        self.assertEqual(len(dat), _130_1K)
+        self.assertFalse(d.needs_input)
+
+        dat = d.decompress(b'', 0)
+        self.assertEqual(len(dat), 0)
+        self.assertFalse(d.needs_input)
+
+    def test_decompressor_arg(self):
+        zd = ZstdDict(b'12345678', True)
+
+        with self.assertRaises(TypeError):
+            d = ZstdDecompressor(zstd_dict={})
+
+        with self.assertRaises(TypeError):
+            d = ZstdDecompressor(options=zd)
+
+        ZstdDecompressor()
+        ZstdDecompressor(zd, {})
+        ZstdDecompressor(zstd_dict=zd, options={DecompressionParameter.window_log_max:25})
+
+    def test_decompressor_1(self):
+        # empty
+        d = ZstdDecompressor()
+        dat = d.decompress(b'')
+
+        self.assertEqual(dat, b'')
+        self.assertFalse(d.eof)
+
+        # 130_1K full
+        d = ZstdDecompressor()
+        dat = d.decompress(DAT_130K_C)
+
+        self.assertEqual(len(dat), _130_1K)
+        self.assertTrue(d.eof)
+        self.assertFalse(d.needs_input)
+
+        # 130_1K full, limit output
+        d = ZstdDecompressor()
+        dat = d.decompress(DAT_130K_C, _130_1K)
+
+        self.assertEqual(len(dat), _130_1K)
+        self.assertTrue(d.eof)
+        self.assertFalse(d.needs_input)
+
+        # 130_1K, without the 4-byte checksum
+        d = ZstdDecompressor()
+        dat = d.decompress(DAT_130K_C[:-4])
+
+        self.assertEqual(len(dat), _130_1K)
+        self.assertFalse(d.eof)
+        self.assertTrue(d.needs_input)
+
+        # above, limit output
+        d = ZstdDecompressor()
+        dat = d.decompress(DAT_130K_C[:-4], _130_1K)
+
+        self.assertEqual(len(dat), _130_1K)
+        self.assertFalse(d.eof)
+        self.assertFalse(d.needs_input)
+
+        # full, unused_data
+        TRAIL = b'89234893abcd'
+        d = ZstdDecompressor()
+        dat = d.decompress(DAT_130K_C + TRAIL, _130_1K)
+
+        self.assertEqual(len(dat), _130_1K)
+        self.assertTrue(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data, TRAIL)
+
+    def test_decompressor_chunks_read_300(self):
+        TRAIL = b'89234893abcd'
+        DAT = DAT_130K_C + TRAIL
+        d = ZstdDecompressor()
+
+        bi = io.BytesIO(DAT)
+        lst = []
+        while True:
+            if d.needs_input:
+                dat = bi.read(300)
+                if not dat:
+                    break
+            else:
+                raise Exception('should not get here')
+
+            ret = d.decompress(dat)
+            lst.append(ret)
+            if d.eof:
+                break
+
+        ret = b''.join(lst)
+
+        self.assertEqual(len(ret), _130_1K)
+        self.assertTrue(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data + bi.read(), TRAIL)
+
+    def test_decompressor_chunks_read_3(self):
+        TRAIL = b'89234893'
+        DAT = DAT_130K_C + TRAIL
+        d = ZstdDecompressor()
+
+        bi = io.BytesIO(DAT)
+        lst = []
+        while True:
+            if d.needs_input:
+                dat = bi.read(3)
+                if not dat:
+                    break
+            else:
+                dat = b''
+
+            ret = d.decompress(dat, 1)
+            lst.append(ret)
+            if d.eof:
+                break
+
+        ret = b''.join(lst)
+
+        self.assertEqual(len(ret), _130_1K)
+        self.assertTrue(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data + bi.read(), TRAIL)
+
+
+    def test_decompress_empty(self):
+        with self.assertRaises(ZstdError):
+            decompress(b'')
+
+        d = ZstdDecompressor()
+        self.assertEqual(d.decompress(b''), b'')
+        self.assertFalse(d.eof)
+
+    def test_decompress_empty_content_frame(self):
+        DAT = compress(b'')
+        # decompress
+        self.assertGreaterEqual(len(DAT), 4)
+        self.assertEqual(decompress(DAT), b'')
+
+        with self.assertRaises(ZstdError):
+            decompress(DAT[:-1])
+
+        # ZstdDecompressor
+        d = ZstdDecompressor()
+        dat = d.decompress(DAT)
+        self.assertEqual(dat, b'')
+        self.assertTrue(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+        d = ZstdDecompressor()
+        dat = d.decompress(DAT[:-1])
+        self.assertEqual(dat, b'')
+        self.assertFalse(d.eof)
+        self.assertTrue(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+class DecompressorFlagsTestCase(unittest.TestCase):
+    """Check the ZstdDecompressor state flags (eof / needs_input /
+    unused_data) and the module-level decompress() function across
+    single frames, concatenated frames, and skippable frames."""
+
+    @classmethod
+    def setUpClass(cls):
+        # Frames flushed with FLUSH_FRAME below carry a content checksum
+        # (checksum_flag=1), which the truncation/corruption tests rely on.
+        options = {CompressionParameter.checksum_flag:1}
+        c = ZstdCompressor(options=options)
+
+        cls.DECOMPRESSED_42 = b'a'*42
+        cls.FRAME_42 = c.compress(cls.DECOMPRESSED_42, c.FLUSH_FRAME)
+
+        cls.DECOMPRESSED_60 = b'a'*60
+        cls.FRAME_60 = c.compress(cls.DECOMPRESSED_60, c.FLUSH_FRAME)
+
+        cls.FRAME_42_60 = cls.FRAME_42 + cls.FRAME_60
+        cls.DECOMPRESSED_42_60 = cls.DECOMPRESSED_42 + cls.DECOMPRESSED_60
+
+        cls._130_1K = 130*_1K
+
+        # NOTE(review): the "UNKNOWN" frames are built via streaming
+        # compress()+flush(), presumably so the frame header does not
+        # record the decompressed size -- confirm against the _zstd docs.
+        c = ZstdCompressor()
+        cls.UNKNOWN_FRAME_42 = c.compress(cls.DECOMPRESSED_42) + c.flush()
+        cls.UNKNOWN_FRAME_60 = c.compress(cls.DECOMPRESSED_60) + c.flush()
+        cls.UNKNOWN_FRAME_42_60 = cls.UNKNOWN_FRAME_42 + cls.UNKNOWN_FRAME_60
+
+        # Arbitrary garbage appended after complete frames in several tests.
+        cls.TRAIL = b'12345678abcdefg!@#$%^&*()_+|'
+
+    def test_function_decompress(self):
+
+        self.assertEqual(len(decompress(COMPRESSED_100_PLUS_32KB)), 100+32*_1K)
+
+        # 1 frame
+        self.assertEqual(decompress(self.FRAME_42), self.DECOMPRESSED_42)
+
+        self.assertEqual(decompress(self.UNKNOWN_FRAME_42), self.DECOMPRESSED_42)
+
+        # Truncated input at any point must raise, not return partial data.
+        pattern = r"Compressed data ended before the end-of-stream marker"
+        with self.assertRaisesRegex(ZstdError, pattern):
+            decompress(self.FRAME_42[:1])
+
+        with self.assertRaisesRegex(ZstdError, pattern):
+            decompress(self.FRAME_42[:-4])
+
+        with self.assertRaisesRegex(ZstdError, pattern):
+            decompress(self.FRAME_42[:-1])
+
+        # 2 frames
+        self.assertEqual(decompress(self.FRAME_42_60), self.DECOMPRESSED_42_60)
+
+        self.assertEqual(decompress(self.UNKNOWN_FRAME_42_60), self.DECOMPRESSED_42_60)
+
+        self.assertEqual(decompress(self.FRAME_42 + self.UNKNOWN_FRAME_60),
+                         self.DECOMPRESSED_42_60)
+
+        self.assertEqual(decompress(self.UNKNOWN_FRAME_42 + self.FRAME_60),
+                         self.DECOMPRESSED_42_60)
+
+        with self.assertRaisesRegex(ZstdError, pattern):
+            decompress(self.FRAME_42_60[:-4])
+
+        with self.assertRaisesRegex(ZstdError, pattern):
+            decompress(self.UNKNOWN_FRAME_42_60[:-1])
+
+        # 130_1K
+        self.assertEqual(decompress(DAT_130K_C), DAT_130K_D)
+
+        with self.assertRaisesRegex(ZstdError, pattern):
+            decompress(DAT_130K_C[:-4])
+
+        with self.assertRaisesRegex(ZstdError, pattern):
+            decompress(DAT_130K_C[:-1])
+
+        # Unknown frame descriptor
+        with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+            decompress(b'aaaaaaaaa')
+
+        with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+            decompress(self.FRAME_42 + b'aaaaaaaaa')
+
+        with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+            decompress(self.UNKNOWN_FRAME_42_60 + b'aaaaaaaaa')
+
+        # doesn't match checksum
+        # Flip one bit-safe byte of the 4-byte trailing checksum so the
+        # frame is well-formed but the checksum verification fails.
+        checksum = DAT_130K_C[-4:]
+        if checksum[0] == 255:
+            wrong_checksum = bytes([254]) + checksum[1:]
+        else:
+            wrong_checksum = bytes([checksum[0]+1]) + checksum[1:]
+
+        dat = DAT_130K_C[:-4] + wrong_checksum
+
+        with self.assertRaisesRegex(ZstdError, "doesn't match checksum"):
+            decompress(dat)
+
+    def test_function_skippable(self):
+        # Skippable frames decompress to nothing on their own.
+        self.assertEqual(decompress(SKIPPABLE_FRAME), b'')
+        self.assertEqual(decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME), b'')
+
+        # 1 frame + 2 skippable
+        self.assertEqual(len(decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + DAT_130K_C)),
+                         self._130_1K)
+
+        self.assertEqual(len(decompress(DAT_130K_C + SKIPPABLE_FRAME + SKIPPABLE_FRAME)),
+                         self._130_1K)
+
+        self.assertEqual(len(decompress(SKIPPABLE_FRAME + DAT_130K_C + SKIPPABLE_FRAME)),
+                         self._130_1K)
+
+        # unknown size
+        self.assertEqual(decompress(SKIPPABLE_FRAME + self.UNKNOWN_FRAME_60),
+                         self.DECOMPRESSED_60)
+
+        self.assertEqual(decompress(self.UNKNOWN_FRAME_60 + SKIPPABLE_FRAME),
+                         self.DECOMPRESSED_60)
+
+        # 2 frames + 1 skippable
+        self.assertEqual(decompress(self.FRAME_42 + SKIPPABLE_FRAME + self.FRAME_60),
+                         self.DECOMPRESSED_42_60)
+
+        self.assertEqual(decompress(SKIPPABLE_FRAME + self.FRAME_42_60),
+                         self.DECOMPRESSED_42_60)
+
+        self.assertEqual(decompress(self.UNKNOWN_FRAME_42_60 + SKIPPABLE_FRAME),
+                         self.DECOMPRESSED_42_60)
+
+        # incomplete
+        with self.assertRaises(ZstdError):
+            decompress(SKIPPABLE_FRAME[:1])
+
+        with self.assertRaises(ZstdError):
+            decompress(SKIPPABLE_FRAME[:-1])
+
+        with self.assertRaises(ZstdError):
+            decompress(self.FRAME_42 + SKIPPABLE_FRAME[:-1])
+
+        # Unknown frame descriptor
+        with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+            decompress(b'aaaaaaaaa' + SKIPPABLE_FRAME)
+
+        with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+            decompress(SKIPPABLE_FRAME + b'aaaaaaaaa')
+
+        with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+            decompress(SKIPPABLE_FRAME + SKIPPABLE_FRAME + b'aaaaaaaaa')
+
+    def test_decompressor_1(self):
+        # empty 1
+        d = ZstdDecompressor()
+
+        dat = d.decompress(b'')
+        self.assertEqual(dat, b'')
+        self.assertFalse(d.eof)
+        self.assertTrue(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+        # With max_length=0 the decompressor reports needs_input == False
+        # even on empty input.
+        dat = d.decompress(b'', 0)
+        self.assertEqual(dat, b'')
+        self.assertFalse(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+        dat = d.decompress(COMPRESSED_100_PLUS_32KB + b'a')
+        self.assertEqual(dat, DECOMPRESSED_100_PLUS_32KB)
+        self.assertTrue(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data, b'a')
+        self.assertEqual(d.unused_data, b'a') # twice
+
+        # empty 2
+        d = ZstdDecompressor()
+
+        dat = d.decompress(b'', 0)
+        self.assertEqual(dat, b'')
+        self.assertFalse(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+        dat = d.decompress(b'')
+        self.assertEqual(dat, b'')
+        self.assertFalse(d.eof)
+        self.assertTrue(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+        dat = d.decompress(COMPRESSED_100_PLUS_32KB + b'a')
+        self.assertEqual(dat, DECOMPRESSED_100_PLUS_32KB)
+        self.assertTrue(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data, b'a')
+        self.assertEqual(d.unused_data, b'a') # twice
+
+        # 1 frame
+        d = ZstdDecompressor()
+        dat = d.decompress(self.FRAME_42)
+
+        self.assertEqual(dat, self.DECOMPRESSED_42)
+        self.assertTrue(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+        # Decompressing past eof raises EOFError.
+        with self.assertRaises(EOFError):
+            d.decompress(b'')
+
+        # 1 frame, trail
+        d = ZstdDecompressor()
+        dat = d.decompress(self.FRAME_42 + self.TRAIL)
+
+        self.assertEqual(dat, self.DECOMPRESSED_42)
+        self.assertTrue(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data, self.TRAIL)
+        self.assertEqual(d.unused_data, self.TRAIL) # twice
+
+        # 1 frame, 32_1K
+        temp = compress(b'a'*(32*_1K))
+        d = ZstdDecompressor()
+        dat = d.decompress(temp, 32*_1K)
+
+        self.assertEqual(dat, b'a'*(32*_1K))
+        self.assertTrue(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+        with self.assertRaises(EOFError):
+            d.decompress(b'')
+
+        # 1 frame, 32_1K+100, trail
+        d = ZstdDecompressor()
+        dat = d.decompress(COMPRESSED_100_PLUS_32KB+self.TRAIL, 100) # 100 bytes
+
+        self.assertEqual(len(dat), 100)
+        self.assertFalse(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+
+        dat = d.decompress(b'') # 32_1K
+
+        self.assertEqual(len(dat), 32*_1K)
+        self.assertTrue(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data, self.TRAIL)
+        self.assertEqual(d.unused_data, self.TRAIL) # twice
+
+        with self.assertRaises(EOFError):
+            d.decompress(b'')
+
+        # incomplete 1
+        d = ZstdDecompressor()
+        dat = d.decompress(self.FRAME_60[:1])
+
+        self.assertFalse(d.eof)
+        self.assertTrue(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+        # incomplete 2
+        d = ZstdDecompressor()
+
+        dat = d.decompress(self.FRAME_60[:-4])
+        self.assertEqual(dat, self.DECOMPRESSED_60)
+        self.assertFalse(d.eof)
+        self.assertTrue(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+        # incomplete 3
+        d = ZstdDecompressor()
+
+        dat = d.decompress(self.FRAME_60[:-1])
+        self.assertEqual(dat, self.DECOMPRESSED_60)
+        self.assertFalse(d.eof)
+        self.assertTrue(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+
+        # incomplete 4
+        d = ZstdDecompressor()
+
+        dat = d.decompress(self.FRAME_60[:-4], 60)
+        self.assertEqual(dat, self.DECOMPRESSED_60)
+        self.assertFalse(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+        dat = d.decompress(b'')
+        self.assertEqual(dat, b'')
+        self.assertFalse(d.eof)
+        self.assertTrue(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+        # Unknown frame descriptor
+        d = ZstdDecompressor()
+        with self.assertRaisesRegex(ZstdError, "Unknown frame descriptor"):
+            d.decompress(b'aaaaaaaaa')
+
+    def test_decompressor_skippable(self):
+        # 1 skippable
+        d = ZstdDecompressor()
+        dat = d.decompress(SKIPPABLE_FRAME)
+
+        self.assertEqual(dat, b'')
+        self.assertTrue(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+        # 1 skippable, max_length=0
+        d = ZstdDecompressor()
+        dat = d.decompress(SKIPPABLE_FRAME, 0)
+
+        self.assertEqual(dat, b'')
+        self.assertTrue(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+        # 1 skippable, trail
+        d = ZstdDecompressor()
+        dat = d.decompress(SKIPPABLE_FRAME + self.TRAIL)
+
+        self.assertEqual(dat, b'')
+        self.assertTrue(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data, self.TRAIL)
+        self.assertEqual(d.unused_data, self.TRAIL) # twice
+
+        # incomplete
+        d = ZstdDecompressor()
+        dat = d.decompress(SKIPPABLE_FRAME[:-1])
+
+        self.assertEqual(dat, b'')
+        self.assertFalse(d.eof)
+        self.assertTrue(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+        # incomplete
+        d = ZstdDecompressor()
+        dat = d.decompress(SKIPPABLE_FRAME[:-1], 0)
+
+        self.assertEqual(dat, b'')
+        self.assertFalse(d.eof)
+        self.assertFalse(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+        dat = d.decompress(b'')
+
+        self.assertEqual(dat, b'')
+        self.assertFalse(d.eof)
+        self.assertTrue(d.needs_input)
+        self.assertEqual(d.unused_data, b'')
+        self.assertEqual(d.unused_data, b'') # twice
+
+
+
+class ZstdDictTestCase(unittest.TestCase):
+    """Tests for ZstdDict construction, training (train_dict /
+    finalize_dict), the as_digested_dict / as_undigested_dict / as_prefix
+    views, and the low-level _zstd training entry points."""
+
+    def test_is_raw(self):
+        # content < 8
+        b = b'1234567'
+        with self.assertRaises(ValueError):
+            ZstdDict(b)
+
+        # content == 8
+        b = b'12345678'
+        zd = ZstdDict(b, is_raw=True)
+        self.assertEqual(zd.dict_id, 0)
+
+        temp = compress(b'aaa12345678', level=3, zstd_dict=zd)
+        self.assertEqual(b'aaa12345678', decompress(temp, zd))
+
+        # is_raw == False
+        b = b'12345678abcd'
+        with self.assertRaises(ValueError):
+            ZstdDict(b)
+
+        # read only attributes
+        with self.assertRaises(AttributeError):
+            zd.dict_content = b
+
+        with self.assertRaises(AttributeError):
+            zd.dict_id = 10000
+
+        # ZstdDict arguments
+        zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=False)
+        self.assertNotEqual(zd.dict_id, 0)
+
+        # note this assertion: even with is_raw=True, dict_id is non-zero
+        # for trained content -- presumably read from the dict header;
+        # confirm against ZstdDict docs.
+        zd = ZstdDict(TRAINED_DICT.dict_content, is_raw=True)
+        self.assertNotEqual(zd.dict_id, 0) # note this assertion
+
+        with self.assertRaises(TypeError):
+            ZstdDict("12345678abcdef", is_raw=True)
+        with self.assertRaises(TypeError):
+            ZstdDict(TRAINED_DICT)
+
+        # invalid parameter
+        with self.assertRaises(TypeError):
+            ZstdDict(desk333=345)
+
+    def test_invalid_dict(self):
+        # Zstandard dictionary magic bytes (little-endian) followed by
+        # garbage: valid-looking header, corrupt body.
+        DICT_MAGIC = 0xEC30A437.to_bytes(4, byteorder='little')
+        dict_content = DICT_MAGIC + b'abcdefghighlmnopqrstuvwxyz'
+
+        # corrupted
+        zd = ZstdDict(dict_content, is_raw=False)
+        with self.assertRaisesRegex(ZstdError, r'ZSTD_CDict.*?corrupted'):
+            ZstdCompressor(zstd_dict=zd.as_digested_dict)
+        with self.assertRaisesRegex(ZstdError, r'ZSTD_DDict.*?corrupted'):
+            ZstdDecompressor(zd)
+
+        # wrong type
+        with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
+            ZstdCompressor(zstd_dict=(zd, b'123'))
+        with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
+            ZstdCompressor(zstd_dict=(zd, 1, 2))
+        with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
+            ZstdCompressor(zstd_dict=(zd, -1))
+        with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
+            ZstdCompressor(zstd_dict=(zd, 3))
+
+        with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
+            ZstdDecompressor(zstd_dict=(zd, b'123'))
+        with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
+            ZstdDecompressor((zd, 1, 2))
+        with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
+            ZstdDecompressor((zd, -1))
+        with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
+            ZstdDecompressor((zd, 3))
+
+    def test_train_dict(self):
+
+
+        # NOTE: this local name shadows the module-level TRAINED_DICT
+        # for the rest of this method.
+        TRAINED_DICT = train_dict(SAMPLES, DICT_SIZE1)
+        ZstdDict(TRAINED_DICT.dict_content, False)
+
+        self.assertNotEqual(TRAINED_DICT.dict_id, 0)
+        self.assertGreater(len(TRAINED_DICT.dict_content), 0)
+        self.assertLessEqual(len(TRAINED_DICT.dict_content), DICT_SIZE1)
+        self.assertTrue(re.match(r'^<ZstdDict dict_id=\d+ dict_size=\d+>$', str(TRAINED_DICT)))
+
+        # compress/decompress
+        c = ZstdCompressor(zstd_dict=TRAINED_DICT)
+        for sample in SAMPLES:
+            dat1 = compress(sample, zstd_dict=TRAINED_DICT)
+            dat2 = decompress(dat1, TRAINED_DICT)
+            self.assertEqual(sample, dat2)
+
+            dat1 = c.compress(sample)
+            dat1 += c.flush()
+            dat2 = decompress(dat1, TRAINED_DICT)
+            self.assertEqual(sample, dat2)
+
+    def test_finalize_dict(self):
+        DICT_SIZE2 = 200*_1K
+        C_LEVEL = 6
+
+        try:
+            dic2 = finalize_dict(TRAINED_DICT, SAMPLES, DICT_SIZE2, C_LEVEL)
+        except NotImplementedError:
+            # < v1.4.5 at compile-time, >= v1.4.5 at run-time
+            return
+
+        self.assertNotEqual(dic2.dict_id, 0)
+        self.assertGreater(len(dic2.dict_content), 0)
+        self.assertLessEqual(len(dic2.dict_content), DICT_SIZE2)
+
+        # compress/decompress
+        c = ZstdCompressor(C_LEVEL, zstd_dict=dic2)
+        for sample in SAMPLES:
+            dat1 = compress(sample, C_LEVEL, zstd_dict=dic2)
+            dat2 = decompress(dat1, dic2)
+            self.assertEqual(sample, dat2)
+
+            dat1 = c.compress(sample)
+            dat1 += c.flush()
+            dat2 = decompress(dat1, dic2)
+            self.assertEqual(sample, dat2)
+
+        # dict mismatch
+        self.assertNotEqual(TRAINED_DICT.dict_id, dic2.dict_id)
+
+        dat1 = compress(SAMPLES[0], zstd_dict=TRAINED_DICT)
+        with self.assertRaises(ZstdError):
+            decompress(dat1, dic2)
+
+    def test_train_dict_arguments(self):
+        # empty sample list
+        with self.assertRaises(ValueError):
+            train_dict([], 100*_1K)
+
+        # non-positive dict_size
+        with self.assertRaises(ValueError):
+            train_dict(SAMPLES, -100)
+
+        with self.assertRaises(ValueError):
+            train_dict(SAMPLES, 0)
+
+    def test_finalize_dict_arguments(self):
+        with self.assertRaises(TypeError):
+            finalize_dict({1:2}, (b'aaa', b'bbb'), 100*_1K, 2)
+
+        with self.assertRaises(ValueError):
+            finalize_dict(TRAINED_DICT, [], 100*_1K, 2)
+
+        with self.assertRaises(ValueError):
+            finalize_dict(TRAINED_DICT, SAMPLES, -100, 2)
+
+        with self.assertRaises(ValueError):
+            finalize_dict(TRAINED_DICT, SAMPLES, 0, 2)
+
+    def test_train_dict_c(self):
+        # Low-level _zstd._train_dict argument validation.
+        # argument wrong type
+        with self.assertRaises(TypeError):
+            _zstd._train_dict({}, (), 100)
+        with self.assertRaises(TypeError):
+            _zstd._train_dict(b'', 99, 100)
+        with self.assertRaises(TypeError):
+            _zstd._train_dict(b'', (), 100.1)
+
+        # size > size_t
+        with self.assertRaises(ValueError):
+            _zstd._train_dict(b'', (2**64+1,), 100)
+
+        # dict_size <= 0
+        with self.assertRaises(ValueError):
+            _zstd._train_dict(b'', (), 0)
+
+    def test_finalize_dict_c(self):
+        # Low-level _zstd._finalize_dict argument validation.
+        with self.assertRaises(TypeError):
+            _zstd._finalize_dict(1, 2, 3, 4, 5)
+
+        # argument wrong type
+        with self.assertRaises(TypeError):
+            _zstd._finalize_dict({}, b'', (), 100, 5)
+        with self.assertRaises(TypeError):
+            _zstd._finalize_dict(TRAINED_DICT.dict_content, {}, (), 100, 5)
+        with self.assertRaises(TypeError):
+            _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5)
+        with self.assertRaises(TypeError):
+            _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', (), 100.1, 5)
+        with self.assertRaises(TypeError):
+            _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5.1)
+
+        # size > size_t
+        with self.assertRaises(ValueError):
+            _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', (2**64+1,), 100, 5)
+
+        # dict_size <= 0
+        with self.assertRaises(ValueError):
+            _zstd._finalize_dict(TRAINED_DICT.dict_content, b'', (), 0, 5)
+
+    def test_train_buffer_protocol_samples(self):
+        # Samples may be arbitrary buffer-protocol objects; the sizes
+        # tuple must count bytes, not items.
+        def _nbytes(dat):
+            if isinstance(dat, (bytes, bytearray)):
+                return len(dat)
+            return memoryview(dat).nbytes
+
+        # prepare samples
+        chunk_lst = []
+        wrong_size_lst = []
+        correct_size_lst = []
+        for _ in range(300):
+            # 'Q' items are 8 bytes each, so len(arr) != arr nbytes.
+            arr = array.array('Q', [random.randint(0, 20) for i in range(20)])
+            chunk_lst.append(arr)
+            correct_size_lst.append(_nbytes(arr))
+            wrong_size_lst.append(len(arr))
+        concatenation = b''.join(chunk_lst)
+
+        # wrong size list
+        with self.assertRaisesRegex(ValueError,
+                "The samples size tuple doesn't match the concatenation's size"):
+            _zstd._train_dict(concatenation, tuple(wrong_size_lst), 100*_1K)
+
+        # correct size list
+        _zstd._train_dict(concatenation, tuple(correct_size_lst), 3*_1K)
+
+        # wrong size list
+        with self.assertRaisesRegex(ValueError,
+                "The samples size tuple doesn't match the concatenation's size"):
+            _zstd._finalize_dict(TRAINED_DICT.dict_content,
+                                  concatenation, tuple(wrong_size_lst), 300*_1K, 5)
+
+        # correct size list
+        _zstd._finalize_dict(TRAINED_DICT.dict_content,
+                              concatenation, tuple(correct_size_lst), 300*_1K, 5)
+
+    def test_as_prefix(self):
+        # V1
+        V1 = THIS_FILE_BYTES
+        zd = ZstdDict(V1, True)
+
+        # V2: V1 with a single byte changed in the middle.
+        mid = len(V1) // 2
+        V2 = V1[:mid] + \
+             (b'a' if V1[mid] != int.from_bytes(b'a') else b'b') + \
+             V1[mid+1:]
+
+        # compress
+        dat = compress(V2, zstd_dict=zd.as_prefix)
+        # A prefix is not a dictionary: no dictionary_id in the frame.
+        self.assertEqual(get_frame_info(dat).dictionary_id, 0)
+
+        # decompress
+        self.assertEqual(decompress(dat, zd.as_prefix), V2)
+
+        # use wrong prefix
+        zd2 = ZstdDict(SAMPLES[0], True)
+        try:
+            decompressed = decompress(dat, zd2.as_prefix)
+        except ZstdError: # expected
+            pass
+        else:
+            self.assertNotEqual(decompressed, V2)
+
+        # read only attribute
+        with self.assertRaises(AttributeError):
+            zd.as_prefix = b'1234'
+
+    def test_as_digested_dict(self):
+        zd = TRAINED_DICT
+
+        # test .as_digested_dict
+        dat = compress(SAMPLES[0], zstd_dict=zd.as_digested_dict)
+        self.assertEqual(decompress(dat, zd.as_digested_dict), SAMPLES[0])
+        with self.assertRaises(AttributeError):
+            zd.as_digested_dict = b'1234'
+
+        # test .as_undigested_dict
+        dat = compress(SAMPLES[0], zstd_dict=zd.as_undigested_dict)
+        self.assertEqual(decompress(dat, zd.as_undigested_dict), SAMPLES[0])
+        with self.assertRaises(AttributeError):
+            zd.as_undigested_dict = b'1234'
+
+    def test_advanced_compression_parameters(self):
+        options = {CompressionParameter.compression_level: 6,
+                  CompressionParameter.window_log: 20,
+                  CompressionParameter.enable_long_distance_matching: 1}
+
+        # automatically select
+        dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT)
+        self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0])
+
+        # explicitly select
+        dat = compress(SAMPLES[0], options=options, zstd_dict=TRAINED_DICT.as_digested_dict)
+        self.assertEqual(decompress(dat, TRAINED_DICT), SAMPLES[0])
+
+    def test_len(self):
+        # len() and repr() both reflect the dictionary content size.
+        self.assertEqual(len(TRAINED_DICT), len(TRAINED_DICT.dict_content))
+        self.assertIn(str(len(TRAINED_DICT)), str(TRAINED_DICT))
+
+class FileTestCase(unittest.TestCase):
+    def setUp(self):
+        """Create a small one-frame fixture (42 bytes of 'a')."""
+        self.DECOMPRESSED_42 = b'a'*42
+        self.FRAME_42 = compress(self.DECOMPRESSED_42)
+
+    def test_init(self):
+        """ZstdFile accepts r/w/x/a modes plus level/options/zstd_dict."""
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+            pass
+        with ZstdFile(io.BytesIO(), "w") as f:
+            pass
+        with ZstdFile(io.BytesIO(), "x") as f:
+            pass
+        with ZstdFile(io.BytesIO(), "a") as f:
+            pass
+
+        # write-mode keyword arguments
+        with ZstdFile(io.BytesIO(), "w", level=12) as f:
+            pass
+        with ZstdFile(io.BytesIO(), "w", options={CompressionParameter.checksum_flag:1}) as f:
+            pass
+        with ZstdFile(io.BytesIO(), "w", options={}) as f:
+            pass
+        with ZstdFile(io.BytesIO(), "w", level=20, zstd_dict=TRAINED_DICT) as f:
+            pass
+
+        # read-mode keyword arguments
+        with ZstdFile(io.BytesIO(), "r", options={DecompressionParameter.window_log_max:25}) as f:
+            pass
+        with ZstdFile(io.BytesIO(), "r", options={}, zstd_dict=TRAINED_DICT) as f:
+            pass
+
+    def test_init_with_PathLike_filename(self):
+        """Opening by os.PathLike path works; 'a' appends a second frame."""
+        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+            filename = pathlib.Path(tmp_f.name)
+
+        with ZstdFile(filename, "a") as f:
+            f.write(DECOMPRESSED_100_PLUS_32KB)
+        with ZstdFile(filename) as f:
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+
+        # appending again doubles the decompressed content
+        with ZstdFile(filename, "a") as f:
+            f.write(DECOMPRESSED_100_PLUS_32KB)
+        with ZstdFile(filename) as f:
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 2)
+
+        os.remove(filename)
+
+    def test_init_with_filename(self):
+        """ZstdFile can open a path directly in r/w/a modes."""
+        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+            filename = pathlib.Path(tmp_f.name)
+
+        with ZstdFile(filename) as f:
+            pass
+        with ZstdFile(filename, "w") as f:
+            pass
+        with ZstdFile(filename, "a") as f:
+            pass
+
+        os.remove(filename)
+
+    def test_init_mode(self):
+        """Mode strings with and without the 'b' suffix are equivalent."""
+        bi = io.BytesIO()
+
+        with ZstdFile(bi, "r"):
+            pass
+        with ZstdFile(bi, "rb"):
+            pass
+        with ZstdFile(bi, "w"):
+            pass
+        with ZstdFile(bi, "wb"):
+            pass
+        with ZstdFile(bi, "a"):
+            pass
+        with ZstdFile(bi, "ab"):
+            pass
+
+    def test_init_with_x_mode(self):
+        """'x'/'xb' create the file; a second open raises FileExistsError."""
+        with tempfile.NamedTemporaryFile() as tmp_f:
+            filename = pathlib.Path(tmp_f.name)
+
+        for mode in ("x", "xb"):
+            with ZstdFile(filename, mode):
+                pass
+            with self.assertRaises(FileExistsError):
+                with ZstdFile(filename, mode):
+                    pass
+            os.remove(filename)
+
+    def test_init_bad_mode(self):
+        """Invalid mode strings and mode/parameter mismatches raise."""
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), (3, "x"))
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "xt")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "x+")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rx")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "wx")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rt")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r+")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "wt")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "w+")
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rw")
+
+        # compression options in read mode (and vice versa) are rejected
+        with self.assertRaisesRegex(TypeError, r"NOT be CompressionParameter"):
+            ZstdFile(io.BytesIO(), 'rb',
+                     options={CompressionParameter.compression_level:5})
+        with self.assertRaisesRegex(TypeError,
+                                    r"NOT be DecompressionParameter"):
+            ZstdFile(io.BytesIO(), 'wb',
+                     options={DecompressionParameter.window_log_max:21})
+
+        # 'level' is only meaningful for writing
+        with self.assertRaises(TypeError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", level=12)
+
+    def test_init_bad_check(self):
+        """Invalid level/options/zstd_dict values raise the right errors."""
+        with self.assertRaises(TypeError):
+            ZstdFile(io.BytesIO(), "w", level='asd')
+        # CHECK_UNKNOWN and anything above CHECK_ID_MAX should be invalid.
+        with self.assertRaises(ZstdError):
+            ZstdFile(io.BytesIO(), "w", options={999:9999})
+        with self.assertRaises(ZstdError):
+            ZstdFile(io.BytesIO(), "w", options={CompressionParameter.window_log:99})
+
+        with self.assertRaises(TypeError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r", options=33)
+
+        with self.assertRaises(ValueError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB),
+                             options={DecompressionParameter.window_log_max:2**31})
+
+        with self.assertRaises(ZstdError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB),
+                             options={444:333})
+
+        # zstd_dict must be a ZstdDict, not a dict or raw bytes
+        with self.assertRaises(TypeError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), zstd_dict={1:2})
+
+        with self.assertRaises(TypeError):
+            ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), zstd_dict=b'dict123456')
+
+    def test_init_close_fp(self):
+        """A failing __init__ must not leak the file it opened."""
+        # get a temp file name
+        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+            tmp_f.write(DAT_130K_C)
+            filename = tmp_f.name
+
+        with self.assertRaises(ValueError):
+            ZstdFile(filename, options={'a':'b'})
+
+        # for PyPy
+        gc.collect()
+
+        os.remove(filename)
+
+    def test_close(self):
+        """close() leaves a caller-supplied fp open, but closes a
+        file that ZstdFile opened itself; close() is idempotent."""
+        with io.BytesIO(COMPRESSED_100_PLUS_32KB) as src:
+            f = ZstdFile(src)
+            f.close()
+            # ZstdFile.close() should not close the underlying file object.
+            self.assertFalse(src.closed)
+            # Try closing an already-closed ZstdFile.
+            f.close()
+            self.assertFalse(src.closed)
+
+        # Test with a real file on disk, opened directly by ZstdFile.
+        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+            filename = pathlib.Path(tmp_f.name)
+
+        f = ZstdFile(filename)
+        fp = f._fp
+        f.close()
+        # Here, ZstdFile.close() *should* close the underlying file object.
+        self.assertTrue(fp.closed)
+        # Try closing an already-closed ZstdFile.
+        f.close()
+
+        os.remove(filename)
+
+    def test_closed(self):
+        """.closed is False while open (reads included) and True after close()."""
+        f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+        try:
+            self.assertFalse(f.closed)
+            f.read()
+            self.assertFalse(f.closed)
+        finally:
+            f.close()
+        self.assertTrue(f.closed)
+
+        f = ZstdFile(io.BytesIO(), "w")
+        try:
+            self.assertFalse(f.closed)
+        finally:
+            f.close()
+        self.assertTrue(f.closed)
+
+    def test_fileno(self):
+        """fileno() delegates to the underlying fp; raises when absent
+        or after close()."""
+        # 1: BytesIO has no fileno
+        f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+        try:
+            self.assertRaises(io.UnsupportedOperation, f.fileno)
+        finally:
+            f.close()
+        self.assertRaises(ValueError, f.fileno)
+
+        # 2: real file, fileno is delegated
+        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+            filename = pathlib.Path(tmp_f.name)
+
+        f = ZstdFile(filename)
+        try:
+            self.assertEqual(f.fileno(), f._fp.fileno())
+            self.assertIsInstance(f.fileno(), int)
+        finally:
+            f.close()
+        self.assertRaises(ValueError, f.fileno)
+
+        os.remove(filename)
+
+        # 3, no .fileno() method
+        class C:
+            def read(self, size=-1):
+                return b'123'
+        with ZstdFile(C(), 'rb') as f:
+            with self.assertRaisesRegex(AttributeError, r'fileno'):
+                f.fileno()
+
+    def test_name(self):
+        """.name delegates to the underlying fp; raises when absent
+        or after close()."""
+        # 1: BytesIO has no .name
+        f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+        try:
+            with self.assertRaises(AttributeError):
+                f.name
+        finally:
+            f.close()
+        with self.assertRaises(ValueError):
+            f.name
+
+        # 2: real file, name is delegated
+        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+            filename = pathlib.Path(tmp_f.name)
+
+        f = ZstdFile(filename)
+        try:
+            self.assertEqual(f.name, f._fp.name)
+            self.assertIsInstance(f.name, str)
+        finally:
+            f.close()
+        with self.assertRaises(ValueError):
+            f.name
+
+        os.remove(filename)
+
+        # 3, no .filename property
+        class C:
+            def read(self, size=-1):
+                return b'123'
+        with ZstdFile(C(), 'rb') as f:
+            with self.assertRaisesRegex(AttributeError, r'name'):
+                f.name
+
+    def test_seekable(self):
+        """seekable() is True for readable seekable sources, False for
+        write mode and non-seekable sources; ValueError after close()."""
+        f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+        try:
+            self.assertTrue(f.seekable())
+            f.read()
+            self.assertTrue(f.seekable())
+        finally:
+            f.close()
+        self.assertRaises(ValueError, f.seekable)
+
+        f = ZstdFile(io.BytesIO(), "w")
+        try:
+            self.assertFalse(f.seekable())
+        finally:
+            f.close()
+        self.assertRaises(ValueError, f.seekable)
+
+        # underlying object reports itself non-seekable
+        src = io.BytesIO(COMPRESSED_100_PLUS_32KB)
+        src.seekable = lambda: False
+        f = ZstdFile(src)
+        try:
+            self.assertFalse(f.seekable())
+        finally:
+            f.close()
+        self.assertRaises(ValueError, f.seekable)
+
+    def test_readable(self):
+        f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+        try:
+            self.assertTrue(f.readable())
+            f.read()
+            self.assertTrue(f.readable())
+        finally:
+            f.close()
+        self.assertRaises(ValueError, f.readable)
+
+        f = ZstdFile(io.BytesIO(), "w")
+        try:
+            self.assertFalse(f.readable())
+        finally:
+            f.close()
+        self.assertRaises(ValueError, f.readable)
+
+    def test_writable(self):
+        f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+        try:
+            self.assertFalse(f.writable())
+            f.read()
+            self.assertFalse(f.writable())
+        finally:
+            f.close()
+        self.assertRaises(ValueError, f.writable)
+
+        f = ZstdFile(io.BytesIO(), "w")
+        try:
+            self.assertTrue(f.writable())
+        finally:
+            f.close()
+        self.assertRaises(ValueError, f.writable)
+
+    def test_read_0(self):
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+            self.assertEqual(f.read(0), b"")
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB),
+                              options={DecompressionParameter.window_log_max:20}) as f:
+            self.assertEqual(f.read(0), b"")
+
+        # empty file
+        with ZstdFile(io.BytesIO(b'')) as f:
+            self.assertEqual(f.read(0), b"")
+            with self.assertRaises(EOFError):
+                f.read(10)
+
+        with ZstdFile(io.BytesIO(b'')) as f:
+            with self.assertRaises(EOFError):
+                f.read(10)
+
+    def test_read_10(self):
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+            chunks = []
+            while True:
+                result = f.read(10)
+                if not result:
+                    break
+                self.assertLessEqual(len(result), 10)
+                chunks.append(result)
+            self.assertEqual(b"".join(chunks), DECOMPRESSED_100_PLUS_32KB)
+
+    def test_read_multistream(self):
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 5)) as f:
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB * 5)
+
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB + SKIPPABLE_FRAME)) as f:
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB + COMPRESSED_DAT)) as f:
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB + DECOMPRESSED_DAT)
+
+    def test_read_incomplete(self):
+        with ZstdFile(io.BytesIO(DAT_130K_C[:-200])) as f:
+            self.assertRaises(EOFError, f.read)
+
+        # Trailing data isn't a valid compressed stream
+        with ZstdFile(io.BytesIO(self.FRAME_42 + b'12345')) as f:
+            self.assertEqual(f.read(), self.DECOMPRESSED_42)
+
+        with ZstdFile(io.BytesIO(SKIPPABLE_FRAME + b'12345')) as f:
+            self.assertEqual(f.read(), b'')
+
+    def test_read_truncated(self):
+        # Drop the stream epilogue: the 4-byte checksum
+        truncated = DAT_130K_C[:-4]
+        with ZstdFile(io.BytesIO(truncated)) as f:
+            self.assertRaises(EOFError, f.read)
+
+        with ZstdFile(io.BytesIO(truncated)) as f:
+            # this is an important test, make sure it doesn't raise EOFError.
+            self.assertEqual(f.read(130*_1K), DAT_130K_D)
+            with self.assertRaises(EOFError):
+                f.read(1)
+
+        # Incomplete header
+        for i in range(1, 20):
+            with ZstdFile(io.BytesIO(truncated[:i])) as f:
+                self.assertRaises(EOFError, f.read, 1)
+
+    def test_read_bad_args(self):
+        f = ZstdFile(io.BytesIO(COMPRESSED_DAT))
+        f.close()
+        self.assertRaises(ValueError, f.read)
+        with ZstdFile(io.BytesIO(), "w") as f:
+            self.assertRaises(ValueError, f.read)
+        with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f:
+            self.assertRaises(TypeError, f.read, float())
+
+    def test_read_bad_data(self):
+        with ZstdFile(io.BytesIO(COMPRESSED_BOGUS)) as f:
+            self.assertRaises(ZstdError, f.read)
+
+    def test_read_exception(self):
+        class C:
+            def read(self, size=-1):
+                raise OSError
+        with ZstdFile(C()) as f:
+            with self.assertRaises(OSError):
+                f.read(10)
+
+    def test_read1(self):
+        with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
+            blocks = []
+            while True:
+                result = f.read1()
+                if not result:
+                    break
+                blocks.append(result)
+            self.assertEqual(b"".join(blocks), DAT_130K_D)
+            self.assertEqual(f.read1(), b"")
+
+    def test_read1_0(self):
+        with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f:
+            self.assertEqual(f.read1(0), b"")
+
+    def test_read1_10(self):
+        with ZstdFile(io.BytesIO(COMPRESSED_DAT)) as f:
+            blocks = []
+            while True:
+                result = f.read1(10)
+                if not result:
+                    break
+                blocks.append(result)
+            self.assertEqual(b"".join(blocks), DECOMPRESSED_DAT)
+            self.assertEqual(f.read1(), b"")
+
+    def test_read1_multistream(self):
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 5)) as f:
+            blocks = []
+            while True:
+                result = f.read1()
+                if not result:
+                    break
+                blocks.append(result)
+            self.assertEqual(b"".join(blocks), DECOMPRESSED_100_PLUS_32KB * 5)
+            self.assertEqual(f.read1(), b"")
+
+    def test_read1_bad_args(self):
+        f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+        f.close()
+        self.assertRaises(ValueError, f.read1)
+        with ZstdFile(io.BytesIO(), "w") as f:
+            self.assertRaises(ValueError, f.read1)
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+            self.assertRaises(TypeError, f.read1, None)
+
+    def test_readinto(self):
+        arr = array.array("I", range(100))
+        self.assertEqual(len(arr), 100)
+        self.assertEqual(len(arr) * arr.itemsize, 400)
+        ba = bytearray(300)
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+            # 0 length output buffer
+            self.assertEqual(f.readinto(ba[0:0]), 0)
+
+            # use correct length for buffer protocol object
+            self.assertEqual(f.readinto(arr), 400)
+            self.assertEqual(arr.tobytes(), DECOMPRESSED_100_PLUS_32KB[:400])
+
+            # normal readinto
+            self.assertEqual(f.readinto(ba), 300)
+            self.assertEqual(ba, DECOMPRESSED_100_PLUS_32KB[400:700])
+
+    def test_peek(self):
+        with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
+            result = f.peek()
+            self.assertGreater(len(result), 0)
+            self.assertTrue(DAT_130K_D.startswith(result))
+            self.assertEqual(f.read(), DAT_130K_D)
+        with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
+            result = f.peek(10)
+            self.assertGreater(len(result), 0)
+            self.assertTrue(DAT_130K_D.startswith(result))
+            self.assertEqual(f.read(), DAT_130K_D)
+
+    def test_peek_bad_args(self):
+        with ZstdFile(io.BytesIO(), "w") as f:
+            self.assertRaises(ValueError, f.peek)
+
+    def test_iterator(self):
+        with io.BytesIO(THIS_FILE_BYTES) as f:
+            lines = f.readlines()
+        compressed = compress(THIS_FILE_BYTES)
+
+        # iter
+        with ZstdFile(io.BytesIO(compressed)) as f:
+            self.assertListEqual(list(iter(f)), lines)
+
+        # readline
+        with ZstdFile(io.BytesIO(compressed)) as f:
+            for line in lines:
+                self.assertEqual(f.readline(), line)
+            self.assertEqual(f.readline(), b'')
+            self.assertEqual(f.readline(), b'')
+
+        # readlines
+        with ZstdFile(io.BytesIO(compressed)) as f:
+            self.assertListEqual(f.readlines(), lines)
+
+    def test_decompress_limited(self):
+        _ZSTD_DStreamInSize = 128*_1K + 3
+
+        bomb = compress(b'\0' * int(2e6), level=10)
+        self.assertLess(len(bomb), _ZSTD_DStreamInSize)
+
+        decomp = ZstdFile(io.BytesIO(bomb))
+        self.assertEqual(decomp.read(1), b'\0')
+
+        # BufferedReader uses 128 KiB buffer in __init__.py
+        max_decomp = 128*_1K
+        self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp,
+            "Excessive amount of data was decompressed")
+
+    def test_write(self):
+        raw_data = THIS_FILE_BYTES[: len(THIS_FILE_BYTES) // 6]
+        with io.BytesIO() as dst:
+            with ZstdFile(dst, "w") as f:
+                f.write(raw_data)
+
+            comp = ZstdCompressor()
+            expected = comp.compress(raw_data) + comp.flush()
+            self.assertEqual(dst.getvalue(), expected)
+
+        with io.BytesIO() as dst:
+            with ZstdFile(dst, "w", level=12) as f:
+                f.write(raw_data)
+
+            comp = ZstdCompressor(12)
+            expected = comp.compress(raw_data) + comp.flush()
+            self.assertEqual(dst.getvalue(), expected)
+
+        with io.BytesIO() as dst:
+            with ZstdFile(dst, "w", options={CompressionParameter.checksum_flag:1}) as f:
+                f.write(raw_data)
+
+            comp = ZstdCompressor(options={CompressionParameter.checksum_flag:1})
+            expected = comp.compress(raw_data) + comp.flush()
+            self.assertEqual(dst.getvalue(), expected)
+
+        with io.BytesIO() as dst:
+            options = {CompressionParameter.compression_level:-5,
+                      CompressionParameter.checksum_flag:1}
+            with ZstdFile(dst, "w",
+                          options=options) as f:
+                f.write(raw_data)
+
+            comp = ZstdCompressor(options=options)
+            expected = comp.compress(raw_data) + comp.flush()
+            self.assertEqual(dst.getvalue(), expected)
+
+    def test_write_empty_frame(self):
+        # .FLUSH_FRAME generates an empty content frame
+        c = ZstdCompressor()
+        self.assertNotEqual(c.flush(c.FLUSH_FRAME), b'')
+        self.assertNotEqual(c.flush(c.FLUSH_FRAME), b'')
+
+        # ZstdFile doesn't emit an empty content frame if nothing was written
+        bo = io.BytesIO()
+        with ZstdFile(bo, 'w') as f:
+            pass
+        self.assertEqual(bo.getvalue(), b'')
+
+        bo = io.BytesIO()
+        with ZstdFile(bo, 'w') as f:
+            f.flush(f.FLUSH_FRAME)
+        self.assertEqual(bo.getvalue(), b'')
+
+        # after .write(b''), an empty content frame is generated
+        bo = io.BytesIO()
+        with ZstdFile(bo, 'w') as f:
+            f.write(b'')
+        self.assertNotEqual(bo.getvalue(), b'')
+
+        # has an empty content frame
+        bo = io.BytesIO()
+        with ZstdFile(bo, 'w') as f:
+            f.flush(f.FLUSH_BLOCK)
+        self.assertNotEqual(bo.getvalue(), b'')
+
+    def test_write_empty_block(self):
+        # If there is no internal data, .FLUSH_BLOCK returns b''.
+        c = ZstdCompressor()
+        self.assertEqual(c.flush(c.FLUSH_BLOCK), b'')
+        self.assertNotEqual(c.compress(b'123', c.FLUSH_BLOCK),
+                            b'')
+        self.assertEqual(c.flush(c.FLUSH_BLOCK), b'')
+        self.assertEqual(c.compress(b''), b'')
+        self.assertEqual(c.compress(b''), b'')
+        self.assertEqual(c.flush(c.FLUSH_BLOCK), b'')
+
+        # mode = .last_mode
+        bo = io.BytesIO()
+        with ZstdFile(bo, 'w') as f:
+            f.write(b'123')
+            f.flush(f.FLUSH_BLOCK)
+            fp_pos = f._fp.tell()
+            self.assertNotEqual(fp_pos, 0)
+            f.flush(f.FLUSH_BLOCK)
+            self.assertEqual(f._fp.tell(), fp_pos)
+
+        # mode != .last_mode
+        bo = io.BytesIO()
+        with ZstdFile(bo, 'w') as f:
+            f.flush(f.FLUSH_BLOCK)
+            self.assertEqual(f._fp.tell(), 0)
+            f.write(b'')
+            f.flush(f.FLUSH_BLOCK)
+            self.assertEqual(f._fp.tell(), 0)
+
+    def test_write_101(self):
+        with io.BytesIO() as dst:
+            with ZstdFile(dst, "w") as f:
+                for start in range(0, len(THIS_FILE_BYTES), 101):
+                    f.write(THIS_FILE_BYTES[start:start+101])
+
+            comp = ZstdCompressor()
+            expected = comp.compress(THIS_FILE_BYTES) + comp.flush()
+            self.assertEqual(dst.getvalue(), expected)
+
+    def test_write_append(self):
+        def comp(data):
+            comp = ZstdCompressor()
+            return comp.compress(data) + comp.flush()
+
+        part1 = THIS_FILE_BYTES[:_1K]
+        part2 = THIS_FILE_BYTES[_1K:1536]
+        part3 = THIS_FILE_BYTES[1536:]
+        expected = b"".join(comp(x) for x in (part1, part2, part3))
+        with io.BytesIO() as dst:
+            with ZstdFile(dst, "w") as f:
+                f.write(part1)
+            with ZstdFile(dst, "a") as f:
+                f.write(part2)
+            with ZstdFile(dst, "a") as f:
+                f.write(part3)
+            self.assertEqual(dst.getvalue(), expected)
+
+    def test_write_bad_args(self):
+        f = ZstdFile(io.BytesIO(), "w")
+        f.close()
+        self.assertRaises(ValueError, f.write, b"foo")
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB), "r") as f:
+            self.assertRaises(ValueError, f.write, b"bar")
+        with ZstdFile(io.BytesIO(), "w") as f:
+            self.assertRaises(TypeError, f.write, None)
+            self.assertRaises(TypeError, f.write, "text")
+            self.assertRaises(TypeError, f.write, 789)
+
+    def test_writelines(self):
+        def comp(data):
+            comp = ZstdCompressor()
+            return comp.compress(data) + comp.flush()
+
+        with io.BytesIO(THIS_FILE_BYTES) as f:
+            lines = f.readlines()
+        with io.BytesIO() as dst:
+            with ZstdFile(dst, "w") as f:
+                f.writelines(lines)
+            expected = comp(THIS_FILE_BYTES)
+            self.assertEqual(dst.getvalue(), expected)
+
+    def test_seek_forward(self):
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+            f.seek(555)
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[555:])
+
+    def test_seek_forward_across_streams(self):
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f:
+            f.seek(len(DECOMPRESSED_100_PLUS_32KB) + 123)
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[123:])
+
+    def test_seek_forward_relative_to_current(self):
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+            f.read(100)
+            f.seek(1236, 1)
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[1336:])
+
+    def test_seek_forward_relative_to_end(self):
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+            f.seek(-555, 2)
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[-555:])
+
+    def test_seek_backward(self):
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+            f.read(1001)
+            f.seek(211)
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[211:])
+
+    def test_seek_backward_across_streams(self):
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB * 2)) as f:
+            f.read(len(DECOMPRESSED_100_PLUS_32KB) + 333)
+            f.seek(737)
+            self.assertEqual(f.read(),
+              DECOMPRESSED_100_PLUS_32KB[737:] + DECOMPRESSED_100_PLUS_32KB)
+
+    def test_seek_backward_relative_to_end(self):
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+            f.seek(-150, 2)
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB[-150:])
+
+    def test_seek_past_end(self):
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+            f.seek(len(DECOMPRESSED_100_PLUS_32KB) + 9001)
+            self.assertEqual(f.tell(), len(DECOMPRESSED_100_PLUS_32KB))
+            self.assertEqual(f.read(), b"")
+
+    def test_seek_past_start(self):
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+            f.seek(-88)
+            self.assertEqual(f.tell(), 0)
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+
+    def test_seek_bad_args(self):
+        f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+        f.close()
+        self.assertRaises(ValueError, f.seek, 0)
+        with ZstdFile(io.BytesIO(), "w") as f:
+            self.assertRaises(ValueError, f.seek, 0)
+        with ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB)) as f:
+            self.assertRaises(ValueError, f.seek, 0, 3)
+            # io.BufferedReader raises TypeError instead of ValueError
+            self.assertRaises((TypeError, ValueError), f.seek, 9, ())
+            self.assertRaises(TypeError, f.seek, None)
+            self.assertRaises(TypeError, f.seek, b"derp")
+
+    def test_seek_not_seekable(self):
+        class C(io.BytesIO):
+            def seekable(self):
+                return False
+        obj = C(COMPRESSED_100_PLUS_32KB)
+        with ZstdFile(obj, 'r') as f:
+            d = f.read(1)
+            self.assertFalse(f.seekable())
+            with self.assertRaisesRegex(io.UnsupportedOperation,
+                                        'File or stream is not seekable'):
+                f.seek(0)
+            d += f.read()
+            self.assertEqual(d, DECOMPRESSED_100_PLUS_32KB)
+
+    def test_tell(self):
+        with ZstdFile(io.BytesIO(DAT_130K_C)) as f:
+            pos = 0
+            while True:
+                self.assertEqual(f.tell(), pos)
+                result = f.read(random.randint(171, 189))
+                if not result:
+                    break
+                pos += len(result)
+            self.assertEqual(f.tell(), len(DAT_130K_D))
+        with ZstdFile(io.BytesIO(), "w") as f:
+            for pos in range(0, len(DAT_130K_D), 143):
+                self.assertEqual(f.tell(), pos)
+                f.write(DAT_130K_D[pos:pos+143])
+            self.assertEqual(f.tell(), len(DAT_130K_D))
+
+    def test_tell_bad_args(self):
+        f = ZstdFile(io.BytesIO(COMPRESSED_100_PLUS_32KB))
+        f.close()
+        self.assertRaises(ValueError, f.tell)
+
+    def test_file_dict(self):
+        # default
+        bi = io.BytesIO()
+        with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT) as f:
+            f.write(SAMPLES[0])
+        bi.seek(0)
+        with ZstdFile(bi, zstd_dict=TRAINED_DICT) as f:
+            dat = f.read()
+        self.assertEqual(dat, SAMPLES[0])
+
+        # .as_(un)digested_dict
+        bi = io.BytesIO()
+        with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f:
+            f.write(SAMPLES[0])
+        bi.seek(0)
+        with ZstdFile(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f:
+            dat = f.read()
+        self.assertEqual(dat, SAMPLES[0])
+
+    def test_file_prefix(self):
+        bi = io.BytesIO()
+        with ZstdFile(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f:
+            f.write(SAMPLES[0])
+        bi.seek(0)
+        with ZstdFile(bi, zstd_dict=TRAINED_DICT.as_prefix) as f:
+            dat = f.read()
+        self.assertEqual(dat, SAMPLES[0])
+
+    def test_UnsupportedOperation(self):
+        # 1
+        with ZstdFile(io.BytesIO(), 'r') as f:
+            with self.assertRaises(io.UnsupportedOperation):
+                f.write(b'1234')
+
+        # 2
+        class T:
+            def read(self, size):
+                return b'a' * size
+
+        with self.assertRaises(TypeError): # on creation
+            with ZstdFile(T(), 'w') as f:
+                pass
+
+        # 3
+        with ZstdFile(io.BytesIO(), 'w') as f:
+            with self.assertRaises(io.UnsupportedOperation):
+                f.read(100)
+            with self.assertRaises(io.UnsupportedOperation):
+                f.seek(100)
+        self.assertEqual(f.closed, True)
+        with self.assertRaises(ValueError):
+            f.readable()
+        with self.assertRaises(ValueError):
+            f.tell()
+        with self.assertRaises(ValueError):
+            f.read(100)
+
+    def test_read_readinto_readinto1(self):
+        lst = []
+        with ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE*5)) as f:
+            while True:
+                method = random.randint(0, 2)
+                size = random.randint(0, 300)
+
+                if method == 0:
+                    dat = f.read(size)
+                    if not dat and size:
+                        break
+                    lst.append(dat)
+                elif method == 1:
+                    ba = bytearray(size)
+                    read_size = f.readinto(ba)
+                    if read_size == 0 and size:
+                        break
+                    lst.append(bytes(ba[:read_size]))
+                elif method == 2:
+                    ba = bytearray(size)
+                    read_size = f.readinto1(ba)
+                    if read_size == 0 and size:
+                        break
+                    lst.append(bytes(ba[:read_size]))
+        self.assertEqual(b''.join(lst), THIS_FILE_BYTES*5)
+
+    def test_zstdfile_flush(self):
+        # closed
+        f = ZstdFile(io.BytesIO(), 'w')
+        f.close()
+        with self.assertRaises(ValueError):
+            f.flush()
+
+        # read
+        with ZstdFile(io.BytesIO(), 'r') as f:
+            # does nothing for read-only stream
+            f.flush()
+
+        # write
+        DAT = b'abcd'
+        bi = io.BytesIO()
+        with ZstdFile(bi, 'w') as f:
+            self.assertEqual(f.write(DAT), len(DAT))
+            self.assertEqual(f.tell(), len(DAT))
+            self.assertEqual(bi.tell(), 0) # not enough for a block
+
+            self.assertEqual(f.flush(), None)
+            self.assertEqual(f.tell(), len(DAT))
+            self.assertGreater(bi.tell(), 0) # flushed
+
+        # write, no .flush() method
+        class C:
+            def write(self, b):
+                return len(b)
+        with ZstdFile(C(), 'w') as f:
+            self.assertEqual(f.write(DAT), len(DAT))
+            self.assertEqual(f.tell(), len(DAT))
+
+            self.assertEqual(f.flush(), None)
+            self.assertEqual(f.tell(), len(DAT))
+
+    def test_zstdfile_flush_mode(self):
+        self.assertEqual(ZstdFile.FLUSH_BLOCK, ZstdCompressor.FLUSH_BLOCK)
+        self.assertEqual(ZstdFile.FLUSH_FRAME, ZstdCompressor.FLUSH_FRAME)
+        with self.assertRaises(AttributeError):
+            ZstdFile.CONTINUE
+
+        bo = io.BytesIO()
+        with ZstdFile(bo, 'w') as f:
+            # flush block
+            self.assertEqual(f.write(b'123'), 3)
+            self.assertIsNone(f.flush(f.FLUSH_BLOCK))
+            p1 = bo.tell()
+            # mode == .last_mode, should return
+            self.assertIsNone(f.flush())
+            p2 = bo.tell()
+            self.assertEqual(p1, p2)
+            # flush frame
+            self.assertEqual(f.write(b'456'), 3)
+            self.assertIsNone(f.flush(mode=f.FLUSH_FRAME))
+            # flush frame
+            self.assertEqual(f.write(b'789'), 3)
+            self.assertIsNone(f.flush(f.FLUSH_FRAME))
+            p1 = bo.tell()
+            # mode == .last_mode, should return
+            self.assertIsNone(f.flush(f.FLUSH_FRAME))
+            p2 = bo.tell()
+            self.assertEqual(p1, p2)
+        self.assertEqual(decompress(bo.getvalue()), b'123456789')
+
+        bo = io.BytesIO()
+        with ZstdFile(bo, 'w') as f:
+            f.write(b'123')
+            with self.assertRaisesRegex(ValueError, r'\.FLUSH_.*?\.FLUSH_'):
+                f.flush(ZstdCompressor.CONTINUE)
+            with self.assertRaises(ValueError):
+                f.flush(-1)
+            with self.assertRaises(ValueError):
+                f.flush(123456)
+            with self.assertRaises(TypeError):
+                f.flush(node=ZstdCompressor.CONTINUE)
+            with self.assertRaises((TypeError, ValueError)):
+                f.flush('FLUSH_FRAME')
+            with self.assertRaises(TypeError):
+                f.flush(b'456', f.FLUSH_BLOCK)
+
+    def test_zstdfile_truncate(self):
+        with ZstdFile(io.BytesIO(), 'w') as f:
+            with self.assertRaises(io.UnsupportedOperation):
+                f.truncate(200)
+
+    def test_zstdfile_iter_issue45475(self):
+        lines = [l for l in ZstdFile(io.BytesIO(COMPRESSED_THIS_FILE))]
+        self.assertGreater(len(lines), 0)
+
+    def test_append_new_file(self):
+        with tempfile.NamedTemporaryFile(delete=True) as tmp_f:
+            filename = tmp_f.name
+
+        with ZstdFile(filename, 'a') as f:
+            pass
+        self.assertTrue(os.path.isfile(filename))
+
+        os.remove(filename)
+
+class OpenTestCase(unittest.TestCase):
+
+    def test_binary_modes(self):
+        with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb") as f:
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+        with io.BytesIO() as bio:
+            with open(bio, "wb") as f:
+                f.write(DECOMPRESSED_100_PLUS_32KB)
+            file_data = decompress(bio.getvalue())
+            self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB)
+            with open(bio, "ab") as f:
+                f.write(DECOMPRESSED_100_PLUS_32KB)
+            file_data = decompress(bio.getvalue())
+            self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB * 2)
+
+    def test_text_modes(self):
+        # empty input
+        with self.assertRaises(EOFError):
+            with open(io.BytesIO(b''), "rt", encoding="utf-8", newline='\n') as reader:
+                for _ in reader:
+                    pass
+
+        # read
+        uncompressed = THIS_FILE_STR.replace(os.linesep, "\n")
+        with open(io.BytesIO(COMPRESSED_THIS_FILE), "rt", encoding="utf-8") as f:
+            self.assertEqual(f.read(), uncompressed)
+
+        with io.BytesIO() as bio:
+            # write
+            with open(bio, "wt", encoding="utf-8") as f:
+                f.write(uncompressed)
+            file_data = decompress(bio.getvalue()).decode("utf-8")
+            self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed)
+            # append
+            with open(bio, "at", encoding="utf-8") as f:
+                f.write(uncompressed)
+            file_data = decompress(bio.getvalue()).decode("utf-8")
+            self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed * 2)
+
+    def test_bad_params(self):
+        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+            TESTFN = pathlib.Path(tmp_f.name)
+
+        with self.assertRaises(ValueError):
+            open(TESTFN, "")
+        with self.assertRaises(ValueError):
+            open(TESTFN, "rbt")
+        with self.assertRaises(ValueError):
+            open(TESTFN, "rb", encoding="utf-8")
+        with self.assertRaises(ValueError):
+            open(TESTFN, "rb", errors="ignore")
+        with self.assertRaises(ValueError):
+            open(TESTFN, "rb", newline="\n")
+
+        os.remove(TESTFN)
+
+    def test_option(self):
+        options = {DecompressionParameter.window_log_max:25}
+        with open(io.BytesIO(COMPRESSED_100_PLUS_32KB), "rb", options=options) as f:
+            self.assertEqual(f.read(), DECOMPRESSED_100_PLUS_32KB)
+
+        options = {CompressionParameter.compression_level:12}
+        with io.BytesIO() as bio:
+            with open(bio, "wb", options=options) as f:
+                f.write(DECOMPRESSED_100_PLUS_32KB)
+            file_data = decompress(bio.getvalue())
+            self.assertEqual(file_data, DECOMPRESSED_100_PLUS_32KB)
+
+    def test_encoding(self):
+        uncompressed = THIS_FILE_STR.replace(os.linesep, "\n")
+
+        with io.BytesIO() as bio:
+            with open(bio, "wt", encoding="utf-16-le") as f:
+                f.write(uncompressed)
+            file_data = decompress(bio.getvalue()).decode("utf-16-le")
+            self.assertEqual(file_data.replace(os.linesep, "\n"), uncompressed)
+            bio.seek(0)
+            with open(bio, "rt", encoding="utf-16-le") as f:
+                self.assertEqual(f.read().replace(os.linesep, "\n"), uncompressed)
+
+    def test_encoding_error_handler(self):
+        with io.BytesIO(compress(b"foo\xffbar")) as bio:
+            with open(bio, "rt", encoding="ascii", errors="ignore") as f:
+                self.assertEqual(f.read(), "foobar")
+
+    def test_newline(self):
+        # Test with explicit newline (universal newline mode disabled).
+        text = THIS_FILE_STR.replace(os.linesep, "\n")
+        with io.BytesIO() as bio:
+            with open(bio, "wt", encoding="utf-8", newline="\n") as f:
+                f.write(text)
+            bio.seek(0)
+            with open(bio, "rt", encoding="utf-8", newline="\r") as f:
+                self.assertEqual(f.readlines(), [text])
+
+    def test_x_mode(self):
+        with tempfile.NamedTemporaryFile(delete=False) as tmp_f:
+            TESTFN = pathlib.Path(tmp_f.name)
+
+        for mode in ("x", "xb", "xt"):
+            os.remove(TESTFN)
+
+            if mode == "xt":
+                encoding = "utf-8"
+            else:
+                encoding = None
+            with open(TESTFN, mode, encoding=encoding):
+                pass
+            with self.assertRaises(FileExistsError):
+                with open(TESTFN, mode):
+                    pass
+
+        os.remove(TESTFN)
+
+    def test_open_dict(self):
+        # default
+        bi = io.BytesIO()
+        with open(bi, 'w', zstd_dict=TRAINED_DICT) as f:
+            f.write(SAMPLES[0])
+        bi.seek(0)
+        with open(bi, zstd_dict=TRAINED_DICT) as f:
+            dat = f.read()
+        self.assertEqual(dat, SAMPLES[0])
+
+        # .as_(un)digested_dict
+        bi = io.BytesIO()
+        with open(bi, 'w', zstd_dict=TRAINED_DICT.as_digested_dict) as f:
+            f.write(SAMPLES[0])
+        bi.seek(0)
+        with open(bi, zstd_dict=TRAINED_DICT.as_undigested_dict) as f:
+            dat = f.read()
+        self.assertEqual(dat, SAMPLES[0])
+
+        # invalid dictionary
+        bi = io.BytesIO()
+        with self.assertRaisesRegex(TypeError, 'zstd_dict'):
+            open(bi, 'w', zstd_dict={1:2, 2:3})
+
+        with self.assertRaisesRegex(TypeError, 'zstd_dict'):
+            open(bi, 'w', zstd_dict=b'1234567890')
+
+    def test_open_prefix(self):
+        bi = io.BytesIO()
+        with open(bi, 'w', zstd_dict=TRAINED_DICT.as_prefix) as f:
+            f.write(SAMPLES[0])
+        bi.seek(0)
+        with open(bi, zstd_dict=TRAINED_DICT.as_prefix) as f:
+            dat = f.read()
+        self.assertEqual(dat, SAMPLES[0])
+
+    def test_buffer_protocol(self):
+        # don't use len() for buffer protocol objects
+        arr = array.array("i", range(1000))
+        LENGTH = len(arr) * arr.itemsize
+
+        with open(io.BytesIO(), "wb") as f:
+            self.assertEqual(f.write(arr), LENGTH)
+            self.assertEqual(f.tell(), LENGTH)
+
+class FreeThreadingMethodTests(unittest.TestCase):
+
+    @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
+    @threading_helper.reap_threads
+    @threading_helper.requires_working_threading()
+    def test_compress_locking(self):
+        input = b'a'* (16*_1K)
+        num_threads = 8
+
+        comp = ZstdCompressor()
+        parts = []
+        for _ in range(num_threads):
+            res = comp.compress(input, ZstdCompressor.FLUSH_BLOCK)
+            if res:
+                parts.append(res)
+        rest1 = comp.flush()
+        expected = b''.join(parts) + rest1
+
+        comp = ZstdCompressor()
+        output = []
+        def run_method(method, input_data, output_data):
+            res = method(input_data, ZstdCompressor.FLUSH_BLOCK)
+            if res:
+                output_data.append(res)
+        threads = []
+
+        for i in range(num_threads):
+            thread = threading.Thread(target=run_method, args=(comp.compress, input, output))
+
+            threads.append(thread)
+
+        with threading_helper.start_threads(threads):
+            pass
+
+        rest2 = comp.flush()
+        self.assertEqual(rest1, rest2)
+        actual = b''.join(output) + rest2
+        self.assertEqual(expected, actual)
+
+    @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
+    @threading_helper.reap_threads
+    @threading_helper.requires_working_threading()
+    def test_decompress_locking(self):
+        input = compress(b'a'* (16*_1K))
+        num_threads = 8
+        # to ensure we decompress over multiple calls, cap output per call
+        window_size = _1K * 16//num_threads
+
+        decomp = ZstdDecompressor()
+        parts = []
+        for _ in range(num_threads):
+            res = decomp.decompress(input, window_size)
+            if res:
+                parts.append(res)
+        expected = b''.join(parts)
+
+        comp = ZstdDecompressor()
+        output = []
+        def run_method(method, input_data, output_data):
+            res = method(input_data, window_size)
+            if res:
+                output_data.append(res)
+        threads = []
+
+        for i in range(num_threads):
+            thread = threading.Thread(target=run_method, args=(comp.decompress, input, output))
+
+            threads.append(thread)
+
+        with threading_helper.start_threads(threads):
+            pass
+
+        actual = b''.join(output)
+        self.assertEqual(expected, actual)
+
+
+
+if __name__ == "__main__":
+    unittest.main()
index cfb44f3ed970eeae1358231c3b353a84ee26d9c1..88356abe8cbaebbb64ffe27ed22cd39f097edb23 100644 (file)
@@ -31,6 +31,11 @@ try:
 except ImportError:
     lzma = None
 
+try:
+    from compression import zstd # We may need its compression method
+except ImportError:
+    zstd = None
+
 __all__ = ["BadZipFile", "BadZipfile", "error",
            "ZIP_STORED", "ZIP_DEFLATED", "ZIP_BZIP2", "ZIP_LZMA",
            "is_zipfile", "ZipInfo", "ZipFile", "PyZipFile", "LargeZipFile",
@@ -58,12 +63,14 @@ ZIP_STORED = 0
 ZIP_DEFLATED = 8
 ZIP_BZIP2 = 12
 ZIP_LZMA = 14
+ZIP_ZSTANDARD = 93
 # Other ZIP compression methods not supported
 
 DEFAULT_VERSION = 20
 ZIP64_VERSION = 45
 BZIP2_VERSION = 46
 LZMA_VERSION = 63
+ZSTANDARD_VERSION = 63
 # we recognize (but not necessarily support) all features up to that version
 MAX_EXTRACT_VERSION = 63
 
@@ -505,6 +512,8 @@ class ZipInfo:
             min_version = max(BZIP2_VERSION, min_version)
         elif self.compress_type == ZIP_LZMA:
             min_version = max(LZMA_VERSION, min_version)
+        elif self.compress_type == ZIP_ZSTANDARD:
+            min_version = max(ZSTANDARD_VERSION, min_version)
 
         self.extract_version = max(min_version, self.extract_version)
         self.create_version = max(min_version, self.create_version)
@@ -766,6 +775,7 @@ compressor_names = {
     14: 'lzma',
     18: 'terse',
     19: 'lz77',
+    93: 'zstd',
     97: 'wavpack',
     98: 'ppmd',
 }
@@ -785,6 +795,10 @@ def _check_compression(compression):
         if not lzma:
             raise RuntimeError(
                 "Compression requires the (missing) lzma module")
+    elif compression == ZIP_ZSTANDARD:
+        if not zstd:
+            raise RuntimeError(
+                "Compression requires the (missing) compression.zstd module")
     else:
         raise NotImplementedError("That compression method is not supported")
 
@@ -798,9 +812,11 @@ def _get_compressor(compress_type, compresslevel=None):
         if compresslevel is not None:
             return bz2.BZ2Compressor(compresslevel)
         return bz2.BZ2Compressor()
-    # compresslevel is ignored for ZIP_LZMA
+    # compresslevel is ignored for ZIP_LZMA and ZIP_ZSTANDARD
     elif compress_type == ZIP_LZMA:
         return LZMACompressor()
+    elif compress_type == ZIP_ZSTANDARD:
+        return zstd.ZstdCompressor()
     else:
         return None
 
@@ -815,6 +831,8 @@ def _get_decompressor(compress_type):
         return bz2.BZ2Decompressor()
     elif compress_type == ZIP_LZMA:
         return LZMADecompressor()
+    elif compress_type == ZIP_ZSTANDARD:
+        return zstd.ZstdDecompressor()
     else:
         descr = compressor_names.get(compress_type)
         if descr:
index 3fda59cdcec71b77d23b42fbd28e7791f9a5b803..17e0c9904cc3aa2d000b588ef8ef66b428b6e6bb 100644 (file)
@@ -2507,7 +2507,7 @@ maninstall:       altmaninstall
 XMLLIBSUBDIRS=  xml xml/dom xml/etree xml/parsers xml/sax
 LIBSUBDIRS=    asyncio \
                collections \
-               compression compression/bz2 compression/gzip \
+               compression compression/bz2 compression/gzip compression/zstd \
                compression/lzma compression/zlib compression/_common \
                concurrent concurrent/futures \
                csv \
index 18dc13b3fd16f05e5093076003cb6c37ab43d880..4d046859a1540efde31758c6b0961179887ea5b4 100644 (file)
@@ -74,33 +74,33 @@ typedef struct {
 
 static const ParameterInfo cp_list[] =
 {
-    {ZSTD_c_compressionLevel, "compressionLevel"},
-    {ZSTD_c_windowLog,        "windowLog"},
-    {ZSTD_c_hashLog,          "hashLog"},
-    {ZSTD_c_chainLog,         "chainLog"},
-    {ZSTD_c_searchLog,        "searchLog"},
-    {ZSTD_c_minMatch,         "minMatch"},
-    {ZSTD_c_targetLength,     "targetLength"},
+    {ZSTD_c_compressionLevel, "compression_level"},
+    {ZSTD_c_windowLog,        "window_log"},
+    {ZSTD_c_hashLog,          "hash_log"},
+    {ZSTD_c_chainLog,         "chain_log"},
+    {ZSTD_c_searchLog,        "search_log"},
+    {ZSTD_c_minMatch,         "min_match"},
+    {ZSTD_c_targetLength,     "target_length"},
     {ZSTD_c_strategy,         "strategy"},
 
-    {ZSTD_c_enableLongDistanceMatching, "enableLongDistanceMatching"},
-    {ZSTD_c_ldmHashLog,       "ldmHashLog"},
-    {ZSTD_c_ldmMinMatch,      "ldmMinMatch"},
-    {ZSTD_c_ldmBucketSizeLog, "ldmBucketSizeLog"},
-    {ZSTD_c_ldmHashRateLog,   "ldmHashRateLog"},
+    {ZSTD_c_enableLongDistanceMatching, "enable_long_distance_matching"},
+    {ZSTD_c_ldmHashLog,       "ldm_hash_log"},
+    {ZSTD_c_ldmMinMatch,      "ldm_min_match"},
+    {ZSTD_c_ldmBucketSizeLog, "ldm_bucket_size_log"},
+    {ZSTD_c_ldmHashRateLog,   "ldm_hash_rate_log"},
 
-    {ZSTD_c_contentSizeFlag,  "contentSizeFlag"},
-    {ZSTD_c_checksumFlag,     "checksumFlag"},
-    {ZSTD_c_dictIDFlag,       "dictIDFlag"},
+    {ZSTD_c_contentSizeFlag,  "content_size_flag"},
+    {ZSTD_c_checksumFlag,     "checksum_flag"},
+    {ZSTD_c_dictIDFlag,       "dict_id_flag"},
 
-    {ZSTD_c_nbWorkers,        "nbWorkers"},
-    {ZSTD_c_jobSize,          "jobSize"},
-    {ZSTD_c_overlapLog,       "overlapLog"}
+    {ZSTD_c_nbWorkers,        "nb_workers"},
+    {ZSTD_c_jobSize,          "job_size"},
+    {ZSTD_c_overlapLog,       "overlap_log"}
 };
 
 static const ParameterInfo dp_list[] =
 {
-    {ZSTD_d_windowLogMax, "windowLogMax"}
+    {ZSTD_d_windowLogMax, "window_log_max"}
 };
 
 void
@@ -180,8 +180,8 @@ _zstd._train_dict
 
     samples_bytes: PyBytesObject
         Concatenation of samples.
-    samples_size_list: object(subclass_of='&PyList_Type')
-        List of samples' sizes.
+    samples_sizes: object(subclass_of='&PyTuple_Type')
+        Tuple of samples' sizes.
     dict_size: Py_ssize_t
         The size of the dictionary.
     /
@@ -191,8 +191,8 @@ Internal function, train a zstd dictionary on sample data.
 
 static PyObject *
 _zstd__train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
-                       PyObject *samples_size_list, Py_ssize_t dict_size)
-/*[clinic end generated code: output=ee53c34c8f77886b input=b21d092c695a3a81]*/
+                       PyObject *samples_sizes, Py_ssize_t dict_size)
+/*[clinic end generated code: output=b5b4f36347c0addd input=2dce5b57d63923e2]*/
 {
     // TODO(emmatyping): The preamble and suffix to this function and _finalize_dict
     // are pretty similar. We should see if we can refactor them to share that code.
@@ -209,7 +209,7 @@ _zstd__train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
         return NULL;
     }
 
-    chunks_number = Py_SIZE(samples_size_list);
+    chunks_number = Py_SIZE(samples_sizes);
     if ((size_t) chunks_number > UINT32_MAX) {
         PyErr_Format(PyExc_ValueError,
                         "The number of samples should be <= %u.", UINT32_MAX);
@@ -225,12 +225,11 @@ _zstd__train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
 
     sizes_sum = 0;
     for (i = 0; i < chunks_number; i++) {
-        PyObject *size = PyList_GetItemRef(samples_size_list, i);
+        PyObject *size = PyTuple_GetItem(samples_sizes, i);
         chunk_sizes[i] = PyLong_AsSize_t(size);
-        Py_DECREF(size);
         if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
             PyErr_Format(PyExc_ValueError,
-                            "Items in samples_size_list should be an int "
+                            "Items in samples_sizes should be an int "
                             "object, with a value between 0 and %u.", SIZE_MAX);
             goto error;
         }
@@ -239,7 +238,7 @@ _zstd__train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
 
     if (sizes_sum != Py_SIZE(samples_bytes)) {
         PyErr_SetString(PyExc_ValueError,
-                        "The samples size list doesn't match the concatenation's size.");
+                        "The samples size tuple doesn't match the concatenation's size.");
         goto error;
     }
 
@@ -287,8 +286,8 @@ _zstd._finalize_dict
         Custom dictionary content.
     samples_bytes: PyBytesObject
         Concatenation of samples.
-    samples_size_list: object(subclass_of='&PyList_Type')
-        List of samples' sizes.
+    samples_sizes: object(subclass_of='&PyTuple_Type')
+        Tuple of samples' sizes.
     dict_size: Py_ssize_t
         The size of the dictionary.
     compression_level: int
@@ -301,9 +300,9 @@ Internal function, finalize a zstd dictionary.
 static PyObject *
 _zstd__finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
                           PyBytesObject *samples_bytes,
-                          PyObject *samples_size_list, Py_ssize_t dict_size,
+                          PyObject *samples_sizes, Py_ssize_t dict_size,
                           int compression_level)
-/*[clinic end generated code: output=9c2a7d8c845cee93 input=08531a803d87c56f]*/
+/*[clinic end generated code: output=5dc5b520fddba37f input=8afd42a249078460]*/
 {
     Py_ssize_t chunks_number;
     size_t *chunk_sizes = NULL;
@@ -319,7 +318,7 @@ _zstd__finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
         return NULL;
     }
 
-    chunks_number = Py_SIZE(samples_size_list);
+    chunks_number = Py_SIZE(samples_sizes);
     if ((size_t) chunks_number > UINT32_MAX) {
         PyErr_Format(PyExc_ValueError,
                         "The number of samples should be <= %u.", UINT32_MAX);
@@ -335,11 +334,11 @@ _zstd__finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
 
     sizes_sum = 0;
     for (i = 0; i < chunks_number; i++) {
-        PyObject *size = PyList_GET_ITEM(samples_size_list, i);
+        PyObject *size = PyTuple_GetItem(samples_sizes, i);
         chunk_sizes[i] = PyLong_AsSize_t(size);
         if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
             PyErr_Format(PyExc_ValueError,
-                            "Items in samples_size_list should be an int "
+                            "Items in samples_sizes should be an int "
                             "object, with a value between 0 and %u.", SIZE_MAX);
             goto error;
         }
@@ -348,7 +347,7 @@ _zstd__finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
 
     if (sizes_sum != Py_SIZE(samples_bytes)) {
         PyErr_SetString(PyExc_ValueError,
-                        "The samples size list doesn't match the concatenation's size.");
+                        "The samples size tuple doesn't match the concatenation's size.");
         goto error;
     }
 
@@ -402,18 +401,18 @@ success:
 /*[clinic input]
 _zstd._get_param_bounds
 
-    is_compress: bool
-        True for CParameter, False for DParameter.
     parameter: int
         The parameter to get bounds.
+    is_compress: bool
+        True for CompressionParameter, False for DecompressionParameter.
 
-Internal function, get CParameter/DParameter bounds.
+Internal function, get CompressionParameter/DecompressionParameter bounds.
 [clinic start generated code]*/
 
 static PyObject *
-_zstd__get_param_bounds_impl(PyObject *module, int is_compress,
-                             int parameter)
-/*[clinic end generated code: output=b751dc710f89ef55 input=fb21ff96aff65df1]*/
+_zstd__get_param_bounds_impl(PyObject *module, int parameter,
+                             int is_compress)
+/*[clinic end generated code: output=9892cd822f937e79 input=884cd1a01125267d]*/
 {
     ZSTD_bounds bound;
     if (is_compress) {
@@ -515,30 +514,30 @@ _zstd__get_frame_info_impl(PyObject *module, Py_buffer *frame_buffer)
 _zstd._set_parameter_types
 
     c_parameter_type: object(subclass_of='&PyType_Type')
-        CParameter IntEnum type object
+        CompressionParameter IntEnum type object
     d_parameter_type: object(subclass_of='&PyType_Type')
-        DParameter IntEnum type object
+        DecompressionParameter IntEnum type object
 
-Internal function, set CParameter/DParameter types for validity check.
+Internal function, set CompressionParameter/DecompressionParameter types for validity check.
 [clinic start generated code]*/
 
 static PyObject *
 _zstd__set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type,
                                 PyObject *d_parameter_type)
-/*[clinic end generated code: output=a13d4890ccbd2873 input=3e7d0d37c3a1045a]*/
+/*[clinic end generated code: output=a13d4890ccbd2873 input=4535545d903853d3]*/
 {
     _zstd_state* const mod_state = get_zstd_state(module);
 
     if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) {
         PyErr_SetString(PyExc_ValueError,
-                        "The two arguments should be CParameter and "
-                        "DParameter types.");
+                        "The two arguments should be CompressionParameter and "
+                        "DecompressionParameter types.");
         return NULL;
     }
 
     Py_XDECREF(mod_state->CParameter_type);
     Py_INCREF(c_parameter_type);
-    mod_state->CParameter_type = (PyTypeObject*) c_parameter_type;
+    mod_state->CParameter_type = (PyTypeObject*)c_parameter_type;
 
     Py_XDECREF(mod_state->DParameter_type);
     Py_INCREF(d_parameter_type);
index 4b78bded67bca7ffa278d5ad0e0c6ed95aef8297..2f8225389b7aea3b16622a7b978647b55c9eabd8 100644 (file)
@@ -10,15 +10,15 @@ preserve
 #include "pycore_modsupport.h"    // _PyArg_CheckPositional()
 
 PyDoc_STRVAR(_zstd__train_dict__doc__,
-"_train_dict($module, samples_bytes, samples_size_list, dict_size, /)\n"
+"_train_dict($module, samples_bytes, samples_sizes, dict_size, /)\n"
 "--\n"
 "\n"
 "Internal function, train a zstd dictionary on sample data.\n"
 "\n"
 "  samples_bytes\n"
 "    Concatenation of samples.\n"
-"  samples_size_list\n"
-"    List of samples\' sizes.\n"
+"  samples_sizes\n"
+"    Tuple of samples\' sizes.\n"
 "  dict_size\n"
 "    The size of the dictionary.");
 
@@ -27,14 +27,14 @@ PyDoc_STRVAR(_zstd__train_dict__doc__,
 
 static PyObject *
 _zstd__train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
-                       PyObject *samples_size_list, Py_ssize_t dict_size);
+                       PyObject *samples_sizes, Py_ssize_t dict_size);
 
 static PyObject *
 _zstd__train_dict(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
 {
     PyObject *return_value = NULL;
     PyBytesObject *samples_bytes;
-    PyObject *samples_size_list;
+    PyObject *samples_sizes;
     Py_ssize_t dict_size;
 
     if (!_PyArg_CheckPositional("_train_dict", nargs, 3, 3)) {
@@ -45,11 +45,11 @@ _zstd__train_dict(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
         goto exit;
     }
     samples_bytes = (PyBytesObject *)args[0];
-    if (!PyList_Check(args[1])) {
-        _PyArg_BadArgument("_train_dict", "argument 2", "list", args[1]);
+    if (!PyTuple_Check(args[1])) {
+        _PyArg_BadArgument("_train_dict", "argument 2", "tuple", args[1]);
         goto exit;
     }
-    samples_size_list = args[1];
+    samples_sizes = args[1];
     {
         Py_ssize_t ival = -1;
         PyObject *iobj = _PyNumber_Index(args[2]);
@@ -62,7 +62,7 @@ _zstd__train_dict(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
         }
         dict_size = ival;
     }
-    return_value = _zstd__train_dict_impl(module, samples_bytes, samples_size_list, dict_size);
+    return_value = _zstd__train_dict_impl(module, samples_bytes, samples_sizes, dict_size);
 
 exit:
     return return_value;
@@ -70,7 +70,7 @@ exit:
 
 PyDoc_STRVAR(_zstd__finalize_dict__doc__,
 "_finalize_dict($module, custom_dict_bytes, samples_bytes,\n"
-"               samples_size_list, dict_size, compression_level, /)\n"
+"               samples_sizes, dict_size, compression_level, /)\n"
 "--\n"
 "\n"
 "Internal function, finalize a zstd dictionary.\n"
@@ -79,8 +79,8 @@ PyDoc_STRVAR(_zstd__finalize_dict__doc__,
 "    Custom dictionary content.\n"
 "  samples_bytes\n"
 "    Concatenation of samples.\n"
-"  samples_size_list\n"
-"    List of samples\' sizes.\n"
+"  samples_sizes\n"
+"    Tuple of samples\' sizes.\n"
 "  dict_size\n"
 "    The size of the dictionary.\n"
 "  compression_level\n"
@@ -92,7 +92,7 @@ PyDoc_STRVAR(_zstd__finalize_dict__doc__,
 static PyObject *
 _zstd__finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
                           PyBytesObject *samples_bytes,
-                          PyObject *samples_size_list, Py_ssize_t dict_size,
+                          PyObject *samples_sizes, Py_ssize_t dict_size,
                           int compression_level);
 
 static PyObject *
@@ -101,7 +101,7 @@ _zstd__finalize_dict(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
     PyObject *return_value = NULL;
     PyBytesObject *custom_dict_bytes;
     PyBytesObject *samples_bytes;
-    PyObject *samples_size_list;
+    PyObject *samples_sizes;
     Py_ssize_t dict_size;
     int compression_level;
 
@@ -118,11 +118,11 @@ _zstd__finalize_dict(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
         goto exit;
     }
     samples_bytes = (PyBytesObject *)args[1];
-    if (!PyList_Check(args[2])) {
-        _PyArg_BadArgument("_finalize_dict", "argument 3", "list", args[2]);
+    if (!PyTuple_Check(args[2])) {
+        _PyArg_BadArgument("_finalize_dict", "argument 3", "tuple", args[2]);
         goto exit;
     }
-    samples_size_list = args[2];
+    samples_sizes = args[2];
     {
         Py_ssize_t ival = -1;
         PyObject *iobj = _PyNumber_Index(args[3]);
@@ -139,29 +139,29 @@ _zstd__finalize_dict(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
     if (compression_level == -1 && PyErr_Occurred()) {
         goto exit;
     }
-    return_value = _zstd__finalize_dict_impl(module, custom_dict_bytes, samples_bytes, samples_size_list, dict_size, compression_level);
+    return_value = _zstd__finalize_dict_impl(module, custom_dict_bytes, samples_bytes, samples_sizes, dict_size, compression_level);
 
 exit:
     return return_value;
 }
 
 PyDoc_STRVAR(_zstd__get_param_bounds__doc__,
-"_get_param_bounds($module, /, is_compress, parameter)\n"
+"_get_param_bounds($module, /, parameter, is_compress)\n"
 "--\n"
 "\n"
-"Internal function, get CParameter/DParameter bounds.\n"
+"Internal function, get CompressionParameter/DecompressionParameter bounds.\n"
 "\n"
-"  is_compress\n"
-"    True for CParameter, False for DParameter.\n"
 "  parameter\n"
-"    The parameter to get bounds.");
+"    The parameter to get bounds.\n"
+"  is_compress\n"
+"    True for CompressionParameter, False for DecompressionParameter.");
 
 #define _ZSTD__GET_PARAM_BOUNDS_METHODDEF    \
     {"_get_param_bounds", _PyCFunction_CAST(_zstd__get_param_bounds), METH_FASTCALL|METH_KEYWORDS, _zstd__get_param_bounds__doc__},
 
 static PyObject *
-_zstd__get_param_bounds_impl(PyObject *module, int is_compress,
-                             int parameter);
+_zstd__get_param_bounds_impl(PyObject *module, int parameter,
+                             int is_compress);
 
 static PyObject *
 _zstd__get_param_bounds(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
@@ -178,7 +178,7 @@ _zstd__get_param_bounds(PyObject *module, PyObject *const *args, Py_ssize_t narg
     } _kwtuple = {
         .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS)
         .ob_hash = -1,
-        .ob_item = { &_Py_ID(is_compress), &_Py_ID(parameter), },
+        .ob_item = { &_Py_ID(parameter), &_Py_ID(is_compress), },
     };
     #undef NUM_KEYWORDS
     #define KWTUPLE (&_kwtuple.ob_base.ob_base)
@@ -187,7 +187,7 @@ _zstd__get_param_bounds(PyObject *module, PyObject *const *args, Py_ssize_t narg
     #  define KWTUPLE NULL
     #endif  // !Py_BUILD_CORE
 
-    static const char * const _keywords[] = {"is_compress", "parameter", NULL};
+    static const char * const _keywords[] = {"parameter", "is_compress", NULL};
     static _PyArg_Parser _parser = {
         .keywords = _keywords,
         .fname = "_get_param_bounds",
@@ -195,23 +195,23 @@ _zstd__get_param_bounds(PyObject *module, PyObject *const *args, Py_ssize_t narg
     };
     #undef KWTUPLE
     PyObject *argsbuf[2];
-    int is_compress;
     int parameter;
+    int is_compress;
 
     args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser,
             /*minpos*/ 2, /*maxpos*/ 2, /*minkw*/ 0, /*varpos*/ 0, argsbuf);
     if (!args) {
         goto exit;
     }
-    is_compress = PyObject_IsTrue(args[0]);
-    if (is_compress < 0) {
+    parameter = PyLong_AsInt(args[0]);
+    if (parameter == -1 && PyErr_Occurred()) {
         goto exit;
     }
-    parameter = PyLong_AsInt(args[1]);
-    if (parameter == -1 && PyErr_Occurred()) {
+    is_compress = PyObject_IsTrue(args[1]);
+    if (is_compress < 0) {
         goto exit;
     }
-    return_value = _zstd__get_param_bounds_impl(module, is_compress, parameter);
+    return_value = _zstd__get_param_bounds_impl(module, parameter, is_compress);
 
 exit:
     return return_value;
@@ -360,12 +360,12 @@ PyDoc_STRVAR(_zstd__set_parameter_types__doc__,
 "_set_parameter_types($module, /, c_parameter_type, d_parameter_type)\n"
 "--\n"
 "\n"
-"Internal function, set CParameter/DParameter types for validity check.\n"
+"Internal function, set CompressionParameter/DecompressionParameter types for validity check.\n"
 "\n"
 "  c_parameter_type\n"
-"    CParameter IntEnum type object\n"
+"    CompressionParameter IntEnum type object\n"
 "  d_parameter_type\n"
-"    DParameter IntEnum type object");
+"    DecompressionParameter IntEnum type object");
 
 #define _ZSTD__SET_PARAMETER_TYPES_METHODDEF    \
     {"_set_parameter_types", _PyCFunction_CAST(_zstd__set_parameter_types), METH_FASTCALL|METH_KEYWORDS, _zstd__set_parameter_types__doc__},
@@ -429,4 +429,4 @@ _zstd__set_parameter_types(PyObject *module, PyObject *const *args, Py_ssize_t n
 exit:
     return return_value;
 }
-/*[clinic end generated code: output=077c8ea2b11fb188 input=a9049054013a1b77]*/
+/*[clinic end generated code: output=189c462236a7096c input=a9049054013a1b77]*/
index d0f677be821572ccb780d24ea4b99826215a5e3d..b735981e7476d5521a57aa47f7510ad96b3a0016 100644 (file)
@@ -71,14 +71,14 @@ _PyZstd_set_c_parameters(ZstdCompressor *self, PyObject *level_or_options,
             if (Py_TYPE(key) == mod_state->DParameter_type) {
                 PyErr_SetString(PyExc_TypeError,
                                 "Key of compression option dict should "
-                                "NOT be DParameter.");
+                                "NOT be DecompressionParameter.");
                 return -1;
             }
 
             int key_v = PyLong_AsInt(key);
             if (key_v == -1 && PyErr_Occurred()) {
                 PyErr_SetString(PyExc_ValueError,
-                                "Key of options dict should be a CParameter attribute.");
+                                "Key of options dict should be a CompressionParameter attribute.");
                 return -1;
             }
 
index 4e3a28068be13065c7097ada50108dd0ef6699db..a4be180c0088fce26ff4297ea7c3b74be26bebf1 100644 (file)
@@ -84,7 +84,7 @@ _PyZstd_set_d_parameters(ZstdDecompressor *self, PyObject *options)
         if (Py_TYPE(key) == mod_state->CParameter_type) {
             PyErr_SetString(PyExc_TypeError,
                             "Key of decompression options dict should "
-                            "NOT be CParameter.");
+                            "NOT be CompressionParameter.");
             return -1;
         }
 
@@ -92,7 +92,7 @@ _PyZstd_set_d_parameters(ZstdDecompressor *self, PyObject *options)
         int key_v = PyLong_AsInt(key);
         if (key_v == -1 && PyErr_Occurred()) {
             PyErr_SetString(PyExc_ValueError,
-                            "Key of options dict should be a DParameter attribute.");
+                            "Key of options dict should be a DecompressionParameter attribute.");
             return -1;
         }