b = s.pack(2, 0)
c = s.pack(3, 0)
- self.assertEqual(b'', zipfile._strip_extra(a, (self.ZIP64_EXTRA,)))
- self.assertEqual(b, zipfile._strip_extra(b, (self.ZIP64_EXTRA,)))
+ self.assertEqual(b'', zipfile._Extra.strip(a, (self.ZIP64_EXTRA,)))
+ self.assertEqual(b, zipfile._Extra.strip(b, (self.ZIP64_EXTRA,)))
self.assertEqual(
- b+b"z", zipfile._strip_extra(b+b"z", (self.ZIP64_EXTRA,)))
+ b+b"z", zipfile._Extra.strip(b+b"z", (self.ZIP64_EXTRA,)))
- self.assertEqual(b+c, zipfile._strip_extra(a+b+c, (self.ZIP64_EXTRA,)))
- self.assertEqual(b+c, zipfile._strip_extra(b+a+c, (self.ZIP64_EXTRA,)))
- self.assertEqual(b+c, zipfile._strip_extra(b+c+a, (self.ZIP64_EXTRA,)))
+ self.assertEqual(b+c, zipfile._Extra.strip(a+b+c, (self.ZIP64_EXTRA,)))
+ self.assertEqual(b+c, zipfile._Extra.strip(b+a+c, (self.ZIP64_EXTRA,)))
+ self.assertEqual(b+c, zipfile._Extra.strip(b+c+a, (self.ZIP64_EXTRA,)))
def test_with_data(self):
s = struct.Struct("<HH")
b = s.pack(2, 2) + b"bb"
c = s.pack(3, 3) + b"ccc"
- self.assertEqual(b"", zipfile._strip_extra(a, (self.ZIP64_EXTRA,)))
- self.assertEqual(b, zipfile._strip_extra(b, (self.ZIP64_EXTRA,)))
+ self.assertEqual(b"", zipfile._Extra.strip(a, (self.ZIP64_EXTRA,)))
+ self.assertEqual(b, zipfile._Extra.strip(b, (self.ZIP64_EXTRA,)))
self.assertEqual(
- b+b"z", zipfile._strip_extra(b+b"z", (self.ZIP64_EXTRA,)))
+ b+b"z", zipfile._Extra.strip(b+b"z", (self.ZIP64_EXTRA,)))
- self.assertEqual(b+c, zipfile._strip_extra(a+b+c, (self.ZIP64_EXTRA,)))
- self.assertEqual(b+c, zipfile._strip_extra(b+a+c, (self.ZIP64_EXTRA,)))
- self.assertEqual(b+c, zipfile._strip_extra(b+c+a, (self.ZIP64_EXTRA,)))
+ self.assertEqual(b+c, zipfile._Extra.strip(a+b+c, (self.ZIP64_EXTRA,)))
+ self.assertEqual(b+c, zipfile._Extra.strip(b+a+c, (self.ZIP64_EXTRA,)))
+ self.assertEqual(b+c, zipfile._Extra.strip(b+c+a, (self.ZIP64_EXTRA,)))
def test_multiples(self):
s = struct.Struct("<HH")
a = s.pack(self.ZIP64_EXTRA, 1) + b"a"
b = s.pack(2, 2) + b"bb"
- self.assertEqual(b"", zipfile._strip_extra(a+a, (self.ZIP64_EXTRA,)))
- self.assertEqual(b"", zipfile._strip_extra(a+a+a, (self.ZIP64_EXTRA,)))
+ self.assertEqual(b"", zipfile._Extra.strip(a+a, (self.ZIP64_EXTRA,)))
+ self.assertEqual(b"", zipfile._Extra.strip(a+a+a, (self.ZIP64_EXTRA,)))
self.assertEqual(
- b"z", zipfile._strip_extra(a+a+b"z", (self.ZIP64_EXTRA,)))
+ b"z", zipfile._Extra.strip(a+a+b"z", (self.ZIP64_EXTRA,)))
self.assertEqual(
- b+b"z", zipfile._strip_extra(a+a+b+b"z", (self.ZIP64_EXTRA,)))
+ b+b"z", zipfile._Extra.strip(a+a+b+b"z", (self.ZIP64_EXTRA,)))
- self.assertEqual(b, zipfile._strip_extra(a+a+b, (self.ZIP64_EXTRA,)))
- self.assertEqual(b, zipfile._strip_extra(a+b+a, (self.ZIP64_EXTRA,)))
- self.assertEqual(b, zipfile._strip_extra(b+a+a, (self.ZIP64_EXTRA,)))
+ self.assertEqual(b, zipfile._Extra.strip(a+a+b, (self.ZIP64_EXTRA,)))
+ self.assertEqual(b, zipfile._Extra.strip(a+b+a, (self.ZIP64_EXTRA,)))
+ self.assertEqual(b, zipfile._Extra.strip(b+a+a, (self.ZIP64_EXTRA,)))
def test_too_short(self):
- self.assertEqual(b"", zipfile._strip_extra(b"", (self.ZIP64_EXTRA,)))
- self.assertEqual(b"z", zipfile._strip_extra(b"z", (self.ZIP64_EXTRA,)))
+ self.assertEqual(b"", zipfile._Extra.strip(b"", (self.ZIP64_EXTRA,)))
+ self.assertEqual(b"z", zipfile._Extra.strip(b"z", (self.ZIP64_EXTRA,)))
self.assertEqual(
- b"zz", zipfile._strip_extra(b"zz", (self.ZIP64_EXTRA,)))
+ b"zz", zipfile._Extra.strip(b"zz", (self.ZIP64_EXTRA,)))
self.assertEqual(
- b"zzz", zipfile._strip_extra(b"zzz", (self.ZIP64_EXTRA,)))
+ b"zzz", zipfile._Extra.strip(b"zzz", (self.ZIP64_EXTRA,)))
if __name__ == "__main__":
_DD_SIGNATURE = 0x08074b50
-_EXTRA_FIELD_STRUCT = struct.Struct('<HH')
-
-def _strip_extra(extra, xids):
- # Remove Extra Fields with specified IDs.
- unpack = _EXTRA_FIELD_STRUCT.unpack
- modified = False
- buffer = []
- start = i = 0
- while i + 4 <= len(extra):
- xid, xlen = unpack(extra[i : i + 4])
- j = i + 4 + xlen
- if xid in xids:
- if i != start:
- buffer.append(extra[start : i])
- start = j
- modified = True
- i = j
- if not modified:
- return extra
- if start != len(extra):
- buffer.append(extra[start:])
- return b''.join(buffer)
+
+class _Extra(bytes):
+ FIELD_STRUCT = struct.Struct('<HH')
+
+ def __new__(cls, val, id=None):
+ return super().__new__(cls, val)
+
+ def __init__(self, val, id=None):
+ self.id = id
+
+ @classmethod
+ def read_one(cls, raw):
+ try:
+ xid, xlen = cls.FIELD_STRUCT.unpack(raw[:4])
+ except struct.error:
+ xid = None
+ xlen = 0
+ return cls(raw[:4+xlen], xid), raw[4+xlen:]
+
+ @classmethod
+ def split(cls, data):
+ # use memoryview for zero-copy slices
+ rest = memoryview(data)
+ while rest:
+ extra, rest = _Extra.read_one(rest)
+ yield extra
+
+ @classmethod
+ def strip(cls, data, xids):
+ """Remove Extra fields with specified IDs."""
+ return b''.join(
+ ex
+ for ex in cls.split(data)
+ if ex.id not in xids
+ )
+
def _check_zipfile(fp):
try:
min_version = 0
if extra:
# Append a ZIP64 field to the extra's
- extra_data = _strip_extra(extra_data, (1,))
+ extra_data = _Extra.strip(extra_data, (1,))
extra_data = struct.pack(
'<HH' + 'Q'*len(extra),
1, 8*len(extra), *extra) + extra_data