]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Customise json dumpers and loaders using a class attribute
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 27 Jun 2021 00:17:02 +0000 (01:17 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 27 Jun 2021 00:22:10 +0000 (01:22 +0100)
This is an intermediary step: the class attribute can be set by
subclassing but the idea is to extend set_json_dumps/loads functions to
create these subclasses automatically.

psycopg/psycopg/types/json.py
tests/types/test_json.py

index 2efa147a2b86dc5c3d033c6a4f786ad313a46561..d8e9fa319952904c45222c413e80b4eac9554d3c 100644 (file)
@@ -16,47 +16,33 @@ from ..errors import DataError
 JsonDumpsFunction = Callable[[Any], str]
 JsonLoadsFunction = Callable[[Union[str, bytes, bytearray]], Any]
 
-# Global load/dump functions, used by default.
-_loads: JsonLoadsFunction = json.loads
-_dumps: JsonDumpsFunction = json.dumps
-
 
 def set_json_dumps(dumps: JsonDumpsFunction) -> None:
     """
     Set a global JSON serialisation function to use by default by JSON dumpers.
 
-    Defaults to the builtin `json.dumps()`. You can override it to use a
-    different JSON library or to use customised arguments.
-
-    If you need a non-global customisation you can subclass the `!JsonDumper`
-    family of classes, overriding the `!get_loads()` method, and register
-    your class in the context required.
+    By default dumping JSON uses the builtin `json.dumps()`. You can override
+    it to use a different JSON library or to use customised arguments.
     """
-    global _dumps
-    _dumps = dumps
+    _JsonDumper._dumps = dumps
 
 
 def set_json_loads(loads: JsonLoadsFunction) -> None:
     """
     Set a global JSON parsing function to use by default by the JSON loaders.
 
-    Defaults to the builtin `json.loads()`. You can override it to use a
-    different JSON library or to use customised arguments.
-
-    If you need a non-global customisation you can subclass the `!JsonLoader`
-    family of classes, overriding the `!get_loads()` method, and register
-    your class in the context required.
+    By default loading JSON uses the builtin `json.loads()`. You can override
+    it to use a different JSON library or to use customised arguments.
     """
-    global _loads
-    _loads = loads
+    _JsonLoader._loads = loads
 
 
 class _JsonWrapper:
-    __slots__ = ("obj", "_dumps")
+    __slots__ = ("obj", "dumps")
 
     def __init__(self, obj: Any, dumps: Optional[JsonDumpsFunction] = None):
         self.obj = obj
-        self._dumps = dumps or _dumps
+        self.dumps = dumps
 
     def __repr__(self) -> str:
         sobj = repr(self.obj)
@@ -64,9 +50,6 @@ class _JsonWrapper:
             sobj = f"{sobj[:35]} ... ({len(sobj)} chars)"
         return f"{self.__class__.__name__}({sobj})"
 
-    def dumps(self) -> str:
-        return self._dumps(self.obj)
-
 
 class Json(_JsonWrapper):
     __slots__ = ()
@@ -80,8 +63,17 @@ class _JsonDumper(Dumper):
 
     format = Format.TEXT
 
+    # The globally used JSON dumps() function. It can be changed globally (by
+    # set_json_dumps) or by a subclass.
+    _dumps: JsonDumpsFunction = json.dumps
+
+    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+        super().__init__(cls, context)
+        self.dumps = self.__class__._dumps
+
     def dump(self, obj: _JsonWrapper) -> bytes:
-        return obj.dumps().encode("utf-8")
+        dumps = obj.dumps or self.dumps
+        return dumps(obj.obj).encode("utf-8")
 
 
 class JsonDumper(_JsonDumper):
@@ -106,28 +98,25 @@ class JsonbBinaryDumper(JsonbDumper):
     format = Format.BINARY
 
     def dump(self, obj: _JsonWrapper) -> bytes:
-        return b"\x01" + obj.dumps().encode("utf-8")
+        dumps = obj.dumps or self.dumps
+        return b"\x01" + dumps(obj.obj).encode("utf-8")
 
 
 class _JsonLoader(Loader):
-    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
-        super().__init__(oid, context)
-        self._loads = self.get_loads()
 
-    def get_loads(self) -> JsonLoadsFunction:
-        r"""
-        Return a `json.loads()`\-compatible function to de-serialize the value.
+    # The globally used JSON loads() function. It can be changed globally (by
+    # set_json_loads) or by a subclass.
+    _loads: JsonLoadsFunction = json.loads
 
-        Subclasses can override this function to specify custom JSON
-        de-serialization per context.
-        """
-        return _loads
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+        super().__init__(oid, context)
+        self.loads = self.__class__._loads
 
     def load(self, data: Buffer) -> Any:
         # json.loads() cannot work on memoryview.
         if isinstance(data, memoryview):
             data = bytes(data)
-        return self._loads(data)
+        return self.loads(data)
 
 
 class JsonLoader(_JsonLoader):
@@ -152,7 +141,7 @@ class JsonbBinaryLoader(_JsonLoader):
         data = data[1:]
         if isinstance(data, memoryview):
             data = bytes(data)
-        return self._loads(data)
+        return self.loads(data)
 
 
 def register_default_globals(ctx: AdaptContext) -> None:
index 95e1983f1b77a2eb9c4bb7ec2dff3eb2b510f3f0..20b7ff2121b2d1f27c2688c6a32a81b0870454fa 100644 (file)
@@ -93,15 +93,19 @@ def test_json_dump_customise_wrapper(conn, wrapper, fmt_in):
 @pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
 def test_json_dump_subclass(conn, wrapper, fmt_in):
+    JDumper = getattr(
+        psycopg.types.json,
+        f"{wrapper}{'Binary' if fmt_in != Format.TEXT else ''}Dumper",
+    )
     wrapper = getattr(psycopg.types.json, wrapper)
 
-    class MyWrapper(wrapper):
-        def dumps(self):
-            return my_dumps(self.obj)
+    class MyJsonDumper(JDumper):
+        _dumps = my_dumps
 
     obj = {"foo": "bar"}
     cur = conn.cursor()
-    cur.execute(f"select %{fmt_in}->>'baz' = 'qux'", (MyWrapper(obj),))
+    MyJsonDumper.register(wrapper, context=cur)
+    cur.execute(f"select %{fmt_in}->>'baz' = 'qux'", (wrapper(obj),))
     assert cur.fetchone()[0] is True
 
 
@@ -130,8 +134,7 @@ def test_json_load_subclass(conn, binary, pgtype):
     )
 
     class MyJsonLoader(JLoader):
-        def get_loads(self):
-            return my_loads
+        _loads = my_loads
 
     cur = conn.cursor(binary=binary)
     MyJsonLoader.register(cur.adapters.types[pgtype].oid, context=cur)