]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-88233: zipfile: refactor _strip_extra (#102084)
authorJason R. Coombs <jaraco@jaraco.com>
Mon, 25 Sep 2023 23:46:58 +0000 (19:46 -0400)
committerGitHub <noreply@github.com>
Mon, 25 Sep 2023 23:46:58 +0000 (19:46 -0400)
* Refactor zipfile._strip_extra to use higher level abstractions for extras instead of a heavy-state loop.

* Add blurb

* Remove _strip_extra and use _Extra.strip directly.

* Use memoryview to avoid unnecessary copies while splitting Extras.

Lib/test/test_zipfile/test_core.py
Lib/zipfile/__init__.py
Misc/NEWS.d/next/Library/2023-02-20-12-00-11.gh-issue-88233.o5Zb0t.rst [new file with mode: 0644]

index 9960259c4cde0c7afbd0504d80abc6ea6cd316ae..0f6c0f2107ce6b69ea190d3c77c788ffbd6f0f8e 100644 (file)
@@ -3203,14 +3203,14 @@ class StripExtraTests(unittest.TestCase):
         b = s.pack(2, 0)
         c = s.pack(3, 0)
 
-        self.assertEqual(b'', zipfile._strip_extra(a, (self.ZIP64_EXTRA,)))
-        self.assertEqual(b, zipfile._strip_extra(b, (self.ZIP64_EXTRA,)))
+        self.assertEqual(b'', zipfile._Extra.strip(a, (self.ZIP64_EXTRA,)))
+        self.assertEqual(b, zipfile._Extra.strip(b, (self.ZIP64_EXTRA,)))
         self.assertEqual(
-            b+b"z", zipfile._strip_extra(b+b"z", (self.ZIP64_EXTRA,)))
+            b+b"z", zipfile._Extra.strip(b+b"z", (self.ZIP64_EXTRA,)))
 
-        self.assertEqual(b+c, zipfile._strip_extra(a+b+c, (self.ZIP64_EXTRA,)))
-        self.assertEqual(b+c, zipfile._strip_extra(b+a+c, (self.ZIP64_EXTRA,)))
-        self.assertEqual(b+c, zipfile._strip_extra(b+c+a, (self.ZIP64_EXTRA,)))
+        self.assertEqual(b+c, zipfile._Extra.strip(a+b+c, (self.ZIP64_EXTRA,)))
+        self.assertEqual(b+c, zipfile._Extra.strip(b+a+c, (self.ZIP64_EXTRA,)))
+        self.assertEqual(b+c, zipfile._Extra.strip(b+c+a, (self.ZIP64_EXTRA,)))
 
     def test_with_data(self):
         s = struct.Struct("<HH")
@@ -3218,38 +3218,38 @@ class StripExtraTests(unittest.TestCase):
         b = s.pack(2, 2) + b"bb"
         c = s.pack(3, 3) + b"ccc"
 
-        self.assertEqual(b"", zipfile._strip_extra(a, (self.ZIP64_EXTRA,)))
-        self.assertEqual(b, zipfile._strip_extra(b, (self.ZIP64_EXTRA,)))
+        self.assertEqual(b"", zipfile._Extra.strip(a, (self.ZIP64_EXTRA,)))
+        self.assertEqual(b, zipfile._Extra.strip(b, (self.ZIP64_EXTRA,)))
         self.assertEqual(
-            b+b"z", zipfile._strip_extra(b+b"z", (self.ZIP64_EXTRA,)))
+            b+b"z", zipfile._Extra.strip(b+b"z", (self.ZIP64_EXTRA,)))
 
-        self.assertEqual(b+c, zipfile._strip_extra(a+b+c, (self.ZIP64_EXTRA,)))
-        self.assertEqual(b+c, zipfile._strip_extra(b+a+c, (self.ZIP64_EXTRA,)))
-        self.assertEqual(b+c, zipfile._strip_extra(b+c+a, (self.ZIP64_EXTRA,)))
+        self.assertEqual(b+c, zipfile._Extra.strip(a+b+c, (self.ZIP64_EXTRA,)))
+        self.assertEqual(b+c, zipfile._Extra.strip(b+a+c, (self.ZIP64_EXTRA,)))
+        self.assertEqual(b+c, zipfile._Extra.strip(b+c+a, (self.ZIP64_EXTRA,)))
 
     def test_multiples(self):
         s = struct.Struct("<HH")
         a = s.pack(self.ZIP64_EXTRA, 1) + b"a"
         b = s.pack(2, 2) + b"bb"
 
-        self.assertEqual(b"", zipfile._strip_extra(a+a, (self.ZIP64_EXTRA,)))
-        self.assertEqual(b"", zipfile._strip_extra(a+a+a, (self.ZIP64_EXTRA,)))
+        self.assertEqual(b"", zipfile._Extra.strip(a+a, (self.ZIP64_EXTRA,)))
+        self.assertEqual(b"", zipfile._Extra.strip(a+a+a, (self.ZIP64_EXTRA,)))
         self.assertEqual(
-            b"z", zipfile._strip_extra(a+a+b"z", (self.ZIP64_EXTRA,)))
+            b"z", zipfile._Extra.strip(a+a+b"z", (self.ZIP64_EXTRA,)))
         self.assertEqual(
-            b+b"z", zipfile._strip_extra(a+a+b+b"z", (self.ZIP64_EXTRA,)))
+            b+b"z", zipfile._Extra.strip(a+a+b+b"z", (self.ZIP64_EXTRA,)))
 
-        self.assertEqual(b, zipfile._strip_extra(a+a+b, (self.ZIP64_EXTRA,)))
-        self.assertEqual(b, zipfile._strip_extra(a+b+a, (self.ZIP64_EXTRA,)))
-        self.assertEqual(b, zipfile._strip_extra(b+a+a, (self.ZIP64_EXTRA,)))
+        self.assertEqual(b, zipfile._Extra.strip(a+a+b, (self.ZIP64_EXTRA,)))
+        self.assertEqual(b, zipfile._Extra.strip(a+b+a, (self.ZIP64_EXTRA,)))
+        self.assertEqual(b, zipfile._Extra.strip(b+a+a, (self.ZIP64_EXTRA,)))
 
     def test_too_short(self):
-        self.assertEqual(b"", zipfile._strip_extra(b"", (self.ZIP64_EXTRA,)))
-        self.assertEqual(b"z", zipfile._strip_extra(b"z", (self.ZIP64_EXTRA,)))
+        self.assertEqual(b"", zipfile._Extra.strip(b"", (self.ZIP64_EXTRA,)))
+        self.assertEqual(b"z", zipfile._Extra.strip(b"z", (self.ZIP64_EXTRA,)))
         self.assertEqual(
-            b"zz", zipfile._strip_extra(b"zz", (self.ZIP64_EXTRA,)))
+            b"zz", zipfile._Extra.strip(b"zz", (self.ZIP64_EXTRA,)))
         self.assertEqual(
-            b"zzz", zipfile._strip_extra(b"zzz", (self.ZIP64_EXTRA,)))
+            b"zzz", zipfile._Extra.strip(b"zzz", (self.ZIP64_EXTRA,)))
 
 
 if __name__ == "__main__":
index 9fc1840ba1e5340658832ddc30c8394f4c1ccda8..2c963de18e4f9531c495836f14b2fbf75284001f 100644 (file)
@@ -188,28 +188,42 @@ _CD64_OFFSET_START_CENTDIR = 9
 
 _DD_SIGNATURE = 0x08074b50
 
-_EXTRA_FIELD_STRUCT = struct.Struct('<HH')
-
-def _strip_extra(extra, xids):
-    # Remove Extra Fields with specified IDs.
-    unpack = _EXTRA_FIELD_STRUCT.unpack
-    modified = False
-    buffer = []
-    start = i = 0
-    while i + 4 <= len(extra):
-        xid, xlen = unpack(extra[i : i + 4])
-        j = i + 4 + xlen
-        if xid in xids:
-            if i != start:
-                buffer.append(extra[start : i])
-            start = j
-            modified = True
-        i = j
-    if not modified:
-        return extra
-    if start != len(extra):
-        buffer.append(extra[start:])
-    return b''.join(buffer)
+
+class _Extra(bytes):
+    FIELD_STRUCT = struct.Struct('<HH')
+
+    def __new__(cls, val, id=None):
+        return super().__new__(cls, val)
+
+    def __init__(self, val, id=None):
+        self.id = id
+
+    @classmethod
+    def read_one(cls, raw):
+        try:
+            xid, xlen = cls.FIELD_STRUCT.unpack(raw[:4])
+        except struct.error:
+            xid = None
+            xlen = 0
+        return cls(raw[:4+xlen], xid), raw[4+xlen:]
+
+    @classmethod
+    def split(cls, data):
+        # use memoryview for zero-copy slices
+        rest = memoryview(data)
+        while rest:
+            extra, rest = _Extra.read_one(rest)
+            yield extra
+
+    @classmethod
+    def strip(cls, data, xids):
+        """Remove Extra fields with specified IDs."""
+        return b''.join(
+            ex
+            for ex in cls.split(data)
+            if ex.id not in xids
+        )
+
 
 def _check_zipfile(fp):
     try:
@@ -1963,7 +1977,7 @@ class ZipFile:
             min_version = 0
             if extra:
                 # Append a ZIP64 field to the extra's
-                extra_data = _strip_extra(extra_data, (1,))
+                extra_data = _Extra.strip(extra_data, (1,))
                 extra_data = struct.pack(
                     '<HH' + 'Q'*len(extra),
                     1, 8*len(extra), *extra) + extra_data
diff --git a/Misc/NEWS.d/next/Library/2023-02-20-12-00-11.gh-issue-88233.o5Zb0t.rst b/Misc/NEWS.d/next/Library/2023-02-20-12-00-11.gh-issue-88233.o5Zb0t.rst
new file mode 100644 (file)
index 0000000..945f92d
--- /dev/null
@@ -0,0 +1,2 @@
+Refactored ``zipfile._strip_extra`` to use higher level abstactions for
+extras instead of a heavy-state loop.