]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-44439: BZ2File.write() / LZMAFile.write() handle buffer protocol correctly (GH...
authorMa Lin <animalize@users.noreply.github.com>
Tue, 22 Jun 2021 07:04:23 +0000 (15:04 +0800)
committerGitHub <noreply@github.com>
Tue, 22 Jun 2021 07:04:23 +0000 (10:04 +0300)
No longer use len() to get the length of the input data. For some buffer protocol objects,
the length obtained by using len() is wrong.

Lib/bz2.py
Lib/gzip.py
Lib/lzma.py
Lib/test/test_bz2.py
Lib/test/test_gzip.py
Lib/test/test_lzma.py
Misc/NEWS.d/next/Library/2021-06-17-15-01-51.bpo-44439.1S7QhT.rst [new file with mode: 0644]

index a2c588e7487f3d98d442c7d7c49ae5dc312f55c8..7f1d20632ef139887bfb497f0261dea83d068978 100644 (file)
@@ -219,14 +219,22 @@ class BZ2File(_compression.BaseStream):
         """Write a byte string to the file.
 
         Returns the number of uncompressed bytes written, which is
-        always len(data). Note that due to buffering, the file on disk
-        may not reflect the data written until close() is called.
+        always the length of data in bytes. Note that due to buffering,
+        the file on disk may not reflect the data written until close()
+        is called.
         """
         self._check_can_write()
+        if isinstance(data, (bytes, bytearray)):
+            length = len(data)
+        else:
+            # accept any data that supports the buffer protocol
+            data = memoryview(data)
+            length = data.nbytes
+
         compressed = self._compressor.compress(data)
         self._fp.write(compressed)
-        self._pos += len(data)
-        return len(data)
+        self._pos += length
+        return length
 
     def writelines(self, seq):
         """Write a sequence of byte strings to the file.
index 1c1e795e1715dba4dcc4716a3e7d24e47bfb32ae..3d837b744800eda3747e26cc8e2c9aa3530e5aaf 100644 (file)
@@ -278,7 +278,7 @@ class GzipFile(_compression.BaseStream):
         if self.fileobj is None:
             raise ValueError("write() on closed GzipFile object")
 
-        if isinstance(data, bytes):
+        if isinstance(data, (bytes, bytearray)):
             length = len(data)
         else:
             # accept any data that supports the buffer protocol
index 2ada7d81d3c813b85a6815285bcb23db6f8cd1da..9abf06d91db1848e1df95ecd29433f592098f514 100644 (file)
@@ -229,14 +229,22 @@ class LZMAFile(_compression.BaseStream):
         """Write a bytes object to the file.
 
         Returns the number of uncompressed bytes written, which is
-        always len(data). Note that due to buffering, the file on disk
-        may not reflect the data written until close() is called.
+        always the length of data in bytes. Note that due to buffering,
+        the file on disk may not reflect the data written until close()
+        is called.
         """
         self._check_can_write()
+        if isinstance(data, (bytes, bytearray)):
+            length = len(data)
+        else:
+            # accept any data that supports the buffer protocol
+            data = memoryview(data)
+            length = data.nbytes
+
         compressed = self._compressor.compress(data)
         self._fp.write(compressed)
-        self._pos += len(data)
-        return len(data)
+        self._pos += length
+        return length
 
     def seek(self, offset, whence=io.SEEK_SET):
         """Change the file position.
index efed3a859ba217ddc48a3adf1fd92bfc204714b7..7913beb87a352520aed648383d3fb1cd77432e3f 100644 (file)
@@ -1,6 +1,7 @@
 from test import support
 from test.support import bigmemtest, _4G
 
+import array
 import unittest
 from io import BytesIO, DEFAULT_BUFFER_SIZE
 import os
@@ -620,6 +621,14 @@ class BZ2FileTest(BaseTest):
             with BZ2File(BytesIO(truncated[:i])) as f:
                 self.assertRaises(EOFError, f.read, 1)
 
+    def test_issue44439(self):
+        q = array.array('Q', [1, 2, 3, 4, 5])
+        LENGTH = len(q) * q.itemsize
+
+        with BZ2File(BytesIO(), 'w') as f:
+            self.assertEqual(f.write(q), LENGTH)
+            self.assertEqual(f.tell(), LENGTH)
+
 
 class BZ2CompressorTest(BaseTest):
     def testCompress(self):
index 446b61ab439ffef7bdbdfac0a45d738dc2bfc882..7b51e45aad92bac75fde153f00f215f3a322e68d 100644 (file)
@@ -592,6 +592,15 @@ class TestGzip(BaseTest):
         with gzip.open(self.filename, "rb") as f:
             f._buffer.raw._fp.prepend()
 
+    def test_issue44439(self):
+        q = array.array('Q', [1, 2, 3, 4, 5])
+        LENGTH = len(q) * q.itemsize
+
+        with gzip.GzipFile(fileobj=io.BytesIO(), mode='w') as f:
+            self.assertEqual(f.write(q), LENGTH)
+            self.assertEqual(f.tell(), LENGTH)
+
+
 class TestOpen(BaseTest):
     def test_binary_modes(self):
         uncompressed = data1 * 50
index db20300056e489975285d5a089f68e858579f469..1e2066b89168f419ebcab08b705ce0231e479785 100644 (file)
@@ -1,4 +1,5 @@
 import _compression
+import array
 from io import BytesIO, UnsupportedOperation, DEFAULT_BUFFER_SIZE
 import os
 import pathlib
@@ -1231,6 +1232,14 @@ class FileTestCase(unittest.TestCase):
         self.assertTrue(d2.eof)
         self.assertEqual(out1 + out2, entire)
 
+    def test_issue44439(self):
+        q = array.array('Q', [1, 2, 3, 4, 5])
+        LENGTH = len(q) * q.itemsize
+
+        with LZMAFile(BytesIO(), 'w') as f:
+            self.assertEqual(f.write(q), LENGTH)
+            self.assertEqual(f.tell(), LENGTH)
+
 
 class OpenTestCase(unittest.TestCase):
 
diff --git a/Misc/NEWS.d/next/Library/2021-06-17-15-01-51.bpo-44439.1S7QhT.rst b/Misc/NEWS.d/next/Library/2021-06-17-15-01-51.bpo-44439.1S7QhT.rst
new file mode 100644 (file)
index 0000000..2739668
--- /dev/null
@@ -0,0 +1,3 @@
+Fix in :meth:`bz2.BZ2File.write` / :meth:`lzma.LZMAFile.write` methods, when
+the input data is an object that supports the buffer protocol, the file length
+may be wrong.