]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-126317: Simplify pickle code by using itertools.batched() (GH-126323)
authorLee Dong Wook <sh95119@gmail.com>
Sat, 2 Nov 2024 14:07:32 +0000 (23:07 +0900)
committerGitHub <noreply@github.com>
Sat, 2 Nov 2024 14:07:32 +0000 (16:07 +0200)
Lib/pickle.py

index ed8138beb908ee36c1c4065b3863b782b6ca4f9b..965e1952fb8c5ea9b99b704a050c8088a13fd725 100644 (file)
@@ -26,7 +26,7 @@ Misc variables:
 from types import FunctionType
 from copyreg import dispatch_table
 from copyreg import _extension_registry, _inverted_registry, _extension_cache
-from itertools import islice
+from itertools import batched
 from functools import partial
 import sys
 from sys import maxsize
@@ -1033,31 +1033,26 @@ class _Pickler:
                 write(APPEND)
             return
 
-        it = iter(items)
         start = 0
-        while True:
-            tmp = list(islice(it, self._BATCHSIZE))
-            n = len(tmp)
-            if n > 1:
+        for batch in batched(items, self._BATCHSIZE):
+            batch_len = len(batch)
+            if batch_len != 1:
                 write(MARK)
-                for i, x in enumerate(tmp, start):
+                for i, x in enumerate(batch, start):
                     try:
                         save(x)
                     except BaseException as exc:
                         exc.add_note(f'when serializing {_T(obj)} item {i}')
                         raise
                 write(APPENDS)
-            elif n:
+            else:
                 try:
-                    save(tmp[0])
+                    save(batch[0])
                 except BaseException as exc:
                     exc.add_note(f'when serializing {_T(obj)} item {start}')
                     raise
                 write(APPEND)
-            # else tmp is empty, and we're done
-            if n < self._BATCHSIZE:
-                return
-            start += n
+            start += batch_len
 
     def save_dict(self, obj):
         if self.bin:
@@ -1086,13 +1081,10 @@ class _Pickler:
                 write(SETITEM)
             return
 
-        it = iter(items)
-        while True:
-            tmp = list(islice(it, self._BATCHSIZE))
-            n = len(tmp)
-            if n > 1:
+        for batch in batched(items, self._BATCHSIZE):
+            if len(batch) != 1:
                 write(MARK)
-                for k, v in tmp:
+                for k, v in batch:
                     save(k)
                     try:
                         save(v)
@@ -1100,8 +1092,8 @@ class _Pickler:
                         exc.add_note(f'when serializing {_T(obj)} item {k!r}')
                         raise
                 write(SETITEMS)
-            elif n:
-                k, v = tmp[0]
+            else:
+                k, v = batch[0]
                 save(k)
                 try:
                     save(v)
@@ -1109,9 +1101,6 @@ class _Pickler:
                     exc.add_note(f'when serializing {_T(obj)} item {k!r}')
                     raise
                 write(SETITEM)
-            # else tmp is empty, and we're done
-            if n < self._BATCHSIZE:
-                return
 
     def save_set(self, obj):
         save = self.save
@@ -1124,21 +1113,15 @@ class _Pickler:
         write(EMPTY_SET)
         self.memoize(obj)
 
-        it = iter(obj)
-        while True:
-            batch = list(islice(it, self._BATCHSIZE))
-            n = len(batch)
-            if n > 0:
-                write(MARK)
-                try:
-                    for item in batch:
-                        save(item)
-                except BaseException as exc:
-                    exc.add_note(f'when serializing {_T(obj)} element')
-                    raise
-                write(ADDITEMS)
-            if n < self._BATCHSIZE:
-                return
+        for batch in batched(obj, self._BATCHSIZE):
+            write(MARK)
+            try:
+                for item in batch:
+                    save(item)
+            except BaseException as exc:
+                exc.add_note(f'when serializing {_T(obj)} element')
+                raise
+            write(ADDITEMS)
     dispatch[set] = save_set
 
     def save_frozenset(self, obj):