class _JsonWrapper:
- __slots__ = ("obj",)
+ __slots__ = ("obj", "_dumps")
- def __init__(self, obj: Any):
+ def __init__(self, obj: Any, dumps: Optional[JsonDumpsFunction] = None):
self.obj = obj
+ self._dumps = dumps or _dumps
def __repr__(self) -> str:
sobj = repr(self.obj)
sobj = f"{sobj[:35]} ... ({len(sobj)} chars)"
return f"{self.__class__.__name__}({sobj})"
+ def dumps(self) -> str:
+ return self._dumps(self.obj)
+
class Json(_JsonWrapper):
__slots__ = ()
format = Format.TEXT
- def __init__(self, cls: type, context: Optional[AdaptContext] = None):
- super().__init__(cls, context)
- self._dumps = self.get_dumps()
-
- def get_dumps(self) -> JsonDumpsFunction:
- r"""
- Return a `json.dumps()`\-compatible function to serialize the object.
-
- Subclasses can override this function to specify custom JSON
- serialization per context.
- """
- return _dumps
-
def dump(self, obj: _JsonWrapper) -> bytes:
- return self._dumps(obj.obj).encode("utf-8")
+ return obj.dumps().encode("utf-8")
class JsonDumper(_JsonDumper):
format = Format.BINARY
def dump(self, obj: _JsonWrapper) -> bytes:
- return b"\x01" + self._dumps(obj.obj).encode("utf-8")
+ return b"\x01" + obj.dumps().encode("utf-8")
class _JsonLoader(Loader):
set_json_dumps(json.dumps)
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
+def test_json_dump_customise_wrapper(conn, wrapper, fmt_in):
+ wrapper = getattr(psycopg.types.json, wrapper)
+ obj = {"foo": "bar"}
+ cur = conn.cursor()
+ cur.execute(f"select %{fmt_in}->>'baz' = 'qux'", (wrapper(obj, my_dumps),))
+ assert cur.fetchone()[0] is True
+
+
@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
def test_json_dump_subclass(conn, wrapper, fmt_in):
- JDumper = getattr(
- psycopg.types.json,
- f"{wrapper}{'Binary' if fmt_in != Format.TEXT else ''}Dumper",
- )
wrapper = getattr(psycopg.types.json, wrapper)
- class MyJsonDumper(JDumper):
- def get_dumps(self):
- return my_dumps
+ class MyWrapper(wrapper):
+ def dumps(self):
+ return my_dumps(self.obj)
obj = {"foo": "bar"}
cur = conn.cursor()
- MyJsonDumper.register(wrapper, context=cur)
- cur.execute(f"select %{fmt_in}->>'baz' = 'qux'", (wrapper(obj),))
+ cur.execute(f"select %{fmt_in}->>'baz' = 'qux'", (MyWrapper(obj),))
assert cur.fetchone()[0] is True