git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-99631: Add custom `loads` and `dumps` support for the `shelve` module (#118065)
author: Furkan Onder <furkanonder@protonmail.com>
Sat, 12 Jul 2025 12:27:32 +0000 (15:27 +0300)
committer: GitHub <noreply@github.com>
Sat, 12 Jul 2025 12:27:32 +0000 (14:27 +0200)
Co-authored-by: Pieter Eendebak <pieter.eendebak@gmail.com>
Co-authored-by: Petr Viktorin <encukou@gmail.com>
Co-authored-by: Bénédikt Tran <10796600+picnixz@users.noreply.github.com>
Doc/library/shelve.rst
Lib/shelve.py
Lib/test/test_shelve.py
Misc/NEWS.d/next/Library/2024-07-16-00-01-04.gh-issue-99631.GWD4fD.rst [new file with mode: 0644]

index 23a2e0c3d0c758373d29129456ac7e0b01079273..238086195240565e75115227c43f40c3e1595c8d 100644 (file)
@@ -17,7 +17,8 @@ This includes most class instances, recursive data types, and objects containing
 lots of shared  sub-objects.  The keys are ordinary strings.
 
 
-.. function:: open(filename, flag='c', protocol=None, writeback=False)
+.. function:: open(filename, flag='c', protocol=None, writeback=False, *, \
+                   serializer=None, deserializer=None)
 
    Open a persistent dictionary.  The filename specified is the base filename for
    the underlying database.  As a side-effect, an extension may be added to the
@@ -41,6 +42,21 @@ lots of shared  sub-objects.  The keys are ordinary strings.
    determine which accessed entries are mutable, nor which ones were actually
    mutated).
 
+   By default, :mod:`shelve` uses :func:`pickle.dumps` and :func:`pickle.loads`
+   for serializing and deserializing. This can be changed by supplying
+   *serializer* and *deserializer*, respectively.
+
+   The *serializer* argument must be a callable which takes an object ``obj``
+   and the *protocol* as inputs and returns the representation of ``obj`` as a
+   :term:`bytes-like object`; the *protocol* value may be ignored by the
+   serializer.
+
+   The *deserializer* argument must be a callable which takes a serialized
+   object given as a :class:`bytes` object and returns the corresponding object.
+
+   A :exc:`ShelveError` is raised if *serializer* is given but *deserializer*
+   is not, or vice versa.
+
    .. versionchanged:: 3.10
       :const:`pickle.DEFAULT_PROTOCOL` is now used as the default pickle
       protocol.
@@ -48,6 +64,10 @@ lots of shared  sub-objects.  The keys are ordinary strings.
    .. versionchanged:: 3.11
       Accepts :term:`path-like object` for filename.
 
+   .. versionchanged:: next
+      Accepts custom *serializer* and *deserializer* functions in place of
+      :func:`pickle.dumps` and :func:`pickle.loads`.
+
    .. note::
 
       Do not rely on the shelf being closed automatically; always call
@@ -129,7 +149,8 @@ Restrictions
   explicitly.
 
 
-.. class:: Shelf(dict, protocol=None, writeback=False, keyencoding='utf-8')
+.. class:: Shelf(dict, protocol=None, writeback=False, \
+                 keyencoding='utf-8', *, serializer=None, deserializer=None)
 
    A subclass of :class:`collections.abc.MutableMapping` which stores pickled
    values in the *dict* object.
@@ -147,6 +168,9 @@ Restrictions
    The *keyencoding* parameter is the encoding used to encode keys before they
    are used with the underlying dict.
 
+   The *serializer* and *deserializer* parameters have the same interpretation
+   as in :func:`~shelve.open`.
+
    A :class:`Shelf` object can also be used as a context manager, in which
    case it will be automatically closed when the :keyword:`with` block ends.
 
@@ -161,8 +185,13 @@ Restrictions
       :const:`pickle.DEFAULT_PROTOCOL` is now used as the default pickle
       protocol.
 
+   .. versionchanged:: next
+      Added the *serializer* and *deserializer* parameters.
 
-.. class:: BsdDbShelf(dict, protocol=None, writeback=False, keyencoding='utf-8')
+
+.. class:: BsdDbShelf(dict, protocol=None, writeback=False, \
+                      keyencoding='utf-8', *, \
+                      serializer=None, deserializer=None)
 
    A subclass of :class:`Shelf` which exposes :meth:`!first`, :meth:`!next`,
    :meth:`!previous`, :meth:`!last` and :meth:`!set_location` methods.
@@ -172,18 +201,27 @@ Restrictions
    modules.  The *dict* object passed to the constructor must support those
    methods.  This is generally accomplished by calling one of
    :func:`!bsddb.hashopen`, :func:`!bsddb.btopen` or :func:`!bsddb.rnopen`.  The
-   optional *protocol*, *writeback*, and *keyencoding* parameters have the same
-   interpretation as for the :class:`Shelf` class.
+   optional *protocol*, *writeback*, *keyencoding*, *serializer* and *deserializer*
+   parameters have the same interpretation as for the :class:`Shelf` class.
+
+   .. versionchanged:: next
+      Added the *serializer* and *deserializer* parameters.
 
 
-.. class:: DbfilenameShelf(filename, flag='c', protocol=None, writeback=False)
+.. class:: DbfilenameShelf(filename, flag='c', protocol=None, \
+                           writeback=False, *, serializer=None, \
+                           deserializer=None)
 
    A subclass of :class:`Shelf` which accepts a *filename* instead of a dict-like
    object.  The underlying file will be opened using :func:`dbm.open`.  By
    default, the file will be created and opened for both read and write.  The
-   optional *flag* parameter has the same interpretation as for the :func:`.open`
-   function.  The optional *protocol* and *writeback* parameters have the same
-   interpretation as for the :class:`Shelf` class.
+   optional *flag* parameter has the same interpretation as for the
+   :func:`.open` function.  The optional *protocol*, *writeback*, *serializer*
+   and *deserializer* parameters have the same interpretation as in
+   :func:`~shelve.open`.
+
+   .. versionchanged:: next
+      Added the *serializer* and *deserializer* parameters.
 
 
 .. _shelve-example:
@@ -225,6 +263,20 @@ object)::
    d.close()                  # close it
 
 
+Exceptions
+----------
+
+.. exception:: ShelveError
+
+   Exception raised when only one of the *serializer* and *deserializer*
+   arguments is passed to :func:`~shelve.open`, :class:`Shelf`,
+   :class:`BsdDbShelf` or :class:`DbfilenameShelf`.
+
+   The *deserializer* and *serializer* arguments must be given together.
+
+   .. versionadded:: next
+
+
 .. seealso::
 
    Module :mod:`dbm`
index b53dc8b7a8ece914b43e782fa07400226edc3134..1010be1e09d702f98ea5d5d3280462ad076435d8 100644 (file)
@@ -56,12 +56,17 @@ entries in the cache, and empty the cache (d.sync() also synchronizes
 the persistent dictionary on disk, if feasible).
 """
 
-from pickle import DEFAULT_PROTOCOL, Pickler, Unpickler
+from pickle import DEFAULT_PROTOCOL, dumps, loads
 from io import BytesIO
 
 import collections.abc
 
-__all__ = ["Shelf", "BsdDbShelf", "DbfilenameShelf", "open"]
+__all__ = ["ShelveError", "Shelf", "BsdDbShelf", "DbfilenameShelf", "open"]
+
+
+class ShelveError(Exception):
+    pass
+
 
 class _ClosedDict(collections.abc.MutableMapping):
     'Marker for a closed dict.  Access attempts raise a ValueError.'
@@ -82,7 +87,7 @@ class Shelf(collections.abc.MutableMapping):
     """
 
     def __init__(self, dict, protocol=None, writeback=False,
-                 keyencoding="utf-8"):
+                 keyencoding="utf-8", *, serializer=None, deserializer=None):
         self.dict = dict
         if protocol is None:
             protocol = DEFAULT_PROTOCOL
@@ -91,6 +96,16 @@ class Shelf(collections.abc.MutableMapping):
         self.cache = {}
         self.keyencoding = keyencoding
 
+        if serializer is None and deserializer is None:
+            self.serializer = dumps
+            self.deserializer = loads
+        elif (serializer is None) ^ (deserializer is None):
+            raise ShelveError("serializer and deserializer must be "
+                              "defined together")
+        else:
+            self.serializer = serializer
+            self.deserializer = deserializer
+
     def __iter__(self):
         for k in self.dict.keys():
             yield k.decode(self.keyencoding)
@@ -110,8 +125,8 @@ class Shelf(collections.abc.MutableMapping):
         try:
             value = self.cache[key]
         except KeyError:
-            f = BytesIO(self.dict[key.encode(self.keyencoding)])
-            value = Unpickler(f).load()
+            f = self.dict[key.encode(self.keyencoding)]
+            value = self.deserializer(f)
             if self.writeback:
                 self.cache[key] = value
         return value
@@ -119,10 +134,8 @@ class Shelf(collections.abc.MutableMapping):
     def __setitem__(self, key, value):
         if self.writeback:
             self.cache[key] = value
-        f = BytesIO()
-        p = Pickler(f, self._protocol)
-        p.dump(value)
-        self.dict[key.encode(self.keyencoding)] = f.getvalue()
+        serialized_value = self.serializer(value, self._protocol)
+        self.dict[key.encode(self.keyencoding)] = serialized_value
 
     def __delitem__(self, key):
         del self.dict[key.encode(self.keyencoding)]
@@ -191,33 +204,29 @@ class BsdDbShelf(Shelf):
     """
 
     def __init__(self, dict, protocol=None, writeback=False,
-                 keyencoding="utf-8"):
-        Shelf.__init__(self, dict, protocol, writeback, keyencoding)
+                 keyencoding="utf-8", *, serializer=None, deserializer=None):
+        Shelf.__init__(self, dict, protocol, writeback, keyencoding,
+                       serializer=serializer, deserializer=deserializer)
 
     def set_location(self, key):
         (key, value) = self.dict.set_location(key)
-        f = BytesIO(value)
-        return (key.decode(self.keyencoding), Unpickler(f).load())
+        return (key.decode(self.keyencoding), self.deserializer(value))
 
     def next(self):
         (key, value) = next(self.dict)
-        f = BytesIO(value)
-        return (key.decode(self.keyencoding), Unpickler(f).load())
+        return (key.decode(self.keyencoding), self.deserializer(value))
 
     def previous(self):
         (key, value) = self.dict.previous()
-        f = BytesIO(value)
-        return (key.decode(self.keyencoding), Unpickler(f).load())
+        return (key.decode(self.keyencoding), self.deserializer(value))
 
     def first(self):
         (key, value) = self.dict.first()
-        f = BytesIO(value)
-        return (key.decode(self.keyencoding), Unpickler(f).load())
+        return (key.decode(self.keyencoding), self.deserializer(value))
 
     def last(self):
         (key, value) = self.dict.last()
-        f = BytesIO(value)
-        return (key.decode(self.keyencoding), Unpickler(f).load())
+        return (key.decode(self.keyencoding), self.deserializer(value))
 
 
 class DbfilenameShelf(Shelf):
@@ -227,9 +236,11 @@ class DbfilenameShelf(Shelf):
     See the module's __doc__ string for an overview of the interface.
     """
 
-    def __init__(self, filename, flag='c', protocol=None, writeback=False):
+    def __init__(self, filename, flag='c', protocol=None, writeback=False, *,
+                 serializer=None, deserializer=None):
         import dbm
-        Shelf.__init__(self, dbm.open(filename, flag), protocol, writeback)
+        Shelf.__init__(self, dbm.open(filename, flag), protocol, writeback,
+                       serializer=serializer, deserializer=deserializer)
 
     def clear(self):
         """Remove all items from the shelf."""
@@ -238,8 +249,8 @@ class DbfilenameShelf(Shelf):
         self.cache.clear()
         self.dict.clear()
 
-
-def open(filename, flag='c', protocol=None, writeback=False):
+def open(filename, flag='c', protocol=None, writeback=False, *,
+         serializer=None, deserializer=None):
     """Open a persistent dictionary for reading and writing.
 
     The filename parameter is the base filename for the underlying
@@ -252,4 +263,5 @@ def open(filename, flag='c', protocol=None, writeback=False):
     See the module's __doc__ string for an overview of the interface.
     """
 
-    return DbfilenameShelf(filename, flag, protocol, writeback)
+    return DbfilenameShelf(filename, flag, protocol, writeback,
+                           serializer=serializer, deserializer=deserializer)
index 08c6562f2a273e01f9e9a149dd5cf34ecea38a98..64609ab9dd9a626a51160f24fd9fa40aebbcb47e 100644 (file)
@@ -1,10 +1,11 @@
+import array
 import unittest
 import dbm
 import shelve
 import pickle
 import os
 
-from test.support import os_helper
+from test.support import import_helper, os_helper
 from collections.abc import MutableMapping
 from test.test_dbm import dbm_iterator
 
@@ -165,6 +166,239 @@ class TestCase(unittest.TestCase):
         with shelve.Shelf({}) as s:
             self.assertEqual(s._protocol, pickle.DEFAULT_PROTOCOL)
 
+    def test_custom_serializer_and_deserializer(self):
+        os.mkdir(self.dirname)
+        self.addCleanup(os_helper.rmtree, self.dirname)
+
+        def serializer(obj, protocol):
+            if isinstance(obj, (bytes, bytearray, str)):
+                if protocol == 5:
+                    return obj
+                return type(obj).__name__
+            elif isinstance(obj, array.array):
+                return obj.tobytes()
+            raise TypeError(f"Unsupported type for serialization: {type(obj)}")
+
+        def deserializer(data):
+            if isinstance(data, (bytes, bytearray, str)):
+                return data.decode("utf-8")
+            elif isinstance(data, array.array):
+                return array.array("b", data)
+            raise TypeError(
+                f"Unsupported type for deserialization: {type(data)}"
+            )
+
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            with self.subTest(proto=proto), shelve.open(
+                self.fn,
+                protocol=proto,
+                serializer=serializer,
+                deserializer=deserializer
+            ) as s:
+                bar = "bar"
+                bytes_data = b"Hello, world!"
+                bytearray_data = bytearray(b"\x00\x01\x02\x03\x04")
+                array_data = array.array("i", [1, 2, 3, 4, 5])
+
+                s["foo"] = bar
+                s["bytes_data"] = bytes_data
+                s["bytearray_data"] = bytearray_data
+                s["array_data"] = array_data
+
+                if proto == 5:
+                    self.assertEqual(s["foo"], str(bar))
+                    self.assertEqual(s["bytes_data"], "Hello, world!")
+                    self.assertEqual(
+                        s["bytearray_data"], bytearray_data.decode()
+                    )
+                    self.assertEqual(
+                        s["array_data"], array_data.tobytes().decode()
+                    )
+                else:
+                    self.assertEqual(s["foo"], "str")
+                    self.assertEqual(s["bytes_data"], "bytes")
+                    self.assertEqual(s["bytearray_data"], "bytearray")
+                    self.assertEqual(
+                        s["array_data"], array_data.tobytes().decode()
+                    )
+
+    def test_custom_incomplete_serializer_and_deserializer(self):
+        dbm_sqlite3 = import_helper.import_module("dbm.sqlite3")
+        os.mkdir(self.dirname)
+        self.addCleanup(os_helper.rmtree, self.dirname)
+
+        with self.assertRaises(dbm_sqlite3.error):
+            def serializer(obj, protocol=None):
+                pass
+
+            def deserializer(data):
+                return data.decode("utf-8")
+
+            with shelve.open(self.fn, serializer=serializer,
+                             deserializer=deserializer) as s:
+                s["foo"] = "bar"
+
+        def serializer(obj, protocol=None):
+            return type(obj).__name__.encode("utf-8")
+
+        def deserializer(data):
+            pass
+
+        with shelve.open(self.fn, serializer=serializer,
+                         deserializer=deserializer) as s:
+            s["foo"] = "bar"
+            self.assertNotEqual(s["foo"], "bar")
+            self.assertIsNone(s["foo"])
+
+    def test_custom_serializer_and_deserializer_bsd_db_shelf(self):
+        berkeleydb = import_helper.import_module("berkeleydb")
+        os.mkdir(self.dirname)
+        self.addCleanup(os_helper.rmtree, self.dirname)
+
+        def serializer(obj, protocol=None):
+            data = obj.__class__.__name__
+            if protocol == 5:
+                data = str(len(data))
+            return data.encode("utf-8")
+
+        def deserializer(data):
+            return data.decode("utf-8")
+
+        def type_name_len(obj):
+            return f"{(len(type(obj).__name__))}"
+
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            with self.subTest(proto=proto), shelve.BsdDbShelf(
+                berkeleydb.btopen(self.fn),
+                protocol=proto,
+                serializer=serializer,
+                deserializer=deserializer,
+            ) as s:
+                bar = "bar"
+                bytes_obj = b"Hello, world!"
+                bytearray_obj = bytearray(b"\x00\x01\x02\x03\x04")
+                arr_obj = array.array("i", [1, 2, 3, 4, 5])
+
+                s["foo"] = bar
+                s["bytes_data"] = bytes_obj
+                s["bytearray_data"] = bytearray_obj
+                s["array_data"] = arr_obj
+
+                if proto == 5:
+                    self.assertEqual(s["foo"], type_name_len(bar))
+                    self.assertEqual(s["bytes_data"], type_name_len(bytes_obj))
+                    self.assertEqual(s["bytearray_data"],
+                                     type_name_len(bytearray_obj))
+                    self.assertEqual(s["array_data"], type_name_len(arr_obj))
+
+                    k, v = s.set_location(b"foo")
+                    self.assertEqual(k, "foo")
+                    self.assertEqual(v, type_name_len(bar))
+
+                    k, v = s.previous()
+                    self.assertEqual(k, "bytes_data")
+                    self.assertEqual(v, type_name_len(bytes_obj))
+
+                    k, v = s.previous()
+                    self.assertEqual(k, "bytearray_data")
+                    self.assertEqual(v, type_name_len(bytearray_obj))
+
+                    k, v = s.previous()
+                    self.assertEqual(k, "array_data")
+                    self.assertEqual(v, type_name_len(arr_obj))
+
+                    k, v = s.next()
+                    self.assertEqual(k, "bytearray_data")
+                    self.assertEqual(v, type_name_len(bytearray_obj))
+
+                    k, v = s.next()
+                    self.assertEqual(k, "bytes_data")
+                    self.assertEqual(v, type_name_len(bytes_obj))
+
+                    k, v = s.first()
+                    self.assertEqual(k, "array_data")
+                    self.assertEqual(v, type_name_len(arr_obj))
+                else:
+                    k, v = s.set_location(b"foo")
+                    self.assertEqual(k, "foo")
+                    self.assertEqual(v, "str")
+
+                    k, v = s.previous()
+                    self.assertEqual(k, "bytes_data")
+                    self.assertEqual(v, "bytes")
+
+                    k, v = s.previous()
+                    self.assertEqual(k, "bytearray_data")
+                    self.assertEqual(v, "bytearray")
+
+                    k, v = s.previous()
+                    self.assertEqual(k, "array_data")
+                    self.assertEqual(v, "array")
+
+                    k, v = s.next()
+                    self.assertEqual(k, "bytearray_data")
+                    self.assertEqual(v, "bytearray")
+
+                    k, v = s.next()
+                    self.assertEqual(k, "bytes_data")
+                    self.assertEqual(v, "bytes")
+
+                    k, v = s.first()
+                    self.assertEqual(k, "array_data")
+                    self.assertEqual(v, "array")
+
+                    self.assertEqual(s["foo"], "str")
+                    self.assertEqual(s["bytes_data"], "bytes")
+                    self.assertEqual(s["bytearray_data"], "bytearray")
+                    self.assertEqual(s["array_data"], "array")
+
+    def test_custom_incomplete_serializer_and_deserializer_bsd_db_shelf(self):
+        berkeleydb = import_helper.import_module("berkeleydb")
+        os.mkdir(self.dirname)
+        self.addCleanup(os_helper.rmtree, self.dirname)
+
+        def serializer(obj, protocol=None):
+            return type(obj).__name__.encode("utf-8")
+
+        def deserializer(data):
+            pass
+
+        with shelve.BsdDbShelf(berkeleydb.btopen(self.fn),
+                               serializer=serializer,
+                               deserializer=deserializer) as s:
+            s["foo"] = "bar"
+            self.assertIsNone(s["foo"])
+            self.assertNotEqual(s["foo"], "bar")
+
+        def serializer(obj, protocol=None):
+            pass
+
+        def deserializer(data):
+            return data.decode("utf-8")
+
+        with shelve.BsdDbShelf(berkeleydb.btopen(self.fn),
+                               serializer=serializer,
+                               deserializer=deserializer) as s:
+            s["foo"] = "bar"
+            self.assertEqual(s["foo"], "")
+            self.assertNotEqual(s["foo"], "bar")
+
+    def test_missing_custom_deserializer(self):
+        def serializer(obj, protocol=None):
+            pass
+
+        kwargs = dict(protocol=2, writeback=False, serializer=serializer)
+        self.assertRaises(shelve.ShelveError, shelve.Shelf, {}, **kwargs)
+        self.assertRaises(shelve.ShelveError, shelve.BsdDbShelf, {}, **kwargs)
+
+    def test_missing_custom_serializer(self):
+        def deserializer(data):
+            pass
+
+        kwargs = dict(protocol=2, writeback=False, deserializer=deserializer)
+        self.assertRaises(shelve.ShelveError, shelve.Shelf, {}, **kwargs)
+        self.assertRaises(shelve.ShelveError, shelve.BsdDbShelf, {}, **kwargs)
+
 
 class TestShelveBase:
     type2test = shelve.Shelf
diff --git a/Misc/NEWS.d/next/Library/2024-07-16-00-01-04.gh-issue-99631.GWD4fD.rst b/Misc/NEWS.d/next/Library/2024-07-16-00-01-04.gh-issue-99631.GWD4fD.rst
new file mode 100644 (file)
index 0000000..735249b
--- /dev/null
@@ -0,0 +1,2 @@
+The :mod:`shelve` module now accepts custom serialization
+and deserialization functions.