]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Allow specifying a context to `set_json_loads/dumps()` functions
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 27 Jun 2021 00:51:28 +0000 (01:51 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 27 Jun 2021 00:52:57 +0000 (01:52 +0100)
Drop subclasses tests from the test suite because they are an
implementation detail: the real interface are the `set_json_*()` functions.

psycopg/psycopg/types/json.py
tests/types/test_json.py

index 7ecea76f44dcbb17baa1f265f7d72bbe5d35c772..74a2e142b4d6f8f63b41c661a9c52c7b85338c21 100644 (file)
@@ -5,7 +5,7 @@ Adapers for JSON types.
 # Copyright (C) 2020-2021 The Psycopg Team
 
 import json
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, Type, Union
 
 from ..pq import Format
 from ..oids import postgres_types as builtins
@@ -17,24 +17,60 @@ JsonDumpsFunction = Callable[[Any], str]
 JsonLoadsFunction = Callable[[Union[str, bytes, bytearray]], Any]
 
 
-def set_json_dumps(dumps: JsonDumpsFunction) -> None:
+def set_json_dumps(
+    dumps: JsonDumpsFunction, context: Optional[AdaptContext] = None
+) -> None:
     """
     Set a global JSON serialisation function to use by default by JSON dumpers.
 
     By default dumping JSON uses the builtin `json.dumps()`. You can override
     it to use a different JSON library or to use customised arguments.
     """
-    _JsonDumper._dumps = dumps
-
-
-def set_json_loads(loads: JsonLoadsFunction) -> None:
+    if context is None:
+        # If changing load function globally, just change the default on the
+        # global class
+        _JsonDumper._dumps = dumps
+    else:
+        # If the scope is smaller than global, create subclassess and register
+        # them in the appropriate scope.
+        grid = [
+            (Json, JsonDumper),
+            (Json, JsonBinaryDumper),
+            (Jsonb, JsonbDumper),
+            (Jsonb, JsonbBinaryDumper),
+        ]
+        dumper: Type[_JsonDumper]
+        for wrapper, base in grid:
+            dumper = type(f"Custom{base.__name__}", (base,), {"_dumps": dumps})
+            dumper.register(wrapper, context=context)
+
+
+def set_json_loads(
+    loads: JsonLoadsFunction, context: Optional[AdaptContext] = None
+) -> None:
     """
     Set a global JSON parsing function to use by default by the JSON loaders.
 
     By default loading JSON uses the builtin `json.loads()`. You can override
     it to use a different JSON library or to use customised arguments.
     """
-    _JsonLoader._loads = loads
+    if context is None:
+        # If changing load function globally, just change the default on the
+        # global class
+        _JsonLoader._loads = loads
+    else:
+        # If the scope is smaller than global, create subclassess and register
+        # them in the appropriate scope.
+        grid = [
+            ("json", JsonLoader),
+            ("json", JsonBinaryLoader),
+            ("jsonb", JsonbLoader),
+            ("jsonb", JsonbBinaryLoader),
+        ]
+        loader: Type[_JsonLoader]
+        for tname, base in grid:
+            loader = type(f"Custom{base.__name__}", (base,), {"_loads": loads})
+            loader.register(tname, context=context)
 
 
 class _JsonWrapper:
index 20b7ff2121b2d1f27c2688c6a32a81b0870454fa..d273babbee915714508ab85ecb19c247a46d15da 100644 (file)
@@ -82,37 +82,32 @@ def test_json_dump_customise(conn, wrapper, fmt_in):
 
 @pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
-def test_json_dump_customise_wrapper(conn, wrapper, fmt_in):
+def test_json_dump_customise_context(conn, wrapper, fmt_in):
     wrapper = getattr(psycopg.types.json, wrapper)
     obj = {"foo": "bar"}
-    cur = conn.cursor()
-    cur.execute(f"select %{fmt_in}->>'baz' = 'qux'", (wrapper(obj, my_dumps),))
-    assert cur.fetchone()[0] is True
+    cur1 = conn.cursor()
+    cur2 = conn.cursor()
+
+    set_json_dumps(my_dumps, cur2)
+    cur1.execute(f"select %{fmt_in}->>'baz'", (wrapper(obj),))
+    assert cur1.fetchone()[0] is None
+    cur2.execute(f"select %{fmt_in}->>'baz'", (wrapper(obj),))
+    assert cur2.fetchone()[0] == "qux"
 
 
 @pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
-def test_json_dump_subclass(conn, wrapper, fmt_in):
-    JDumper = getattr(
-        psycopg.types.json,
-        f"{wrapper}{'Binary' if fmt_in != Format.TEXT else ''}Dumper",
-    )
+def test_json_dump_customise_wrapper(conn, wrapper, fmt_in):
     wrapper = getattr(psycopg.types.json, wrapper)
-
-    class MyJsonDumper(JDumper):
-        _dumps = my_dumps
-
     obj = {"foo": "bar"}
     cur = conn.cursor()
-    MyJsonDumper.register(wrapper, context=cur)
-    cur.execute(f"select %{fmt_in}->>'baz' = 'qux'", (wrapper(obj),))
+    cur.execute(f"select %{fmt_in}->>'baz' = 'qux'", (wrapper(obj, my_dumps),))
     assert cur.fetchone()[0] is True
 
 
 @pytest.mark.parametrize("binary", [True, False])
 @pytest.mark.parametrize("pgtype", ["json", "jsonb"])
 def test_json_load_customise(conn, binary, pgtype):
-    obj = {"foo": "bar"}
     cur = conn.cursor(binary=binary)
 
     set_json_loads(my_loads)
@@ -127,21 +122,20 @@ def test_json_load_customise(conn, binary, pgtype):
 
 @pytest.mark.parametrize("binary", [True, False])
 @pytest.mark.parametrize("pgtype", ["json", "jsonb"])
-def test_json_load_subclass(conn, binary, pgtype):
-    JLoader = getattr(
-        psycopg.types.json,
-        f"{pgtype.title()}{'Binary' if binary else ''}Loader",
-    )
-
-    class MyJsonLoader(JLoader):
-        _loads = my_loads
-
-    cur = conn.cursor(binary=binary)
-    MyJsonLoader.register(cur.adapters.types[pgtype].oid, context=cur)
-    cur.execute(f"""select '{{"foo": "bar"}}'::{pgtype}""")
-    obj = cur.fetchone()[0]
-    assert obj["foo"] == "bar"
-    assert obj["answer"] == 42
+def test_json_load_customise_context(conn, binary, pgtype):
+    cur1 = conn.cursor(binary=binary)
+    cur2 = conn.cursor(binary=binary)
+
+    set_json_loads(my_loads, cur2)
+    cur1.execute(f"""select '{{"foo": "bar"}}'::{pgtype}""")
+    got = cur1.fetchone()[0]
+    assert got["foo"] == "bar"
+    assert "answer" not in got
+
+    cur2.execute(f"""select '{{"foo": "bar"}}'::{pgtype}""")
+    got = cur2.fetchone()[0]
+    assert got["foo"] == "bar"
+    assert got["answer"] == 42
 
 
 def my_dumps(obj):