]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(crdb): fix json adaptation
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 22 May 2022 00:44:17 +0000 (02:44 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jul 2022 11:58:34 +0000 (12:58 +0100)
The function set_json_dumps() now adapts the currently registered
adapter, not the default one, to be able to customize the Json dumper on
CRDB, without resetting its oid to the PostgreSQL standard one, which
CRDB doesn't know.

psycopg/psycopg/crdb.py
psycopg/psycopg/types/json.py
tests/types/test_json.py

index acbd99a1717a840ea5577b5214c7648ed767060a..64f74c618be05f798dc1c6462faff3c655eaf0b7 100644 (file)
@@ -66,7 +66,7 @@ class CrdbEnumBinaryDumper(EnumBinaryDumper):
 
 
 def register_crdb_adapters(context: AdaptContext) -> None:
-    from .types import string
+    from .types import string, json
 
     adapters = context.adapters
 
@@ -76,6 +76,10 @@ def register_crdb_adapters(context: AdaptContext) -> None:
     adapters.register_dumper(Enum, CrdbEnumBinaryDumper)
     adapters.register_dumper(Enum, CrdbEnumDumper)
 
+    # CRDB doesn't have json/jsonb: both dump as the jsonb oid
+    adapters.register_dumper(json.Json, json.JsonbBinaryDumper)
+    adapters.register_dumper(json.Json, json.JsonbDumper)
+
 
 register_crdb_adapters(adapters)
 
index 2a5835a15de070b8d5f20a07f35cbf3e43ce016a..dc7d83f619d349fe9fc61a67b86ed1c1b5aabc37 100644 (file)
@@ -5,12 +5,13 @@ Adapers for JSON types.
 # Copyright (C) 2020 The Psycopg Team
 
 import json
-from typing import Any, Callable, Optional, Type, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
 
+from .. import abc
+from .. import errors as e
 from .. import postgres
 from ..pq import Format
-from ..abc import AdaptContext
-from ..adapt import Buffer, Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader, PyFormat, AdaptersMap
 from ..errors import DataError
 
 JsonDumpsFunction = Callable[[Any], str]
@@ -18,7 +19,7 @@ JsonLoadsFunction = Callable[[Union[str, bytes, bytearray]], Any]
 
 
 def set_json_dumps(
-    dumps: JsonDumpsFunction, context: Optional[AdaptContext] = None
+    dumps: JsonDumpsFunction, context: Optional[abc.AdaptContext] = None
 ) -> None:
     """
     Set the JSON serialisation function to store JSON objects in the database.
@@ -40,22 +41,28 @@ def set_json_dumps(
         # global class
         _JsonDumper._dumps = dumps
     else:
+        adapters = context.adapters
+
         # If the scope is smaller than global, create subclassess and register
         # them in the appropriate scope.
         grid = [
-            (Json, JsonDumper),
-            (Json, JsonBinaryDumper),
-            (Jsonb, JsonbDumper),
-            (Jsonb, JsonbBinaryDumper),
+            (Json, PyFormat.BINARY),
+            (Json, PyFormat.TEXT),
+            (Jsonb, PyFormat.BINARY),
+            (Jsonb, PyFormat.TEXT),
         ]
         dumper: Type[_JsonDumper]
-        for wrapper, base in grid:
-            dumper = type(f"Custom{base.__name__}", (base,), {"_dumps": dumps})
-            context.adapters.register_dumper(wrapper, dumper)
+        for wrapper, format in grid:
+            base = _get_current_dumper(adapters, wrapper, format)
+            name = base.__name__
+            if not base.__name__.startswith("Custom"):
+                name = f"Custom{name}"
+            dumper = type(name, (base,), {"_dumps": dumps})
+            adapters.register_dumper(wrapper, dumper)
 
 
 def set_json_loads(
-    loads: JsonLoadsFunction, context: Optional[AdaptContext] = None
+    loads: JsonLoadsFunction, context: Optional[abc.AdaptContext] = None
 ) -> None:
     """
     Set the JSON parsing function to fetch JSON objects from the database.
@@ -116,7 +123,7 @@ class _JsonDumper(Dumper):
     # set_json_dumps) or by a subclass.
     _dumps: JsonDumpsFunction = json.dumps
 
-    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+    def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
         super().__init__(cls, context)
         self.dumps = self.__class__._dumps
 
@@ -157,7 +164,7 @@ class _JsonLoader(Loader):
     # set_json_loads) or by a subclass.
     _loads: JsonLoadsFunction = json.loads
 
-    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+    def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
         super().__init__(oid, context)
         self.loads = self.__class__._loads
 
@@ -193,7 +200,24 @@ class JsonbBinaryLoader(_JsonLoader):
         return self.loads(data)
 
 
-def register_default_adapters(context: AdaptContext) -> None:
+def _get_current_dumper(
+    adapters: AdaptersMap, cls: type, format: PyFormat
+) -> Type[abc.Dumper]:
+    try:
+        return adapters.get_dumper(cls, format)
+    except e.ProgrammingError:
+        return _default_dumpers[cls, format]
+
+
+_default_dumpers: Dict[Tuple[Type[_JsonWrapper], PyFormat], Type[Dumper]] = {
+    (Json, PyFormat.BINARY): JsonBinaryDumper,
+    (Json, PyFormat.TEXT): JsonDumper,
+    (Jsonb, PyFormat.BINARY): JsonbBinaryDumper,
+    (Jsonb, PyFormat.TEXT): JsonDumper,
+}
+
+
+def register_default_adapters(context: abc.AdaptContext) -> None:
     adapters = context.adapters
 
     # Currently json binary format is nothing different than text, maybe with
index 482bb3e3d1887356bde29bb606e9a241e1d7aaaf..9eef91270902bab8cb0679051cfc071faabf279c 100644 (file)
@@ -26,7 +26,9 @@ samples = [
 def test_json_dump(conn, val, fmt_in):
     obj = json.loads(val)
     cur = conn.cursor()
-    cur.execute(f"select pg_typeof(%{fmt_in.value}) = 'json'::regtype", (Json(obj),))
+    cur.execute(
+        f"select pg_typeof(%{fmt_in.value})::regtype = 'json'::regtype", (Json(obj),)
+    )
     assert cur.fetchone()[0] is True
     cur.execute(f"select %{fmt_in.value}::text = %s::json::text", (Json(obj), val))
     assert cur.fetchone()[0] is True
@@ -50,6 +52,7 @@ def test_json_load(conn, val, jtype, fmt_out):
     assert cur.fetchone()[0] == json.loads(val)
 
 
+@pytest.mark.crdb("skip", reason="copy")
 @pytest.mark.parametrize("val", samples)
 @pytest.mark.parametrize("jtype", ["json", "jsonb"])
 @pytest.mark.parametrize("fmt_out", pq.Format)